Commit b73b0609 authored by Paul's avatar Paul
Browse files

Calculate strides correctly

parent 27ca76f4
......@@ -102,10 +102,7 @@ migraphx::shape to_shape(const py::buffer_info& info)
});
auto strides = info.strides;
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t {
if(n > 0)
return n / i;
else
return 0;
return n > 0 ? i / n : 0;
});
return migraphx::shape{t, info.shape, strides};
}
......
......@@ -35,7 +35,7 @@ def check_argument(a):
def check_shapes(r, m):
lens = list(m.shape)
strides = [s/m.itemsize for s in m.strides]
strides = [int(s/m.itemsize) for s in m.strides]
elements = nelements(lens)
assert_eq(r.get_shape().elements(), elements)
assert_eq(r.get_shape().lens(), lens)
......@@ -58,7 +58,7 @@ def test_shape(shape):
def test_input():
if sys.version_info >= (3, 0):
test_shape([4])
# test_shape([2, 3])
test_shape([2, 3])
else:
data = list(range(4))
m = create_buffer('f', data, [4])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment