Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0df28887
Commit
0df28887
authored
Feb 13, 2019
by
Paul
Browse files
Fix bug with incorrect stride calculation:
parent
a5b0afa0
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
34 additions
and
3 deletions
+34
-3
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+18
-3
test/py/CMakeLists.txt
test/py/CMakeLists.txt
+1
-0
test/py/array.py
test/py/array.py
+15
-0
No files found.
src/py/migraphx_py.cpp
View file @
0df28887
...
...
@@ -60,6 +60,10 @@ template <class T>
py
::
buffer_info
to_buffer_info
(
T
&
x
)
{
migraphx
::
shape
s
=
x
.
get_shape
();
auto
strides
=
s
.
strides
();
std
::
transform
(
strides
.
begin
(),
strides
.
end
(),
strides
.
begin
(),
[
&
](
auto
i
)
{
return
i
*
s
.
type_size
();
});
py
::
buffer_info
b
;
visit_type
(
s
,
[
&
](
auto
as
)
{
b
=
py
::
buffer_info
(
x
.
data
(),
...
...
@@ -67,7 +71,7 @@ py::buffer_info to_buffer_info(T& x)
py
::
format_descriptor
<
decltype
(
as
())
>::
format
(),
s
.
lens
().
size
(),
s
.
lens
(),
s
.
strides
()
);
strides
);
});
return
b
;
}
...
...
@@ -75,11 +79,22 @@ py::buffer_info to_buffer_info(T& x)
migraphx
::
shape
to_shape
(
const
py
::
buffer_info
&
info
)
{
migraphx
::
shape
::
type_t
t
;
std
::
size_t
n
=
0
;
visit_types
([
&
](
auto
as
)
{
if
(
info
.
format
==
py
::
format_descriptor
<
decltype
(
as
())
>::
format
())
if
(
info
.
format
==
py
::
format_descriptor
<
decltype
(
as
())
>::
format
())
{
t
=
as
.
type_enum
();
n
=
sizeof
(
as
());
}
});
auto
strides
=
info
.
strides
;
std
::
transform
(
strides
.
begin
(),
strides
.
end
(),
strides
.
begin
(),
[
&
](
auto
i
)
->
std
::
size_t
{
if
(
n
>
0
)
return
n
*
i
;
else
return
0
;
});
return
migraphx
::
shape
{
t
,
info
.
shape
,
info
.
strides
};
return
migraphx
::
shape
{
t
,
info
.
shape
,
strides
};
}
PYBIND11_MODULE
(
migraphx
,
m
)
...
...
test/py/CMakeLists.txt
View file @
0df28887
...
...
@@ -18,4 +18,5 @@ add_dependencies(check migraphx_py)
add_py_test
(
cpu cpu.py WORKING_DIRECTORY
${
TEST_ONNX_DIR
}
)
if
(
MIGRAPHX_ENABLE_GPU
)
add_py_test
(
gpu gpu.py WORKING_DIRECTORY
${
TEST_ONNX_DIR
}
)
add_py_test
(
array array.py WORKING_DIRECTORY
${
TEST_ONNX_DIR
}
)
endif
()
test/py/array.py
0 → 100644
View file @
0df28887
import
migraphx
p
=
migraphx
.
parse_onnx
(
"conv_relu_maxpool.onnx"
)
p
.
compile
(
migraphx
.
get_target
(
"gpu"
))
params
=
{}
for
key
,
value
in
p
.
get_parameter_shapes
().
items
():
params
[
key
]
=
migraphx
.
to_gpu
(
migraphx
.
generate_argument
(
value
))
r1
=
migraphx
.
from_gpu
(
p
.
run
(
params
))
r2
=
migraphx
.
from_gpu
(
p
.
run
(
params
))
assert
r1
==
r2
q1
=
memoryview
(
r1
)
q2
=
memoryview
(
r2
)
assert
q1
.
tobytes
()
==
q2
.
tobytes
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment