Commit 3032998d authored by Paul

Improve array test

parent 704752eb
@@ -124,6 +124,11 @@ struct tensor_view
         return m_data + this->size();
     }
+    std::vector<T> to_vector() const
+    {
+        return std::vector<T>(this->begin(), this->end());
+    }
     friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x)
     {
         if(!x.empty())
@@ -28,6 +28,11 @@ struct throw_half
     {
         throw std::runtime_error("Half not supported in python yet.");
     }
+    void operator()(migraphx::tensor_view<migraphx::half>) const
+    {
+        throw std::runtime_error("Half not supported in python yet.");
+    }
 };
 template <class F>
@@ -42,6 +47,8 @@ struct skip_half
     }
     void operator()(migraphx::shape::as<migraphx::half>) const {}
+    void operator()(migraphx::tensor_view<migraphx::half>) const {}
 };
 template <class F>
@@ -50,6 +57,12 @@ void visit_type(const migraphx::shape& s, F f)
     s.visit_type(throw_half<F>{f});
 }
+template <class T, class F>
+void visit(const migraphx::raw_data<T>& x, F f)
+{
+    x.visit(throw_half<F>{f});
+}
 template <class F>
 void visit_types(F f)
 {
@@ -123,6 +136,14 @@ PYBIND11_MODULE(migraphx, m)
             py::buffer_info info = b.request();
             new(&x) migraphx::argument(to_shape(info), info.ptr);
         })
+        .def("get_shape", &migraphx::argument::get_shape)
+        .def("tolist", [](migraphx::argument& x) {
+            py::list l{x.get_shape().elements()};
+            visit(x, [&](auto data) {
+                l = py::cast(data.to_vector());
+            });
+            return l;
+        })
         .def("__eq__", std::equal_to<migraphx::argument>{})
         .def("__ne__", std::not_equal_to<migraphx::argument>{})
         .def("__repr__", [](const migraphx::argument& x) { return migraphx::to_string(x); });
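For orientation, here is a minimal sketch (not part of the commit) of how the newly bound get_shape and tolist methods could be exercised from Python. It only uses calls that already appear in this change; the ONNX file name is the one used by the test below.

import migraphx

# Build and compile the same model as the test, then run it once.
p = migraphx.parse_onnx("conv_relu_maxpool.onnx")
p.compile(migraphx.get_target("gpu"))
params = {}
for key, value in p.get_parameter_shapes().items():
    params[key] = migraphx.to_gpu(migraphx.generate_argument(value))
result = migraphx.from_gpu(p.run(params))

# New in this commit: inspect the result through the Python bindings.
print(result.get_shape().elements())  # total number of output elements
print(result.tolist()[:5])            # output values as a flat Python list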
-import migraphx
+import migraphx, struct
+def assert_eq(x, y):
+    if x == y:
+        pass
+    else:
+        raise Exception(str(x) + " != " + str(y))
+def get_lens(m):
+    return list(m.shape)
+def get_strides(m):
+    return [s/m.itemsize for s in m.strides]
+def read_float(b, index):
+    return struct.unpack_from('f', b, index*4)[0]
+def check_list(a, b):
+    l = a.tolist()
+    for i in range(len(l)):
+        assert_eq(l[i], read_float(b, i))
+def run(p):
+    params = {}
+    for key, value in p.get_parameter_shapes().items():
+        params[key] = migraphx.to_gpu(migraphx.generate_argument(value))
+    return migraphx.from_gpu(p.run(params))
 p = migraphx.parse_onnx("conv_relu_maxpool.onnx")
 p.compile(migraphx.get_target("gpu"))
-params = {}
-for key, value in p.get_parameter_shapes().items():
-    params[key] = migraphx.to_gpu(migraphx.generate_argument(value))
-r1 = migraphx.from_gpu(p.run(params))
-r2 = migraphx.from_gpu(p.run(params))
-assert r1 == r2
-q1 = memoryview(r1)
-q2 = memoryview(r2)
-assert q1.tobytes() == q2.tobytes()
+r1 = run(p)
+r2 = run(p)
+assert_eq(r1, r2)
+assert_eq(r1.tolist(), r2.tolist())
+assert_eq(r1.tolist()[0], read_float(r1, 0))
+m1 = memoryview(r1)
+m2 = memoryview(r2)
+assert_eq(r1.get_shape().elements(), reduce(lambda x, y: x * y, get_lens(m1), 1))
+assert_eq(r1.get_shape().lens(), get_lens(m1))
+assert_eq(r1.get_shape().strides(), get_strides(m1))
+check_list(r1, m1.tobytes())
+check_list(r2, m2.tobytes())
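As a side note (not part of the commit), the byte-offset arithmetic behind read_float and check_list can be checked in isolation: unpack_from('f', b, index*4) reads the 32-bit float starting at byte offset index*4, so the i-th entry of tolist() is compared against the i-th float in the argument's raw buffer.

import struct

# Three 32-bit floats packed back to back, mimicking the memoryview bytes.
buf = struct.pack('fff', 1.0, 2.5, -3.0)
values = [struct.unpack_from('f', buf, i * 4)[0] for i in range(3)]
assert values == [1.0, 2.5, -3.0]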