import argparse
import time

import numpy as np
from tqdm import tqdm

import migraphx
from utils.data import prepare_dataloader


def eval_migraphx(onnx_path, dataloader, device):
    """Compile an ONNX model with MIGraphX and evaluate it on a dataloader.

    Runs two full passes over ``dataloader``: pass 0 warms up the compiled
    program (kernel caches etc.) and pass 1 is the timed evaluation whose
    statistics are returned.

    Args:
        onnx_path: path to the ONNX model file.
        dataloader: iterable of ``(data, label)`` torch-tensor batches.
        device: integer GPU device id to compile for.

    Returns:
        Tuple ``(accuracy, seconds)`` for the timed pass only.
    """
    # Parse the ONNX graph into a MIGraphX program.
    model = migraphx.parse_onnx(onnx_path)

    # The model is assumed to have a single input; take its parameter name.
    # NOTE(review): the original also called model.get_inputs()/get_outputs(),
    # which are not part of the MIGraphX Python API (that is ONNX Runtime's
    # interface) and whose results were unused — removed.
    input_name = model.get_parameter_names()[0]

    # Compile for the GPU target on the requested device.
    # NOTE(review): device defaults to -1 upstream, which is not a valid GPU
    # id — confirm callers always pass a real device when using the GPU target.
    model.compile(t=migraphx.get_target("gpu"), device_id=device)

    correct, total = 0, 0
    start_time = end_time = 0.0
    for it in range(2):
        timed = it == 1
        desc = "eval onnx model" if timed else "warmup"
        if timed:
            # Only the timed pass contributes to the accuracy statistics;
            # the original accumulated the warmup pass into the counters too.
            correct, total = 0, 0
            start_time = time.time()
        for data, label in tqdm(dataloader, desc=desc, total=len(dataloader)):
            data = data.numpy().astype(np.float32)
            label = label.numpy().astype(np.float32)
            results = model.run({input_name: data})
            # results[0] is a migraphx.argument; np.argmax reads it via the
            # buffer protocol — presumably logits of shape (N, num_classes),
            # TODO confirm against the exported model.
            predictions = np.argmax(results[0], axis=-1)
            correct += (label == predictions).sum()
            total += len(label)
        if timed:
            end_time = time.time()
    return correct / total, end_time - start_time


def main(args):
    """Evaluate the calibrated QAT ONNX model on CIFAR-10 via MIGraphX."""
    test_dataloader, _ = prepare_dataloader("./data/cifar10", False, args.batch_size)
    # Evaluate the ONNX model; args.device is the GPU id used for compilation.
    acc_onnx, runtime_onnx = eval_migraphx(
        "./checkpoints/calibrated/pretrained_qat.onnx", test_dataloader, args.device
    )
    print("==============================================================")
    print(f"MIGraphX Model Acc: {acc_onnx}, Inference Time: {runtime_onnx:.4f}s")
    print("==============================================================")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--device", type=int, default=-1)
    parser.add_argument("--num_classes", type=int, default=10)
    args = parser.parse_args()
    main(args)