onnx2mxr.py 630 Bytes
Newer Older
root's avatar
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import migraphx
import os


def main(
    onnx_model_dir="./models/model_1/onnx",
    mxr_model_dir="./models/model_1/mxr2",
    batch_sizes=(1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048),
):
    """Compile static-batch ONNX models with MIGraphX and save them as .mxr files.

    For every batch size, the file
    ``model-static-batch-size-<batch_size>.onnx`` is read from
    ``onnx_model_dir``, compiled for the ``"gpu"`` target on device 0 with
    ``offload_copy`` disabled, and written to ``mxr_model_dir`` under the
    same stem with an ``.mxr`` extension.

    Args:
        onnx_model_dir: Directory holding the input ONNX files.
        mxr_model_dir: Directory receiving the compiled programs; it is
            created if it does not exist yet.
        batch_sizes: Iterable of batch sizes to convert.

    Any error raised by MIGraphX (for example a missing ONNX file) propagates
    to the caller and stops the remaining conversions.
    """
    # Create the output directory up front: migraphx.save is not expected to
    # create missing parent directories, and failing only after a lengthy
    # compile would waste that work.
    if mxr_model_dir:
        os.makedirs(mxr_model_dir, exist_ok=True)

    for batch_size in batch_sizes:
        model_stem = f"model-static-batch-size-{batch_size}"
        model_path = os.path.join(onnx_model_dir, f"{model_stem}.onnx")
        model = migraphx.parse_onnx(model_path)
        print(f"compile {model_path}")
        model.compile(migraphx.get_target("gpu"), offload_copy=False, device_id=0)

        migraphx.save(model, os.path.join(mxr_model_dir, f"{model_stem}.mxr"))

# Run the conversion only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()