# array.py
# Recovered from a GitLab file view (last committed by Paul); the page's size
# banner, blame markers, and line-number gutter have been removed.
import functools
import operator
import struct

import migraphx

def assert_eq(x, y):
    """Raise ``Exception`` with message ``"<x> != <y>"`` unless ``x == y``.

    Returns None when the two values compare equal.
    """
    if not x == y:
        message = str(x) + " != " + str(y)
        raise Exception(message)

def get_lens(m):
    """Return the dimensions of buffer view ``m`` (its ``shape``) as a list."""
    return [extent for extent in m.shape]

def get_strides(m):
    """Return the strides of buffer view ``m`` measured in elements, not bytes.

    ``m.strides`` is expressed in bytes; dividing by ``m.itemsize`` converts
    each stride to an element count. Floor division keeps the result integral
    on Python 3, where ``/`` would yield floats. This assumes every byte
    stride is a multiple of the item size, which holds for ordinary
    contiguous/strided buffers.
    """
    return [s // m.itemsize for s in m.strides]

def read_float(b, index):
    """Decode the native-format 4-byte float stored at element ``index`` of ``b``.

    ``b`` may be any object supporting the buffer protocol; the byte offset is
    ``index * 4``.
    """
    offset = index * 4
    (value,) = struct.unpack_from('f', b, offset)
    return value

def check_list(a):
    """Check that ``a.tolist()`` matches the raw buffer of ``a`` element by element.

    Each list element at position ``i`` is compared (via ``assert_eq``) with
    the float decoded by ``read_float(a, i)``. Because ``read_float`` reads a
    4-byte float at byte offset ``4 * i``, this assumes ``a`` exposes a
    contiguous buffer of float32 values and that ``tolist()`` is flat.
    NOTE(review): confirm for non-float32 or multi-dimensional outputs.

    Raises:
        Exception: from ``assert_eq`` on the first mismatching element.
    """
    for i, listed in enumerate(a.tolist()):
        assert_eq(listed, read_float(a, i))

def run(p):
    """Execute compiled program ``p`` once and return its result on the host.

    For every parameter shape reported by ``p.get_parameter_shapes()``, an
    argument is created with ``migraphx.generate_argument`` and copied to the
    GPU with ``migraphx.to_gpu``. The program's output is copied back with
    ``migraphx.from_gpu``.
    """
    shapes = p.get_parameter_shapes()
    gpu_params = {
        name: migraphx.to_gpu(migraphx.generate_argument(shape))
        for name, shape in shapes.items()
    }
    gpu_result = p.run(gpu_params)
    return migraphx.from_gpu(gpu_result)

30
31
32
33

p = migraphx.parse_onnx("conv_relu_maxpool.onnx")
p.compile(migraphx.get_target("gpu"))

Paul's avatar
Paul committed
34
35
36
37
38
r1 = run(p)
r2 = run(p)
assert_eq(r1, r2)
assert_eq(r1.tolist(), r2.tolist())

Paul's avatar
Paul committed
39
40
check_list(r1)
check_list(r2)
Paul's avatar
Paul committed
41
42
43
44
45
46
47

m1 = memoryview(r1)
m2 = memoryview(r2)

assert_eq(r1.get_shape().elements(), reduce(lambda x,y: x*y,get_lens(m1), 1))
assert_eq(r1.get_shape().lens(), get_lens(m1))
assert_eq(r1.get_shape().strides(), get_strides(m1))
48