Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2a79a9ff
Unverified
Commit
2a79a9ff
authored
Feb 01, 2022
by
Paul Fultz II
Committed by
GitHub
Feb 01, 2022
Browse files
Add python type annotations to api.py (#1061)
This will also check the types using mypy on the CI.
parent
7e7ef0b8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
195 additions
and
136 deletions
+195
-136
.github/workflows/ci.yaml
.github/workflows/ci.yaml
+4
-2
tools/api.py
tools/api.py
+182
-132
tools/generate.sh
tools/generate.sh
+9
-2
No files found.
.github/workflows/ci.yaml
View file @
2a79a9ff
...
...
@@ -142,10 +142,12 @@ jobs:
with
:
python-version
:
3.6
-
name
:
Install pyflakes
run
:
pip install pyflakes==2.3.1
run
:
pip install pyflakes==2.3.1
mypy==0.931
-
name
:
Run pyflakes
run
:
pyflakes examples/ tools/ src/ test/ doc/
run
:
|
pyflakes examples/ tools/ src/ test/ doc/
mypy tools/api.py
linux
:
...
...
tools/api.py
View file @
2a79a9ff
import
string
,
sys
,
re
,
runpy
from
functools
import
wraps
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
type_map
=
{}
cpp_type_map
=
{}
functions
=
[]
cpp_classes
=
[]
type_map
:
Dict
[
str
,
Callable
[[
'Parameter'
],
None
]]
=
{}
cpp_type_map
:
Dict
[
str
,
str
]
=
{}
functions
:
List
[
'Function'
]
=
[]
cpp_classes
:
List
[
'CPPClass'
]
=
[]
error_type
=
''
success_type
=
''
try_wrap
=
''
c_header_preamble
=
[]
c_api_body_preamble
=
[]
cpp_header_preamble
=
[]
c_header_preamble
:
List
[
str
]
=
[]
c_api_body_preamble
:
List
[
str
]
=
[]
cpp_header_preamble
:
List
[
str
]
=
[]
def
bad_param_error
(
msg
):
...
...
@@ -23,31 +24,31 @@ class Template(string.Template):
class
Type
:
def
__init__
(
self
,
name
)
:
def
__init__
(
self
,
name
:
str
)
->
None
:
self
.
name
=
name
.
strip
()
def
is_pointer
(
self
):
def
is_pointer
(
self
)
->
bool
:
return
self
.
name
.
endswith
(
'*'
)
def
is_reference
(
self
):
def
is_reference
(
self
)
->
bool
:
return
self
.
name
.
endswith
(
'&'
)
def
is_const
(
self
):
def
is_const
(
self
)
->
bool
:
return
self
.
name
.
startswith
(
'const '
)
def
is_variadic
(
self
):
return
self
.
name
.
startswith
(
'...'
)
def
add_pointer
(
self
):
def
add_pointer
(
self
)
->
'Type'
:
return
Type
(
self
.
name
+
'*'
)
def
add_reference
(
self
):
return
Type
(
self
.
name
+
'&'
)
def
add_const
(
self
):
def
add_const
(
self
)
->
'Type'
:
return
Type
(
'const '
+
self
.
name
)
def
inner_type
(
self
):
def
inner_type
(
self
)
->
Optional
[
'Type'
]
:
i
=
self
.
name
.
find
(
'<'
)
j
=
self
.
name
.
rfind
(
'>'
)
if
i
>
0
and
j
>
0
:
...
...
@@ -55,7 +56,7 @@ class Type:
else
:
return
None
def
remove_generic
(
self
):
def
remove_generic
(
self
)
->
'Type'
:
i
=
self
.
name
.
find
(
'<'
)
j
=
self
.
name
.
rfind
(
'>'
)
if
i
>
0
and
j
>
0
:
...
...
@@ -63,25 +64,25 @@ class Type:
else
:
return
self
def
remove_pointer
(
self
):
def
remove_pointer
(
self
)
->
'Type'
:
if
self
.
is_pointer
():
return
Type
(
self
.
name
[
0
:
-
1
])
return
self
def
remove_reference
(
self
):
def
remove_reference
(
self
)
->
'Type'
:
if
self
.
is_reference
():
return
Type
(
self
.
name
[
0
:
-
1
])
return
self
def
remove_const
(
self
):
def
remove_const
(
self
)
->
'Type'
:
if
self
.
is_const
():
return
Type
(
self
.
name
[
6
:])
return
self
def
basic
(
self
):
def
basic
(
self
)
->
'Type'
:
return
self
.
remove_pointer
().
remove_const
().
remove_reference
()
def
decay
(
self
):
def
decay
(
self
)
->
'Type'
:
t
=
self
.
remove_reference
()
if
t
.
is_pointer
():
return
t
...
...
@@ -93,7 +94,7 @@ class Type:
return
self
.
add_const
()
return
self
def
str
(
self
):
def
str
(
self
)
->
str
:
return
self
.
name
...
...
@@ -113,20 +114,20 @@ extern "C" ${error_type} ${name}(${params})
class
CFunction
:
def
__init__
(
self
,
name
)
:
def
__init__
(
self
,
name
:
str
)
->
None
:
self
.
name
=
name
self
.
params
=
[]
self
.
body
=
[]
self
.
va_start
=
[]
self
.
va_end
=
[]
self
.
params
:
List
[
str
]
=
[]
self
.
body
:
List
[
str
]
=
[]
self
.
va_start
:
List
[
str
]
=
[]
self
.
va_end
:
List
[
str
]
=
[]
def
add_param
(
self
,
type
,
pname
)
:
def
add_param
(
self
,
type
:
str
,
pname
:
str
)
->
None
:
self
.
params
.
append
(
'{} {}'
.
format
(
type
,
pname
))
def
add_statement
(
self
,
stmt
)
:
def
add_statement
(
self
,
stmt
:
str
)
->
None
:
self
.
body
.
append
(
stmt
)
def
add_vlist
(
self
,
name
)
:
def
add_vlist
(
self
,
name
:
str
)
->
None
:
last_param
=
self
.
params
[
-
1
].
split
()[
-
1
]
self
.
va_start
=
[
'va_list {};'
.
format
(
name
),
...
...
@@ -135,7 +136,7 @@ class CFunction:
self
.
va_end
=
[
'va_end({});'
.
format
(
name
)]
self
.
add_param
(
'...'
,
''
)
def
substitute
(
self
,
form
)
:
def
substitute
(
self
,
form
:
Template
)
->
str
:
return
form
.
substitute
(
error_type
=
error_type
,
try_wrap
=
try_wrap
,
name
=
self
.
name
,
...
...
@@ -144,25 +145,29 @@ class CFunction:
va_start
=
"
\n
"
.
join
(
self
.
va_start
),
va_end
=
"
\n
"
.
join
(
self
.
va_end
))
def
generate_header
(
self
):
def
generate_header
(
self
)
->
str
:
return
self
.
substitute
(
header_function
)
def
generate_body
(
self
):
def
generate_body
(
self
)
->
str
:
return
self
.
substitute
(
c_api_impl
)
class
BadParam
:
def
__init__
(
self
,
cond
,
msg
)
:
def
__init__
(
self
,
cond
:
str
,
msg
:
str
)
->
None
:
self
.
cond
=
cond
self
.
msg
=
msg
class
Parameter
:
def
__init__
(
self
,
name
,
type
,
optional
=
False
,
returns
=
False
):
def
__init__
(
self
,
name
:
str
,
type
:
str
,
optional
:
bool
=
False
,
returns
:
bool
=
False
)
->
None
:
self
.
name
=
name
self
.
type
=
Type
(
type
)
self
.
optional
=
optional
self
.
cparams
=
[]
self
.
cparams
:
List
[
Tuple
[
str
,
str
]]
=
[]
self
.
size_cparam
=
-
1
self
.
size_name
=
''
self
.
read
=
'${name}'
...
...
@@ -170,15 +175,15 @@ class Parameter:
self
.
cpp_read
=
'${name}'
self
.
cpp_write
=
'${name}'
self
.
returns
=
returns
self
.
bad_param_check
=
None
self
.
bad_param_check
:
Optional
[
BadParam
]
=
None
def
get_name
(
self
,
prefix
=
None
)
:
def
get_name
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
if
prefix
:
return
prefix
+
self
.
name
else
:
return
self
.
name
def
get_cpp_type
(
self
):
def
get_cpp_type
(
self
)
->
str
:
if
self
.
type
.
str
()
in
cpp_type_map
:
return
cpp_type_map
[
self
.
type
.
basic
().
str
()]
elif
self
.
type
.
basic
().
str
()
in
cpp_type_map
:
...
...
@@ -188,7 +193,10 @@ class Parameter:
else
:
return
self
.
type
.
str
()
def
substitute
(
self
,
s
,
prefix
=
None
,
result
=
None
):
def
substitute
(
self
,
s
:
str
,
prefix
:
Optional
[
str
]
=
None
,
result
:
Optional
[
str
]
=
None
)
->
str
:
ctype
=
None
if
len
(
self
.
cparams
)
>
0
:
ctype
=
Type
(
self
.
cparams
[
0
][
0
]).
basic
().
str
()
...
...
@@ -199,12 +207,13 @@ class Parameter:
size
=
self
.
size_name
,
result
=
result
or
''
)
def
add_param
(
self
,
t
,
name
=
None
):
def
add_param
(
self
,
t
:
Union
[
str
,
Type
],
name
:
Optional
[
str
]
=
None
)
->
None
:
if
not
isinstance
(
t
,
str
):
t
=
t
.
str
()
self
.
cparams
.
append
((
t
,
name
or
self
.
name
))
def
add_size_param
(
self
,
name
=
None
)
:
def
add_size_param
(
self
,
name
:
Optional
[
str
]
=
None
)
->
None
:
self
.
size_cparam
=
len
(
self
.
cparams
)
self
.
size_name
=
name
or
self
.
name
+
'_size'
if
self
.
returns
:
...
...
@@ -212,7 +221,7 @@ class Parameter:
else
:
self
.
add_param
(
'size_t'
,
self
.
size_name
)
def
bad_param
(
self
,
cond
,
msg
)
:
def
bad_param
(
self
,
cond
:
str
,
msg
:
str
)
->
None
:
self
.
bad_param_check
=
BadParam
(
cond
,
msg
)
def
remove_size_param
(
self
,
name
):
...
...
@@ -223,7 +232,7 @@ class Parameter:
self
.
size_name
=
name
return
p
def
update
(
self
):
def
update
(
self
)
->
None
:
t
=
self
.
type
.
basic
().
str
()
g
=
self
.
type
.
remove_generic
().
basic
().
str
()
if
t
in
type_map
:
...
...
@@ -239,18 +248,18 @@ class Parameter:
raise
ValueError
(
"Error for {}: write cannot be a string"
.
format
(
self
.
type
.
str
()))
def
cpp_param
(
self
,
prefix
=
None
)
:
def
cpp_param
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
return
self
.
substitute
(
'${cpptype} ${name}'
,
prefix
=
prefix
)
def
cpp_arg
(
self
,
prefix
=
None
)
:
def
cpp_arg
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
return
self
.
substitute
(
self
.
cpp_read
,
prefix
=
prefix
)
def
cpp_output_args
(
self
,
prefix
=
None
)
:
def
cpp_output_args
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
List
[
str
]
:
return
[
'&{prefix}{n}'
.
format
(
prefix
=
prefix
,
n
=
n
)
for
t
,
n
in
self
.
cparams
]
def
output_declarations
(
self
,
prefix
=
None
)
:
def
output_declarations
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
List
[
str
]
:
return
[
'{type} {prefix}{n};'
.
format
(
type
=
Type
(
t
).
remove_pointer
().
str
(),
prefix
=
prefix
,
...
...
@@ -262,16 +271,16 @@ class Parameter:
'&{prefix}{n};'
.
format
(
prefix
=
prefix
,
n
=
n
)
for
t
,
n
in
self
.
cparams
]
def
cpp_output
(
self
,
prefix
=
None
)
:
def
cpp_output
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
return
self
.
substitute
(
self
.
cpp_write
,
prefix
=
prefix
)
def
input
(
self
,
prefix
=
None
)
:
def
input
(
self
,
prefix
:
Optional
[
str
]
=
None
)
->
str
:
return
'('
+
self
.
substitute
(
self
.
read
,
prefix
=
prefix
)
+
')'
def
outputs
(
self
,
result
=
None
)
:
def
outputs
(
self
,
result
:
Optional
[
str
]
=
None
)
->
List
[
str
]
:
return
[
self
.
substitute
(
w
,
result
=
result
)
for
w
in
self
.
write
]
def
add_to_cfunction
(
self
,
cfunction
)
:
def
add_to_cfunction
(
self
,
cfunction
:
CFunction
)
->
None
:
for
t
,
name
in
self
.
cparams
:
if
t
.
startswith
(
'...'
):
cfunction
.
add_vlist
(
name
)
...
...
@@ -285,35 +294,35 @@ class Parameter:
body
=
bad_param_error
(
msg
)))
def
template_var
(
s
)
:
def
template_var
(
s
:
str
)
->
str
:
return
'${'
+
s
+
'}'
def
to_template_vars
(
params
)
:
def
to_template_vars
(
params
:
List
[
Union
[
Any
,
Parameter
]])
->
str
:
return
', '
.
join
([
template_var
(
p
.
name
)
for
p
in
params
])
class
Function
:
def
__init__
(
self
,
name
,
params
=
None
,
shared_size
=
False
,
returns
=
None
,
invoke
=
None
,
fname
=
None
,
return_name
=
None
,
**
kwargs
):
name
:
str
,
params
:
Optional
[
List
[
Parameter
]]
=
None
,
shared_size
:
bool
=
False
,
returns
:
Optional
[
str
]
=
None
,
invoke
:
Optional
[
str
]
=
None
,
fname
:
Optional
[
str
]
=
None
,
return_name
:
Optional
[
str
]
=
None
,
**
kwargs
)
->
None
:
self
.
name
=
name
self
.
params
=
params
or
[]
self
.
shared_size
=
False
self
.
cfunction
=
None
self
.
cfunction
:
Optional
[
CFunction
]
=
None
self
.
fname
=
fname
self
.
invoke
=
invoke
or
'${__fname__}($@)'
self
.
return_name
=
return_name
or
'out'
self
.
returns
=
Parameter
(
self
.
return_name
,
returns
,
returns
=
True
)
if
returns
else
None
def
share_params
(
self
):
def
share_params
(
self
)
->
None
:
if
self
.
shared_size
==
True
:
size_param_name
=
'size'
size_type
=
Type
(
'size_t'
)
...
...
@@ -323,7 +332,7 @@ class Function:
size_type
=
Type
(
p
[
0
])
self
.
params
.
append
(
Parameter
(
size_param_name
,
size_type
.
str
()))
def
update
(
self
):
def
update
(
self
)
->
None
:
self
.
share_params
()
for
param
in
self
.
params
:
param
.
update
()
...
...
@@ -331,11 +340,12 @@ class Function:
self
.
returns
.
update
()
self
.
create_cfunction
()
def
inputs
(
self
):
def
inputs
(
self
)
->
str
:
return
', '
.
join
([
p
.
input
()
for
p
in
self
.
params
])
def
input_map
(
self
):
m
=
{}
# TODO: Shoule we remove Optional?
def
input_map
(
self
)
->
Dict
[
str
,
Optional
[
str
]]:
m
:
Dict
[
str
,
Optional
[
str
]]
=
{}
for
p
in
self
.
params
:
m
[
p
.
name
]
=
p
.
input
()
m
[
'return'
]
=
self
.
return_name
...
...
@@ -343,14 +353,22 @@ class Function:
m
[
'__fname__'
]
=
self
.
fname
return
m
def
get_invoke
(
self
):
def
get_invoke
(
self
)
->
str
:
return
Template
(
self
.
invoke
).
safe_substitute
(
self
.
input_map
())
def
write_to_tmp_var
(
self
):
def
write_to_tmp_var
(
self
)
->
bool
:
if
not
self
.
returns
:
return
False
return
len
(
self
.
returns
.
write
)
>
1
or
self
.
returns
.
write
[
0
].
count
(
'${result}'
)
>
1
def
create_cfunction
(
self
):
def
get_cfunction
(
self
)
->
CFunction
:
if
self
.
cfunction
:
return
self
.
cfunction
raise
Exception
(
"self.cfunction is None: self.update() needs to be called."
)
def
create_cfunction
(
self
)
->
None
:
self
.
cfunction
=
CFunction
(
self
.
name
)
# Add the return as a parameter
if
self
.
returns
:
...
...
@@ -358,12 +376,12 @@ class Function:
# Add the input parameters
for
param
in
self
.
params
:
param
.
add_to_cfunction
(
self
.
cfunction
)
f
=
self
.
get_invoke
()
f
:
Optional
[
str
]
=
self
.
get_invoke
()
# Write the assignments
assigns
=
[]
if
self
.
returns
:
result
=
f
if
self
.
write_to_tmp_var
():
if
self
.
write_to_tmp_var
()
and
f
:
f
=
'auto&& api_result = '
+
f
result
=
'api_result'
else
:
...
...
@@ -416,31 +434,37 @@ cpp_class_constructor_template = Template('''
class
CPPMember
:
def
__init__
(
self
,
name
,
function
,
prefix
,
method
=
True
):
def
__init__
(
self
,
name
:
str
,
function
:
Function
,
prefix
:
str
,
method
:
bool
=
True
)
->
None
:
self
.
name
=
name
self
.
function
=
function
self
.
prefix
=
prefix
self
.
method
=
method
def
get_function_params
(
self
):
def
get_function_params
(
self
)
->
List
[
Union
[
Any
,
Parameter
]]
:
if
self
.
method
:
return
self
.
function
.
params
[
1
:]
else
:
return
self
.
function
.
params
def
get_args
(
self
):
def
get_args
(
self
)
->
str
:
output_args
=
[]
if
self
.
function
.
returns
:
output_args
=
self
.
function
.
returns
.
cpp_output_args
(
self
.
prefix
)
if
not
self
.
function
.
cfunction
:
raise
Exception
(
'self.function.update() must be called'
)
return
', '
.
join
(
[
'&{}'
.
format
(
self
.
function
.
cfunction
.
name
)]
+
output_args
+
[
p
.
cpp_arg
(
self
.
prefix
)
for
p
in
self
.
get_function_params
()])
def
get_params
(
self
):
def
get_params
(
self
)
->
str
:
return
', '
.
join
(
[
p
.
cpp_param
(
self
.
prefix
)
for
p
in
self
.
get_function_params
()])
def
get_return_declarations
(
self
):
def
get_return_declarations
(
self
)
->
str
:
if
self
.
function
.
returns
:
return
'
\n
'
.
join
([
d
...
...
@@ -452,7 +476,9 @@ class CPPMember:
def
get_result
(
self
):
return
self
.
function
.
returns
.
input
(
self
.
prefix
)
def
generate_method
(
self
):
def
generate_method
(
self
)
->
str
:
if
not
self
.
function
.
cfunction
:
raise
Exception
(
'self.function.update() must be called'
)
if
self
.
function
.
returns
:
return_type
=
self
.
function
.
returns
.
get_cpp_type
()
return
cpp_class_method_template
.
safe_substitute
(
...
...
@@ -472,7 +498,9 @@ class CPPMember:
args
=
self
.
get_args
(),
success
=
success_type
)
def
generate_constructor
(
self
,
name
):
def
generate_constructor
(
self
,
name
:
str
)
->
str
:
if
not
self
.
function
.
cfunction
:
raise
Exception
(
'self.function.update() must be called'
)
return
cpp_class_constructor_template
.
safe_substitute
(
name
=
name
,
cfunction
=
self
.
function
.
cfunction
.
name
,
...
...
@@ -482,98 +510,101 @@ class CPPMember:
class
CPPClass
:
def
__init__
(
self
,
name
,
ctype
)
:
def
__init__
(
self
,
name
:
str
,
ctype
:
str
)
->
None
:
self
.
name
=
name
self
.
ctype
=
ctype
self
.
constructors
=
[]
self
.
methods
=
[]
self
.
constructors
:
List
[
CPPMember
]
=
[]
self
.
methods
:
List
[
CPPMember
]
=
[]
self
.
prefix
=
'p'
def
add_method
(
self
,
name
,
f
)
:
def
add_method
(
self
,
name
:
str
,
f
:
Function
)
->
None
:
self
.
methods
.
append
(
CPPMember
(
name
,
f
,
self
.
prefix
,
method
=
True
))
def
add_constructor
(
self
,
name
,
f
)
:
def
add_constructor
(
self
,
name
:
str
,
f
:
Function
)
->
None
:
self
.
constructors
.
append
(
CPPMember
(
name
,
f
,
self
.
prefix
,
method
=
True
))
def
generate_methods
(
self
):
def
generate_methods
(
self
)
->
str
:
return
'
\n
'
.
join
([
m
.
generate_method
()
for
m
in
self
.
methods
])
def
generate_constructors
(
self
):
def
generate_constructors
(
self
)
->
str
:
return
'
\n
'
.
join
(
[
m
.
generate_constructor
(
self
.
name
)
for
m
in
self
.
constructors
])
def
substitute
(
self
,
s
,
**
kwargs
):
t
=
s
if
isinstance
(
s
,
str
):
t
=
string
.
Template
(
s
)
def
substitute
(
self
,
s
:
Union
[
string
.
Template
,
str
],
**
kwargs
)
->
str
:
t
=
string
.
Template
(
s
)
if
isinstance
(
s
,
str
)
else
s
destroy
=
self
.
ctype
+
'_destroy'
return
t
.
safe_substitute
(
name
=
self
.
name
,
ctype
=
self
.
ctype
,
destroy
=
destroy
,
**
kwargs
)
def
generate
(
self
):
def
generate
(
self
)
->
str
:
return
self
.
substitute
(
cpp_class_template
,
constructors
=
self
.
substitute
(
self
.
generate_constructors
()),
methods
=
self
.
substitute
(
self
.
generate_methods
()))
def
params
(
virtual
=
None
,
**
kwargs
):
def
params
(
virtual
:
Optional
[
Dict
[
str
,
str
]]
=
None
,
**
kwargs
)
->
List
[
Parameter
]:
result
=
[]
for
name
in
virtual
or
{}:
result
.
append
(
Parameter
(
name
,
virtual
[
name
]))
v
:
Dict
[
str
,
str
]
=
virtual
or
{}
for
name
in
v
:
result
.
append
(
Parameter
(
name
,
v
[
name
]))
for
name
in
kwargs
:
result
.
append
(
Parameter
(
name
,
kwargs
[
name
]))
return
result
def
add_function
(
name
,
*
args
,
**
kwargs
):
def
add_function
(
name
:
str
,
*
args
,
**
kwargs
)
->
Function
:
f
=
Function
(
name
,
*
args
,
**
kwargs
)
functions
.
append
(
f
)
return
f
def
once
(
f
)
:
def
once
(
f
:
Callable
)
->
Any
:
@
wraps
(
f
)
def
decorated
(
*
args
,
**
kwargs
):
if
not
decorated
.
has_run
:
decorated
.
has_run
=
True
return
f
(
*
args
,
**
kwargs
)
decorated
.
has_run
=
False
return
decorated
d
:
Any
=
decorated
d
.
has_run
=
False
return
d
@
once
def
process_functions
():
def
process_functions
()
->
None
:
for
f
in
functions
:
f
.
update
()
def
generate_lines
(
p
)
:
def
generate_lines
(
p
:
List
[
str
])
->
str
:
return
'
\n
'
.
join
(
p
)
def
generate_c_header
():
def
generate_c_header
()
->
str
:
process_functions
()
return
generate_lines
(
c_header_preamble
+
[
f
.
cfunction
.
generate_header
()
for
f
in
functions
])
return
generate_lines
(
c_header_preamble
+
[
f
.
get_cfunction
().
generate_header
()
for
f
in
functions
])
def
generate_c_api_body
():
def
generate_c_api_body
()
->
str
:
process_functions
()
return
generate_lines
(
c_api_body_preamble
+
[
f
.
cfunction
.
generate_body
()
for
f
in
functions
])
return
generate_lines
(
c_api_body_preamble
+
[
f
.
get_cfunction
().
generate_body
()
for
f
in
functions
])
def
generate_cpp_header
():
def
generate_cpp_header
()
->
str
:
process_functions
()
return
generate_lines
(
cpp_header_preamble
+
[
c
.
generate
()
for
c
in
cpp_classes
])
def
cwrap
(
name
)
:
def
cwrap
(
name
:
str
)
->
Callable
:
def
with_cwrap
(
f
):
type_map
[
name
]
=
f
...
...
@@ -677,13 +708,17 @@ protected:
@
once
def
add_handle_preamble
():
def
add_handle_preamble
()
->
None
:
c_api_body_preamble
.
append
(
handle_preamble
)
cpp_header_preamble
.
append
(
string
.
Template
(
cpp_handle_preamble
).
substitute
(
success
=
success_type
))
def
add_handle
(
name
,
ctype
,
cpptype
,
destroy
=
None
,
ref
=
None
):
def
add_handle
(
name
:
str
,
ctype
:
str
,
cpptype
:
str
,
destroy
:
Optional
[
str
]
=
None
,
ref
:
Optional
[
bool
]
=
None
)
->
None
:
opaque_type
=
ctype
+
'_t'
def
handle_wrap
(
p
):
...
...
@@ -718,8 +753,12 @@ def add_handle(name, ctype, cpptype, destroy=None, ref=None):
@
cwrap
(
'std::vector'
)
def
vector_c_wrap
(
p
):
t
=
p
.
type
.
inner_type
().
add_pointer
()
def
vector_c_wrap
(
p
:
Parameter
)
->
None
:
inner
=
p
.
type
.
inner_type
()
# Not a generic type
if
not
inner
:
return
t
=
inner
.
add_pointer
()
if
p
.
returns
:
if
p
.
type
.
is_reference
():
if
p
.
type
.
is_const
():
...
...
@@ -747,7 +786,7 @@ def vector_c_wrap(p):
@
cwrap
(
'std::string'
)
def
string_c_wrap
(
p
)
:
def
string_c_wrap
(
p
:
Parameter
)
->
None
:
t
=
Type
(
'char*'
)
if
p
.
returns
:
if
p
.
type
.
is_reference
():
...
...
@@ -771,7 +810,11 @@ def string_c_wrap(p):
class
Handle
:
def
__init__
(
self
,
name
,
ctype
,
cpptype
,
ref
=
None
):
def
__init__
(
self
,
name
:
str
,
ctype
:
str
,
cpptype
:
str
,
ref
:
Optional
[
bool
]
=
None
)
->
None
:
self
.
name
=
name
self
.
ctype
=
ctype
self
.
cpptype
=
cpptype
...
...
@@ -779,17 +822,21 @@ class Handle:
add_handle
(
name
,
ctype
,
cpptype
,
ref
=
ref
)
cpp_type_map
[
cpptype
]
=
name
def
cname
(
self
,
name
)
:
def
cname
(
self
,
name
:
str
)
->
str
:
return
self
.
ctype
+
'_'
+
name
def
substitute
(
self
,
s
,
**
kwargs
):
def
substitute
(
self
,
s
:
str
,
**
kwargs
)
->
str
:
return
Template
(
s
).
safe_substitute
(
name
=
self
.
name
,
ctype
=
self
.
ctype
,
cpptype
=
self
.
cpptype
,
**
kwargs
)
def
constructor
(
self
,
name
,
params
=
None
,
fname
=
None
,
invoke
=
None
,
**
kwargs
):
def
constructor
(
self
,
name
:
str
,
params
:
Optional
[
List
[
Parameter
]]
=
None
,
fname
:
Optional
[
str
]
=
None
,
invoke
:
Optional
[
str
]
=
None
,
**
kwargs
)
->
'Handle'
:
create
=
self
.
substitute
(
'allocate<${cpptype}>($@)'
)
if
fname
:
create
=
self
.
substitute
(
'allocate<${cpptype}>(${fname}($@))'
,
...
...
@@ -805,13 +852,13 @@ class Handle:
return
self
def
method
(
self
,
name
,
params
=
None
,
fname
=
None
,
invoke
=
None
,
cpp_name
=
None
,
const
=
None
,
**
kwargs
):
name
:
str
,
params
:
Optional
[
List
[
Parameter
]]
=
None
,
fname
:
Optional
[
str
]
=
None
,
invoke
:
Optional
[
str
]
=
None
,
cpp_name
:
Optional
[
str
]
=
None
,
const
:
Optional
[
bool
]
=
None
,
**
kwargs
)
->
'Handle'
:
cpptype
=
self
.
cpptype
if
const
:
cpptype
=
Type
(
cpptype
).
add_const
().
str
()
...
...
@@ -832,11 +879,14 @@ class Handle:
add_function
(
self
.
cname
(
name
),
params
=
params
,
**
kwargs
)
return
self
def
add_cpp_class
(
self
):
def
add_cpp_class
(
self
)
->
None
:
cpp_classes
.
append
(
self
.
cpp_class
)
def
handle
(
ctype
,
cpptype
,
name
=
None
,
ref
=
None
):
def
handle
(
ctype
:
str
,
cpptype
:
str
,
name
:
Optional
[
str
]
=
None
,
ref
:
Optional
[
bool
]
=
None
)
->
Callable
:
def
with_handle
(
f
):
n
=
name
or
f
.
__name__
h
=
Handle
(
n
,
ctype
,
cpptype
,
ref
=
ref
)
...
...
@@ -865,10 +915,10 @@ def template_eval(template, **kwargs):
return
template
def
run
(
)
:
runpy
.
run_path
(
sys
.
argv
[
1
])
if
len
(
sys
.
arg
v
)
>
2
:
f
=
open
(
sys
.
argv
[
2
]).
read
()
def
run
(
args
:
List
[
str
])
->
None
:
runpy
.
run_path
(
args
[
0
])
if
len
(
arg
s
)
>
1
:
f
=
open
(
args
[
1
]).
read
()
r
=
template_eval
(
f
)
sys
.
stdout
.
write
(
r
)
else
:
...
...
@@ -879,4 +929,4 @@ def run():
if
__name__
==
"__main__"
:
sys
.
modules
[
'api'
]
=
sys
.
modules
[
'__main__'
]
run
()
run
(
sys
.
argv
[
1
:]
)
tools/generate.sh
View file @
2a79a9ff
DIR
=
"
$(
cd
"
$(
dirname
"
${
BASH_SOURCE
[0]
}
"
)
"
&&
pwd
)
"
SRC_DIR
=
$DIR
/../src
ls
-1
$DIR
/include/ | xargs
-n
1
-P
$(
nproc
)
-I
{}
-t
bash
-c
"python3.6
$DIR
/te.py
$DIR
/include/{} | clang-format-5.0 -style=file >
$SRC_DIR
/include/migraphx/{}"
PYTHON
=
python3
if
type
-p
python3.6
>
/dev/null
;
then
PYTHON
=
python3.6
fi
if
type
-p
python3.8
>
/dev/null
;
then
PYTHON
=
python3.8
fi
ls
-1
$DIR
/include/ | xargs
-n
1
-P
$(
nproc
)
-I
{}
-t
bash
-c
"
$PYTHON
$DIR
/te.py
$DIR
/include/{} | clang-format-5.0 -style=file >
$SRC_DIR
/include/migraphx/{}"
function
api
{
python3.6
$DIR
/api.py
$SRC_DIR
/api/migraphx.py
$1
| clang-format-5.0
-style
=
file
>
$2
$PYTHON
$DIR
/api.py
$SRC_DIR
/api/migraphx.py
$1
| clang-format-5.0
-style
=
file
>
$2
}
api
$DIR
/api/migraphx.h
$SRC_DIR
/api/include/migraphx/migraphx.h
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment