Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b73b0609
Commit
b73b0609
authored
Feb 19, 2019
by
Paul
Browse files
Calculate strides correctly
parent
27ca76f4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
6 deletions
+3
-6
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+1
-4
test/py/test_array.py
test/py/test_array.py
+2
-2
No files found.
src/py/migraphx_py.cpp
View file @
b73b0609
...
...
@@ -102,10 +102,7 @@ migraphx::shape to_shape(const py::buffer_info& info)
});
auto
strides
=
info
.
strides
;
std
::
transform
(
strides
.
begin
(),
strides
.
end
(),
strides
.
begin
(),
[
&
](
auto
i
)
->
std
::
size_t
{
if
(
n
>
0
)
return
n
/
i
;
else
return
0
;
return
n
>
0
?
i
/
n
:
0
;
});
return
migraphx
::
shape
{
t
,
info
.
shape
,
strides
};
}
...
...
test/py/test_array.py
View file @
b73b0609
...
...
@@ -35,7 +35,7 @@ def check_argument(a):
def
check_shapes
(
r
,
m
):
lens
=
list
(
m
.
shape
)
strides
=
[
s
/
m
.
itemsize
for
s
in
m
.
strides
]
strides
=
[
int
(
s
/
m
.
itemsize
)
for
s
in
m
.
strides
]
elements
=
nelements
(
lens
)
assert_eq
(
r
.
get_shape
().
elements
(),
elements
)
assert_eq
(
r
.
get_shape
().
lens
(),
lens
)
...
...
@@ -58,7 +58,7 @@ def test_shape(shape):
def
test_input
():
if
sys
.
version_info
>=
(
3
,
0
):
test_shape
([
4
])
#
test_shape([2, 3])
test_shape
([
2
,
3
])
else
:
data
=
list
(
range
(
4
))
m
=
create_buffer
(
'f'
,
data
,
[
4
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment