import migraphx
import os


def main(onnx_model_dir="./models/model_1/onnx",
         mxr_model_dir="./models/model_1/mxr2",
         batch_sizes=(1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048),
         device_id=0):
    """Compile static-batch ONNX models for the GPU and save them as MXR files.

    For every batch size ``b`` in ``batch_sizes`` the model
    ``<onnx_model_dir>/model-static-batch-size-<b>.onnx`` is parsed with
    ``migraphx.parse_onnx``, compiled for the MIGraphX ``"gpu"`` target and
    written to ``<mxr_model_dir>/model-static-batch-size-<b>.mxr``.

    Args:
        onnx_model_dir: Directory containing the source ``.onnx`` files.
        mxr_model_dir: Output directory for the compiled ``.mxr`` files.
            It is created (including parents) if it does not exist yet.
        batch_sizes: Iterable of batch sizes. One ONNX file per size is
            expected in ``onnx_model_dir``; a missing file makes
            ``migraphx.parse_onnx`` raise and stops the run.
        device_id: Value forwarded to ``compile(device_id=...)``. It is
            presumably the GPU index; confirm against the MIGraphX build
            in use.
    """
    # Create the output directory before the first compilation. A missing
    # directory would otherwise only surface in migraphx.save, after a
    # potentially long GPU compile has already been spent.
    os.makedirs(mxr_model_dir, exist_ok=True)

    for batch_size in batch_sizes:
        # The .onnx input and the .mxr output share the same file stem.
        stem = f"model-static-batch-size-{batch_size}"
        model_path = os.path.join(onnx_model_dir, f"{stem}.onnx")
        mxr_path = os.path.join(mxr_model_dir, f"{stem}.mxr")

        model = migraphx.parse_onnx(model_path)
        print(f"compile {model_path}")
        # offload_copy=False: per the MIGraphX docs, no host<->device copy
        # instructions are inserted. Whoever runs the saved .mxr must
        # therefore supply inputs/outputs that already live in GPU memory.
        model.compile(migraphx.get_target("gpu"), offload_copy=False,
                      device_id=device_id)

        migraphx.save(model, mxr_path)


if __name__ == "__main__":
    main()