Commit e141854c authored by Paul's avatar Paul
Browse files

Fix bug when reading from buffer

parent 6bc81428
...@@ -103,7 +103,7 @@ migraphx::shape to_shape(const py::buffer_info& info) ...@@ -103,7 +103,7 @@ migraphx::shape to_shape(const py::buffer_info& info)
auto strides = info.strides; auto strides = info.strides;
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t { std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t {
if(n > 0) if(n > 0)
return n * i; return n / i;
else else
return 0; return 0;
}); });
......
import migraphx, struct import migraphx, struct, array, sys
try:
from functools import reduce
except:
pass
def assert_eq(x, y): def assert_eq(x, y):
if x == y: if x == y:
...@@ -6,20 +10,37 @@ def assert_eq(x, y): ...@@ -6,20 +10,37 @@ def assert_eq(x, y):
else: else:
raise Exception(str(x) + " != " + str(y)) raise Exception(str(x) + " != " + str(y))
def get_lens(m):
return list(m.shape)
def get_strides(m):
return [s/m.itemsize for s in m.strides]
def read_float(b, index): def read_float(b, index):
return struct.unpack_from('f', b, index*4)[0] return struct.unpack_from('f', b, index*4)[0]
def check_list(a): def write_float(b, index):
struct.pack_into('f', b, index*4)
def nelements(lens):
return reduce(lambda x,y: x*y,lens, 1)
def create_buffer(t, data):
a = array.array(t, data)
if sys.version_info >= (3, 0):
m = memoryview(a.tobytes())
return m.cast(t)
else:
m = memoryview(a.tostring())
return m
def check_argument(a):
l = a.tolist() l = a.tolist()
for i in range(len(l)): for i in range(len(l)):
assert_eq(l[i], read_float(a, i)) assert_eq(l[i], read_float(a, i))
def check_shapes(r, m):
lens = list(m.shape)
strides = [s/m.itemsize for s in m.strides]
elements = nelements(lens)
assert_eq(r.get_shape().elements(), elements)
assert_eq(r.get_shape().lens(), lens)
assert_eq(r.get_shape().strides(), strides)
def run(p): def run(p):
params = {} params = {}
for key, value in p.get_parameter_shapes().items(): for key, value in p.get_parameter_shapes().items():
...@@ -27,22 +48,38 @@ def run(p): ...@@ -27,22 +48,38 @@ def run(p):
return migraphx.from_gpu(p.run(params)) return migraphx.from_gpu(p.run(params))
def test_input():
data = list(range(4))
m = create_buffer('f', data)
if sys.version_info >= (3, 0):
a = migraphx.argument(m)
check_shapes(a, m)
assert_eq(a.tolist(), data)
else:
a1 = migraphx.argument(m)
a2 = migraphx.argument(bytearray(a1))
check_shapes(a2, m)
assert_eq(a1.tolist(), m.tolist())
def test_output():
p = migraphx.parse_onnx("conv_relu_maxpool.onnx")
p.compile(migraphx.get_target("gpu"))
p = migraphx.parse_onnx("conv_relu_maxpool.onnx") r1 = run(p)
p.compile(migraphx.get_target("gpu")) r2 = run(p)
assert_eq(r1, r2)
assert_eq(r1.tolist(), r2.tolist())
r1 = run(p) check_argument(r1)
r2 = run(p) check_argument(r2)
assert_eq(r1, r2)
assert_eq(r1.tolist(), r2.tolist())
check_list(r1) m1 = memoryview(r1)
check_list(r2) m2 = memoryview(r2)
m1 = memoryview(r1) check_shapes(r1, m1)
m2 = memoryview(r2) check_shapes(r2, m2)
assert_eq(r1.get_shape().elements(), reduce(lambda x,y: x*y,get_lens(m1), 1))
assert_eq(r1.get_shape().lens(), get_lens(m1))
assert_eq(r1.get_shape().strides(), get_strides(m1))
test_input()
test_output()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment