Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7f65a88e
Commit
7f65a88e
authored
Feb 04, 2022
by
Paul
Browse files
Merge branch 'develop' into mlir-c
parents
79bfe69f
b20e3d4d
Changes
66
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
331 additions
and
154 deletions
+331
-154
test/ref_ops_nonstd_shape_test.cpp
test/ref_ops_nonstd_shape_test.cpp
+58
-0
test/verify/run_verify.cpp
test/verify/run_verify.cpp
+4
-0
test/verify/test_arg_ops.cpp
test/verify/test_arg_ops.cpp
+76
-18
test/verify/test_conv_bias_clipped_relu.cpp
test/verify/test_conv_bias_clipped_relu.cpp
+2
-2
tools/api.py
tools/api.py
+182
-132
tools/generate.sh
tools/generate.sh
+9
-2
No files found.
test/ref_ops_nonstd_shape_test.cpp
0 → 100644
View file @
7f65a88e
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ref/target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/pass_manager.hpp>
#include "test.hpp"
TEST_CASE
(
argmax_test_nonstd_shape
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
float
>
data
=
{
1.2255
,
1.6834
,
-
2.0305
,
-
0.3221
,
0.4701
,
0.2583
,
0.7545
,
2.5758
,
-
1.6849
,
0.0928
,
0.9022
,
-
0.8765
,
-
0.4090
,
0.9301
,
2.0724
,
-
1.5706
,
0.4867
,
-
0.1493
,
0.6957
,
-
0.2179
,
0.7142
,
0.7177
,
0.0183
,
1.3497
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
auto
dl
=
mm
->
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
auto
dl_trans
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
2
,
0
}}}),
dl
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"argmax"
,
{{
"axis"
,
-
3
}}),
dl_trans
);
auto
p_uncompiled
=
p
;
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
auto
res_gold
=
p_uncompiled
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
int64_t
>
res_gold_vec
;
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
res_gold_vec
));
}
TEST_CASE
(
argmin_test_nonstd_shape
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
float
>
data
=
{
1.2255
,
1.6834
,
-
2.0305
,
-
0.3221
,
0.4701
,
0.2583
,
0.7545
,
2.5758
,
-
1.6849
,
0.0928
,
0.9022
,
-
0.8765
,
-
0.4090
,
0.9301
,
2.0724
,
-
1.5706
,
0.4867
,
-
0.1493
,
0.6957
,
-
0.2179
,
0.7142
,
0.7177
,
0.0183
,
1.3497
};
migraphx
::
shape
data_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
}};
auto
dl
=
mm
->
add_literal
(
migraphx
::
literal
{
data_shape
,
data
});
auto
dl_trans
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
2
,
0
}}}),
dl
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"argmin"
,
{{
"axis"
,
-
1
}}),
dl_trans
);
auto
p_uncompiled
=
p
;
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
auto
res_gold
=
p_uncompiled
.
eval
({}).
back
();
std
::
vector
<
int64_t
>
result_vec
;
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
int64_t
>
res_gold_vec
;
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
res_gold_vec
));
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/verify/run_verify.cpp
View file @
7f65a88e
...
...
@@ -6,6 +6,7 @@
#include <migraphx/ref/target.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/load_save.hpp>
#include <migraphx/verify_args.hpp>
#include <set>
...
...
@@ -15,6 +16,7 @@
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_TEST_COMPILE
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_TEST
)
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DUMP_TEST
)
// An improved async, that doesn't block
template
<
class
Function
>
...
...
@@ -125,6 +127,8 @@ void run_verify::verify(const std::string& name, const migraphx::program& p) con
using
result_future
=
std
::
future
<
std
::
pair
<
migraphx
::
program
,
std
::
vector
<
migraphx
::
argument
>>>
;
auto_print
::
set_terminate_handler
(
name
);
if
(
migraphx
::
enabled
(
MIGRAPHX_DUMP_TEST
{}))
migraphx
::
save
(
p
,
name
+
".mx"
);
std
::
vector
<
std
::
pair
<
std
::
string
,
result_future
>>
results
;
std
::
vector
<
std
::
string
>
target_names
;
for
(
const
auto
&
tname
:
migraphx
::
get_targets
())
...
...
test/verify/test_arg_ops.cpp
100755 → 100644
View file @
7f65a88e
...
...
@@ -2,34 +2,92 @@
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/argmax.hpp>
#include <migraphx/op/argmin.hpp>
template
<
class
T
,
int
Axis
>
struct
test_arg_ops
:
verify_program
<
test_arg_ops
<
T
,
Axis
>>
template
<
class
T
,
int
Axis
,
int
NonStdShape
>
struct
test_arg_ops
:
verify_program
<
test_arg_ops
<
T
,
Axis
,
NonStdShape
>>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
1025
}};
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
1
,
4
,
1025
}};
auto
param
=
mm
->
add_parameter
(
"data"
,
s
);
switch
(
NonStdShape
)
{
case
0
:
param
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
param
);
break
;
case
1
:
param
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
4
,
1025
}}}),
param
);
break
;
case
2
:
param
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
3
}}}),
param
);
break
;
default:
break
;
}
mm
->
add_instruction
(
T
{
Axis
},
param
);
return
p
;
}
};
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
>;
// transpose argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
0
>;
// transpose argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
0
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
0
>;
// broadcast argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
1
>;
// broadcast argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
1
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
1
>;
// slice argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
2
>;
// slice argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
2
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
2
>;
// default case, standard shape argmax tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
0
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
2
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmax
,
-
2
,
3
>;
// default case, standard shape argmin tests
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
0
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
1
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
2
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
3
,
3
>;
template
struct
test_arg_ops
<
migraphx
::
op
::
argmin
,
-
4
,
3
>;
test/verify/test_conv_bias_clipped_relu.cpp
View file @
7f65a88e
...
...
@@ -28,9 +28,9 @@ struct test_conv_bias_clipped_relu : verify_program<test_conv_bias_clipped_relu>
auto
min_val
=
mm
->
add_literal
(
0.0
f
);
auto
max_val
=
mm
->
add_literal
(
6.0
f
);
min_val
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_
lens
}}),
min_val
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
conv
->
get_shape
().
lens
()
}}),
min_val
);
max_val
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
input_
lens
}}),
max_val
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
conv
->
get_shape
().
lens
()
}}),
max_val
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"clip"
),
bias_add
,
min_val
,
max_val
);
return
p
;
}
...
...
tools/api.py
View file @
7f65a88e
import
string
,
sys
,
re
,
runpy
from
functools
import
wraps
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
type_map
=
{}
cpp_type_map
=
{}
functions
=
[]
cpp_classes
=
[]
type_map
:
Dict
[
str
,
Callable
[[
'Parameter'
],
None
]]
=
{}
cpp_type_map
:
Dict
[
str
,
str
]
=
{}
functions
:
List
[
'Function'
]
=
[]
cpp_classes
:
List
[
'CPPClass'
]
=
[]
error_type
=
''
success_type
=
''
try_wrap
=
''
c_header_preamble
=
[]
c_api_body_preamble
=
[]
cpp_header_preamble
=
[]
c_header_preamble
:
List
[
str
]
=
[]
c_api_body_preamble
:
List
[
str
]
=
[]
cpp_header_preamble
:
List
[
str
]
=
[]
def
bad_param_error
(
msg
):
...
...
@@ -23,31 +24,31 @@ class Template(string.Template):
class
Type
:
def
__init__
(
self
,
name
)
:
def
__init__
(
self
,
name
:
str
)
->
None
:
self
.
name
=
name
.
strip
()
def
is_pointer
(
self
):
def
is_pointer
(
self
)
->
bool
:
return
self
.
name
.
endswith
(
'*'
)
def
is_reference
(
self
):
def
is_reference
(
self
)
->
bool
:
return
self
.
name
.
endswith
(
'&'
)
def
is_const
(
self
):
def
is_const
(
self
)
->
bool
:
return
self
.
name
.
startswith
(
'const '
)
def
is_variadic
(
self
):
return
self
.
name
.
startswith
(
'...'
)
def
add_pointer
(
self
):
def
add_pointer
(
self
)
->
'Type'
:
return
Type
(
self
.
name
+
'*'
)
def
add_reference
(
self
):
return
Type
(
self
.
name
+
'&'
)
def
add_const
(
self
):
def
add_const
(
self
)
->
'Type'
:
return
Type
(
'const '
+
self
.
name
)
def
inner_type
(
self
):
def
inner_type
(
self
)
->
Optional
[
'Type'
]
:
i
=
self
.
name
.
find
(
'<'
)
j
=
self
.
name
.
rfind
(
'>'
)
if
i
>
0
and
j
>
0
:
...
...
@@ -55,7 +56,7 @@ class Type:
else
:
return
None
def
remove_generic
(
self
):
def
remove_generic
(
self
)
->
'Type'
:
i
=
self
.
name
.
find
(
'<'
)
j
=
self
.
name
.
rfind
(
'>'
)
if
i
>
0
and
j
>
0
:
...
...
@@ -63,25 +64,25 @@ class Type:
else
:
return
self
def
remove_pointer
(
self
):
def
remove_pointer
(
self
)
->
'Type'
:
if
self
.
is_pointer
():
return
Type
(
self
.
name
[
0
:
-
1
])
return
self
def
remove_reference
(
self
):
def
remove_reference
(
self
)
->
'Type'
:
if
self
.
is_reference
():
return
Type
(
self
.
name
[
0
:
-
1
])
return
self
def
remove_const
(
self
):
def
remove_const
(
self
)
->
'Type'
:
if
self
.
is_const
():
return
Type
(
self
.
name
[
6
:])
return
self
def
basic
(
self
):
def
basic
(
self
)
->
'Type'
:
return
self
.
remove_pointer
().
remove_const
().
remove_reference
()
def
decay
(
self
):
def
decay
(
self
)
->
'Type'
:
t
=
self
.
remove_reference
()
if
t
.
is_pointer
():
return
t
...
...
@@ -93,7 +94,7 @@ class Type:
return
self
.
add_const
()
return
self
def
str
(
self
):
def
str
(
self
)
->
str
:
return
self
.
name
...
...
@@ -113,20 +114,20 @@ extern "C" ${error_type} ${name}(${params})
class
CFunction
:
def
__init__
(
self
,
name
)
:
def
__init__
(
self
,
name
:
str
)
->
None
:
self
.
name
=
name
self
.
params
=
[]
self
.
body
=
[]
self
.
va_start
=
[]
self
.
va_end
=
[]
self
.
params
:
List
[
str
]
=
[]
self
.
body
:
List
[
str
]
=
[]
self
.
va_start
:
List
[
str
]
=
[]
self
.
va_end
:
List
[
str
]
=
[]
def
add_param
(
self
,
type
,
pname
)
:
def
add_param
(
self
,
type
:
str
,
pname
:
str
)
->
None
:
self
.
params
.
append
(
'{} {}'
.
format
(
type
,
pname
))
def
add_statement
(
self
,
stmt
)
:
def
add_statement
(
self
,
stmt
:
str
)
->
None
:
self
.
body
.
append
(
stmt
)
def
add_vlist
(
self
,
name
)
:
def
add_vlist
(
self
,
name
:
str
)
->
None
:
last_param
=
self
.
params
[
-
1
].
split
()[
-
1
]
self
.
va_start
=
[
'va_list {};'
.
format
(
name
),
...
...
@@ -135,7 +136,7 @@ class CFunction:
self
.
va_end
=
[
'va_end({});'
.
format
(
name
)]
self
.
add_param
(
'...'
,
''
)
def
substitute
(
self
,
form
)
:
def
substitute
(
self
,
form
:
Template
)
->
str
:
return
form
.
substitute
(
error_type
=
error_type
,
try_wrap
=
try_wrap
,
name
=
self
.
name
,
...
...
@@ -144,25 +145,29 @@ class CFunction:
va_start
=
"
\n
"
.
join
(
self
.
va_start
),
va_end
=
"
\n
"
.
join
(
self
.
va_end
))
def
generate_header
(
self
):
def
generate_header
(
self
)
->
str
:
return
self
.
substitute
(
header_function
)
def
generate_body
(
self
):
def
generate_body
(
self
)
->
str
:
return
self
.
substitute
(
c_api_impl
)
class
BadParam
:
def
__init__
(
self
,
cond
,
msg
)
:
def
__init__
(
self
,
cond
:
str
,
msg
:
str
)
->
None
:
self
.
cond
=
cond
self
.
msg
=
msg
class
Parameter
:
def
__init__
(
self
,
name
,
type
,
optional
=
False
,
returns
=
False
):
def
__init__
(
self
,
name
:
str
,
type
:
str
,
optional
:
bool
=
False
,
returns
:
bool
=
False
)
->
None
:
self
.
name
=
name
self
.
type
=
Type
(
type
)
self
.
optional
=
optional
self
.
cparams
=
[]
self
.
cparams
:
List
[
Tuple
[
str
,
str
]]
=
[]
self
.
size_cparam
=
-
1
self
.
size_name
=
''
self
.
read
=
'${name}'
...
...
@@ -170,15 +175,15 @@ class Parameter:
self
.
cpp_read
=
'${name}'
self
.
cpp_write
=
'${name}'
self
.
returns
=
returns
self
.
bad_param_check
=
None
self
.
bad_param_check
:
Optional
[
BadParam
]
=
None
def
get_name
(
self
,
prefix
=
None
)
:
def
get_name
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
if
prefix
:
return
prefix
+
self
.
name
else
:
return
self
.
name
def
get_cpp_type
(
self
):
def
get_cpp_type
(
self
)
->
str
:
if
self
.
type
.
str
()
in
cpp_type_map
:
return
cpp_type_map
[
self
.
type
.
basic
().
str
()]
elif
self
.
type
.
basic
().
str
()
in
cpp_type_map
:
...
...
@@ -188,7 +193,10 @@ class Parameter:
else
:
return
self
.
type
.
str
()
def
substitute
(
self
,
s
,
prefix
=
None
,
result
=
None
):
def
substitute
(
self
,
s
:
str
,
prefix
:
Optional
[
str
]
=
None
,
result
:
Optional
[
str
]
=
None
)
->
str
:
ctype
=
None
if
len
(
self
.
cparams
)
>
0
:
ctype
=
Type
(
self
.
cparams
[
0
][
0
]).
basic
().
str
()
...
...
@@ -199,12 +207,13 @@ class Parameter:
size
=
self
.
size_name
,
result
=
result
or
''
)
def
add_param
(
self
,
t
,
name
=
None
):
def
add_param
(
self
,
t
:
Union
[
str
,
Type
],
name
:
Optional
[
str
]
=
None
)
->
None
:
if
not
isinstance
(
t
,
str
):
t
=
t
.
str
()
self
.
cparams
.
append
((
t
,
name
or
self
.
name
))
def
add_size_param
(
self
,
name
=
None
)
:
def
add_size_param
(
self
,
name
:
Optional
[
str
]
=
None
)
->
None
:
self
.
size_cparam
=
len
(
self
.
cparams
)
self
.
size_name
=
name
or
self
.
name
+
'_size'
if
self
.
returns
:
...
...
@@ -212,7 +221,7 @@ class Parameter:
else
:
self
.
add_param
(
'size_t'
,
self
.
size_name
)
def
bad_param
(
self
,
cond
,
msg
)
:
def
bad_param
(
self
,
cond
:
str
,
msg
:
str
)
->
None
:
self
.
bad_param_check
=
BadParam
(
cond
,
msg
)
def
remove_size_param
(
self
,
name
):
...
...
@@ -223,7 +232,7 @@ class Parameter:
self
.
size_name
=
name
return
p
def
update
(
self
):
def
update
(
self
)
->
None
:
t
=
self
.
type
.
basic
().
str
()
g
=
self
.
type
.
remove_generic
().
basic
().
str
()
if
t
in
type_map
:
...
...
@@ -239,18 +248,18 @@ class Parameter:
raise
ValueError
(
"Error for {}: write cannot be a string"
.
format
(
self
.
type
.
str
()))
def
cpp_param
(
self
,
prefix
=
None
)
:
def
cpp_param
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
return
self
.
substitute
(
'${cpptype} ${name}'
,
prefix
=
prefix
)
def
cpp_arg
(
self
,
prefix
=
None
)
:
def
cpp_arg
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
return
self
.
substitute
(
self
.
cpp_read
,
prefix
=
prefix
)
def
cpp_output_args
(
self
,
prefix
=
None
)
:
def
cpp_output_args
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
List
[
str
]
:
return
[
'&{prefix}{n}'
.
format
(
prefix
=
prefix
,
n
=
n
)
for
t
,
n
in
self
.
cparams
]
def
output_declarations
(
self
,
prefix
=
None
)
:
def
output_declarations
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
List
[
str
]
:
return
[
'{type} {prefix}{n};'
.
format
(
type
=
Type
(
t
).
remove_pointer
().
str
(),
prefix
=
prefix
,
...
...
@@ -262,16 +271,16 @@ class Parameter:
'&{prefix}{n};'
.
format
(
prefix
=
prefix
,
n
=
n
)
for
t
,
n
in
self
.
cparams
]
def
cpp_output
(
self
,
prefix
=
None
)
:
def
cpp_output
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
return
self
.
substitute
(
self
.
cpp_write
,
prefix
=
prefix
)
def
input
(
self
,
prefix
=
None
)
:
def
input
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
return
'('
+
self
.
substitute
(
self
.
read
,
prefix
=
prefix
)
+
')'
def
outputs
(
self
,
result
=
None
)
:
def
outputs
(
self
,
result
:
Optional
[
str
]
=
None
)
->
List
[
str
]
:
return
[
self
.
substitute
(
w
,
result
=
result
)
for
w
in
self
.
write
]
def
add_to_cfunction
(
self
,
cfunction
)
:
def
add_to_cfunction
(
self
,
cfunction
:
CFunction
)
->
None
:
for
t
,
name
in
self
.
cparams
:
if
t
.
startswith
(
'...'
):
cfunction
.
add_vlist
(
name
)
...
...
@@ -285,35 +294,35 @@ class Parameter:
body
=
bad_param_error
(
msg
)))
def
template_var
(
s
)
:
def
template_var
(
s
:
str
)
->
str
:
return
'${'
+
s
+
'}'
def
to_template_vars
(
params
)
:
def
to_template_vars
(
params
:
List
[
Union
[
Any
,
Parameter
]])
->
str
:
return
', '
.
join
([
template_var
(
p
.
name
)
for
p
in
params
])
class
Function
:
def
__init__
(
self
,
name
,
params
=
None
,
shared_size
=
False
,
returns
=
None
,
invoke
=
None
,
fname
=
None
,
return_name
=
None
,
**
kwargs
):
name
:
str
,
params
:
Optional
[
List
[
Parameter
]]
=
None
,
shared_size
:
bool
=
False
,
returns
:
Optional
[
str
]
=
None
,
invoke
:
Optional
[
str
]
=
None
,
fname
:
Optional
[
str
]
=
None
,
return_name
:
Optional
[
str
]
=
None
,
**
kwargs
)
->
None
:
self
.
name
=
name
self
.
params
=
params
or
[]
self
.
shared_size
=
False
self
.
cfunction
=
None
self
.
cfunction
:
Optional
[
CFunction
]
=
None
self
.
fname
=
fname
self
.
invoke
=
invoke
or
'${__fname__}($@)'
self
.
return_name
=
return_name
or
'out'
self
.
returns
=
Parameter
(
self
.
return_name
,
returns
,
returns
=
True
)
if
returns
else
None
def
share_params
(
self
):
def
share_params
(
self
)
->
None
:
if
self
.
shared_size
==
True
:
size_param_name
=
'size'
size_type
=
Type
(
'size_t'
)
...
...
@@ -323,7 +332,7 @@ class Function:
size_type
=
Type
(
p
[
0
])
self
.
params
.
append
(
Parameter
(
size_param_name
,
size_type
.
str
()))
def
update
(
self
):
def
update
(
self
)
->
None
:
self
.
share_params
()
for
param
in
self
.
params
:
param
.
update
()
...
...
@@ -331,11 +340,12 @@ class Function:
self
.
returns
.
update
()
self
.
create_cfunction
()
def
inputs
(
self
):
def
inputs
(
self
)
->
str
:
return
', '
.
join
([
p
.
input
()
for
p
in
self
.
params
])
def
input_map
(
self
):
m
=
{}
# TODO: Shoule we remove Optional?
def
input_map
(
self
)
->
Dict
[
str
,
Optional
[
str
]]:
m
:
Dict
[
str
,
Optional
[
str
]]
=
{}
for
p
in
self
.
params
:
m
[
p
.
name
]
=
p
.
input
()
m
[
'return'
]
=
self
.
return_name
...
...
@@ -343,14 +353,22 @@ class Function:
m
[
'__fname__'
]
=
self
.
fname
return
m
def
get_invoke
(
self
):
def
get_invoke
(
self
)
->
str
:
return
Template
(
self
.
invoke
).
safe_substitute
(
self
.
input_map
())
def
write_to_tmp_var
(
self
):
def
write_to_tmp_var
(
self
)
->
bool
:
if
not
self
.
returns
:
return
False
return
len
(
self
.
returns
.
write
)
>
1
or
self
.
returns
.
write
[
0
].
count
(
'${result}'
)
>
1
def
create_cfunction
(
self
):
def
get_cfunction
(
self
)
->
CFunction
:
if
self
.
cfunction
:
return
self
.
cfunction
raise
Exception
(
"self.cfunction is None: self.update() needs to be called."
)
def
create_cfunction
(
self
)
->
None
:
self
.
cfunction
=
CFunction
(
self
.
name
)
# Add the return as a parameter
if
self
.
returns
:
...
...
@@ -358,12 +376,12 @@ class Function:
# Add the input parameters
for
param
in
self
.
params
:
param
.
add_to_cfunction
(
self
.
cfunction
)
f
=
self
.
get_invoke
()
f
:
Optional
[
str
]
=
self
.
get_invoke
()
# Write the assignments
assigns
=
[]
if
self
.
returns
:
result
=
f
if
self
.
write_to_tmp_var
():
if
self
.
write_to_tmp_var
()
and
f
:
f
=
'auto&& api_result = '
+
f
result
=
'api_result'
else
:
...
...
@@ -416,31 +434,37 @@ cpp_class_constructor_template = Template('''
class
CPPMember
:
def
__init__
(
self
,
name
,
function
,
prefix
,
method
=
True
):
def
__init__
(
self
,
name
:
str
,
function
:
Function
,
prefix
:
str
,
method
:
bool
=
True
)
->
None
:
self
.
name
=
name
self
.
function
=
function
self
.
prefix
=
prefix
self
.
method
=
method
def
get_function_params
(
self
):
def
get_function_params
(
self
)
->
List
[
Union
[
Any
,
Parameter
]]
:
if
self
.
method
:
return
self
.
function
.
params
[
1
:]
else
:
return
self
.
function
.
params
def
get_args
(
self
):
def
get_args
(
self
)
->
str
:
output_args
=
[]
if
self
.
function
.
returns
:
output_args
=
self
.
function
.
returns
.
cpp_output_args
(
self
.
prefix
)
if
not
self
.
function
.
cfunction
:
raise
Exception
(
'self.function.update() must be called'
)
return
', '
.
join
(
[
'&{}'
.
format
(
self
.
function
.
cfunction
.
name
)]
+
output_args
+
[
p
.
cpp_arg
(
self
.
prefix
)
for
p
in
self
.
get_function_params
()])
def
get_params
(
self
):
def
get_params
(
self
)
->
str
:
return
', '
.
join
(
[
p
.
cpp_param
(
self
.
prefix
)
for
p
in
self
.
get_function_params
()])
def
get_return_declarations
(
self
):
def
get_return_declarations
(
self
)
->
str
:
if
self
.
function
.
returns
:
return
'
\n
'
.
join
([
d
...
...
@@ -452,7 +476,9 @@ class CPPMember:
def
get_result
(
self
):
return
self
.
function
.
returns
.
input
(
self
.
prefix
)
def
generate_method
(
self
):
def
generate_method
(
self
)
->
str
:
if
not
self
.
function
.
cfunction
:
raise
Exception
(
'self.function.update() must be called'
)
if
self
.
function
.
returns
:
return_type
=
self
.
function
.
returns
.
get_cpp_type
()
return
cpp_class_method_template
.
safe_substitute
(
...
...
@@ -472,7 +498,9 @@ class CPPMember:
args
=
self
.
get_args
(),
success
=
success_type
)
def
generate_constructor
(
self
,
name
):
def
generate_constructor
(
self
,
name
:
str
)
->
str
:
if
not
self
.
function
.
cfunction
:
raise
Exception
(
'self.function.update() must be called'
)
return
cpp_class_constructor_template
.
safe_substitute
(
name
=
name
,
cfunction
=
self
.
function
.
cfunction
.
name
,
...
...
@@ -482,98 +510,101 @@ class CPPMember:
class
CPPClass
:
def
__init__
(
self
,
name
,
ctype
)
:
def
__init__
(
self
,
name
:
str
,
ctype
:
str
)
->
None
:
self
.
name
=
name
self
.
ctype
=
ctype
self
.
constructors
=
[]
self
.
methods
=
[]
self
.
constructors
:
List
[
CPPMember
]
=
[]
self
.
methods
:
List
[
CPPMember
]
=
[]
self
.
prefix
=
'p'
def
add_method
(
self
,
name
,
f
)
:
def
add_method
(
self
,
name
:
str
,
f
:
Function
)
->
None
:
self
.
methods
.
append
(
CPPMember
(
name
,
f
,
self
.
prefix
,
method
=
True
))
def
add_constructor
(
self
,
name
,
f
)
:
def
add_constructor
(
self
,
name
:
str
,
f
:
Function
)
->
None
:
self
.
constructors
.
append
(
CPPMember
(
name
,
f
,
self
.
prefix
,
method
=
True
))
def
generate_methods
(
self
):
def
generate_methods
(
self
)
->
str
:
return
'
\n
'
.
join
([
m
.
generate_method
()
for
m
in
self
.
methods
])
def
generate_constructors
(
self
):
def
generate_constructors
(
self
)
->
str
:
return
'
\n
'
.
join
(
[
m
.
generate_constructor
(
self
.
name
)
for
m
in
self
.
constructors
])
def
substitute
(
self
,
s
,
**
kwargs
):
t
=
s
if
isinstance
(
s
,
str
):
t
=
string
.
Template
(
s
)
def
substitute
(
self
,
s
:
Union
[
string
.
Template
,
str
],
**
kwargs
)
->
str
:
t
=
string
.
Template
(
s
)
if
isinstance
(
s
,
str
)
else
s
destroy
=
self
.
ctype
+
'_destroy'
return
t
.
safe_substitute
(
name
=
self
.
name
,
ctype
=
self
.
ctype
,
destroy
=
destroy
,
**
kwargs
)
def
generate
(
self
):
def
generate
(
self
)
->
str
:
return
self
.
substitute
(
cpp_class_template
,
constructors
=
self
.
substitute
(
self
.
generate_constructors
()),
methods
=
self
.
substitute
(
self
.
generate_methods
()))
def
params
(
virtual
=
None
,
**
kwargs
):
def
params
(
virtual
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
**
kwargs
)
->
List
[
Parameter
]:
result
=
[]
for
name
in
virtual
or
{}:
result
.
append
(
Parameter
(
name
,
virtual
[
name
]))
v
:
Dict
[
str
,
str
]
=
virtual
or
{}
for
name
in
v
:
result
.
append
(
Parameter
(
name
,
v
[
name
]))
for
name
in
kwargs
:
result
.
append
(
Parameter
(
name
,
kwargs
[
name
]))
return
result
def
add_function
(
name
,
*
args
,
**
kwargs
):
def
add_function
(
name
:
str
,
*
args
,
**
kwargs
)
->
Function
:
f
=
Function
(
name
,
*
args
,
**
kwargs
)
functions
.
append
(
f
)
return
f
def
once
(
f
)
:
def
once
(
f
:
Callable
)
->
Any
:
@
wraps
(
f
)
def
decorated
(
*
args
,
**
kwargs
):
if
not
decorated
.
has_run
:
decorated
.
has_run
=
True
return
f
(
*
args
,
**
kwargs
)
decorated
.
has_run
=
False
return
decorated
d
:
Any
=
decorated
d
.
has_run
=
False
return
d
@
once
def
process_functions
():
def
process_functions
()
->
None
:
for
f
in
functions
:
f
.
update
()
def
generate_lines
(
p
)
:
def
generate_lines
(
p
:
List
[
str
])
->
str
:
return
'
\n
'
.
join
(
p
)
def
generate_c_header
():
def
generate_c_header
()
->
str
:
process_functions
()
return
generate_lines
(
c_header_preamble
+
[
f
.
cfunction
.
generate_header
()
for
f
in
functions
])
return
generate_lines
(
c_header_preamble
+
[
f
.
get_cfunction
().
generate_header
()
for
f
in
functions
])
def
generate_c_api_body
():
def
generate_c_api_body
()
->
str
:
process_functions
()
return
generate_lines
(
c_api_body_preamble
+
[
f
.
cfunction
.
generate_body
()
for
f
in
functions
])
return
generate_lines
(
c_api_body_preamble
+
[
f
.
get_cfunction
().
generate_body
()
for
f
in
functions
])
def
generate_cpp_header
():
def
generate_cpp_header
()
->
str
:
process_functions
()
return
generate_lines
(
cpp_header_preamble
+
[
c
.
generate
()
for
c
in
cpp_classes
])
def
cwrap
(
name
)
:
def
cwrap
(
name
:
str
)
->
Callable
:
def
with_cwrap
(
f
):
type_map
[
name
]
=
f
...
...
@@ -677,13 +708,17 @@ protected:
@
once
def
add_handle_preamble
():
def
add_handle_preamble
()
->
None
:
c_api_body_preamble
.
append
(
handle_preamble
)
cpp_header_preamble
.
append
(
string
.
Template
(
cpp_handle_preamble
).
substitute
(
success
=
success_type
))
def
add_handle
(
name
,
ctype
,
cpptype
,
destroy
=
None
,
ref
=
None
):
def
add_handle
(
name
:
str
,
ctype
:
str
,
cpptype
:
str
,
destroy
:
Optional
[
str
]
=
None
,
ref
:
Optional
[
bool
]
=
None
)
->
None
:
opaque_type
=
ctype
+
'_t'
def
handle_wrap
(
p
):
...
...
@@ -718,8 +753,12 @@ def add_handle(name, ctype, cpptype, destroy=None, ref=None):
@
cwrap
(
'std::vector'
)
def
vector_c_wrap
(
p
):
t
=
p
.
type
.
inner_type
().
add_pointer
()
def
vector_c_wrap
(
p
:
Parameter
)
->
None
:
inner
=
p
.
type
.
inner_type
()
# Not a generic type
if
not
inner
:
return
t
=
inner
.
add_pointer
()
if
p
.
returns
:
if
p
.
type
.
is_reference
():
if
p
.
type
.
is_const
():
...
...
@@ -747,7 +786,7 @@ def vector_c_wrap(p):
@
cwrap
(
'std::string'
)
def
string_c_wrap
(
p
)
:
def
string_c_wrap
(
p
:
Parameter
)
->
None
:
t
=
Type
(
'char*'
)
if
p
.
returns
:
if
p
.
type
.
is_reference
():
...
...
@@ -771,7 +810,11 @@ def string_c_wrap(p):
class
Handle
:
def
__init__
(
self
,
name
,
ctype
,
cpptype
,
ref
=
None
):
def
__init__
(
self
,
name
:
str
,
ctype
:
str
,
cpptype
:
str
,
ref
:
Optional
[
bool
]
=
None
)
->
None
:
self
.
name
=
name
self
.
ctype
=
ctype
self
.
cpptype
=
cpptype
...
...
@@ -779,17 +822,21 @@ class Handle:
add_handle
(
name
,
ctype
,
cpptype
,
ref
=
ref
)
cpp_type_map
[
cpptype
]
=
name
def
cname
(
self
,
name
)
:
def
cname
(
self
,
name
:
str
)
->
str
:
return
self
.
ctype
+
'_'
+
name
def
substitute
(
self
,
s
,
**
kwargs
):
def
substitute
(
self
,
s
:
str
,
**
kwargs
)
->
str
:
return
Template
(
s
).
safe_substitute
(
name
=
self
.
name
,
ctype
=
self
.
ctype
,
cpptype
=
self
.
cpptype
,
**
kwargs
)
def
constructor
(
self
,
name
,
params
=
None
,
fname
=
None
,
invoke
=
None
,
**
kwargs
):
def
constructor
(
self
,
name
:
str
,
params
:
Optional
[
List
[
Parameter
]]
=
None
,
fname
:
Optional
[
str
]
=
None
,
invoke
:
Optional
[
str
]
=
None
,
**
kwargs
)
->
'Handle'
:
create
=
self
.
substitute
(
'allocate<${cpptype}>($@)'
)
if
fname
:
create
=
self
.
substitute
(
'allocate<${cpptype}>(${fname}($@))'
,
...
...
@@ -805,13 +852,13 @@ class Handle:
return
self
def
method
(
self
,
name
,
params
=
None
,
fname
=
None
,
invoke
=
None
,
cpp_name
=
None
,
const
=
None
,
**
kwargs
):
name
:
str
,
params
:
Optional
[
List
[
Parameter
]]
=
None
,
fname
:
Optional
[
str
]
=
None
,
invoke
:
Optional
[
str
]
=
None
,
cpp_name
:
Optional
[
str
]
=
None
,
const
:
Optional
[
bool
]
=
None
,
**
kwargs
)
->
'Handle'
:
cpptype
=
self
.
cpptype
if
const
:
cpptype
=
Type
(
cpptype
).
add_const
().
str
()
...
...
@@ -832,11 +879,14 @@ class Handle:
add_function
(
self
.
cname
(
name
),
params
=
params
,
**
kwargs
)
return
self
def
add_cpp_class
(
self
):
def
add_cpp_class
(
self
)
->
None
:
cpp_classes
.
append
(
self
.
cpp_class
)
def
handle
(
ctype
,
cpptype
,
name
=
None
,
ref
=
None
):
def
handle
(
ctype
:
str
,
cpptype
:
str
,
name
:
Optional
[
str
]
=
None
,
ref
:
Optional
[
bool
]
=
None
)
->
Callable
:
def
with_handle
(
f
):
n
=
name
or
f
.
__name__
h
=
Handle
(
n
,
ctype
,
cpptype
,
ref
=
ref
)
...
...
@@ -865,10 +915,10 @@ def template_eval(template, **kwargs):
return
template
def
run
(
)
:
runpy
.
run_path
(
sys
.
argv
[
1
])
if
len
(
sys
.
arg
v
)
>
2
:
f
=
open
(
sys
.
argv
[
2
]).
read
()
def
run
(
args
:
List
[
str
])
->
None
:
runpy
.
run_path
(
args
[
0
])
if
len
(
arg
s
)
>
1
:
f
=
open
(
args
[
1
]).
read
()
r
=
template_eval
(
f
)
sys
.
stdout
.
write
(
r
)
else
:
...
...
@@ -879,4 +929,4 @@ def run():
if
__name__
==
"__main__"
:
sys
.
modules
[
'api'
]
=
sys
.
modules
[
'__main__'
]
run
()
run
(
sys
.
argv
[
1
:]
)
tools/generate.sh
View file @
7f65a88e
DIR
=
"
$(
cd
"
$(
dirname
"
${
BASH_SOURCE
[0]
}
"
)
"
&&
pwd
)
"
SRC_DIR
=
$DIR
/../src
ls
-1
$DIR
/include/ | xargs
-n
1
-P
$(
nproc
)
-I
{}
-t
bash
-c
"python3.6
$DIR
/te.py
$DIR
/include/{} | clang-format-5.0 -style=file >
$SRC_DIR
/include/migraphx/{}"
PYTHON
=
python3
if
type
-p
python3.6
>
/dev/null
;
then
PYTHON
=
python3.6
fi
if
type
-p
python3.8
>
/dev/null
;
then
PYTHON
=
python3.8
fi
ls
-1
$DIR
/include/ | xargs
-n
1
-P
$(
nproc
)
-I
{}
-t
bash
-c
"
$PYTHON
$DIR
/te.py
$DIR
/include/{} | clang-format-5.0 -style=file >
$SRC_DIR
/include/migraphx/{}"
function
api
{
python3.6
$DIR
/api.py
$SRC_DIR
/api/migraphx.py
$1
| clang-format-5.0
-style
=
file
>
$2
$PYTHON
$DIR
/api.py
$SRC_DIR
/api/migraphx.py
$1
| clang-format-5.0
-style
=
file
>
$2
}
api
$DIR
/api/migraphx.h
$SRC_DIR
/api/include/migraphx/migraphx.h
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment