"""GPU smoke test for the MIGraphX Python bindings (test_gpu.py).

Parses an ONNX model, compiles it for the "gpu" target, feeds every
program parameter a generated argument copied to the GPU, runs the
program, copies the result back with ``from_gpu`` and prints it.

Usage:
    python test_gpu.py [MODEL]

MODEL defaults to ``conv_relu_maxpool_test.onnx`` and is resolved
relative to the current working directory.
"""
import argparse
import os

import migraphx

DEFAULT_MODEL = "conv_relu_maxpool_test.onnx"


def main(argv=None):
    """Parse, compile and run an ONNX model on the GPU, printing each stage.

    Args:
        argv: Command-line arguments excluding the program name; ``None``
            lets argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Compile and run an ONNX model on the MIGraphX GPU target.")
    parser.add_argument(
        "model", nargs="?", default=DEFAULT_MODEL,
        help="path to the ONNX model (default: %(default)s)")
    args = parser.parse_args(argv)

    # Report a missing model explicitly (e.g. when launched from the wrong
    # working directory) rather than relying on whatever parse_onnx raises.
    if not os.path.isfile(args.model):
        parser.error("ONNX model not found: {}".format(args.model))

    p = migraphx.parse_onnx(args.model)
    print(p)
    print("Compiling ...")
    p.compile(migraphx.get_target("gpu"))
    print(p)

    params = {}
    for key, value in p.get_parameter_shapes().items():
        print("Parameter {} -> {}".format(key, value))
        # Each input is a generated argument of the parameter's shape,
        # copied to the GPU with to_gpu before the program runs.
        params[key] = migraphx.to_gpu(migraphx.generate_argument(value))

    # NOTE(review): the value returned by run() is passed straight to
    # from_gpu, so this assumes run() returns a single GPU argument; newer
    # MIGraphX releases may return a list of outputs -- confirm against the
    # installed version.
    r = migraphx.from_gpu(p.run(params))
    print(r)


if __name__ == "__main__":
    main()