infer_pytroch.py 2.03 KB
Newer Older
lijian6's avatar
lijian6 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from sys import argv
import os
import cv2
import json
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from vit_model import vit_base_patch16_224_in21k as create_model


def main(intputdir, json_path='./models/class_indices.json',
         model_weight_path="./models/model.pth", num_classes=5):
    """Classify every image file in a directory and report a match rate.

    Each regular file in ``intputdir`` is opened as an image, converted to
    RGB, resized to 256, center-cropped to 224, normalized with mean/std 0.5,
    and passed through the model built by ``create_model``. For every class
    whose softmax probability is >= 0.5, the class name and probability are
    printed. The image counts as a "match" when that class name occurs as a
    substring of ``intputdir``. This presumes the directory is named after
    the ground-truth class; confirm against how the script is invoked.

    Args:
        intputdir: Directory containing the images to classify.
        json_path: JSON file mapping class-index strings ("0", "1", ...) to
            class names.
        model_weight_path: Path to the saved ``state_dict`` for the model.
        num_classes: Number of output classes of the model head.

    Raises:
        FileNotFoundError: If ``json_path`` does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    # read class_indict: str(class index) -> class name
    # Explicit raise instead of assert so the check survives `python -O`.
    if not os.path.exists(json_path):
        raise FileNotFoundError("file: '{}' dose not exist.".format(json_path))
    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model and load weights
    model = create_model(num_classes=num_classes, has_logits=False).to(device)
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    # Inference mode only needs to be set once, not per image.
    model.eval()

    img_cnt = 0
    match_cnt = 0
    for filename in sorted(os.listdir(intputdir)):
        # os.path.join works whether or not intputdir ends with a separator.
        img_path = os.path.join(intputdir, filename)
        # Skip subdirectories and other non-regular entries.
        if not os.path.isfile(img_path):
            continue
        img_cnt += 1

        # Close the file handle promptly; force 3 channels so grayscale or
        # RGBA images match the 3-channel Normalize above.
        with Image.open(img_path) as pil_img:
            img = data_transform(pil_img.convert("RGB"))
        # [C, H, W] -> [1, C, H, W]: add the batch dimension.
        img = torch.unsqueeze(img, dim=0)

        with torch.no_grad():
            # Take the single batch row rather than squeezing, so the result
            # stays 1-D even when num_classes == 1.
            output = model(img.to(device))[0].cpu()
            predict = torch.softmax(output, dim=0)

        for i, prob in enumerate(predict.tolist()):
            if prob >= 0.5:
                class_name = class_indict[str(i)]
                print("class: {:10}   prob: {:.3}".format(class_name, prob))
                # Substring test: str.find() returns -1 (truthy) on a miss
                # and 0 (falsy) on a match at index 0, so it can't be used
                # directly as a boolean.
                if class_name in intputdir:
                    match_cnt += 1

    # Guard against an empty directory instead of dividing by zero.
    acc = match_cnt / img_cnt if img_cnt else 0.0
    print("Img_cnt: {:<5} match_cnt: {:<5} acc:{:.3}".format(img_cnt, match_cnt, acc))

if __name__ == '__main__':
    # Expect one positional argument: the directory of images to classify.
    # Exit with a usage message instead of an IndexError when it is missing.
    if len(argv) < 2:
        raise SystemExit("usage: {} <image_dir>".format(argv[0]))
    main(argv[1])