Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4467c158
Unverified
Commit
4467c158
authored
Mar 09, 2022
by
Paul Fultz II
Committed by
GitHub
Mar 09, 2022
Browse files
Add python API to construct shape class (#1128)
Add python API to construct shape class
parent
0e6bd17c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
2 deletions
+33
-2
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+10
-1
test/py/CMakeLists.txt
test/py/CMakeLists.txt
+2
-1
test/py/test_shape.py
test/py/test_shape.py
+21
-0
No files found.
src/py/migraphx_py.cpp
View file @
4467c158
...
...
@@ -211,12 +211,21 @@ migraphx::shape to_shape(const py::buffer_info& info)
MIGRAPHX_PYBIND11_MODULE
(
migraphx
,
m
)
{
py
::
class_
<
migraphx
::
shape
>
(
m
,
"shape"
)
.
def
(
py
::
init
<>
())
.
def
(
py
::
init
([](
py
::
kwargs
kwargs
)
{
auto
v
=
migraphx
::
to_value
(
kwargs
);
auto
t
=
migraphx
::
shape
::
parse_type
(
v
.
get
(
"type"
,
std
::
string
{
"float"
}));
auto
lens
=
v
.
get
<
std
::
size_t
>
(
"lens"
,
{
1
});
if
(
v
.
contains
(
"strides"
))
return
migraphx
::
shape
(
t
,
lens
,
v
.
at
(
"strides"
).
to_vector
<
std
::
size_t
>
());
else
return
migraphx
::
shape
(
t
,
lens
);
}))
.
def
(
"type"
,
&
migraphx
::
shape
::
type
)
.
def
(
"lens"
,
&
migraphx
::
shape
::
lens
)
.
def
(
"strides"
,
&
migraphx
::
shape
::
strides
)
.
def
(
"elements"
,
&
migraphx
::
shape
::
elements
)
.
def
(
"bytes"
,
&
migraphx
::
shape
::
bytes
)
.
def
(
"type_string"
,
&
migraphx
::
shape
::
type_string
)
.
def
(
"type_size"
,
&
migraphx
::
shape
::
type_size
)
.
def
(
"packed"
,
&
migraphx
::
shape
::
packed
)
.
def
(
"transposed"
,
&
migraphx
::
shape
::
transposed
)
...
...
test/py/CMakeLists.txt
View file @
4467c158
...
...
@@ -25,10 +25,11 @@ endforeach()
add_py_test
(
ref test_cpu.py WORKING_DIRECTORY
${
TEST_ONNX_DIR
}
)
add_py_test
(
save_load test_save_load.py WORKING_DIRECTORY
${
TEST_ONNX_DIR
}
)
add_py_test
(
op test_op.py WORKING_DIRECTORY
${
TEST_ONNX_DIR
}
)
add_py_test
(
shape test_shape.py WORKING_DIRECTORY
${
TEST_ONNX_DIR
}
)
if
(
MIGRAPHX_ENABLE_GPU
)
add_py_test
(
gpu_offload test_gpu_offload.py WORKING_DIRECTORY
${
TEST_ONNX_DIR
}
)
add_py_test
(
gpu test_gpu.py WORKING_DIRECTORY
${
TEST_ONNX_DIR
}
)
add_py_test
(
array test_array.py WORKING_DIRECTORY
${
TEST_ONNX_DIR
}
)
add_py_test
(
backend onnx_backend_test.py WORKING_DIRECTORY
${
TEST_ONNX_DIR
}
)
add_py_test
(
op test_op.py WORKING_DIRECTORY
${
TEST_ONNX_DIR
}
)
endif
()
test/py/test_shape.py
0 → 100644
View file @
4467c158
import
migraphx
def
test_create_shape
():
s
=
migraphx
.
shape
(
lens
=
[
1
,
64
,
3
,
3
])
assert
s
.
standard
()
assert
s
.
packed
()
assert
s
.
lens
()
==
[
1
,
64
,
3
,
3
]
def
test_create_shape_broadcast
():
s
=
migraphx
.
shape
(
lens
=
[
1
,
64
,
3
,
3
],
strides
=
[
0
,
1
,
0
,
0
])
assert
s
.
broadcasted
()
assert
s
.
lens
()
==
[
1
,
64
,
3
,
3
]
assert
s
.
strides
()
==
[
0
,
1
,
0
,
0
]
def
test_create_shape_type
():
s
=
migraphx
.
shape
(
type
=
'uint8'
)
assert
s
.
type_string
()
==
'uint8_type'
assert
s
.
type_size
()
==
1
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment