Commit 0df28887 authored by Paul's avatar Paul
Browse files

Fix bug with incorrect stride calculation:

parent a5b0afa0
...@@ -60,6 +60,10 @@ template <class T> ...@@ -60,6 +60,10 @@ template <class T>
py::buffer_info to_buffer_info(T& x) py::buffer_info to_buffer_info(T& x)
{ {
migraphx::shape s = x.get_shape(); migraphx::shape s = x.get_shape();
auto strides = s.strides();
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) {
return i * s.type_size();
});
py::buffer_info b; py::buffer_info b;
visit_type(s, [&](auto as) { visit_type(s, [&](auto as) {
b = py::buffer_info(x.data(), b = py::buffer_info(x.data(),
...@@ -67,7 +71,7 @@ py::buffer_info to_buffer_info(T& x) ...@@ -67,7 +71,7 @@ py::buffer_info to_buffer_info(T& x)
py::format_descriptor<decltype(as())>::format(), py::format_descriptor<decltype(as())>::format(),
s.lens().size(), s.lens().size(),
s.lens(), s.lens(),
s.strides()); strides);
}); });
return b; return b;
} }
...@@ -75,11 +79,22 @@ py::buffer_info to_buffer_info(T& x) ...@@ -75,11 +79,22 @@ py::buffer_info to_buffer_info(T& x)
migraphx::shape to_shape(const py::buffer_info& info) migraphx::shape to_shape(const py::buffer_info& info)
{ {
migraphx::shape::type_t t; migraphx::shape::type_t t;
std::size_t n = 0;
visit_types([&](auto as) { visit_types([&](auto as) {
if(info.format == py::format_descriptor<decltype(as())>::format()) if(info.format == py::format_descriptor<decltype(as())>::format()) {
t = as.type_enum(); t = as.type_enum();
n = sizeof(as());
}
});
auto strides = info.strides;
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t {
if (n > 0)
return n * i;
else
return 0;
}); });
return migraphx::shape{t, info.shape, info.strides}; return migraphx::shape{t, info.shape, strides};
} }
PYBIND11_MODULE(migraphx, m) PYBIND11_MODULE(migraphx, m)
......
...@@ -18,4 +18,5 @@ add_dependencies(check migraphx_py) ...@@ -18,4 +18,5 @@ add_dependencies(check migraphx_py)
add_py_test(cpu cpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(cpu cpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
add_py_test(gpu gpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(gpu gpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(array array.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
endif() endif()
import migraphx
p = migraphx.parse_onnx("conv_relu_maxpool.onnx")
p.compile(migraphx.get_target("gpu"))
params = {}
for key, value in p.get_parameter_shapes().items():
params[key] = migraphx.to_gpu(migraphx.generate_argument(value))
r1 = migraphx.from_gpu(p.run(params))
r2 = migraphx.from_gpu(p.run(params))
assert r1 == r2
q1 = memoryview(r1)
q2 = memoryview(r2)
assert q1.tobytes() == q2.tobytes()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment