Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
08ac24cf
Commit
08ac24cf
authored
Feb 07, 2022
by
Khalique Ahmed
Browse files
Merge branch 'develop' of
https://github.com/ROCmSoftwarePlatform/AMDMIGraphX
into rocblas_api_opt
parents
96c82f21
b20e3d4d
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
318 additions
and
232 deletions
+318
-232
.github/workflows/ci.yaml
.github/workflows/ci.yaml
+4
-2
src/include/migraphx/iota_iterator.hpp
src/include/migraphx/iota_iterator.hpp
+7
-0
src/onnx/parse_resize.cpp
src/onnx/parse_resize.cpp
+11
-7
src/onnx/parse_upsample.cpp
src/onnx/parse_upsample.cpp
+0
-86
src/program.cpp
src/program.cpp
+72
-2
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+19
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+14
-1
test/onnx/upsample_linear_test.onnx
test/onnx/upsample_linear_test.onnx
+0
-0
tools/api.py
tools/api.py
+182
-132
tools/generate.sh
tools/generate.sh
+9
-2
No files found.
.github/workflows/ci.yaml
View file @
08ac24cf
...
...
@@ -142,10 +142,12 @@ jobs:
with
:
python-version
:
3.6
-
name
:
Install pyflakes
run
:
pip install pyflakes==2.3.1
run
:
pip install pyflakes==2.3.1
mypy==0.931
-
name
:
Run pyflakes
run
:
pyflakes examples/ tools/ src/ test/ doc/
run
:
|
pyflakes examples/ tools/ src/ test/ doc/
mypy tools/api.py
linux
:
...
...
src/include/migraphx/iota_iterator.hpp
View file @
08ac24cf
...
...
@@ -89,6 +89,13 @@ inline std::ptrdiff_t operator-(basic_iota_iterator<F, Iterator> x,
return
x
.
index
-
y
.
index
;
}
template
<
class
F
,
class
Iterator
>
inline
basic_iota_iterator
<
F
,
Iterator
>
operator
-
(
basic_iota_iterator
<
F
,
Iterator
>
x
,
std
::
ptrdiff_t
y
)
{
return
x
-=
y
;
}
template
<
class
F
,
class
Iterator
>
inline
bool
operator
==
(
basic_iota_iterator
<
F
,
Iterator
>
x
,
basic_iota_iterator
<
F
,
Iterator
>
y
)
{
...
...
src/onnx/parse_resize.cpp
View file @
08ac24cf
...
...
@@ -163,9 +163,9 @@ static std::string get_nearest_mode(const onnx_parser::attribute_map& attr)
struct
parse_resize
:
op_parser
<
parse_resize
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Resize"
}};
}
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Resize"
}
,
{
"Upsample"
}
};
}
instruction_ref
parse
(
const
op_desc
&
/*
opd
*/
,
instruction_ref
parse
(
const
op_desc
&
opd
,
const
onnx_parser
&
/*parser*/
,
onnx_parser
::
node_info
info
,
std
::
vector
<
instruction_ref
>
args
)
const
...
...
@@ -183,7 +183,7 @@ struct parse_resize : op_parser<parse_resize>
if
(
contains
(
info
.
attributes
,
"exclude_outside"
)
and
info
.
attributes
.
at
(
"exclude_outside"
).
i
()
==
1
)
{
MIGRAPHX_THROW
(
"PARSE_
RESIZE
: exclude_outside 1 is not supported!"
);
MIGRAPHX_THROW
(
"PARSE_
"
+
opd
.
op_name
+
"
: exclude_outside 1 is not supported!"
);
}
// input data shape info
...
...
@@ -215,12 +215,14 @@ struct parse_resize : op_parser<parse_resize>
if
(
type
==
shape
::
int64_type
)
{
auto
arg_out_s
=
arg
->
eval
();
check_arg_empty
(
arg_out_s
,
"PARSE_RESIZE: dynamic output size is not supported!"
);
check_arg_empty
(
arg_out_s
,
"PARSE_"
+
opd
.
op_name
+
": dynamic output size is not supported!"
);
arg_out_s
.
visit
([
&
](
auto
ol
)
{
out_lens
.
assign
(
ol
.
begin
(),
ol
.
end
());
});
if
(
out_lens
.
size
()
!=
in_lens
.
size
())
{
MIGRAPHX_THROW
(
"PARSE_RESIZE: specified output size does not match input size"
);
MIGRAPHX_THROW
(
"PARSE_"
+
opd
.
op_name
+
": specified output size does not match input size"
);
}
// compute the scale
...
...
@@ -239,12 +241,14 @@ struct parse_resize : op_parser<parse_resize>
{
auto
arg_scale
=
arg
->
eval
();
check_arg_empty
(
arg_scale
,
"PARSE_RESIZE: dynamic input scale is not supported!"
);
"PARSE_"
+
opd
.
op_name
+
": dynamic input scale is not supported!"
);
arg_scale
.
visit
([
&
](
auto
v
)
{
vec_scale
.
assign
(
v
.
begin
(),
v
.
end
());
});
if
(
in_lens
.
size
()
!=
vec_scale
.
size
())
{
MIGRAPHX_THROW
(
"PARSE_RESIZE: ranks of input and scale are different!"
);
MIGRAPHX_THROW
(
"PARSE_"
+
opd
.
op_name
+
": ranks of input and scale are different!"
);
}
std
::
transform
(
in_lens
.
begin
(),
...
...
src/onnx/parse_upsample.cpp
deleted
100644 → 0
View file @
96c82f21
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/onnx/checks.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/make_op.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
onnx
{
struct
parse_upsample
:
op_parser
<
parse_upsample
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"Upsample"
}};
}
instruction_ref
parse
(
const
op_desc
&
/*opd*/
,
const
onnx_parser
&
/*parser*/
,
onnx_parser
::
node_info
info
,
std
::
vector
<
instruction_ref
>
args
)
const
{
if
(
contains
(
info
.
attributes
,
"mode"
))
{
auto
mode
=
info
.
attributes
.
at
(
"mode"
).
s
();
if
(
mode
!=
"nearest"
)
{
MIGRAPHX_THROW
(
"PARSE_UPSAMPLE: only nearest mode is supported!"
);
}
}
auto
arg_scale
=
args
[
1
]
->
eval
();
check_arg_empty
(
arg_scale
,
"PARSE_UPSAMPLE: only constant scale is supported!"
);
std
::
vector
<
float
>
vec_scale
;
arg_scale
.
visit
([
&
](
auto
v
)
{
vec_scale
.
assign
(
v
.
begin
(),
v
.
end
());
});
auto
in_s
=
args
[
0
]
->
get_shape
();
auto
in_lens
=
in_s
.
lens
();
if
(
in_lens
.
size
()
!=
vec_scale
.
size
())
{
MIGRAPHX_THROW
(
"PARSE_UPSAMPLE: ranks of input and scale are different!"
);
}
std
::
vector
<
std
::
size_t
>
out_lens
(
in_lens
.
size
());
std
::
transform
(
in_lens
.
begin
(),
in_lens
.
end
(),
vec_scale
.
begin
(),
out_lens
.
begin
(),
[
&
](
auto
idx
,
auto
scale
)
{
return
static_cast
<
std
::
size_t
>
(
idx
*
scale
);
});
std
::
vector
<
float
>
idx_scale
(
in_lens
.
size
());
std
::
transform
(
out_lens
.
begin
(),
out_lens
.
end
(),
in_lens
.
begin
(),
idx_scale
.
begin
(),
[](
auto
od
,
auto
id
)
{
return
(
od
==
id
)
?
1.0
f
:
(
id
-
1.0
f
)
/
(
od
-
1.0
f
);
});
shape
out_s
{
in_s
.
type
(),
out_lens
};
std
::
vector
<
int
>
ind
(
out_s
.
elements
());
// map out_idx to in_idx
shape_for_each
(
out_s
,
[
&
](
auto
idx
)
{
auto
in_idx
=
idx
;
std
::
transform
(
idx
.
begin
(),
idx
.
end
(),
idx_scale
.
begin
(),
in_idx
.
begin
(),
// nearest mode
[](
auto
index
,
auto
scale
)
{
return
static_cast
<
std
::
size_t
>
(
std
::
round
(
index
*
scale
));
});
ind
[
out_s
.
index
(
idx
)]
=
static_cast
<
int64_t
>
(
in_s
.
index
(
in_idx
));
});
// reshape input to one-dimension
std
::
vector
<
int64_t
>
rsp_lens
=
{
static_cast
<
int64_t
>
(
in_s
.
elements
())};
shape
ind_s
{
shape
::
int32_type
,
out_lens
};
auto
rsp
=
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
rsp_lens
}}),
args
[
0
]);
auto
ins_ind
=
info
.
add_literal
(
literal
(
ind_s
,
ind
));
return
info
.
add_instruction
(
make_op
(
"gather"
,
{{
"axis"
,
0
}}),
rsp
,
ins_ind
);
}
};
}
// namespace onnx
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/program.cpp
View file @
08ac24cf
...
...
@@ -180,6 +180,63 @@ void program::finalize()
mm
->
finalize
(
this
->
impl
->
ctx
);
}
template
<
class
T
>
std
::
string
classify
(
T
x
)
{
switch
(
std
::
fpclassify
(
x
))
{
case
FP_INFINITE
:
return
"inf"
;
case
FP_NAN
:
return
"nan"
;
case
FP_NORMAL
:
return
"normal"
;
case
FP_SUBNORMAL
:
return
"subnormal"
;
case
FP_ZERO
:
return
"zero"
;
default:
return
"unknown"
;
}
}
std
::
unordered_set
<
std
::
string
>
classify_argument
(
const
argument
&
a
)
{
std
::
unordered_set
<
std
::
string
>
result
;
a
.
visit
(
[
&
](
auto
t
)
{
for
(
const
auto
&
x
:
t
)
result
.
insert
(
classify
(
x
));
},
[
&
](
const
auto
&
xs
)
{
for
(
const
auto
&
x
:
xs
)
{
auto
r
=
classify_argument
(
x
);
result
.
insert
(
r
.
begin
(),
r
.
end
());
}
});
return
result
;
}
void
preview_argument
(
std
::
ostream
&
os
,
const
argument
&
a
)
{
a
.
visit
(
[
&
](
auto
t
)
{
if
(
t
.
size
()
<=
10
)
{
os
<<
t
;
}
else
{
os
<<
to_string_range
(
t
.
begin
(),
t
.
begin
()
+
5
);
os
<<
", ..., "
;
os
<<
to_string_range
(
t
.
end
()
-
5
,
t
.
end
());
}
},
[
&
](
const
auto
&
xs
)
{
for
(
const
auto
&
x
:
xs
)
{
os
<<
'{'
;
preview_argument
(
os
,
x
);
os
<<
'}'
;
}
});
}
template
<
class
F
>
std
::
vector
<
argument
>
generic_eval
(
const
module
*
mod
,
context
&
ctx
,
...
...
@@ -312,8 +369,21 @@ std::vector<argument> program::eval(parameter_map params) const
if
(
trace_level
>
1
and
ins
->
name
().
front
()
!=
'@'
and
ins
->
name
()
!=
"load"
and
not
result
.
empty
())
{
target
tgt
=
make_target
(
this
->
impl
->
target_name
);
std
::
cout
<<
"Output: "
<<
tgt
.
copy_from
(
result
)
<<
std
::
endl
;
target
tgt
=
make_target
(
this
->
impl
->
target_name
);
auto
buffer
=
tgt
.
copy_from
(
result
);
if
(
trace_level
==
2
)
{
std
::
cout
<<
"Output has "
<<
to_string_range
(
classify_argument
(
buffer
))
<<
std
::
endl
;
std
::
cout
<<
"Output: "
;
preview_argument
(
std
::
cout
,
buffer
);
std
::
cout
<<
std
::
endl
;
}
else
{
std
::
cout
<<
"Output: "
<<
buffer
<<
std
::
endl
;
}
}
return
result
;
}));
...
...
test/onnx/gen_onnx.py
View file @
08ac24cf
...
...
@@ -5074,6 +5074,25 @@ def unknown_aten_test():
return
([
node
],
[
x
,
y
],
[
a
])
@
onnx_test
def
upsample_linear_test
():
scales
=
np
.
array
([
1.0
,
1.0
,
2.0
,
2.0
],
dtype
=
np
.
float32
)
scales_tensor
=
helper
.
make_tensor
(
name
=
'scales'
,
data_type
=
TensorProto
.
FLOAT
,
dims
=
scales
.
shape
,
vals
=
scales
.
flatten
().
astype
(
np
.
float32
))
X
=
helper
.
make_tensor_value_info
(
'X'
,
TensorProto
.
FLOAT
,
[
1
,
1
,
2
,
2
])
Y
=
helper
.
make_tensor_value_info
(
'Y'
,
TensorProto
.
FLOAT
,
[])
node
=
onnx
.
helper
.
make_node
(
'Upsample'
,
inputs
=
[
'X'
,
''
,
'scales'
],
outputs
=
[
'Y'
],
mode
=
'linear'
)
return
([
node
],
[
X
],
[
Y
],
[
scales_tensor
])
@
onnx_test
def
upsample_test
():
scales
=
np
.
array
([
1.0
,
1.0
,
2.0
,
3.0
],
dtype
=
np
.
float32
)
...
...
test/onnx/onnx_test.cpp
View file @
08ac24cf
...
...
@@ -3643,7 +3643,7 @@ TEST_CASE(resize_nonstd_input_test)
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
resiz
e_upsample_linear_
ac_test
)
static
auto
creat
e_upsample_linear_
prog
(
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
...
...
@@ -3734,6 +3734,12 @@ TEST_CASE(resize_upsample_linear_ac_test)
auto
add1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
mul1
,
slc10
);
mm
->
add_return
({
add1
});
return
p
;
}
TEST_CASE
(
resize_upsample_linear_ac_test
)
{
auto
p
=
create_upsample_linear_prog
();
auto
prog
=
migraphx
::
parse_onnx
(
"resize_upsample_linear_ac_test.onnx"
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -4753,6 +4759,13 @@ TEST_CASE(unknown_test_throw)
EXPECT
(
test
::
throws
([
&
]
{
migraphx
::
parse_onnx
(
"unknown_test.onnx"
);
}));
}
TEST_CASE
(
upsample_linear_test
)
{
auto
p
=
create_upsample_linear_prog
();
auto
prog
=
migraphx
::
parse_onnx
(
"upsample_linear_test.onnx"
);
EXPECT
(
p
==
prog
);
}
TEST_CASE
(
upsample_test
)
{
migraphx
::
program
p
;
...
...
test/onnx/upsample_linear_test.onnx
0 → 100644
View file @
08ac24cf
File added
tools/api.py
View file @
08ac24cf
import
string
,
sys
,
re
,
runpy
from
functools
import
wraps
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
type_map
=
{}
cpp_type_map
=
{}
functions
=
[]
cpp_classes
=
[]
type_map
:
Dict
[
str
,
Callable
[[
'Parameter'
],
None
]]
=
{}
cpp_type_map
:
Dict
[
str
,
str
]
=
{}
functions
:
List
[
'Function'
]
=
[]
cpp_classes
:
List
[
'CPPClass'
]
=
[]
error_type
=
''
success_type
=
''
try_wrap
=
''
c_header_preamble
=
[]
c_api_body_preamble
=
[]
cpp_header_preamble
=
[]
c_header_preamble
:
List
[
str
]
=
[]
c_api_body_preamble
:
List
[
str
]
=
[]
cpp_header_preamble
:
List
[
str
]
=
[]
def
bad_param_error
(
msg
):
...
...
@@ -23,31 +24,31 @@ class Template(string.Template):
class
Type
:
def
__init__
(
self
,
name
)
:
def
__init__
(
self
,
name
:
str
)
->
None
:
self
.
name
=
name
.
strip
()
def
is_pointer
(
self
):
def
is_pointer
(
self
)
->
bool
:
return
self
.
name
.
endswith
(
'*'
)
def
is_reference
(
self
):
def
is_reference
(
self
)
->
bool
:
return
self
.
name
.
endswith
(
'&'
)
def
is_const
(
self
):
def
is_const
(
self
)
->
bool
:
return
self
.
name
.
startswith
(
'const '
)
def
is_variadic
(
self
):
return
self
.
name
.
startswith
(
'...'
)
def
add_pointer
(
self
):
def
add_pointer
(
self
)
->
'Type'
:
return
Type
(
self
.
name
+
'*'
)
def
add_reference
(
self
):
return
Type
(
self
.
name
+
'&'
)
def
add_const
(
self
):
def
add_const
(
self
)
->
'Type'
:
return
Type
(
'const '
+
self
.
name
)
def
inner_type
(
self
):
def
inner_type
(
self
)
->
Optional
[
'Type'
]
:
i
=
self
.
name
.
find
(
'<'
)
j
=
self
.
name
.
rfind
(
'>'
)
if
i
>
0
and
j
>
0
:
...
...
@@ -55,7 +56,7 @@ class Type:
else
:
return
None
def
remove_generic
(
self
):
def
remove_generic
(
self
)
->
'Type'
:
i
=
self
.
name
.
find
(
'<'
)
j
=
self
.
name
.
rfind
(
'>'
)
if
i
>
0
and
j
>
0
:
...
...
@@ -63,25 +64,25 @@ class Type:
else
:
return
self
def
remove_pointer
(
self
):
def
remove_pointer
(
self
)
->
'Type'
:
if
self
.
is_pointer
():
return
Type
(
self
.
name
[
0
:
-
1
])
return
self
def
remove_reference
(
self
):
def
remove_reference
(
self
)
->
'Type'
:
if
self
.
is_reference
():
return
Type
(
self
.
name
[
0
:
-
1
])
return
self
def
remove_const
(
self
):
def
remove_const
(
self
)
->
'Type'
:
if
self
.
is_const
():
return
Type
(
self
.
name
[
6
:])
return
self
def
basic
(
self
):
def
basic
(
self
)
->
'Type'
:
return
self
.
remove_pointer
().
remove_const
().
remove_reference
()
def
decay
(
self
):
def
decay
(
self
)
->
'Type'
:
t
=
self
.
remove_reference
()
if
t
.
is_pointer
():
return
t
...
...
@@ -93,7 +94,7 @@ class Type:
return
self
.
add_const
()
return
self
def
str
(
self
):
def
str
(
self
)
->
str
:
return
self
.
name
...
...
@@ -113,20 +114,20 @@ extern "C" ${error_type} ${name}(${params})
class
CFunction
:
def
__init__
(
self
,
name
)
:
def
__init__
(
self
,
name
:
str
)
->
None
:
self
.
name
=
name
self
.
params
=
[]
self
.
body
=
[]
self
.
va_start
=
[]
self
.
va_end
=
[]
self
.
params
:
List
[
str
]
=
[]
self
.
body
:
List
[
str
]
=
[]
self
.
va_start
:
List
[
str
]
=
[]
self
.
va_end
:
List
[
str
]
=
[]
def
add_param
(
self
,
type
,
pname
)
:
def
add_param
(
self
,
type
:
str
,
pname
:
str
)
->
None
:
self
.
params
.
append
(
'{} {}'
.
format
(
type
,
pname
))
def
add_statement
(
self
,
stmt
)
:
def
add_statement
(
self
,
stmt
:
str
)
->
None
:
self
.
body
.
append
(
stmt
)
def
add_vlist
(
self
,
name
)
:
def
add_vlist
(
self
,
name
:
str
)
->
None
:
last_param
=
self
.
params
[
-
1
].
split
()[
-
1
]
self
.
va_start
=
[
'va_list {};'
.
format
(
name
),
...
...
@@ -135,7 +136,7 @@ class CFunction:
self
.
va_end
=
[
'va_end({});'
.
format
(
name
)]
self
.
add_param
(
'...'
,
''
)
def
substitute
(
self
,
form
)
:
def
substitute
(
self
,
form
:
Template
)
->
str
:
return
form
.
substitute
(
error_type
=
error_type
,
try_wrap
=
try_wrap
,
name
=
self
.
name
,
...
...
@@ -144,25 +145,29 @@ class CFunction:
va_start
=
"
\n
"
.
join
(
self
.
va_start
),
va_end
=
"
\n
"
.
join
(
self
.
va_end
))
def
generate_header
(
self
):
def
generate_header
(
self
)
->
str
:
return
self
.
substitute
(
header_function
)
def
generate_body
(
self
):
def
generate_body
(
self
)
->
str
:
return
self
.
substitute
(
c_api_impl
)
class
BadParam
:
def
__init__
(
self
,
cond
,
msg
)
:
def
__init__
(
self
,
cond
:
str
,
msg
:
str
)
->
None
:
self
.
cond
=
cond
self
.
msg
=
msg
class
Parameter
:
def
__init__
(
self
,
name
,
type
,
optional
=
False
,
returns
=
False
):
def
__init__
(
self
,
name
:
str
,
type
:
str
,
optional
:
bool
=
False
,
returns
:
bool
=
False
)
->
None
:
self
.
name
=
name
self
.
type
=
Type
(
type
)
self
.
optional
=
optional
self
.
cparams
=
[]
self
.
cparams
:
List
[
Tuple
[
str
,
str
]]
=
[]
self
.
size_cparam
=
-
1
self
.
size_name
=
''
self
.
read
=
'${name}'
...
...
@@ -170,15 +175,15 @@ class Parameter:
self
.
cpp_read
=
'${name}'
self
.
cpp_write
=
'${name}'
self
.
returns
=
returns
self
.
bad_param_check
=
None
self
.
bad_param_check
:
Optional
[
BadParam
]
=
None
def
get_name
(
self
,
prefix
=
None
)
:
def
get_name
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
if
prefix
:
return
prefix
+
self
.
name
else
:
return
self
.
name
def
get_cpp_type
(
self
):
def
get_cpp_type
(
self
)
->
str
:
if
self
.
type
.
str
()
in
cpp_type_map
:
return
cpp_type_map
[
self
.
type
.
basic
().
str
()]
elif
self
.
type
.
basic
().
str
()
in
cpp_type_map
:
...
...
@@ -188,7 +193,10 @@ class Parameter:
else
:
return
self
.
type
.
str
()
def
substitute
(
self
,
s
,
prefix
=
None
,
result
=
None
):
def
substitute
(
self
,
s
:
str
,
prefix
:
Optional
[
str
]
=
None
,
result
:
Optional
[
str
]
=
None
)
->
str
:
ctype
=
None
if
len
(
self
.
cparams
)
>
0
:
ctype
=
Type
(
self
.
cparams
[
0
][
0
]).
basic
().
str
()
...
...
@@ -199,12 +207,13 @@ class Parameter:
size
=
self
.
size_name
,
result
=
result
or
''
)
def
add_param
(
self
,
t
,
name
=
None
):
def
add_param
(
self
,
t
:
Union
[
str
,
Type
],
name
:
Optional
[
str
]
=
None
)
->
None
:
if
not
isinstance
(
t
,
str
):
t
=
t
.
str
()
self
.
cparams
.
append
((
t
,
name
or
self
.
name
))
def
add_size_param
(
self
,
name
=
None
)
:
def
add_size_param
(
self
,
name
:
Optional
[
str
]
=
None
)
->
None
:
self
.
size_cparam
=
len
(
self
.
cparams
)
self
.
size_name
=
name
or
self
.
name
+
'_size'
if
self
.
returns
:
...
...
@@ -212,7 +221,7 @@ class Parameter:
else
:
self
.
add_param
(
'size_t'
,
self
.
size_name
)
def
bad_param
(
self
,
cond
,
msg
)
:
def
bad_param
(
self
,
cond
:
str
,
msg
:
str
)
->
None
:
self
.
bad_param_check
=
BadParam
(
cond
,
msg
)
def
remove_size_param
(
self
,
name
):
...
...
@@ -223,7 +232,7 @@ class Parameter:
self
.
size_name
=
name
return
p
def
update
(
self
):
def
update
(
self
)
->
None
:
t
=
self
.
type
.
basic
().
str
()
g
=
self
.
type
.
remove_generic
().
basic
().
str
()
if
t
in
type_map
:
...
...
@@ -239,18 +248,18 @@ class Parameter:
raise
ValueError
(
"Error for {}: write cannot be a string"
.
format
(
self
.
type
.
str
()))
def
cpp_param
(
self
,
prefix
=
None
)
:
def
cpp_param
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
return
self
.
substitute
(
'${cpptype} ${name}'
,
prefix
=
prefix
)
def
cpp_arg
(
self
,
prefix
=
None
)
:
def
cpp_arg
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
return
self
.
substitute
(
self
.
cpp_read
,
prefix
=
prefix
)
def
cpp_output_args
(
self
,
prefix
=
None
)
:
def
cpp_output_args
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
List
[
str
]
:
return
[
'&{prefix}{n}'
.
format
(
prefix
=
prefix
,
n
=
n
)
for
t
,
n
in
self
.
cparams
]
def
output_declarations
(
self
,
prefix
=
None
)
:
def
output_declarations
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
List
[
str
]
:
return
[
'{type} {prefix}{n};'
.
format
(
type
=
Type
(
t
).
remove_pointer
().
str
(),
prefix
=
prefix
,
...
...
@@ -262,16 +271,16 @@ class Parameter:
'&{prefix}{n};'
.
format
(
prefix
=
prefix
,
n
=
n
)
for
t
,
n
in
self
.
cparams
]
def
cpp_output
(
self
,
prefix
=
None
)
:
def
cpp_output
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
return
self
.
substitute
(
self
.
cpp_write
,
prefix
=
prefix
)
def
input
(
self
,
prefix
=
None
)
:
def
input
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
return
'('
+
self
.
substitute
(
self
.
read
,
prefix
=
prefix
)
+
')'
def
outputs
(
self
,
result
=
None
)
:
def
outputs
(
self
,
result
:
Optional
[
str
]
=
None
)
->
List
[
str
]
:
return
[
self
.
substitute
(
w
,
result
=
result
)
for
w
in
self
.
write
]
def
add_to_cfunction
(
self
,
cfunction
)
:
def
add_to_cfunction
(
self
,
cfunction
:
CFunction
)
->
None
:
for
t
,
name
in
self
.
cparams
:
if
t
.
startswith
(
'...'
):
cfunction
.
add_vlist
(
name
)
...
...
@@ -285,35 +294,35 @@ class Parameter:
body
=
bad_param_error
(
msg
)))
def
template_var
(
s
)
:
def
template_var
(
s
:
str
)
->
str
:
return
'${'
+
s
+
'}'
def
to_template_vars
(
params
)
:
def
to_template_vars
(
params
:
List
[
Union
[
Any
,
Parameter
]])
->
str
:
return
', '
.
join
([
template_var
(
p
.
name
)
for
p
in
params
])
class
Function
:
def
__init__
(
self
,
name
,
params
=
None
,
shared_size
=
False
,
returns
=
None
,
invoke
=
None
,
fname
=
None
,
return_name
=
None
,
**
kwargs
):
name
:
str
,
params
:
Optional
[
List
[
Parameter
]]
=
None
,
shared_size
:
bool
=
False
,
returns
:
Optional
[
str
]
=
None
,
invoke
:
Optional
[
str
]
=
None
,
fname
:
Optional
[
str
]
=
None
,
return_name
:
Optional
[
str
]
=
None
,
**
kwargs
)
->
None
:
self
.
name
=
name
self
.
params
=
params
or
[]
self
.
shared_size
=
False
self
.
cfunction
=
None
self
.
cfunction
:
Optional
[
CFunction
]
=
None
self
.
fname
=
fname
self
.
invoke
=
invoke
or
'${__fname__}($@)'
self
.
return_name
=
return_name
or
'out'
self
.
returns
=
Parameter
(
self
.
return_name
,
returns
,
returns
=
True
)
if
returns
else
None
def
share_params
(
self
):
def
share_params
(
self
)
->
None
:
if
self
.
shared_size
==
True
:
size_param_name
=
'size'
size_type
=
Type
(
'size_t'
)
...
...
@@ -323,7 +332,7 @@ class Function:
size_type
=
Type
(
p
[
0
])
self
.
params
.
append
(
Parameter
(
size_param_name
,
size_type
.
str
()))
def
update
(
self
):
def
update
(
self
)
->
None
:
self
.
share_params
()
for
param
in
self
.
params
:
param
.
update
()
...
...
@@ -331,11 +340,12 @@ class Function:
self
.
returns
.
update
()
self
.
create_cfunction
()
def
inputs
(
self
):
def
inputs
(
self
)
->
str
:
return
', '
.
join
([
p
.
input
()
for
p
in
self
.
params
])
def
input_map
(
self
):
m
=
{}
# TODO: Shoule we remove Optional?
def
input_map
(
self
)
->
Dict
[
str
,
Optional
[
str
]]:
m
:
Dict
[
str
,
Optional
[
str
]]
=
{}
for
p
in
self
.
params
:
m
[
p
.
name
]
=
p
.
input
()
m
[
'return'
]
=
self
.
return_name
...
...
@@ -343,14 +353,22 @@ class Function:
m
[
'__fname__'
]
=
self
.
fname
return
m
def
get_invoke
(
self
):
def
get_invoke
(
self
)
->
str
:
return
Template
(
self
.
invoke
).
safe_substitute
(
self
.
input_map
())
def
write_to_tmp_var
(
self
):
def
write_to_tmp_var
(
self
)
->
bool
:
if
not
self
.
returns
:
return
False
return
len
(
self
.
returns
.
write
)
>
1
or
self
.
returns
.
write
[
0
].
count
(
'${result}'
)
>
1
def
create_cfunction
(
self
):
def
get_cfunction
(
self
)
->
CFunction
:
if
self
.
cfunction
:
return
self
.
cfunction
raise
Exception
(
"self.cfunction is None: self.update() needs to be called."
)
def
create_cfunction
(
self
)
->
None
:
self
.
cfunction
=
CFunction
(
self
.
name
)
# Add the return as a parameter
if
self
.
returns
:
...
...
@@ -358,12 +376,12 @@ class Function:
# Add the input parameters
for
param
in
self
.
params
:
param
.
add_to_cfunction
(
self
.
cfunction
)
f
=
self
.
get_invoke
()
f
:
Optional
[
str
]
=
self
.
get_invoke
()
# Write the assignments
assigns
=
[]
if
self
.
returns
:
result
=
f
if
self
.
write_to_tmp_var
():
if
self
.
write_to_tmp_var
()
and
f
:
f
=
'auto&& api_result = '
+
f
result
=
'api_result'
else
:
...
...
@@ -416,31 +434,37 @@ cpp_class_constructor_template = Template('''
class
CPPMember
:
def
__init__
(
self
,
name
,
function
,
prefix
,
method
=
True
):
def
__init__
(
self
,
name
:
str
,
function
:
Function
,
prefix
:
str
,
method
:
bool
=
True
)
->
None
:
self
.
name
=
name
self
.
function
=
function
self
.
prefix
=
prefix
self
.
method
=
method
def
get_function_params
(
self
):
def
get_function_params
(
self
)
->
List
[
Union
[
Any
,
Parameter
]]
:
if
self
.
method
:
return
self
.
function
.
params
[
1
:]
else
:
return
self
.
function
.
params
def
get_args
(
self
):
def
get_args
(
self
)
->
str
:
output_args
=
[]
if
self
.
function
.
returns
:
output_args
=
self
.
function
.
returns
.
cpp_output_args
(
self
.
prefix
)
if
not
self
.
function
.
cfunction
:
raise
Exception
(
'self.function.update() must be called'
)
return
', '
.
join
(
[
'&{}'
.
format
(
self
.
function
.
cfunction
.
name
)]
+
output_args
+
[
p
.
cpp_arg
(
self
.
prefix
)
for
p
in
self
.
get_function_params
()])
def
get_params
(
self
):
def
get_params
(
self
)
->
str
:
return
', '
.
join
(
[
p
.
cpp_param
(
self
.
prefix
)
for
p
in
self
.
get_function_params
()])
def
get_return_declarations
(
self
):
def
get_return_declarations
(
self
)
->
str
:
if
self
.
function
.
returns
:
return
'
\n
'
.
join
([
d
...
...
@@ -452,7 +476,9 @@ class CPPMember:
def
get_result
(
self
):
return
self
.
function
.
returns
.
input
(
self
.
prefix
)
def
generate_method
(
self
):
def
generate_method
(
self
)
->
str
:
if
not
self
.
function
.
cfunction
:
raise
Exception
(
'self.function.update() must be called'
)
if
self
.
function
.
returns
:
return_type
=
self
.
function
.
returns
.
get_cpp_type
()
return
cpp_class_method_template
.
safe_substitute
(
...
...
@@ -472,7 +498,9 @@ class CPPMember:
args
=
self
.
get_args
(),
success
=
success_type
)
def
generate_constructor
(
self
,
name
):
def
generate_constructor
(
self
,
name
:
str
)
->
str
:
if
not
self
.
function
.
cfunction
:
raise
Exception
(
'self.function.update() must be called'
)
return
cpp_class_constructor_template
.
safe_substitute
(
name
=
name
,
cfunction
=
self
.
function
.
cfunction
.
name
,
...
...
@@ -482,98 +510,101 @@ class CPPMember:
class
CPPClass
:
def
__init__
(
self
,
name
,
ctype
)
:
def
__init__
(
self
,
name
:
str
,
ctype
:
str
)
->
None
:
self
.
name
=
name
self
.
ctype
=
ctype
self
.
constructors
=
[]
self
.
methods
=
[]
self
.
constructors
:
List
[
CPPMember
]
=
[]
self
.
methods
:
List
[
CPPMember
]
=
[]
self
.
prefix
=
'p'
def
add_method
(
self
,
name
,
f
)
:
def
add_method
(
self
,
name
:
str
,
f
:
Function
)
->
None
:
self
.
methods
.
append
(
CPPMember
(
name
,
f
,
self
.
prefix
,
method
=
True
))
def
add_constructor
(
self
,
name
,
f
)
:
def
add_constructor
(
self
,
name
:
str
,
f
:
Function
)
->
None
:
self
.
constructors
.
append
(
CPPMember
(
name
,
f
,
self
.
prefix
,
method
=
True
))
def
generate_methods
(
self
):
def
generate_methods
(
self
)
->
str
:
return
'
\n
'
.
join
([
m
.
generate_method
()
for
m
in
self
.
methods
])
def
generate_constructors
(
self
):
def
generate_constructors
(
self
)
->
str
:
return
'
\n
'
.
join
(
[
m
.
generate_constructor
(
self
.
name
)
for
m
in
self
.
constructors
])
def
substitute
(
self
,
s
,
**
kwargs
):
t
=
s
if
isinstance
(
s
,
str
):
t
=
string
.
Template
(
s
)
def
substitute
(
self
,
s
:
Union
[
string
.
Template
,
str
],
**
kwargs
)
->
str
:
t
=
string
.
Template
(
s
)
if
isinstance
(
s
,
str
)
else
s
destroy
=
self
.
ctype
+
'_destroy'
return
t
.
safe_substitute
(
name
=
self
.
name
,
ctype
=
self
.
ctype
,
destroy
=
destroy
,
**
kwargs
)
def
generate
(
self
):
def
generate
(
self
)
->
str
:
return
self
.
substitute
(
cpp_class_template
,
constructors
=
self
.
substitute
(
self
.
generate_constructors
()),
methods
=
self
.
substitute
(
self
.
generate_methods
()))
def
params
(
virtual
=
None
,
**
kwargs
):
def
params
(
virtual
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
**
kwargs
)
->
List
[
Parameter
]:
result
=
[]
for
name
in
virtual
or
{}:
result
.
append
(
Parameter
(
name
,
virtual
[
name
]))
v
:
Dict
[
str
,
str
]
=
virtual
or
{}
for
name
in
v
:
result
.
append
(
Parameter
(
name
,
v
[
name
]))
for
name
in
kwargs
:
result
.
append
(
Parameter
(
name
,
kwargs
[
name
]))
return
result
def
add_function
(
name
,
*
args
,
**
kwargs
):
def
add_function
(
name
:
str
,
*
args
,
**
kwargs
)
->
Function
:
f
=
Function
(
name
,
*
args
,
**
kwargs
)
functions
.
append
(
f
)
return
f
def
once
(
f
)
:
def
once
(
f
:
Callable
)
->
Any
:
@
wraps
(
f
)
def
decorated
(
*
args
,
**
kwargs
):
if
not
decorated
.
has_run
:
decorated
.
has_run
=
True
return
f
(
*
args
,
**
kwargs
)
decorated
.
has_run
=
False
return
decorated
d
:
Any
=
decorated
d
.
has_run
=
False
return
d
@
once
def
process_functions
():
def
process_functions
()
->
None
:
for
f
in
functions
:
f
.
update
()
def
generate_lines
(
p
)
:
def
generate_lines
(
p
:
List
[
str
])
->
str
:
return
'
\n
'
.
join
(
p
)
def
generate_c_header
():
def
generate_c_header
()
->
str
:
process_functions
()
return
generate_lines
(
c_header_preamble
+
[
f
.
cfunction
.
generate_header
()
for
f
in
functions
])
return
generate_lines
(
c_header_preamble
+
[
f
.
get_cfunction
().
generate_header
()
for
f
in
functions
])
def
generate_c_api_body
():
def
generate_c_api_body
()
->
str
:
process_functions
()
return
generate_lines
(
c_api_body_preamble
+
[
f
.
cfunction
.
generate_body
()
for
f
in
functions
])
return
generate_lines
(
c_api_body_preamble
+
[
f
.
get_cfunction
().
generate_body
()
for
f
in
functions
])
def
generate_cpp_header
():
def
generate_cpp_header
()
->
str
:
process_functions
()
return
generate_lines
(
cpp_header_preamble
+
[
c
.
generate
()
for
c
in
cpp_classes
])
def
cwrap
(
name
)
:
def
cwrap
(
name
:
str
)
->
Callable
:
def
with_cwrap
(
f
):
type_map
[
name
]
=
f
...
...
@@ -677,13 +708,17 @@ protected:
@
once
def
add_handle_preamble
():
def
add_handle_preamble
()
->
None
:
c_api_body_preamble
.
append
(
handle_preamble
)
cpp_header_preamble
.
append
(
string
.
Template
(
cpp_handle_preamble
).
substitute
(
success
=
success_type
))
def
add_handle
(
name
,
ctype
,
cpptype
,
destroy
=
None
,
ref
=
None
):
def
add_handle
(
name
:
str
,
ctype
:
str
,
cpptype
:
str
,
destroy
:
Optional
[
str
]
=
None
,
ref
:
Optional
[
bool
]
=
None
)
->
None
:
opaque_type
=
ctype
+
'_t'
def
handle_wrap
(
p
):
...
...
@@ -718,8 +753,12 @@ def add_handle(name, ctype, cpptype, destroy=None, ref=None):
@
cwrap
(
'std::vector'
)
def
vector_c_wrap
(
p
):
t
=
p
.
type
.
inner_type
().
add_pointer
()
def
vector_c_wrap
(
p
:
Parameter
)
->
None
:
inner
=
p
.
type
.
inner_type
()
# Not a generic type
if
not
inner
:
return
t
=
inner
.
add_pointer
()
if
p
.
returns
:
if
p
.
type
.
is_reference
():
if
p
.
type
.
is_const
():
...
...
@@ -747,7 +786,7 @@ def vector_c_wrap(p):
@
cwrap
(
'std::string'
)
def
string_c_wrap
(
p
)
:
def
string_c_wrap
(
p
:
Parameter
)
->
None
:
t
=
Type
(
'char*'
)
if
p
.
returns
:
if
p
.
type
.
is_reference
():
...
...
@@ -771,7 +810,11 @@ def string_c_wrap(p):
class
Handle
:
def
__init__
(
self
,
name
,
ctype
,
cpptype
,
ref
=
None
):
def
__init__
(
self
,
name
:
str
,
ctype
:
str
,
cpptype
:
str
,
ref
:
Optional
[
bool
]
=
None
)
->
None
:
self
.
name
=
name
self
.
ctype
=
ctype
self
.
cpptype
=
cpptype
...
...
@@ -779,17 +822,21 @@ class Handle:
add_handle
(
name
,
ctype
,
cpptype
,
ref
=
ref
)
cpp_type_map
[
cpptype
]
=
name
def
cname
(
self
,
name
)
:
def
cname
(
self
,
name
:
str
)
->
str
:
return
self
.
ctype
+
'_'
+
name
def
substitute
(
self
,
s
,
**
kwargs
):
def
substitute
(
self
,
s
:
str
,
**
kwargs
)
->
str
:
return
Template
(
s
).
safe_substitute
(
name
=
self
.
name
,
ctype
=
self
.
ctype
,
cpptype
=
self
.
cpptype
,
**
kwargs
)
def
constructor
(
self
,
name
,
params
=
None
,
fname
=
None
,
invoke
=
None
,
**
kwargs
):
def
constructor
(
self
,
name
:
str
,
params
:
Optional
[
List
[
Parameter
]]
=
None
,
fname
:
Optional
[
str
]
=
None
,
invoke
:
Optional
[
str
]
=
None
,
**
kwargs
)
->
'Handle'
:
create
=
self
.
substitute
(
'allocate<${cpptype}>($@)'
)
if
fname
:
create
=
self
.
substitute
(
'allocate<${cpptype}>(${fname}($@))'
,
...
...
@@ -805,13 +852,13 @@ class Handle:
return
self
def
method
(
self
,
name
,
params
=
None
,
fname
=
None
,
invoke
=
None
,
cpp_name
=
None
,
const
=
None
,
**
kwargs
):
name
:
str
,
params
:
Optional
[
List
[
Parameter
]]
=
None
,
fname
:
Optional
[
str
]
=
None
,
invoke
:
Optional
[
str
]
=
None
,
cpp_name
:
Optional
[
str
]
=
None
,
const
:
Optional
[
bool
]
=
None
,
**
kwargs
)
->
'Handle'
:
cpptype
=
self
.
cpptype
if
const
:
cpptype
=
Type
(
cpptype
).
add_const
().
str
()
...
...
@@ -832,11 +879,14 @@ class Handle:
add_function
(
self
.
cname
(
name
),
params
=
params
,
**
kwargs
)
return
self
def
add_cpp_class
(
self
):
def
add_cpp_class
(
self
)
->
None
:
cpp_classes
.
append
(
self
.
cpp_class
)
def
handle
(
ctype
,
cpptype
,
name
=
None
,
ref
=
None
):
def
handle
(
ctype
:
str
,
cpptype
:
str
,
name
:
Optional
[
str
]
=
None
,
ref
:
Optional
[
bool
]
=
None
)
->
Callable
:
def
with_handle
(
f
):
n
=
name
or
f
.
__name__
h
=
Handle
(
n
,
ctype
,
cpptype
,
ref
=
ref
)
...
...
@@ -865,10 +915,10 @@ def template_eval(template, **kwargs):
return
template
def
run
(
)
:
runpy
.
run_path
(
sys
.
argv
[
1
])
if
len
(
sys
.
arg
v
)
>
2
:
f
=
open
(
sys
.
argv
[
2
]).
read
()
def
run
(
args
:
List
[
str
])
->
None
:
runpy
.
run_path
(
args
[
0
])
if
len
(
arg
s
)
>
1
:
f
=
open
(
args
[
1
]).
read
()
r
=
template_eval
(
f
)
sys
.
stdout
.
write
(
r
)
else
:
...
...
@@ -879,4 +929,4 @@ def run():
if
__name__
==
"__main__"
:
sys
.
modules
[
'api'
]
=
sys
.
modules
[
'__main__'
]
run
()
run
(
sys
.
argv
[
1
:]
)
tools/generate.sh
View file @
08ac24cf
DIR
=
"
$(
cd
"
$(
dirname
"
${
BASH_SOURCE
[0]
}
"
)
"
&&
pwd
)
"
SRC_DIR
=
$DIR
/../src
ls
-1
$DIR
/include/ | xargs
-n
1
-P
$(
nproc
)
-I
{}
-t
bash
-c
"python3.6
$DIR
/te.py
$DIR
/include/{} | clang-format-5.0 -style=file >
$SRC_DIR
/include/migraphx/{}"
PYTHON
=
python3
if
type
-p
python3.6
>
/dev/null
;
then
PYTHON
=
python3.6
fi
if
type
-p
python3.8
>
/dev/null
;
then
PYTHON
=
python3.8
fi
ls
-1
$DIR
/include/ | xargs
-n
1
-P
$(
nproc
)
-I
{}
-t
bash
-c
"
$PYTHON
$DIR
/te.py
$DIR
/include/{} | clang-format-5.0 -style=file >
$SRC_DIR
/include/migraphx/{}"
function
api
{
python3.6
$DIR
/api.py
$SRC_DIR
/api/migraphx.py
$1
| clang-format-5.0
-style
=
file
>
$2
$PYTHON
$DIR
/api.py
$SRC_DIR
/api/migraphx.py
$1
| clang-format-5.0
-style
=
file
>
$2
}
api
$DIR
/api/migraphx.h
$SRC_DIR
/api/include/migraphx/migraphx.h
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment