Commit 27ca76f4 authored by Paul's avatar Paul
Browse files

Fix tests on python 3

parent df2c494a
...@@ -16,8 +16,8 @@ endfunction() ...@@ -16,8 +16,8 @@ endfunction()
add_dependencies(tests migraphx_py) add_dependencies(tests migraphx_py)
add_dependencies(check migraphx_py) add_dependencies(check migraphx_py)
add_py_test(cpu cpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(cpu test_cpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
if(MIGRAPHX_ENABLE_GPU) if(MIGRAPHX_ENABLE_GPU)
add_py_test(gpu gpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(gpu test_gpu.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_py_test(array array.py WORKING_DIRECTORY ${TEST_ONNX_DIR}) add_py_test(array test_array.py WORKING_DIRECTORY ${TEST_ONNX_DIR})
endif() endif()
...@@ -19,11 +19,11 @@ def write_float(b, index): ...@@ -19,11 +19,11 @@ def write_float(b, index):
def nelements(lens): def nelements(lens):
return reduce(lambda x,y: x*y,lens, 1) return reduce(lambda x,y: x*y,lens, 1)
def create_buffer(t, data): def create_buffer(t, data, shape):
a = array.array(t, data) a = array.array(t, data)
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
m = memoryview(a.tobytes()) m = memoryview(a.tobytes())
return m.cast(t) return m.cast(t, shape)
else: else:
m = memoryview(a.tostring()) m = memoryview(a.tostring())
return m return m
...@@ -48,14 +48,20 @@ def run(p): ...@@ -48,14 +48,20 @@ def run(p):
return migraphx.from_gpu(p.run(params)) return migraphx.from_gpu(p.run(params))
def test_shape(shape):
data = list(range(nelements(shape)))
m = create_buffer('f', data, shape)
a = migraphx.argument(m)
check_shapes(a, m)
assert_eq(a.tolist(), data)
def test_input(): def test_input():
data = list(range(4))
m = create_buffer('f', data)
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
a = migraphx.argument(m) test_shape([4])
check_shapes(a, m) # test_shape([2, 3])
assert_eq(a.tolist(), data)
else: else:
data = list(range(4))
m = create_buffer('f', data, [4])
a1 = migraphx.argument(m) a1 = migraphx.argument(m)
a2 = migraphx.argument(bytearray(a1)) a2 = migraphx.argument(bytearray(a1))
check_shapes(a2, m) check_shapes(a2, m)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment