Commit 3032998d authored by Paul's avatar Paul
Browse files

Improve array test

parent 704752eb
......@@ -124,6 +124,11 @@ struct tensor_view
return m_data + this->size();
}
std::vector<T> to_vector() const
{
return std::vector<T>(this->begin(), this->end());
}
friend std::ostream& operator<<(std::ostream& os, const tensor_view<T>& x)
{
if(!x.empty())
......
......@@ -28,6 +28,11 @@ struct throw_half
{
throw std::runtime_error("Half not supported in python yet.");
}
void operator()(migraphx::tensor_view<migraphx::half>) const
{
throw std::runtime_error("Half not supported in python yet.");
}
};
template <class F>
......@@ -42,6 +47,8 @@ struct skip_half
}
void operator()(migraphx::shape::as<migraphx::half>) const {}
void operator()(migraphx::tensor_view<migraphx::half>) const {}
};
template <class F>
......@@ -50,6 +57,12 @@ void visit_type(const migraphx::shape& s, F f)
s.visit_type(throw_half<F>{f});
}
template <class T, class F>
void visit(const migraphx::raw_data<T>& x, F f)
{
x.visit(throw_half<F>{f});
}
template <class F>
void visit_types(F f)
{
......@@ -123,6 +136,14 @@ PYBIND11_MODULE(migraphx, m)
py::buffer_info info = b.request();
new(&x) migraphx::argument(to_shape(info), info.ptr);
})
.def("get_shape", &migraphx::argument::get_shape)
.def("tolist", [](migraphx::argument& x) {
py::list l{x.get_shape().elements()};
visit(x, [&](auto data) {
l = py::cast(data.to_vector());
});
return l;
})
.def("__eq__", std::equal_to<migraphx::argument>{})
.def("__ne__", std::not_equal_to<migraphx::argument>{})
.def("__repr__", [](const migraphx::argument& x) { return migraphx::to_string(x); });
......
import migraphx
import migraphx, struct
def assert_eq(x, y):
if x == y:
pass
else:
raise Exception(str(x) + " != " + str(y))
def get_lens(m):
return list(m.shape)
def get_strides(m):
return [s/m.itemsize for s in m.strides]
def read_float(b, index):
return struct.unpack_from('f', b, index*4)[0]
def check_list(a, b):
l = a.tolist()
for i in range(len(l)):
assert_eq(l[i], read_float(b, i))
def run(p):
params = {}
for key, value in p.get_parameter_shapes().items():
params[key] = migraphx.to_gpu(migraphx.generate_argument(value))
return migraphx.from_gpu(p.run(params))
p = migraphx.parse_onnx("conv_relu_maxpool.onnx")
p.compile(migraphx.get_target("gpu"))
params = {}
for key, value in p.get_parameter_shapes().items():
params[key] = migraphx.to_gpu(migraphx.generate_argument(value))
r1 = migraphx.from_gpu(p.run(params))
r2 = migraphx.from_gpu(p.run(params))
r1 = run(p)
r2 = run(p)
assert_eq(r1, r2)
assert_eq(r1.tolist(), r2.tolist())
assert_eq(r1.tolist()[0], read_float(r1, 0))
m1 = memoryview(r1)
m2 = memoryview(r2)
assert_eq(r1.get_shape().elements(), reduce(lambda x,y: x*y,get_lens(m1), 1))
assert_eq(r1.get_shape().lens(), get_lens(m1))
assert_eq(r1.get_shape().strides(), get_strides(m1))
assert r1 == r2
q1 = memoryview(r1)
q2 = memoryview(r2)
assert q1.tobytes() == q2.tobytes()
check_list(r1, m1.tobytes())
check_list(r2, m2.tobytes())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment