# test_gpu_offload.py
import migraphx

# Load the ONNX model and show the program as parsed.
p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx")
print(p)

# Compile for the "gpu" target with offload_copy disabled, then show the
# program again after compilation.
print("Compiling ...")
gpu_target = migraphx.get_target("gpu")
p.compile(gpu_target, offload_copy=False)
print(p)

# Build one generated argument per program parameter, keyed by parameter name.
# Each argument is passed through to_gpu() explicitly -- presumably required
# because compile() was called with offload_copy=False (confirm against the
# MIGraphX documentation).
params = {}
for name, shape in p.get_parameter_shapes().items():
    print("Parameter {} -> {}".format(name, shape))
    host_argument = migraphx.generate_argument(shape)
    params[name] = migraphx.to_gpu(host_argument)

# Run the program, take the last entry of the returned results, and pass it
# through from_gpu() before printing it.
outputs = p.run(params)
r = migraphx.from_gpu(outputs[-1])
print(r)