infer_pytroch_migraphx.py 3.05 KB
Newer Older
lijian6's avatar
lijian6 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from sys import argv
import os
import cv2
import json
import numpy as np
import migraphx
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from vit_model import vit_base_patch16_224_in21k as create_model


def _softmax(logits, axis=-1):
    """Return a numerically stable softmax of *logits* along *axis*.

    Subtracting the per-row maximum before ``np.exp`` prevents overflow for
    large logits without changing the mathematical result.
    """
    logits = np.asarray(logits)
    exps = np.exp(logits - logits.max(axis=axis, keepdims=True))
    return exps / exps.sum(axis=axis, keepdims=True)


def _report_matches(probs, class_indict, intputdir, backend, threshold=0.5):
    """Print each class whose probability reaches *threshold*; count hits.

    Args:
        probs: 1-D sequence of per-class probabilities for a single image.
        class_indict: Mapping from stringified class index to class name.
        intputdir: Input directory path. A confident class counts as a hit
            when its name occurs in this path. This assumes each directory
            holds one class named in its path (e.g. ``.../roses/``) -- TODO
            confirm the dataset layout.
        backend: Label printed before "class:" (e.g. "pytorch").
        threshold: Minimum probability for a class to be reported.

    Returns:
        Number of confident classes whose name appears in *intputdir*.
    """
    matches = 0
    for idx, prob in enumerate(probs):
        if prob >= threshold:
            class_name = class_indict[str(idx)]
            print("{} class: {:10}   prob: {:.3}".format(backend, class_name, prob))
            # Membership test, not str.find(): find() returns -1 (truthy) on
            # a miss and 0 (falsy) on a match at the start of the string.
            if class_name in intputdir:
                matches += 1
    return matches


def main(intputdir):
    """Compare PyTorch and MIGraphX ViT predictions on a directory of images.

    Loads ``./models/class_indices.json``, the PyTorch checkpoint
    ``./models/model.pth`` and the ONNX export ``./models/model.onnx``. Each
    regular file in *intputdir* is classified with both backends, confident
    predictions are printed, and per-backend accuracy is reported.

    Args:
        intputdir: Directory containing the images to classify. Its path is
            expected to contain the ground-truth class name.

    Raises:
        FileNotFoundError: If the class-index JSON file is missing.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    # read class_indict
    json_path = './models/class_indices.json'
    if not os.path.exists(json_path):
        # Raise instead of assert: asserts are stripped under `python -O`.
        raise FileNotFoundError("file: '{}' does not exist.".format(json_path))
    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create pth model, load weights, and switch to inference mode once
    model = create_model(num_classes=5, has_logits=False).to(device)
    model_weight_path = "./models/model.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()

    # create onnx model
    model1 = migraphx.parse_onnx('./models/model.onnx')
    model1.compile(t=migraphx.get_target("gpu"), device_id=1)
    # The input parameter name is the same for every image; look it up once.
    inputName = model1.get_parameter_names()[0]

    Img_cnt = 0
    pytorch_match_cnt = 0
    migraphx_match_cnt = 0
    for filename in os.listdir(intputdir):
        # os.path.join works whether or not intputdir ends with a separator.
        img_path = os.path.join(intputdir, filename)
        # Skip sub-directories and other non-regular entries.
        if not os.path.isfile(img_path):
            continue
        Img_cnt += 1
        # The context manager closes the file handle. convert("RGB")
        # guarantees the 3 channels that Normalize expects (grayscale/RGBA).
        with Image.open(img_path) as im:
            img = data_transform(im.convert("RGB"))
        # expand batch dimension -> [1, C, H, W]
        img = torch.unsqueeze(img, dim=0)

        # migraphx run + postprocess
        results = model1.run({inputName: migraphx.argument(img.numpy())})
        migraphx_probs = _softmax(np.asarray(results[0]), axis=1)
        migraphx_match_cnt += _report_matches(
            migraphx_probs[0], class_indict, intputdir, "migraphx")

        # pytorch run + postprocess
        with torch.no_grad():
            output = torch.squeeze(model(img.to(device))).cpu()
            pytorch_probs = torch.softmax(output, dim=0).numpy()
        pytorch_match_cnt += _report_matches(
            pytorch_probs, class_indict, intputdir, "pytorch")

    if Img_cnt == 0:
        # Avoid ZeroDivisionError in the accuracy computation below.
        print("No images found in '{}'.".format(intputdir))
        return

    print("Pytorch Img_cnt: {:<5} match_cnt: {:<5} acc:{:.3}".format(Img_cnt, pytorch_match_cnt, pytorch_match_cnt/Img_cnt))
    print("Migraphx Img_cnt: {:<5} match_cnt: {:<5} acc:{:.3}".format(Img_cnt, migraphx_match_cnt, migraphx_match_cnt/Img_cnt))


# Script entry point: expects exactly one argument, the image directory.
if __name__ == '__main__':
    if len(argv) != 2:
        # Print a usage line instead of failing with a bare IndexError.
        raise SystemExit("usage: {} <image_dir>".format(argv[0]))
    main(argv[1])