# infer_pytroch_migraphx.py
from sys import argv
import os
import cv2
import json
import numpy as np
import migraphx
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from vit_model import vit_base_patch16_224_in21k as create_model


def _count_confident_matches(backend, probs, class_indict, intputdir):
    """Print each class with probability >= 0.5 and count directory matches.

    Args:
        backend: Label used as the prefix of the printed line
            (e.g. "migraphx" or "pytorch").
        probs: 1-D sequence of per-class probabilities for one image.
        class_indict: Mapping from the class index (as a string) to its name.
        intputdir: Directory path whose text is treated as the ground truth.

    Returns:
        The number of confident classes whose name occurs in ``intputdir``.
    """
    matches = 0
    for i, prob in enumerate(probs):
        prob = float(prob)
        if prob >= 0.5:
            label = class_indict[str(i)]
            print("{} class: {:10}   prob: {:.3}".format(backend, label, prob))
            # Use a membership test: str.find() returns -1 (truthy) when the
            # label is absent and 0 (falsy) when it is at the very start.
            if label in intputdir:
                matches += 1
    return matches


def main(intputdir):
    """Compare PyTorch and MIGraphX classification accuracy on one directory.

    Every regular file in ``intputdir`` is preprocessed and run through both
    the PyTorch model and the MIGraphX-compiled ONNX model. A prediction is
    counted as a match when a class with probability >= 0.5 has a name that
    occurs in the ``intputdir`` path string, i.e. the directory name is used
    as the ground-truth label. Per-image predictions and the final accuracy
    of both backends are printed to stdout.

    Args:
        intputdir: Path of the directory that holds the images to classify.

    Raises:
        FileNotFoundError: If the class-index JSON file does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    # read class_indict
    json_path = './models/class_indices.json'
    if not os.path.exists(json_path):
        # Raise instead of assert so the check survives `python -O`.
        raise FileNotFoundError("file: '{}' dose not exist.".format(json_path))

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create pth model
    model = create_model(num_classes=5, has_logits=False).to(device)
    # load model weights
    model_weight_path = "./models/model.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    # Switching to eval mode is loop-invariant, so do it once up front.
    model.eval()

    # create onnx model
    # NOTE(review): this path ('../Models') differs from the './models' prefix
    # used above, and device_id=1 requires a second GPU -- confirm both.
    model1 = migraphx.parse_onnx('../Models/model.onnx')
    model1.compile(t=migraphx.get_target("gpu"), device_id=1)
    inputName = model1.get_parameter_names()[0]

    Img_cnt = 0
    pytorch_match_cnt = 0
    migraphx_match_cnt = 0
    for filename in os.listdir(intputdir):
        # os.path.join works whether or not intputdir ends with a separator.
        img_path = os.path.join(intputdir, filename)
        if not os.path.isfile(img_path):
            # Skip sub-directories and other non-file entries.
            continue
        Img_cnt += 1

        # Force three channels so Normalize (3-channel mean/std) always fits.
        img = Image.open(img_path).convert("RGB")
        # [C, H, W]
        img = data_transform(img)
        # expand batch dimension -> [N, C, H, W]
        img = torch.unsqueeze(img, dim=0)

        # migraphx run
        img1 = img.numpy()
        results = model1.run({inputName: migraphx.argument(img1)})
        # migraphx postprocess: numerically stable softmax over the class axis
        logits = np.array(results[0])
        shifted = np.exp(logits - logits.max(axis=1, keepdims=True))
        probs = shifted / shifted.sum(axis=1, keepdims=True)
        migraphx_match_cnt += _count_confident_matches(
            "migraphx", probs[0], class_indict, intputdir)

        # pytorch run + postprocess
        with torch.no_grad():
            output = torch.squeeze(model(img.to(device))).cpu()
            predict = torch.softmax(output, dim=0)
        pytorch_match_cnt += _count_confident_matches(
            "pytorch", predict.numpy(), class_indict, intputdir)

    if Img_cnt == 0:
        # Avoid dividing by zero when the directory holds no files.
        print("no images found in '{}'".format(intputdir))
        return

    print("Pytorch Img_cnt: {:<5} match_cnt: {:<5} acc:{:.3}".format(Img_cnt, pytorch_match_cnt, pytorch_match_cnt/Img_cnt))
    print("Migraphx Img_cnt: {:<5} match_cnt: {:<5} acc:{:.3}".format(Img_cnt, migraphx_match_cnt, migraphx_match_cnt/Img_cnt))


if __name__ == '__main__':
    # Exactly one argument is needed: the directory of images to classify.
    # Exit with a usage hint instead of a bare IndexError when it is missing.
    if len(argv) < 2:
        raise SystemExit("usage: {} <image_dir>".format(argv[0]))
    main(argv[1])