Commit b73b0609 authored by Paul's avatar Paul
Browse files

Calculate strides correctly

parent 27ca76f4
...@@ -102,10 +102,7 @@ migraphx::shape to_shape(const py::buffer_info& info) ...@@ -102,10 +102,7 @@ migraphx::shape to_shape(const py::buffer_info& info)
}); });
auto strides = info.strides; auto strides = info.strides;
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t { std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t {
if(n > 0) return n > 0 ? i / n : 0;
return n / i;
else
return 0;
}); });
return migraphx::shape{t, info.shape, strides}; return migraphx::shape{t, info.shape, strides};
} }
......
...@@ -35,7 +35,7 @@ def check_argument(a): ...@@ -35,7 +35,7 @@ def check_argument(a):
def check_shapes(r, m): def check_shapes(r, m):
lens = list(m.shape) lens = list(m.shape)
strides = [s/m.itemsize for s in m.strides] strides = [int(s/m.itemsize) for s in m.strides]
elements = nelements(lens) elements = nelements(lens)
assert_eq(r.get_shape().elements(), elements) assert_eq(r.get_shape().elements(), elements)
assert_eq(r.get_shape().lens(), lens) assert_eq(r.get_shape().lens(), lens)
...@@ -58,7 +58,7 @@ def test_shape(shape): ...@@ -58,7 +58,7 @@ def test_shape(shape):
def test_input(): def test_input():
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
test_shape([4]) test_shape([4])
# test_shape([2, 3]) test_shape([2, 3])
else: else:
data = list(range(4)) data = list(range(4))
m = create_buffer('f', data, [4]) m = create_buffer('f', data, [4])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment