Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
11e155c2
Commit
11e155c2
authored
Jun 13, 2022
by
Paul
Browse files
Merge
parents
8a9c5bce
aa7ff911
Changes
397
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
609 additions
and
72 deletions
+609
-72
test/verify/test_scatter0.cpp
test/verify/test_scatter0.cpp
+1
-1
test/verify/test_scatter1.cpp
test/verify/test_scatter1.cpp
+1
-1
test/verify/test_scatternd.cpp
test/verify/test_scatternd.cpp
+30
-0
test/verify/test_scatternd_add.cpp
test/verify/test_scatternd_add.cpp
+30
-0
test/verify/test_scatternd_mul.cpp
test/verify/test_scatternd_mul.cpp
+28
-0
test/verify/test_sqrt_half1.cpp
test/verify/test_sqrt_half1.cpp
+20
-0
test/verify/test_sqrt_half2.cpp
test/verify/test_sqrt_half2.cpp
+21
-0
test/verify/test_sqrt_half4.cpp
test/verify/test_sqrt_half4.cpp
+20
-0
test/verify/test_sub_int.cpp
test/verify/test_sub_int.cpp
+21
-0
tools/api.py
tools/api.py
+326
-46
tools/api/api.cpp
tools/api/api.cpp
+54
-0
tools/api/migraphx.h
tools/api/migraphx.h
+4
-2
tools/generate.sh
tools/generate.sh
+4
-2
tools/include/context.hpp
tools/include/context.hpp
+11
-1
tools/include/schedule_model.hpp
tools/include/schedule_model.hpp
+6
-6
tools/install_prereqs.sh
tools/install_prereqs.sh
+13
-3
tools/te.py
tools/te.py
+19
-10
No files found.
test/verify/test_scatter0.cpp
View file @
11e155c2
...
@@ -18,7 +18,7 @@ struct test_scatter0 : verify_program<test_scatter0>
...
@@ -18,7 +18,7 @@ struct test_scatter0 : verify_program<test_scatter0>
auto
pd
=
mm
->
add_parameter
(
"data"
,
sd
);
auto
pd
=
mm
->
add_parameter
(
"data"
,
sd
);
auto
li
=
mm
->
add_literal
(
migraphx
::
literal
{
si
,
vi
});
auto
li
=
mm
->
add_literal
(
migraphx
::
literal
{
si
,
vi
});
auto
pu
=
mm
->
add_parameter
(
"update"
,
su
);
auto
pu
=
mm
->
add_parameter
(
"update"
,
su
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatter"
,
{{
"axis"
,
-
1
}}),
pd
,
li
,
pu
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatter
_none
"
,
{{
"axis"
,
-
1
}}),
pd
,
li
,
pu
);
mm
->
add_return
({
r
});
mm
->
add_return
({
r
});
return
p
;
return
p
;
...
...
test/verify/test_scatter1.cpp
View file @
11e155c2
...
@@ -19,7 +19,7 @@ struct test_scatter1 : verify_program<test_scatter1>
...
@@ -19,7 +19,7 @@ struct test_scatter1 : verify_program<test_scatter1>
auto
pd
=
mm
->
add_parameter
(
"data"
,
sd
);
auto
pd
=
mm
->
add_parameter
(
"data"
,
sd
);
auto
li
=
mm
->
add_literal
(
migraphx
::
literal
{
si
,
vi
});
auto
li
=
mm
->
add_literal
(
migraphx
::
literal
{
si
,
vi
});
auto
pu
=
mm
->
add_parameter
(
"update"
,
su
);
auto
pu
=
mm
->
add_parameter
(
"update"
,
su
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatter"
,
{{
"axis"
,
-
2
}}),
pd
,
li
,
pu
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatter
_none
"
,
{{
"axis"
,
-
2
}}),
pd
,
li
,
pu
);
mm
->
add_return
({
r
});
mm
->
add_return
({
r
});
return
p
;
return
p
;
...
...
test/verify/test_scatternd.cpp
0 → 100644
View file @
11e155c2
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_scatternd
:
verify_program
<
test_scatternd
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
1
}};
migraphx
::
shape
is
{
itype
,
{
4
,
1
}};
migraphx
::
shape
us
{
dtype
,
{
4
}};
std
::
vector
<
int64_t
>
ind_vec
{
4
,
3
,
1
,
7
};
auto
ld
=
mm
->
add_literal
(
migraphx
::
literal
{
ds
,
{
1
}});
auto
data
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
8
}}}),
ld
);
auto
indices
=
mm
->
add_literal
(
migraphx
::
literal
{
is
,
ind_vec
});
auto
updates
=
mm
->
add_parameter
(
"update"
,
us
);
auto
scatternd
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_none"
),
data
,
indices
,
updates
);
mm
->
add_return
({
scatternd
});
return
p
;
}
};
test/verify/test_scatternd_add.cpp
0 → 100644
View file @
11e155c2
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_scatternd_add
:
verify_program
<
test_scatternd_add
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
8
}};
migraphx
::
shape
is
{
itype
,
{
1
,
4
}};
migraphx
::
shape
us
{
dtype
,
{
4
}};
std
::
vector
<
int64_t
>
ind_vec
{
4
,
3
,
1
,
7
};
auto
data
=
mm
->
add_parameter
(
"data"
,
ds
);
auto
indices
=
mm
->
add_literal
(
migraphx
::
literal
{
is
,
ind_vec
});
auto
t_ind
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
indices
);
auto
updates
=
mm
->
add_parameter
(
"update"
,
us
);
auto
scatternd
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_add"
),
data
,
t_ind
,
updates
);
mm
->
add_return
({
scatternd
});
return
p
;
}
};
test/verify/test_scatternd_mul.cpp
0 → 100644
View file @
11e155c2
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_scatternd_mul
:
verify_program
<
test_scatternd_mul
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
dtype
=
migraphx
::
shape
::
float_type
;
auto
itype
=
migraphx
::
shape
::
int64_type
;
migraphx
::
shape
ds
{
dtype
,
{
8
}};
migraphx
::
shape
is
{
itype
,
{
4
,
1
}};
migraphx
::
shape
us
{
dtype
,
{
4
}};
std
::
vector
<
int64_t
>
ind_vec
{
4
,
3
,
1
,
7
};
auto
data
=
mm
->
add_parameter
(
"data"
,
ds
);
auto
indices
=
mm
->
add_literal
(
migraphx
::
literal
{
is
,
ind_vec
});
auto
updates
=
mm
->
add_parameter
(
"update"
,
us
);
auto
scatternd
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"scatternd_mul"
),
data
,
indices
,
updates
);
mm
->
add_return
({
scatternd
});
return
p
;
}
};
test/verify/test_sqrt_half1.cpp
0 → 100644
View file @
11e155c2
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// math op on half-precision float with odd size tensor can't fit half2 packing
struct
test_sqrt_half1
:
verify_program
<
test_sqrt_half1
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
half_type
,
{
5
}};
auto
param
=
mm
->
add_parameter
(
"x"
,
s
);
auto
param_abs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"abs"
),
param
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"sqrt"
),
param_abs
);
return
p
;
}
};
test/verify/test_sqrt_half2.cpp
0 → 100644
View file @
11e155c2
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// math op on half-precision float with tensor size that's divisible by 2,
// but not divisible by 4
struct
test_sqrt_half2
:
verify_program
<
test_sqrt_half2
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
half_type
,
{
6
}};
auto
param
=
mm
->
add_parameter
(
"x"
,
s
);
auto
param_abs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"abs"
),
param
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"sqrt"
),
param_abs
);
return
p
;
}
};
test/verify/test_sqrt_half4.cpp
0 → 100644
View file @
11e155c2
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// math op on half-precision float with tensor size that fits into half4 packing
struct
test_sqrt_half4
:
verify_program
<
test_sqrt_half4
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
half_type
,
{
8
}};
auto
param
=
mm
->
add_parameter
(
"x"
,
s
);
auto
param_abs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"abs"
),
param
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"sqrt"
),
param_abs
);
return
p
;
}
};
test/verify/test_sub_int.cpp
0 → 100644
View file @
11e155c2
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
test_sub_int
:
verify_program
<
test_sub_int
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
}};
auto
x
=
mm
->
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int16_type
,
{
4
,
5
}});
auto
y
=
mm
->
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int16_type
,
{
2
,
3
,
4
,
5
}});
auto
xb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
4
,
5
}}}),
x
);
auto
diff
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"sub"
),
y
,
xb
);
mm
->
add_return
({
diff
});
return
p
;
}
};
tools/api.py
View file @
11e155c2
...
@@ -15,7 +15,7 @@ c_api_body_preamble: List[str] = []
...
@@ -15,7 +15,7 @@ c_api_body_preamble: List[str] = []
cpp_header_preamble
:
List
[
str
]
=
[]
cpp_header_preamble
:
List
[
str
]
=
[]
def
bad_param_error
(
msg
):
def
bad_param_error
(
msg
:
str
):
return
'throw std::runtime_error("{}")'
.
format
(
msg
)
return
'throw std::runtime_error("{}")'
.
format
(
msg
)
...
@@ -89,7 +89,7 @@ class Type:
...
@@ -89,7 +89,7 @@ class Type:
else
:
else
:
return
t
.
remove_const
()
return
t
.
remove_const
()
def
const_compatible
(
self
,
t
):
def
const_compatible
(
self
,
t
:
'Type'
):
if
t
.
is_const
():
if
t
.
is_const
():
return
self
.
add_const
()
return
self
.
add_const
()
return
self
return
self
...
@@ -102,6 +102,10 @@ header_function = Template('''
...
@@ -102,6 +102,10 @@ header_function = Template('''
${error_type} ${name}(${params});
${error_type} ${name}(${params});
'''
)
'''
)
function_pointer_typedef
=
Template
(
'''
typedef ${error_type} (*${fname})(${params});
'''
)
c_api_impl
=
Template
(
'''
c_api_impl
=
Template
(
'''
extern "C" ${error_type} ${name}(${params})
extern "C" ${error_type} ${name}(${params})
{
{
...
@@ -136,18 +140,23 @@ class CFunction:
...
@@ -136,18 +140,23 @@ class CFunction:
self
.
va_end
=
[
'va_end({});'
.
format
(
name
)]
self
.
va_end
=
[
'va_end({});'
.
format
(
name
)]
self
.
add_param
(
'...'
,
''
)
self
.
add_param
(
'...'
,
''
)
def
substitute
(
self
,
form
:
Template
)
->
str
:
def
substitute
(
self
,
form
:
Template
,
**
kwargs
)
->
str
:
return
form
.
substitute
(
error_type
=
error_type
,
return
form
.
substitute
(
error_type
=
error_type
,
try_wrap
=
try_wrap
,
try_wrap
=
try_wrap
,
name
=
self
.
name
,
name
=
self
.
name
,
params
=
', '
.
join
(
self
.
params
),
params
=
', '
.
join
(
self
.
params
),
body
=
";
\n
"
.
join
(
self
.
body
),
body
=
";
\n
"
.
join
(
self
.
body
),
va_start
=
"
\n
"
.
join
(
self
.
va_start
),
va_start
=
"
\n
"
.
join
(
self
.
va_start
),
va_end
=
"
\n
"
.
join
(
self
.
va_end
))
va_end
=
"
\n
"
.
join
(
self
.
va_end
),
**
kwargs
)
def
generate_header
(
self
)
->
str
:
def
generate_header
(
self
)
->
str
:
return
self
.
substitute
(
header_function
)
return
self
.
substitute
(
header_function
)
def
generate_function_pointer
(
self
,
name
:
Optional
[
str
]
=
None
)
->
str
:
return
self
.
substitute
(
function_pointer_typedef
,
fname
=
name
or
self
.
name
)
def
generate_body
(
self
)
->
str
:
def
generate_body
(
self
)
->
str
:
return
self
.
substitute
(
c_api_impl
)
return
self
.
substitute
(
c_api_impl
)
...
@@ -163,7 +172,9 @@ class Parameter:
...
@@ -163,7 +172,9 @@ class Parameter:
name
:
str
,
name
:
str
,
type
:
str
,
type
:
str
,
optional
:
bool
=
False
,
optional
:
bool
=
False
,
returns
:
bool
=
False
)
->
None
:
returns
:
bool
=
False
,
virtual
:
bool
=
False
,
this
:
bool
=
False
)
->
None
:
self
.
name
=
name
self
.
name
=
name
self
.
type
=
Type
(
type
)
self
.
type
=
Type
(
type
)
self
.
optional
=
optional
self
.
optional
=
optional
...
@@ -175,7 +186,11 @@ class Parameter:
...
@@ -175,7 +186,11 @@ class Parameter:
self
.
cpp_read
=
'${name}'
self
.
cpp_read
=
'${name}'
self
.
cpp_write
=
'${name}'
self
.
cpp_write
=
'${name}'
self
.
returns
=
returns
self
.
returns
=
returns
self
.
virtual
=
virtual
self
.
this
=
this
self
.
bad_param_check
:
Optional
[
BadParam
]
=
None
self
.
bad_param_check
:
Optional
[
BadParam
]
=
None
self
.
virtual_read
:
Optional
[
List
[
str
]]
=
None
self
.
virtual_write
:
Optional
[
str
]
=
None
def
get_name
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
def
get_name
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
if
prefix
:
if
prefix
:
...
@@ -248,6 +263,48 @@ class Parameter:
...
@@ -248,6 +263,48 @@ class Parameter:
raise
ValueError
(
"Error for {}: write cannot be a string"
.
format
(
raise
ValueError
(
"Error for {}: write cannot be a string"
.
format
(
self
.
type
.
str
()))
self
.
type
.
str
()))
def
virtual_arg
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
List
[
str
]:
read
=
self
.
virtual_read
if
not
read
and
len
(
self
.
write
)
>=
len
(
self
.
cparams
):
read
=
[
Template
(
w
.
partition
(
'='
)[
2
]).
safe_substitute
(
result
=
'${name}'
)
for
w
in
self
.
write
]
if
not
read
:
raise
ValueError
(
"No virtual_read parameter provided for: "
+
self
.
type
.
str
())
if
isinstance
(
read
,
str
):
raise
ValueError
(
"Error for {}: virtual_read cannot be a string"
.
format
(
self
.
type
.
str
()))
return
[
self
.
substitute
(
r
,
prefix
=
prefix
)
for
r
in
read
]
def
virtual_param
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
return
self
.
substitute
(
'${type} ${name}'
,
prefix
=
prefix
)
def
virtual_output_args
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
List
[
str
]:
return
[
'&{prefix}{n}'
.
format
(
prefix
=
prefix
or
''
,
n
=
n
)
for
t
,
n
in
self
.
cparams
]
def
virtual_output_declarations
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
List
[
str
]:
return
[
'std::remove_pointer_t<{type}> {prefix}{n};'
.
format
(
type
=
Type
(
t
).
str
(),
prefix
=
prefix
or
''
,
n
=
n
)
for
t
,
n
in
self
.
cparams
]
def
virtual_output
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
write
=
self
.
virtual_write
if
not
write
:
if
'*'
in
self
.
read
or
'->'
in
self
.
read
:
write
=
Template
(
self
.
read
).
safe_substitute
(
name
=
'(&${name})'
)
else
:
write
=
self
.
read
return
self
.
substitute
(
write
,
prefix
=
prefix
)
def
cpp_param
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
def
cpp_param
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
return
self
.
substitute
(
'${cpptype} ${name}'
,
prefix
=
prefix
)
return
self
.
substitute
(
'${cpptype} ${name}'
,
prefix
=
prefix
)
...
@@ -311,6 +368,7 @@ class Function:
...
@@ -311,6 +368,7 @@ class Function:
invoke
:
Optional
[
str
]
=
None
,
invoke
:
Optional
[
str
]
=
None
,
fname
:
Optional
[
str
]
=
None
,
fname
:
Optional
[
str
]
=
None
,
return_name
:
Optional
[
str
]
=
None
,
return_name
:
Optional
[
str
]
=
None
,
virtual
:
bool
=
False
,
**
kwargs
)
->
None
:
**
kwargs
)
->
None
:
self
.
name
=
name
self
.
name
=
name
self
.
params
=
params
or
[]
self
.
params
=
params
or
[]
...
@@ -321,6 +379,10 @@ class Function:
...
@@ -321,6 +379,10 @@ class Function:
self
.
return_name
=
return_name
or
'out'
self
.
return_name
=
return_name
or
'out'
self
.
returns
=
Parameter
(
self
.
return_name
,
returns
,
self
.
returns
=
Parameter
(
self
.
return_name
,
returns
,
returns
=
True
)
if
returns
else
None
returns
=
True
)
if
returns
else
None
for
p
in
self
.
params
:
p
.
virtual
=
virtual
if
self
.
returns
:
self
.
returns
.
virtual
=
virtual
def
share_params
(
self
)
->
None
:
def
share_params
(
self
)
->
None
:
if
self
.
shared_size
==
True
:
if
self
.
shared_size
==
True
:
...
@@ -556,6 +618,9 @@ def params(virtual: Optional[Dict[str, str]] = None,
...
@@ -556,6 +618,9 @@ def params(virtual: Optional[Dict[str, str]] = None,
return
result
return
result
gparams
=
params
def
add_function
(
name
:
str
,
*
args
,
**
kwargs
)
->
Function
:
def
add_function
(
name
:
str
,
*
args
,
**
kwargs
)
->
Function
:
f
=
Function
(
name
,
*
args
,
**
kwargs
)
f
=
Function
(
name
,
*
args
,
**
kwargs
)
functions
.
append
(
f
)
functions
.
append
(
f
)
...
@@ -627,7 +692,7 @@ extern "C" struct ${ctype};
...
@@ -627,7 +692,7 @@ extern "C" struct ${ctype};
struct ${ctype} {
struct ${ctype} {
template<class... Ts>
template<class... Ts>
${ctype}(Ts&&... xs)
${ctype}(Ts&&... xs)
: object(std::forward<Ts>(xs)...)
: object(std::forward<Ts>(xs)...)
// NOLINT(readability-redundant-member-init)
{}
{}
${cpptype} object;
${cpptype} object;
};
};
...
@@ -656,6 +721,55 @@ void destroy(T* x)
...
@@ -656,6 +721,55 @@ void destroy(T* x)
{
{
delete x; // NOLINT
delete x; // NOLINT
}
}
// TODO: Move to interface preamble
template <class C, class D>
struct manage_generic_ptr
{
manage_generic_ptr() = default;
manage_generic_ptr(std::nullptr_t)
{
}
manage_generic_ptr(void* pdata, C pcopier, D pdeleter)
: data(nullptr), copier(pcopier), deleter(pdeleter)
{
copier(&data, pdata);
}
manage_generic_ptr(const manage_generic_ptr& rhs)
: data(nullptr), copier(rhs.copier), deleter(rhs.deleter)
{
if(copier)
copier(&data, rhs.data);
}
manage_generic_ptr(manage_generic_ptr&& other) noexcept
: data(other.data), copier(other.copier), deleter(other.deleter)
{
other.data = nullptr;
other.copier = nullptr;
other.deleter = nullptr;
}
manage_generic_ptr& operator=(manage_generic_ptr rhs)
{
std::swap(data, rhs.data);
std::swap(copier, rhs.copier);
std::swap(deleter, rhs.deleter);
return *this;
}
~manage_generic_ptr()
{
if(data != nullptr)
deleter(data);
}
void* data = nullptr;
C copier = nullptr;
D deleter = nullptr;
};
'''
'''
cpp_handle_preamble
=
'''
cpp_handle_preamble
=
'''
...
@@ -718,38 +832,53 @@ def add_handle(name: str,
...
@@ -718,38 +832,53 @@ def add_handle(name: str,
ctype
:
str
,
ctype
:
str
,
cpptype
:
str
,
cpptype
:
str
,
destroy
:
Optional
[
str
]
=
None
,
destroy
:
Optional
[
str
]
=
None
,
ref
:
Optional
[
bool
]
=
None
)
->
None
:
ref
=
False
,
skip_def
=
False
)
->
None
:
opaque_type
=
ctype
+
'_t'
opaque_type
=
ctype
+
'_t'
const_opaque_type
=
'const_'
+
opaque_type
def
handle_wrap
(
p
):
def
handle_wrap
(
p
:
Parameter
):
t
=
Type
(
opaque_type
)
t
=
Type
(
opaque_type
)
if
p
.
type
.
is_const
():
if
p
.
type
.
is_const
():
t
=
Type
(
'const_'
+
opaque_type
)
t
=
Type
(
'const_'
+
opaque_type
)
if
p
.
returns
:
# p.read = 'object_cast<${ctype}>(&(${name}))'
if
p
.
virtual
:
p
.
add_param
(
t
)
elif
p
.
returns
:
p
.
add_param
(
t
.
add_pointer
())
p
.
add_param
(
t
.
add_pointer
())
if
p
.
type
.
is_reference
():
p
.
cpp_write
=
'${cpptype}(${name}, false)'
p
.
write
=
[
'*${name} = object_cast<${ctype}>(&(${result}))'
]
elif
p
.
type
.
is_pointer
():
p
.
cpp_write
=
'${cpptype}(${name}, false)'
p
.
write
=
[
'*${name} = object_cast<${ctype}>(${result})'
]
else
:
p
.
cpp_write
=
'${cpptype}(${name})'
p
.
write
=
[
'*${name} = allocate<${ctype}>(${result})'
]
else
:
else
:
p
.
add_param
(
t
)
p
.
add_param
(
t
)
p
.
bad_param
(
'${name} == nullptr'
,
'Null pointer'
)
p
.
bad_param
(
'${name} == nullptr'
,
'Null pointer'
)
if
p
.
type
.
is_reference
():
p
.
virtual_read
=
[
'object_cast<${ctype}>(&(${name}))'
]
p
.
cpp_write
=
'${cpptype}(${name}, false)'
p
.
write
=
[
'*${name} = object_cast<${ctype}>(&(${result}))'
]
elif
p
.
type
.
is_pointer
():
p
.
virtual_read
=
[
'object_cast<${ctype}>(${result})'
]
p
.
cpp_write
=
'${cpptype}(${name}, false)'
p
.
write
=
[
'*${name} = object_cast<${ctype}>(${result})'
]
else
:
p
.
virtual_read
=
[
'object_cast<${ctype}>(&(${name}))'
]
p
.
cpp_write
=
'${cpptype}(${name})'
p
.
write
=
[
'*${name} = allocate<${ctype}>(${result})'
]
if
skip_def
:
p
.
read
=
'*${name}'
else
:
p
.
read
=
'${name}->object'
p
.
read
=
'${name}->object'
p
.
cpp_read
=
'${name}.get_handle_ptr()'
p
.
cpp_read
=
'${name}.get_handle_ptr()'
type_map
[
cpptype
]
=
handle_wrap
type_map
[
cpptype
]
=
handle_wrap
if
not
ref
:
if
not
ref
:
add_function
(
destroy
or
ctype
+
'_'
+
'destroy'
,
add_function
(
destroy
or
ctype
+
'_'
+
'destroy'
,
params
({
name
:
opaque_type
}),
params
({
name
:
opaque_type
}),
fname
=
'destroy'
)
fname
=
'destroy'
)
add_function
(
ctype
+
'_'
+
'assign_to'
,
params
(
output
=
opaque_type
,
input
=
const_opaque_type
),
invoke
=
'*output = *input'
)
add_handle_preamble
()
add_handle_preamble
()
c_header_preamble
.
append
(
handle_typedef
.
substitute
(
locals
()))
c_header_preamble
.
append
(
handle_typedef
.
substitute
(
locals
()))
c_api_body_preamble
.
append
(
handle_definition
.
substitute
(
locals
()))
if
not
skip_def
:
c_api_body_preamble
.
append
(
handle_definition
.
substitute
(
locals
()))
@
cwrap
(
'std::vector'
)
@
cwrap
(
'std::vector'
)
...
@@ -759,30 +888,32 @@ def vector_c_wrap(p: Parameter) -> None:
...
@@ -759,30 +888,32 @@ def vector_c_wrap(p: Parameter) -> None:
if
not
inner
:
if
not
inner
:
return
return
t
=
inner
.
add_pointer
()
t
=
inner
.
add_pointer
()
if
p
.
type
.
is_reference
():
if
p
.
type
.
is_const
():
t
=
t
.
add_const
()
if
p
.
returns
:
if
p
.
returns
:
if
p
.
type
.
is_reference
():
if
p
.
type
.
is_reference
():
if
p
.
type
.
is_const
():
t
=
t
.
add_const
()
p
.
add_param
(
t
.
add_pointer
())
p
.
add_param
(
t
.
add_pointer
())
p
.
add_size_param
()
p
.
add_size_param
()
p
.
bad_param
(
'${name} == nullptr or ${size} == nullptr'
,
p
.
bad_param
(
'${name} == nullptr or ${size} == nullptr'
,
'Null pointer'
)
'Null pointer'
)
p
.
cpp_write
=
'${type}(${name}, ${name}+${size})'
p
.
write
=
[
'*${name} = ${result}.data()'
,
'*${size} = ${result}.size()'
]
else
:
else
:
p
.
add_param
(
t
)
p
.
add_param
(
t
)
p
.
bad_param
(
'${name} == nullptr'
,
'Null pointer'
)
p
.
bad_param
(
'${name} == nullptr'
,
'Null pointer'
)
p
.
cpp_write
=
'${type}(${name}, ${name}+${size})'
p
.
write
=
[
'std::copy(${result}.begin(), ${result}.end(), ${name})'
]
else
:
else
:
p
.
add_param
(
t
)
p
.
add_param
(
t
)
p
.
add_size_param
()
p
.
add_size_param
()
p
.
bad_param
(
'${name} == nullptr and ${size} != 0'
,
'Null pointer'
)
p
.
bad_param
(
'${name} == nullptr and ${size} != 0'
,
'Null pointer'
)
p
.
read
=
'${type}(${name}, ${name}+${size})'
p
.
read
=
'${type}(${name}, ${name}+${size})'
p
.
cpp_write
=
'${type}(${name}, ${name}+${size})'
p
.
virtual_read
=
[
'${name}.data()'
,
'${name}.size()'
]
if
p
.
type
.
is_reference
():
p
.
write
=
[
'*${name} = ${result}.data()'
,
'*${size} = ${result}.size()'
]
else
:
p
.
write
=
[
'std::copy(${result}.begin(), ${result}.end(), ${name})'
]
@
cwrap
(
'std::string'
)
@
cwrap
(
'std::string'
)
...
@@ -792,34 +923,34 @@ def string_c_wrap(p: Parameter) -> None:
...
@@ -792,34 +923,34 @@ def string_c_wrap(p: Parameter) -> None:
if
p
.
type
.
is_reference
():
if
p
.
type
.
is_reference
():
p
.
add_param
(
t
.
add_pointer
())
p
.
add_param
(
t
.
add_pointer
())
p
.
bad_param
(
'${name} == nullptr'
,
'Null pointer'
)
p
.
bad_param
(
'${name} == nullptr'
,
'Null pointer'
)
p
.
cpp_write
=
'${type}(${name})'
p
.
write
=
[
'*${name} = ${result}.c_str()'
]
else
:
else
:
p
.
add_param
(
t
)
p
.
add_param
(
t
)
p
.
add_param
(
'size_t'
,
p
.
name
+
'_size'
)
p
.
add_param
(
'size_t'
,
p
.
name
+
'_size'
)
p
.
bad_param
(
'${name} == nullptr'
,
'Null pointer'
)
p
.
bad_param
(
'${name} == nullptr'
,
'Null pointer'
)
p
.
cpp_write
=
'${type}(${name})'
p
.
write
=
[
'auto* it = std::copy_n(${result}.begin(), std::min(${result}.size(), ${name}_size - 1), ${name});'
'*it =
\'\\
0
\'
'
]
else
:
else
:
p
.
add_param
(
t
)
p
.
add_param
(
t
)
p
.
bad_param
(
'${name} == nullptr'
,
'Null pointer'
)
p
.
bad_param
(
'${name} == nullptr'
,
'Null pointer'
)
p
.
read
=
'${type}(${name})'
p
.
read
=
'${type}(${name})'
p
.
cpp_write
=
'${type}(${name})'
p
.
virtual_read
=
[
'${name}.c_str()'
]
if
p
.
type
.
is_reference
():
p
.
write
=
[
'*${name} = ${result}.c_str()'
]
else
:
p
.
write
=
[
'auto* it = std::copy_n(${result}.begin(), std::min(${result}.size(), ${name}_size - 1), ${name});'
'*it =
\'\\
0
\'
'
]
class
Handle
:
class
Handle
:
def
__init__
(
self
,
def
__init__
(
self
,
name
:
str
,
ctype
:
str
,
cpptype
:
str
,
**
kwargs
)
->
None
:
name
:
str
,
ctype
:
str
,
cpptype
:
str
,
ref
:
Optional
[
bool
]
=
None
)
->
None
:
self
.
name
=
name
self
.
name
=
name
self
.
ctype
=
ctype
self
.
ctype
=
ctype
self
.
cpptype
=
cpptype
self
.
cpptype
=
cpptype
self
.
opaque_type
=
self
.
ctype
+
'_t'
self
.
cpp_class
=
CPPClass
(
name
,
ctype
)
self
.
cpp_class
=
CPPClass
(
name
,
ctype
)
add_handle
(
name
,
ctype
,
cpptype
,
ref
=
ref
)
add_handle
(
name
,
ctype
,
cpptype
,
**
kwargs
)
cpp_type_map
[
cpptype
]
=
name
cpp_type_map
[
cpptype
]
=
name
def
cname
(
self
,
name
:
str
)
->
str
:
def
cname
(
self
,
name
:
str
)
->
str
:
...
@@ -829,6 +960,7 @@ class Handle:
...
@@ -829,6 +960,7 @@ class Handle:
return
Template
(
s
).
safe_substitute
(
name
=
self
.
name
,
return
Template
(
s
).
safe_substitute
(
name
=
self
.
name
,
ctype
=
self
.
ctype
,
ctype
=
self
.
ctype
,
cpptype
=
self
.
cpptype
,
cpptype
=
self
.
cpptype
,
opaque_type
=
self
.
opaque_type
,
**
kwargs
)
**
kwargs
)
def
constructor
(
self
,
def
constructor
(
self
,
...
@@ -883,6 +1015,137 @@ class Handle:
...
@@ -883,6 +1015,137 @@ class Handle:
cpp_classes
.
append
(
self
.
cpp_class
)
cpp_classes
.
append
(
self
.
cpp_class
)
interface_handle_definition
=
Template
(
'''
extern "C" struct ${ctype};
struct ${ctype} {
template<class... Ts>
${ctype}(void* p, ${copier} c, ${deleter} d, Ts&&... xs)
: object_ptr(p, c, d), xobject(std::forward<Ts>(xs)...)
{}
manage_generic_ptr<${copier}, ${deleter}> object_ptr = nullptr;
${cpptype} xobject;
${functions}
};
'''
)
c_api_virtual_impl
=
Template
(
'''
${return_type} ${name}(${params}) const
{
${output_decls}
if (${fname} == nullptr)
throw std::runtime_error("${name} function is missing.");
auto api_error_result = ${fname}(${args});
if (api_error_result != ${success})
throw std::runtime_error("Error in ${name}.");
return ${output};
}
'''
)
def
generate_virtual_impl
(
f
:
Function
,
fname
:
str
)
->
str
:
success
=
success_type
name
=
f
.
name
return_type
=
'void'
output_decls
=
''
output
=
''
largs
=
[]
lparams
=
[]
if
f
.
returns
:
return_type
=
f
.
returns
.
type
.
str
()
output_decls
=
'
\n
'
.
join
(
f
.
returns
.
virtual_output_declarations
())
largs
+=
f
.
returns
.
virtual_output_args
()
output
=
f
.
returns
.
virtual_output
()
largs
+=
[
arg
for
p
in
f
.
params
for
arg
in
p
.
virtual_arg
()]
lparams
+=
[
p
.
virtual_param
()
for
p
in
f
.
params
if
not
p
.
this
]
args
=
', '
.
join
(
largs
)
params
=
', '
.
join
(
lparams
)
return
c_api_virtual_impl
.
substitute
(
locals
())
class
Interface
(
Handle
):
def
__init__
(
self
,
name
:
str
,
ctype
:
str
,
cpptype
:
str
)
->
None
:
super
().
__init__
(
name
,
ctype
,
cpptype
,
skip_def
=
True
)
self
.
ifunctions
:
List
[
Function
]
=
[]
self
.
members
:
List
[
str
]
=
[]
def
mname
(
self
,
name
:
str
)
->
str
:
return
name
+
"_f"
def
constructor
(
# type: ignore
self
,
name
:
str
,
params
:
Optional
[
List
[
Parameter
]]
=
None
,
**
kwargs
)
->
'Interface'
:
create
=
self
.
substitute
(
'allocate<${opaque_type}>($@)'
)
initial_params
=
gparams
(
obj
=
'void*'
,
c
=
self
.
cname
(
'copy'
),
d
=
self
.
cname
(
'delete'
))
add_function
(
self
.
cname
(
name
),
params
=
initial_params
+
(
params
or
[]),
invoke
=
create
,
returns
=
self
.
opaque_type
,
return_name
=
self
.
name
,
**
kwargs
)
return
self
def
method
(
self
,
*
args
,
**
kwargs
)
->
'Interface'
:
super
().
method
(
*
args
,
**
kwargs
)
return
self
def
virtual
(
self
,
name
:
str
,
params
:
Optional
[
List
[
Parameter
]]
=
None
,
const
:
Optional
[
bool
]
=
None
,
**
kwargs
)
->
'Interface'
:
# Add this parameter to the function
this
=
Parameter
(
'obj'
,
'void*'
,
this
=
True
)
this
.
virtual_read
=
[
'object_ptr.data'
]
f
=
Function
(
name
,
params
=
[
this
]
+
(
params
or
[]),
virtual
=
True
,
**
kwargs
)
self
.
ifunctions
.
append
(
f
)
add_function
(
self
.
cname
(
'set_'
+
name
),
params
=
gparams
(
obj
=
self
.
opaque_type
,
input
=
self
.
cname
(
name
)),
invoke
=
'${{obj}}->{name} = ${{input}}'
.
format
(
name
=
self
.
mname
(
name
)))
return
self
def
generate_function
(
self
,
f
:
Function
):
cname
=
self
.
cname
(
f
.
name
)
mname
=
self
.
mname
(
f
.
name
)
function
=
generate_virtual_impl
(
f
,
fname
=
mname
)
return
f
"
{
cname
}
{
mname
}
= nullptr;
{
function
}
"
def
generate
(
self
):
required_functions
=
[
Function
(
'copy'
,
params
=
gparams
(
out
=
'void**'
,
input
=
'void*'
),
virtual
=
True
),
Function
(
'delete'
,
params
=
gparams
(
input
=
'void*'
),
virtual
=
True
)
]
for
f
in
self
.
ifunctions
+
required_functions
:
f
.
update
()
c_header_preamble
.
extend
([
f
.
get_cfunction
().
generate_function_pointer
(
self
.
cname
(
f
.
name
))
for
f
in
self
.
ifunctions
+
required_functions
])
function_list
=
[
self
.
generate_function
(
f
)
for
f
in
self
.
ifunctions
]
ctype
=
self
.
ctype
cpptype
=
self
.
cpptype
copier
=
self
.
cname
(
'copy'
)
deleter
=
self
.
cname
(
'delete'
)
functions
=
'
\n
'
.
join
(
function_list
)
c_api_body_preamble
.
append
(
interface_handle_definition
.
substitute
(
locals
()))
def
handle
(
ctype
:
str
,
def
handle
(
ctype
:
str
,
cpptype
:
str
,
cpptype
:
str
,
name
:
Optional
[
str
]
=
None
,
name
:
Optional
[
str
]
=
None
,
...
@@ -902,6 +1165,23 @@ def handle(ctype: str,
...
@@ -902,6 +1165,23 @@ def handle(ctype: str,
return
with_handle
return
with_handle
def
interface
(
ctype
:
str
,
cpptype
:
str
,
name
:
Optional
[
str
]
=
None
)
->
Callable
:
def
with_interface
(
f
):
n
=
name
or
f
.
__name__
h
=
Interface
(
n
,
ctype
,
cpptype
)
f
(
h
)
h
.
generate
()
@
wraps
(
f
)
def
decorated
(
*
args
,
**
kwargs
):
return
f
(
*
args
,
**
kwargs
)
return
decorated
return
with_interface
def
template_eval
(
template
,
**
kwargs
):
def
template_eval
(
template
,
**
kwargs
):
start
=
'<%'
start
=
'<%'
end
=
'%>'
end
=
'%>'
...
@@ -924,7 +1204,7 @@ def run(args: List[str]) -> None:
...
@@ -924,7 +1204,7 @@ def run(args: List[str]) -> None:
else
:
else
:
sys
.
stdout
.
write
(
generate_c_header
())
sys
.
stdout
.
write
(
generate_c_header
())
sys
.
stdout
.
write
(
generate_c_api_body
())
sys
.
stdout
.
write
(
generate_c_api_body
())
sys
.
stdout
.
write
(
generate_cpp_header
())
#
sys.stdout.write(generate_cpp_header())
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
tools/api/api.cpp
100755 → 100644
View file @
11e155c2
...
@@ -4,12 +4,14 @@
...
@@ -4,12 +4,14 @@
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/quantization.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/json.hpp>
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <algorithm>
#include <algorithm>
...
@@ -72,6 +74,23 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t)
...
@@ -72,6 +74,23 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t)
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Unknown type"
);
MIGRAPHX_THROW
(
migraphx_status_bad_param
,
"Unknown type"
);
}
}
template
<
class
T
>
auto
to_obj_vector
(
const
T
*
x
,
std
::
size_t
n
)
{
std
::
vector
<
decltype
((
*
x
)
->
object
)
>
result
;
std
::
transform
(
x
,
x
+
n
,
std
::
back_inserter
(
result
),
[
&
](
auto
&&
y
)
{
return
y
->
object
;
});
return
result
;
}
template
<
class
T
,
class
U
>
auto
to_objptr_vector
(
const
U
*
x
,
std
::
size_t
n
)
{
std
::
vector
<
T
>
result
;
std
::
transform
(
x
,
x
+
n
,
std
::
back_inserter
(
result
),
[
&
](
auto
&&
y
)
{
return
std
::
addressof
(
y
->
object
);
});
return
result
;
}
target
get_target
(
const
std
::
string
&
name
)
{
return
make_target
(
name
);
}
target
get_target
(
const
std
::
string
&
name
)
{
return
make_target
(
name
);
}
void
set_offload_copy
(
compile_options
&
options
,
bool
value
)
{
options
.
offload_copy
=
value
;
}
void
set_offload_copy
(
compile_options
&
options
,
bool
value
)
{
options
.
offload_copy
=
value
;
}
...
@@ -194,6 +213,41 @@ void print_program(const program& p) { std::cout << p << std::endl; }
...
@@ -194,6 +213,41 @@ void print_program(const program& p) { std::cout << p << std::endl; }
void
print_module
(
const
module
&
m
)
{
std
::
cout
<<
m
<<
std
::
endl
;
}
void
print_module
(
const
module
&
m
)
{
std
::
cout
<<
m
<<
std
::
endl
;
}
struct
experimental_custom_op
{
std
::
string
name
;
experimental_custom_op
()
=
default
;
experimental_custom_op
(
std
::
string
pname
)
:
name
(
std
::
move
(
pname
))
{}
};
template
<
class
CustomOp
>
struct
custom_operation
{
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
,
F
)
{
return
pack
();
}
CustomOp
op
;
std
::
string
name
()
const
{
return
op
.
xobject
.
name
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
return
op
.
compute_shape
(
std
::
move
(
inputs
));
}
argument
compute
(
const
std
::
vector
<
argument
>&
)
const
{
MIGRAPHX_THROW
(
"Not computable"
);
}
};
template
<
class
CustomOp
>
void
register_custom_op
(
const
CustomOp
&
op
)
{
register_op
(
custom_operation
<
CustomOp
>
{
op
});
}
migraphx
::
context
get_context
(
const
program
&
p
)
{
return
p
.
get_context
();
}
}
// namespace migraphx
}
// namespace migraphx
<%
generate_c_api_body
()
%>
<%
generate_c_api_body
()
%>
tools/api/migraphx.h
100755 → 100644
View file @
11e155c2
...
@@ -25,7 +25,8 @@ extern "C" {
...
@@ -25,7 +25,8 @@ extern "C" {
#endif
#endif
// return code, more to be added later
// return code, more to be added later
typedef
enum
{
typedef
enum
{
migraphx_status_success
=
0
,
migraphx_status_success
=
0
,
migraphx_status_bad_param
=
1
,
migraphx_status_bad_param
=
1
,
migraphx_status_unknown_target
=
3
,
migraphx_status_unknown_target
=
3
,
...
@@ -35,7 +36,8 @@ typedef enum {
...
@@ -35,7 +36,8 @@ typedef enum {
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x,
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x,
/// An enum to represent the different data type inputs
/// An enum to represent the different data type inputs
typedef
enum
{
typedef
enum
{
migraphx_shape_tuple_type
,
migraphx_shape_tuple_type
,
MIGRAPHX_SHAPE_VISIT_TYPES
(
MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
)
MIGRAPHX_SHAPE_VISIT_TYPES
(
MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
)
}
migraphx_shape_datatype_t
;
}
migraphx_shape_datatype_t
;
...
...
tools/generate.sh
View file @
11e155c2
...
@@ -7,11 +7,13 @@ fi
...
@@ -7,11 +7,13 @@ fi
if
type
-p
python3.8
>
/dev/null
;
then
if
type
-p
python3.8
>
/dev/null
;
then
PYTHON
=
python3.8
PYTHON
=
python3.8
fi
fi
ls
-1
$DIR
/include/ | xargs
-n
1
-P
$(
nproc
)
-I
{}
-t
bash
-c
"
$PYTHON
$DIR
/te.py
$DIR
/include/{} | clang-format-
5.
0 -style=file >
$SRC_DIR
/include/migraphx/{}"
ls
-1
$DIR
/include/ | xargs
-n
1
-P
$(
nproc
)
-I
{}
-t
bash
-c
"
$PYTHON
$DIR
/te.py
$DIR
/include/{} | clang-format-
1
0 -style=file >
$SRC_DIR
/include/migraphx/{}"
function
api
{
function
api
{
$PYTHON
$DIR
/api.py
$SRC_DIR
/api/migraphx.py
$1
| clang-format-
5.
0
-style
=
file
>
$2
$PYTHON
$DIR
/api.py
$SRC_DIR
/api/migraphx.py
$1
| clang-format-
1
0
-style
=
file
>
$2
}
}
api
$DIR
/api/migraphx.h
$SRC_DIR
/api/include/migraphx/migraphx.h
api
$DIR
/api/migraphx.h
$SRC_DIR
/api/include/migraphx/migraphx.h
echo
"Finished generating header migraphx.h"
api
$DIR
/api/api.cpp
$SRC_DIR
/api/api.cpp
api
$DIR
/api/api.cpp
$SRC_DIR
/api/api.cpp
echo
"Finished generating source api.cpp "
tools/include/context.hpp
View file @
11e155c2
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
#include <utility>
#include <utility>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/value.hpp>
#include <migraphx/any_ptr.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -33,12 +34,21 @@ value to_value_context(const T&)
...
@@ -33,12 +34,21 @@ value to_value_context(const T&)
}
}
template
<
class
T
>
template
<
class
T
>
void
from_value_context
(
T
&
,
const
value
&
){}
void
from_value_context
(
T
&
,
const
value
&
)
{
}
template
<
class
T
>
any_ptr
get_queue_context
(
T
&
)
{
return
{};
}
<%
<%
interface
(
'
context
'
,
interface
(
'
context
'
,
virtual
(
'
to_value
'
,
returns
=
'
value
'
,
const
=
True
,
default
=
'
to_value_context
'
),
virtual
(
'
to_value
'
,
returns
=
'
value
'
,
const
=
True
,
default
=
'
to_value_context
'
),
virtual
(
'
from_value
'
,
v
=
'
const
value
&
'
,
default
=
'
from_value_context
'
),
virtual
(
'
from_value
'
,
v
=
'
const
value
&
'
,
default
=
'
from_value_context
'
),
virtual
(
'
get_queue
'
,
returns
=
'
any_ptr
'
,
default
=
'
get_queue_context
'
),
virtual
(
'
finish
'
,
returns
=
'
void
'
,
const
=
True
))
%>
virtual
(
'
finish
'
,
returns
=
'
void
'
,
const
=
True
))
%>
inline
void
migraphx_to_value
(
value
&
v
,
const
context
&
ctx
)
inline
void
migraphx_to_value
(
value
&
v
,
const
context
&
ctx
)
...
...
tools/include/schedule_model.hpp
View file @
11e155c2
...
@@ -26,11 +26,11 @@ struct schedule_model
...
@@ -26,11 +26,11 @@ struct schedule_model
/// Get the number of concurrent instruction allowed
/// Get the number of concurrent instruction allowed
std
::
size_t
concurrency
()
const
;
std
::
size_t
concurrency
()
const
;
/// Schedule a concurrent instruction
/// Schedule a concurrent instruction
void
sched
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
n
)
const
;
void
sched
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
n
)
const
;
// Insert necessary waits before an instruction
// Insert necessary waits before an instruction
void
wait
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
;
void
wait
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
;
// Insert necessary records after an instruction
// Insert necessary records after an instruction
void
record
(
module
&
p
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
;
void
record
(
module
&
m
,
instruction_ref
ins
,
std
::
size_t
wait_id
)
const
;
/// Compute weights for an operation
/// Compute weights for an operation
std
::
size_t
weight
(
const
operation
&
op
)
const
;
std
::
size_t
weight
(
const
operation
&
op
)
const
;
};
};
...
@@ -40,9 +40,9 @@ struct schedule_model
...
@@ -40,9 +40,9 @@ struct schedule_model
<%
<%
interface
(
'
schedule_model
'
,
interface
(
'
schedule_model
'
,
virtual
(
'
concurrency
'
,
returns
=
'
std
::
size_t
'
,
const
=
True
),
virtual
(
'
concurrency
'
,
returns
=
'
std
::
size_t
'
,
const
=
True
),
virtual
(
'
sched
'
,
p
=
'
module
&
'
,
ins
=
'
instruction_ref
'
,
n
=
'
std
::
size_t
'
,
const
=
True
),
virtual
(
'
sched
'
,
m
=
'
module
&
'
,
ins
=
'
instruction_ref
'
,
n
=
'
std
::
size_t
'
,
const
=
True
),
virtual
(
'
wait
'
,
p
=
'
module
&
'
,
ins
=
'
instruction_ref
'
,
wait_id
=
'
std
::
size_t
'
,
const
=
True
),
virtual
(
'
wait
'
,
m
=
'
module
&
'
,
ins
=
'
instruction_ref
'
,
wait_id
=
'
std
::
size_t
'
,
const
=
True
),
virtual
(
'
record
'
,
p
=
'
module
&
'
,
ins
=
'
instruction_ref
'
,
wait_id
=
'
std
::
size_t
'
,
const
=
True
),
virtual
(
'
record
'
,
m
=
'
module
&
'
,
ins
=
'
instruction_ref
'
,
wait_id
=
'
std
::
size_t
'
,
const
=
True
),
virtual
(
'
weight
'
,
returns
=
'
std
::
size_t
'
,
op
=
'
const
operation
&
'
,
const
=
True
)
virtual
(
'
weight
'
,
returns
=
'
std
::
size_t
'
,
op
=
'
const
operation
&
'
,
const
=
True
)
)
)
%>
%>
...
...
tools/install_prereqs.sh
View file @
11e155c2
...
@@ -4,12 +4,20 @@
...
@@ -4,12 +4,20 @@
set
-e
set
-e
#install pip3, rocm-cmake, rocblas and miopen
export
LC_ALL
=
C.UTF-8
apt update
&&
apt
install
-y
python3-pip rocm-cmake rocblas miopen-hip openmp-extras
export
LANG
=
C.UTF-8
# Need pip3 and Python headers to build dependencies
apt update
&&
apt
install
-y
python3-pip python3-dev cmake rocm-cmake rocblas miopen-hip openmp-extras
# Needed for cmake to build various pip packages
pip3
install
setuptools wheel
# install rbuild to build dependencies
# install rbuild to build dependencies
pip3
install
https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz
pip3
install
https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz
PREFIX
=
/usr/local
PREFIX
=
/usr/local
REQ_FILE_DIR
=
""
REQ_FILE_DIR
=
""
if
[
"$#"
-ge
2
]
;
then
if
[
"$#"
-ge
2
]
;
then
...
@@ -19,7 +27,7 @@ elif [ "$#" -eq 1 ]; then
...
@@ -19,7 +27,7 @@ elif [ "$#" -eq 1 ]; then
PREFIX
=
$1
PREFIX
=
$1
fi
fi
echo
"Dependencies are install at
$PREFIX
"
echo
"Dependencies are install
ed
at
$PREFIX
"
# Install deps with rbuild
# Install deps with rbuild
rbuild prepare
-d
$PREFIX
-s
develop
rbuild prepare
-d
$PREFIX
-s
develop
...
@@ -27,3 +35,5 @@ rbuild prepare -d $PREFIX -s develop
...
@@ -27,3 +35,5 @@ rbuild prepare -d $PREFIX -s develop
# install onnx package for unit tests
# install onnx package for unit tests
pip3
install
onnx
==
1.8.1
numpy
==
1.18.5
typing
==
3.7.4
pytest
==
6.0.1
packaging
==
16.8
pip3
install
onnx
==
1.8.1
numpy
==
1.18.5
typing
==
3.7.4
pytest
==
6.0.1
packaging
==
16.8
# pin version of protobuf in Python for onnx runtime unit tests
pip3
install
protobuf
==
3.20.0
tools/te.py
View file @
11e155c2
...
@@ -12,16 +12,15 @@ headers = '''
...
@@ -12,16 +12,15 @@ headers = '''
'''
'''
form
=
string
.
Template
(
'''
form
=
string
.
Template
(
'''
#ifdef TYPE_ERASED_DECLARATION
/*
// Type-erased interface for:
* Type-erased interface for:
struct ${struct_name}
*
{
* struct ${struct_name}
${decl_members}
* {
};
${comment_members}
* };
#else
*
*/
struct ${struct_name}
struct ${struct_name}
{
{
...
@@ -189,6 +188,7 @@ inline const ValueType & any_cast(const ${struct_name} & x)
...
@@ -189,6 +188,7 @@ inline const ValueType & any_cast(const ${struct_name} & x)
if (y == nullptr) throw std::bad_cast();
if (y == nullptr) throw std::bad_cast();
return *y;
return *y;
}
}
#endif
'''
)
'''
)
nonvirtual_member
=
string
.
Template
(
'''
nonvirtual_member
=
string
.
Template
(
'''
...
@@ -214,6 +214,10 @@ ${return_type} ${internal_name}(${member_params}) ${member_const} override
...
@@ -214,6 +214,10 @@ ${return_type} ${internal_name}(${member_params}) ${member_const} override
comment_member
=
string
.
Template
(
comment_member
=
string
.
Template
(
'''* ${friend} ${return_type} ${name}(${params}) ${const};'''
)
'''* ${friend} ${return_type} ${name}(${params}) ${const};'''
)
decl_member
=
string
.
Template
(
''' ${comment}
${friend} ${return_type} ${name}(${params}) ${const};
'''
)
default_member
=
string
.
Template
(
'''
default_member
=
string
.
Template
(
'''
template<class T>
template<class T>
static auto private_detail_te_default_${name}(char, T&& private_detail_te_self ${comma} ${member_params})
static auto private_detail_te_default_${name}(char, T&& private_detail_te_self ${comma} ${member_params})
...
@@ -279,7 +283,8 @@ def convert_member(d, struct_name):
...
@@ -279,7 +283,8 @@ def convert_member(d, struct_name):
'this'
:
'(*this)'
,
'this'
:
'(*this)'
,
'using'
:
''
,
'using'
:
''
,
'brief'
:
''
,
'brief'
:
''
,
'return_'
:
''
'return_'
:
''
,
'comment'
:
'// '
}
}
args
=
[]
args
=
[]
params
=
[]
params
=
[]
...
@@ -306,6 +311,7 @@ def convert_member(d, struct_name):
...
@@ -306,6 +311,7 @@ def convert_member(d, struct_name):
member
[
'friend'
]
=
'friend'
member
[
'friend'
]
=
'friend'
elif
x
==
'default'
:
elif
x
==
'default'
:
member
[
'default'
]
=
t
member
[
'default'
]
=
t
member
[
'comment'
]
=
member
[
'comment'
]
+
'(optional)'
elif
x
==
'using'
:
elif
x
==
'using'
:
member
[
'using'
]
=
'using {};'
.
format
(
d
[
name
][
'using'
])
member
[
'using'
]
=
'using {};'
.
format
(
d
[
name
][
'using'
])
elif
x
==
'__brief__'
:
elif
x
==
'__brief__'
:
...
@@ -347,18 +353,21 @@ def generate_form(name, members):
...
@@ -347,18 +353,21 @@ def generate_form(name, members):
virtual_members
=
[]
virtual_members
=
[]
comment_members
=
[]
comment_members
=
[]
default_members
=
[]
default_members
=
[]
decl_members
=
[]
for
member
in
members
:
for
member
in
members
:
m
=
convert_member
(
member
,
name
)
m
=
convert_member
(
member
,
name
)
nonvirtual_members
.
append
(
nonvirtual_member
.
substitute
(
m
))
nonvirtual_members
.
append
(
nonvirtual_member
.
substitute
(
m
))
pure_virtual_members
.
append
(
pure_virtual_member
.
substitute
(
m
))
pure_virtual_members
.
append
(
pure_virtual_member
.
substitute
(
m
))
virtual_members
.
append
(
virtual_member
.
substitute
(
m
))
virtual_members
.
append
(
virtual_member
.
substitute
(
m
))
comment_members
.
append
(
comment_member
.
substitute
(
m
))
comment_members
.
append
(
comment_member
.
substitute
(
m
))
decl_members
.
append
(
decl_member
.
substitute
(
m
))
if
'default'
in
m
:
if
'default'
in
m
:
default_members
.
append
(
default_member
.
substitute
(
m
))
default_members
.
append
(
default_member
.
substitute
(
m
))
return
form
.
substitute
(
nonvirtual_members
=
''
.
join
(
nonvirtual_members
),
return
form
.
substitute
(
nonvirtual_members
=
''
.
join
(
nonvirtual_members
),
pure_virtual_members
=
''
.
join
(
pure_virtual_members
),
pure_virtual_members
=
''
.
join
(
pure_virtual_members
),
virtual_members
=
''
.
join
(
virtual_members
),
virtual_members
=
''
.
join
(
virtual_members
),
default_members
=
''
.
join
(
default_members
),
default_members
=
''
.
join
(
default_members
),
decl_members
=
''
.
join
(
decl_members
),
comment_members
=
'
\n
'
.
join
(
comment_members
),
comment_members
=
'
\n
'
.
join
(
comment_members
),
struct_name
=
name
)
struct_name
=
name
)
...
...
Prev
1
…
16
17
18
19
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment