# -*- coding: utf-8 -*-
"""
分类器示例
"""
import cv2
import argparse
import numpy as np
import onnxruntime as ort

def Preprocessing(pathOfImage):
    """Load an image and preprocess it for ImageNet-style classification.

    Reads the image, converts BGR->RGB, resizes so the short side is 256
    (keeping aspect ratio), center-crops a 224x224 window, transposes to
    CHW, and normalizes with the standard ImageNet mean/stddev.

    Args:
        pathOfImage: Path to the input image file.

    Returns:
        np.ndarray of shape (1, 3, 224, 224), dtype float32.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    image = cv2.imread(pathOfImage, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread silently returns None on a bad path; fail loudly here
        # instead of with a cryptic AttributeError on cvtColor below.
        raise FileNotFoundError("Cannot read image: %s" % pathOfImage)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Resize so the short side becomes 256, preserving aspect ratio.
    ratio = 256.0 / min(image.shape[0], image.shape[1])
    if image.shape[0] > image.shape[1]:
        new_size = [int(round(ratio * image.shape[0])), 256]
    else:
        new_size = [256, int(round(ratio * image.shape[1]))]
    image = cv2.resize(image, (new_size[1], new_size[0]))

    # Center-crop a 224x224 window.
    h, w, _ = image.shape
    start_x = w // 2 - 224 // 2
    start_y = h // 2 - 224 // 2
    image = image[start_y:start_y + 224, start_x:start_x + 224, :]

    # HWC -> CHW, then float32 for the network.
    img_data = image.transpose(2, 0, 1).astype('float32')

    # Per-channel ImageNet normalization (RGB mean/stddev on the 0-255 scale),
    # vectorized via broadcasting instead of a per-channel Python loop.
    mean_vec = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    stddev_vec = np.array([58.395, 57.12, 57.375], dtype=np.float32)
    norm_img_data = (img_data - mean_vec[:, None, None]) / stddev_vec[:, None, None]

    # Add the batch dimension: (1, 3, 224, 224).
    return norm_img_data.reshape(1, 3, 224, 224).astype('float32')

def postprocess(scores, pathOfImage, labelsPath='../Resource/synset.txt'):
    """Turn raw network scores into a class label, print it, and annotate the image.

    Loads the label list, sorts the scores in decreasing order of
    probability, prints the top-1 class, and writes the label onto a copy
    of the input image via saveimage().

    Args:
        scores: Raw output scores from the network (any array-like; squeezed).
        pathOfImage: Path of the original input image to annotate.
        labelsPath: Path to the synset label file (one label per line).
            Defaults to the original hard-coded '../Resource/synset.txt'.
    """
    with open(labelsPath, 'r') as f:
        labels = [l.rstrip() for l in f]
    preds = np.squeeze(scores)
    # Indices of classes, highest score first.
    a = np.argsort(preds)[::-1]
    print('class=%s ; probability=%f' %(labels[a[0]],preds[a[0]]))

    text = 'class=%s ' % (labels[a[0]])
    saveimage(pathOfImage, text)

def ort_seg_dcu(model_path, image, staticInfer, dynamicInfer):
    """Run inference on DCU through the ONNX Runtime MIGraphX execution provider.

    Args:
        model_path: Path to the ONNX model file.
        image: Preprocessed input array (shape (1, 3, 224, 224), float32).
        staticInfer: If True, use FP16 provider options for a static-shape model.
        dynamicInfer: If True, use FP16 + dynamic-shape provider options;
            when both flags are set, the dynamic options win (checked last).

    Returns:
        np.ndarray holding the model's first output.
    """
    provider_options = []
    if staticInfer:
        provider_options = [{'device_id':'0','migraphx_fp16_enable':'true'}]

    if dynamicInfer:
        # Dynamic-shape models additionally declare the max shape to profile.
        provider_options = [{'device_id':'0','migraphx_fp16_enable':'true','dynamic_model':'true', 'migraphx_profile_max_shapes':'data:1x3x224x224'}]

    dcu_session = ort.InferenceSession(model_path, providers=['MIGraphXExecutionProvider'], provider_options=provider_options)

    # Feed the single image to the model input(s) by name. The original
    # also collected output names but never used them; that dead code is gone.
    input_names = [node.name for node in dcu_session.get_inputs()]
    input_dict = {name: data for name, data in zip(input_names, [image])}

    # An empty output-name list requests all model outputs.
    result = dcu_session.run([], input_dict)
    scores = np.array(result[0])

    print("ort result.shape:",scores.shape)

    return scores

def saveimage(pathOfImage, text, outputPath="./output_image.jpg"):
    """Draw the classification text onto the original image and save it.

    Args:
        pathOfImage: Path of the original input image.
        text: Label string drawn in the top-left corner.
        outputPath: Where to write the annotated image. Defaults to the
            original hard-coded "./output_image.jpg".

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    iimage = cv2.imread(pathOfImage, cv2.IMREAD_COLOR)
    if iimage is None:
        raise FileNotFoundError("Cannot read image: %s" % pathOfImage)
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    font_color = (0, 0, 255)  # red in BGR
    font_thickness = 1
    text_position = (5, 20)
    cv2.putText(iimage, text, text_position, font, font_scale, font_color, font_thickness)
    cv2.imwrite(outputPath, iimage)
    # The original called cv2.destroyAllWindows() here, but no window is
    # ever created in this script, so that no-op call has been removed.

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--imgPath', type=str, default='../Resource/Images/ImageNet_01.jpg', help="image path")
    parser.add_argument('--staticModelPath', type=str, default='../Resource/Models/resnet50_static.onnx', help="static onnx filepath")
    parser.add_argument('--dynamicModelPath', type=str, default='../Resource/Models/resnet50_dynamic.onnx', help="dynamic onnx filepath")
    parser.add_argument("--staticInfer",action="store_true",default=False,help="Performing static inference")
    parser.add_argument("--dynamicInfer",action="store_true",default=False,help="Performing dynamic inference")
    args = parser.parse_args()

    # Bug fix: with neither flag set, `result` was never assigned and the
    # final postprocess() call raised NameError. Require one mode up front.
    if not (args.staticInfer or args.dynamicInfer):
        parser.error("one of --staticInfer or --dynamicInfer is required")

    # Preprocess the input image.
    image = Preprocessing(args.imgPath)

    # Static-shape inference.
    if args.staticInfer:
        result = ort_seg_dcu(args.staticModelPath, image, args.staticInfer, args.dynamicInfer)

    # Dynamic-shape inference (overwrites the static result when both flags are given).
    if args.dynamicInfer:
        result = ort_seg_dcu(args.dynamicModelPath, image, args.staticInfer, args.dynamicInfer)

    # Parse and report the classification result.
    postprocess(result, args.imgPath)