Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
333cd06e
Commit
333cd06e
authored
Jan 11, 2023
by
charlie
Browse files
Python API partial update
Adds dynamic_dimension object
parent
863bdfbf
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
5 deletions
+40
-5
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+31
-5
test/py/test_shape.py
test/py/test_shape.py
+9
-0
No files found.
src/py/migraphx_py.cpp
View file @
333cd06e
...
@@ -236,7 +236,10 @@ migraphx::shape to_shape(const py::buffer_info& info)
...
@@ -236,7 +236,10 @@ migraphx::shape to_shape(const py::buffer_info& info)
MIGRAPHX_PYBIND11_MODULE
(
migraphx
,
m
)
MIGRAPHX_PYBIND11_MODULE
(
migraphx
,
m
)
{
{
py
::
class_
<
migraphx
::
shape
>
(
m
,
"shape"
)
py
::
class_
<
migraphx
::
shape
>
py_shape
(
m
,
"shape"
);
// TODO: update this def to also create dynamic shapes
py_shape
.
def
(
py
::
init
([](
py
::
kwargs
kwargs
)
{
.
def
(
py
::
init
([](
py
::
kwargs
kwargs
)
{
auto
v
=
migraphx
::
to_value
(
kwargs
);
auto
v
=
migraphx
::
to_value
(
kwargs
);
auto
t
=
migraphx
::
shape
::
parse_type
(
v
.
get
(
"type"
,
"float"
));
auto
t
=
migraphx
::
shape
::
parse_type
(
v
.
get
(
"type"
,
"float"
));
...
@@ -249,19 +252,30 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
...
@@ -249,19 +252,30 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.
def
(
"type"
,
&
migraphx
::
shape
::
type
)
.
def
(
"type"
,
&
migraphx
::
shape
::
type
)
.
def
(
"lens"
,
&
migraphx
::
shape
::
lens
)
.
def
(
"lens"
,
&
migraphx
::
shape
::
lens
)
.
def
(
"strides"
,
&
migraphx
::
shape
::
strides
)
.
def
(
"strides"
,
&
migraphx
::
shape
::
strides
)
.
def
(
"ndim"
,
&
migraphx
::
shape
::
ndim
)
.
def
(
"elements"
,
&
migraphx
::
shape
::
elements
)
.
def
(
"elements"
,
&
migraphx
::
shape
::
elements
)
.
def
(
"bytes"
,
&
migraphx
::
shape
::
bytes
)
.
def
(
"bytes"
,
&
migraphx
::
shape
::
bytes
)
.
def
(
"type_string"
,
&
migraphx
::
shape
::
type_string
)
.
def
(
"type_string"
,
&
migraphx
::
shape
::
type_string
)
.
def
(
"type_size"
,
&
migraphx
::
shape
::
type_size
)
.
def
(
"type_size"
,
&
migraphx
::
shape
::
type_size
)
.
def
(
"dyn_dims"
,
&
migraphx
::
shape
::
dyn_dims
)
.
def
(
"packed"
,
&
migraphx
::
shape
::
packed
)
.
def
(
"packed"
,
&
migraphx
::
shape
::
packed
)
.
def
(
"transposed"
,
&
migraphx
::
shape
::
transposed
)
.
def
(
"transposed"
,
&
migraphx
::
shape
::
transposed
)
.
def
(
"broadcasted"
,
&
migraphx
::
shape
::
broadcasted
)
.
def
(
"broadcasted"
,
&
migraphx
::
shape
::
broadcasted
)
.
def
(
"standard"
,
&
migraphx
::
shape
::
standard
)
.
def
(
"standard"
,
&
migraphx
::
shape
::
standard
)
.
def
(
"scalar"
,
&
migraphx
::
shape
::
scalar
)
.
def
(
"scalar"
,
&
migraphx
::
shape
::
scalar
)
.
def
(
"dynamic"
,
&
migraphx
::
shape
::
dynamic
)
.
def
(
"__eq__"
,
std
::
equal_to
<
migraphx
::
shape
>
{})
.
def
(
"__eq__"
,
std
::
equal_to
<
migraphx
::
shape
>
{})
.
def
(
"__ne__"
,
std
::
not_equal_to
<
migraphx
::
shape
>
{})
.
def
(
"__ne__"
,
std
::
not_equal_to
<
migraphx
::
shape
>
{})
.
def
(
"__repr__"
,
[](
const
migraphx
::
shape
&
s
)
{
return
migraphx
::
to_string
(
s
);
});
.
def
(
"__repr__"
,
[](
const
migraphx
::
shape
&
s
)
{
return
migraphx
::
to_string
(
s
);
});
py
::
class_
<
migraphx
::
shape
::
dynamic_dimension
>
(
py_shape
,
"dynamic_dimension"
)
.
def
(
py
::
init
<
std
::
size_t
,
std
::
size_t
,
std
::
size_t
>
())
.
def_readwrite
(
"min"
,
&
migraphx
::
shape
::
dynamic_dimension
::
min
)
.
def_readwrite
(
"max"
,
&
migraphx
::
shape
::
dynamic_dimension
::
max
)
.
def_readwrite
(
"opt"
,
&
migraphx
::
shape
::
dynamic_dimension
::
opt
)
.
def
(
"is_fixed"
,
&
migraphx
::
shape
::
dynamic_dimension
::
is_fixed
)
.
def
(
"has_optimal"
,
&
migraphx
::
shape
::
dynamic_dimension
::
has_optimal
);
py
::
class_
<
migraphx
::
argument
>
(
m
,
"argument"
,
py
::
buffer_protocol
())
py
::
class_
<
migraphx
::
argument
>
(
m
,
"argument"
,
py
::
buffer_protocol
())
.
def_buffer
([](
migraphx
::
argument
&
x
)
->
py
::
buffer_info
{
return
to_buffer_info
(
x
);
})
.
def_buffer
([](
migraphx
::
argument
&
x
)
->
py
::
buffer_info
{
return
to_buffer_info
(
x
);
})
.
def
(
py
::
init
([](
py
::
buffer
b
)
{
.
def
(
py
::
init
([](
py
::
buffer
b
)
{
...
@@ -428,26 +442,38 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
...
@@ -428,26 +442,38 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
"parse_onnx"
,
"parse_onnx"
,
[](
const
std
::
string
&
filename
,
[](
const
std
::
string
&
filename
,
unsigned
int
default_dim_value
,
unsigned
int
default_dim_value
,
migraphx
::
shape
::
dynamic_dimension
default_dyn_dim_value
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
map_input_dims
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>>
map_dyn_input_dims
,
bool
skip_unknown_operators
,
bool
skip_unknown_operators
,
bool
print_program_on_error
,
bool
print_program_on_error
,
int64_t
max_loop_iterations
)
{
int64_t
max_loop_iterations
,
bool
use_dyn_output
)
{
migraphx
::
onnx_options
options
;
migraphx
::
onnx_options
options
;
options
.
default_dim_value
=
default_dim_value
;
options
.
default_dim_value
=
default_dim_value
;
options
.
default_dyn_dim_value
=
default_dyn_dim_value
;
options
.
map_input_dims
=
map_input_dims
;
options
.
map_input_dims
=
map_input_dims
;
options
.
map_dyn_input_dims
=
map_dyn_input_dims
;
options
.
skip_unknown_operators
=
skip_unknown_operators
;
options
.
skip_unknown_operators
=
skip_unknown_operators
;
options
.
print_program_on_error
=
print_program_on_error
;
options
.
print_program_on_error
=
print_program_on_error
;
options
.
max_loop_iterations
=
max_loop_iterations
;
options
.
max_loop_iterations
=
max_loop_iterations
;
options
.
use_dyn_output
=
use_dyn_output
;
return
migraphx
::
parse_onnx
(
filename
,
options
);
return
migraphx
::
parse_onnx
(
filename
,
options
);
},
},
"Parse onnx file"
,
"Parse onnx file"
,
py
::
arg
(
"filename"
),
py
::
arg
(
"filename"
),
py
::
arg
(
"default_dim_value"
)
=
1
,
py
::
arg
(
"default_dim_value"
)
=
0
,
py
::
arg
(
"default_dyn_dim_value"
)
=
migraphx
::
shape
::
dynamic_dimension
{
1
,
1
},
py
::
arg
(
"map_input_dims"
)
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
(),
py
::
arg
(
"map_input_dims"
)
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
size_t
>>
(),
py
::
arg
(
"map_dyn_input_dims"
)
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>>
(),
py
::
arg
(
"skip_unknown_operators"
)
=
false
,
py
::
arg
(
"skip_unknown_operators"
)
=
false
,
py
::
arg
(
"print_program_on_error"
)
=
false
,
py
::
arg
(
"print_program_on_error"
)
=
false
,
py
::
arg
(
"max_loop_iterations"
)
=
10
);
py
::
arg
(
"max_loop_iterations"
)
=
10
,
py
::
arg
(
"use_dyn_output"
)
=
false
);
// TODO: also update reading from ONNX buffer
m
.
def
(
m
.
def
(
"parse_onnx_buffer"
,
"parse_onnx_buffer"
,
[](
const
std
::
string
&
onnx_buffer
,
[](
const
std
::
string
&
onnx_buffer
,
...
...
test/py/test_shape.py
View file @
333cd06e
...
@@ -49,6 +49,15 @@ def test_create_shape_type():
...
@@ -49,6 +49,15 @@ def test_create_shape_type():
assert
s
.
type_size
()
==
4
assert
s
.
type_size
()
==
4
def
test_create_dynamic_dimension
():
dd
=
migraphx
.
shape
.
dynamic_dimension
(
1
,
4
)
assert
dd
.
min
==
1
assert
dd
.
max
==
4
assert
dd
.
opt
==
0
assert
dd
.
is_fixed
==
False
assert
dd
.
has_opt
==
False
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_create_shape
()
test_create_shape
()
test_create_shape_broadcast
()
test_create_shape_broadcast
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment