"vscode:/vscode.git/clone" did not exist on "016033188c73fe876791995d739c8471dd0b1fc3"
Commit e141854c authored by Paul's avatar Paul
Browse files

Fix bug when reading from buffer

parent 6bc81428
......@@ -103,7 +103,7 @@ migraphx::shape to_shape(const py::buffer_info& info)
auto strides = info.strides;
std::transform(strides.begin(), strides.end(), strides.begin(), [&](auto i) -> std::size_t {
if(n > 0)
return n * i;
return n / i;
else
return 0;
});
......
import migraphx, struct
import migraphx, struct, array, sys
try:
from functools import reduce
except:
pass
def assert_eq(x, y):
if x == y:
......@@ -6,20 +10,37 @@ def assert_eq(x, y):
else:
raise Exception(str(x) + " != " + str(y))
def get_lens(m):
return list(m.shape)
def get_strides(m):
return [s/m.itemsize for s in m.strides]
def read_float(b, index):
return struct.unpack_from('f', b, index*4)[0]
def check_list(a):
def write_float(b, index):
struct.pack_into('f', b, index*4)
def nelements(lens):
return reduce(lambda x,y: x*y,lens, 1)
def create_buffer(t, data):
a = array.array(t, data)
if sys.version_info >= (3, 0):
m = memoryview(a.tobytes())
return m.cast(t)
else:
m = memoryview(a.tostring())
return m
def check_argument(a):
l = a.tolist()
for i in range(len(l)):
assert_eq(l[i], read_float(a, i))
def check_shapes(r, m):
lens = list(m.shape)
strides = [s/m.itemsize for s in m.strides]
elements = nelements(lens)
assert_eq(r.get_shape().elements(), elements)
assert_eq(r.get_shape().lens(), lens)
assert_eq(r.get_shape().strides(), strides)
def run(p):
params = {}
for key, value in p.get_parameter_shapes().items():
......@@ -27,22 +48,38 @@ def run(p):
return migraphx.from_gpu(p.run(params))
def test_input():
data = list(range(4))
m = create_buffer('f', data)
if sys.version_info >= (3, 0):
a = migraphx.argument(m)
check_shapes(a, m)
assert_eq(a.tolist(), data)
else:
a1 = migraphx.argument(m)
a2 = migraphx.argument(bytearray(a1))
check_shapes(a2, m)
assert_eq(a1.tolist(), m.tolist())
def test_output():
p = migraphx.parse_onnx("conv_relu_maxpool.onnx")
p.compile(migraphx.get_target("gpu"))
p = migraphx.parse_onnx("conv_relu_maxpool.onnx")
p.compile(migraphx.get_target("gpu"))
r1 = run(p)
r2 = run(p)
assert_eq(r1, r2)
assert_eq(r1.tolist(), r2.tolist())
r1 = run(p)
r2 = run(p)
assert_eq(r1, r2)
assert_eq(r1.tolist(), r2.tolist())
check_argument(r1)
check_argument(r2)
check_list(r1)
check_list(r2)
m1 = memoryview(r1)
m2 = memoryview(r2)
m1 = memoryview(r1)
m2 = memoryview(r2)
check_shapes(r1, m1)
check_shapes(r2, m2)
assert_eq(r1.get_shape().elements(), reduce(lambda x,y: x*y,get_lens(m1), 1))
assert_eq(r1.get_shape().lens(), get_lens(m1))
assert_eq(r1.get_shape().strides(), get_strides(m1))
test_input()
test_output()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment