evaluate_migraphx.py 2.09 KB
Newer Older
liucong's avatar
liucong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse
import numpy as np
import migraphx
import torch
import time
from tqdm import tqdm
from utils.data import prepare_dataloader

def _migraphx_epoch(model, input_name, dataloader, desc):
    """Run ``model`` once over every batch of ``dataloader``.

    Returns ``(correct, total)``: the number of batch elements whose argmax
    prediction equals the label, and the number of elements seen.
    """
    correct, total = 0, 0
    for data, label in tqdm(dataloader, desc=desc, total=len(dataloader)):
        # The model is fed float32 numpy arrays. Labels only need to be
        # comparable with the integer argmax indices, so their dtype is kept.
        batch = data.numpy().astype(np.float32)
        labels = label.numpy()
        results = model.run({input_name: batch})
        predictions = np.argmax(results[0], axis=-1)

        correct += int((labels == predictions).sum())
        total += len(labels)
    return correct, total


def eval_migraphx(onnx_path, dataloader, device, *, warmup=True):
    """Compile an ONNX model with MIGraphX and measure accuracy and runtime.

    The model is parsed from ``onnx_path`` and compiled for the MIGraphX
    ``"gpu"`` target with ``device_id=device``. If ``warmup`` is true, one
    untimed pass over ``dataloader`` runs first. A timed evaluation pass
    follows. Only the timed pass contributes to the reported accuracy.

    Args:
        onnx_path: Path to the ONNX model file.
        dataloader: Sized iterable of ``(data, label)`` batches whose items
            provide ``.numpy()`` (e.g. torch tensors).
        device: Device id forwarded to ``model.compile(device_id=...)``.
        warmup: Whether to run an untimed warmup pass before timing.

    Returns:
        ``(accuracy, seconds)``: the fraction of correctly classified samples
        in the timed pass, and the elapsed time of that pass (including
        data loading and progress-bar overhead).

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    # Load the model.
    model = migraphx.parse_onnx(onnx_path)

    # Name of the first model parameter. Each batch is bound to it in model.run().
    input_name = model.get_parameter_names()[0]

    # Compile the model.
    model.compile(t=migraphx.get_target("gpu"), device_id=device)

    if warmup:
        # Untimed pass. Presumably this keeps one-off first-run costs out of
        # the measurement. Its predictions are deliberately not counted.
        _migraphx_epoch(model, input_name, dataloader, "warmup")

    start_time = time.perf_counter()
    correct, total = _migraphx_epoch(model, input_name, dataloader, "eval onnx model")
    elapsed = time.perf_counter() - start_time

    if total == 0:
        raise ValueError("dataloader yielded no samples; cannot compute accuracy")
    return correct / total, elapsed

def main(args):
    """Evaluate an ONNX model with MIGraphX and print its accuracy and runtime.

    Reads ``args.batch_size`` and ``args.device``. ``args.onnx_path`` and
    ``args.data_dir`` are optional. When they are absent or empty, the
    script's original default paths are used.
    """
    onnx_path = getattr(args, "onnx_path", None) or "./checkpoints/calibrated/pretrained_qat.onnx"
    data_dir = getattr(args, "data_dir", None) or "./data/cifar10"

    # prepare_dataloader returns a pair; only the first loader is used here.
    # NOTE(review): the positional False presumably selects the evaluation
    # split; confirm against utils.data.prepare_dataloader.
    test_dataloader, _ = prepare_dataloader(data_dir, False, args.batch_size)

    # Evaluate the ONNX model.
    acc_onnx, runtime_onnx = eval_migraphx(onnx_path, test_dataloader, args.device)

    print("==============================================================")
    print(f"MIGraphX Model Acc: {acc_onnx}, Inference Time: {runtime_onnx:.4f}s")
    print("==============================================================")

if __name__ == '__main__':
    # argparse is already imported at the top of the file.
    parser = argparse.ArgumentParser(
        description="Evaluate a MIGraphX-compiled ONNX model and report accuracy and inference time."
    )

    parser.add_argument("--batch_size", type=int, default=16,
                        help="batch size used to build the evaluation dataloader")

    # NOTE(review): the default -1 is forwarded unchanged to MIGraphX's GPU
    # compile as device_id; confirm that value is accepted.
    parser.add_argument("--device", type=int, default=-1,
                        help="device id forwarded to MIGraphX compile (device_id)")

    # NOTE(review): main() does not read --num_classes; kept so existing
    # command lines continue to parse.
    parser.add_argument("--num_classes", type=int, default=10,
                        help="number of classes (currently unused)")

    args = parser.parse_args()

    main(args)