# test_gpu.py
import migraphx

# Smoke test: parse an ONNX model with MIGraphX, compile it for the "gpu"
# target, run it on generated parameter arguments and print the result.
p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx")

# Dump the program as parsed, then again after compiling for the gpu target.
print(p)
print("Compiling ...")
gpu_target = migraphx.get_target("gpu")
p.compile(gpu_target)
print(p)

# One argument per program parameter: produced by
# migraphx.generate_argument for the parameter's shape, then passed through
# migraphx.to_gpu before being handed to run().
params = {}
for name, shape in p.get_parameter_shapes().items():
    print("Parameter {0} -> {1}".format(name, shape))
    host_arg = migraphx.generate_argument(shape)
    params[name] = migraphx.to_gpu(host_arg)

# Execute the compiled program and bring the result back via from_gpu.
# NOTE(review): this treats run()'s return value as a single argument; some
# MIGraphX releases may return a list of outputs instead — confirm against
# the installed version.
gpu_result = p.run(params)
r = migraphx.from_gpu(gpu_result)
print(r)