"vscode:/vscode.git/clone" did not exist on "d92c69f4a713bc540606265127d60e496054d5bf"
Classifier.py 4.03 KB
Newer Older
yangql's avatar
yangql committed
1
2
3
4
5
# -*- coding: utf-8 -*-
"""
分类器示例
"""
import cv2
import argparse
import numpy as np
import onnxruntime as ort

def Preprocessing(pathOfImage):
    # Read the image with OpenCV (BGR order) and convert it to RGB
    image = cv2.imread(pathOfImage, cv2.IMREAD_COLOR)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
    # Resize so the shorter side becomes 256, preserving the aspect ratio
    ratio = float(256) / min(image.shape[0], image.shape[1])
    if image.shape[0] > image.shape[1]:
        new_size = [int(round(ratio * image.shape[0])), 256]
    else:
        new_size = [256, int(round(ratio * image.shape[1]))]
    image = np.array(cv2.resize(image, (new_size[1],new_size[0])))
    
    # Crop the central 224x224 window
    h, w, c = image.shape
    start_x = w//2 - 224//2
    start_y = h//2 - 224//2
    image = image[start_y:start_y+224, start_x:start_x+224, :]
    
    # Transpose from HWC to CHW layout
    image = image.transpose(2, 0, 1)
    
    # Convert the input data to float32
    img_data = image.astype('float32')
    
    # Normalize each channel with the standard ImageNet mean/std (scaled to the 0-255 range)
    mean_vec = np.array([123.675, 116.28, 103.53])
    stddev_vec = np.array([58.395, 57.12, 57.375])
    norm_img_data = np.zeros(img_data.shape).astype('float32')
    for i in range(img_data.shape[0]):
        norm_img_data[i,:,:] = (img_data[i,:,:] - mean_vec[i]) / stddev_vec[i]
    
    # Add the batch dimension: (1, 3, 224, 224)
    norm_img_data = norm_img_data.reshape(1, 3, 224, 224).astype('float32')
    return norm_img_data
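# Note (sketch only, not used by this script): the per-channel normalization loop in
# Preprocessing() is equivalent to a single NumPy broadcast expression:
#   norm = ((img_data - mean_vec[:, None, None]) / stddev_vec[:, None, None]).astype('float32')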

def postprocess(scores, pathOfImage):
    '''
    Postprocessing: sort the scores produced by the network into class IDs in decreasing
    order of probability, print the top-1 prediction, and annotate the input image.
    '''
    with open('../Resource/synset.txt', 'r') as f:
        labels = [l.rstrip() for l in f]
    preds = np.squeeze(scores)
    a = np.argsort(preds)[::-1]
    print('class=%s ; probability=%f' %(labels[a[0]],preds[a[0]]))

    text = 'class=%s ' % (labels[a[0]])
    saveimage(pathOfImage,text)
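# Sketch (illustrative, not part of the original script): because `a` holds class IDs
# sorted by decreasing probability, a top-5 report could be added inside postprocess():
#   for idx in a[:5]:
#       print('class=%s ; probability=%f' % (labels[idx], preds[idx]))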

def ort_seg_dcu(model_path, image, staticInfer, dynamicInfer):
    # MIGraphX execution provider options: FP16 enabled, static- vs. dynamic-shape model
    provider_options = []
    if staticInfer:
        provider_options = [{'device_id': '0', 'migraphx_fp16_enable': 'true', 'dynamic_model': 'false'}]
    if dynamicInfer:
        provider_options = [{'device_id': '0', 'migraphx_fp16_enable': 'true', 'dynamic_model': 'true',
                             'migraphx_profile_max_shapes': 'data:1x3x224x224'}]
    # Create an ONNX Runtime session on the MIGraphX (DCU) execution provider
    dcu_session = ort.InferenceSession(model_path, providers=['MIGraphXExecutionProvider'],
                                       provider_options=provider_options)
    input_name = dcu_session.get_inputs()[0].name

    results = dcu_session.run(None, input_feed={input_name: image})
    scores = np.array(results[0])
    print("ort result.shape:", scores.shape)

    return scores
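# Sketch (assumption, not part of the original example): if the installed onnxruntime
# build does not expose the MIGraphX provider, the model can still be sanity-checked on CPU:
#   if 'MIGraphXExecutionProvider' not in ort.get_available_providers():
#       cpu_session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])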

def saveimage(pathOfImage, text):
    # Draw the predicted label on the original image and write the result to disk
    iimage = cv2.imread(pathOfImage, cv2.IMREAD_COLOR)
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    font_color = (0, 0, 255)  # red in BGR
    font_thickness = 1
    text_position = (5, 20)
    cv2.putText(iimage, text, text_position, font, font_scale, font_color, font_thickness)
    cv2.imwrite("./output_image.jpg", iimage)
    cv2.destroyAllWindows()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--imgPath', type=str, default='../Resource/Images/ImageNet_01.jpg', help="image path")
    parser.add_argument('--staticModelPath', type=str, default='../Resource/Models/resnet50_static.onnx', help="static onnx filepath")
    parser.add_argument('--dynamicModelPath', type=str, default='../Resource/Models/resnet50_dynamic.onnx', help="dynamic onnx filepath")
    parser.add_argument("--staticInfer",action="store_true",default=False,help="Performing static inference")
    parser.add_argument("--dynamicInfer",action="store_true",default=False,help="Performing dynamic inference")
    args = parser.parse_args()

    # Preprocess the input image
    image = Preprocessing(args.imgPath)

    # Static-shape inference
    result = None
    if args.staticInfer:
        result = ort_seg_dcu(args.staticModelPath, image, args.staticInfer, args.dynamicInfer)

    # Dynamic-shape inference
    if args.dynamicInfer:
        result = ort_seg_dcu(args.dynamicModelPath, image, args.staticInfer, args.dynamicInfer)
    
    # Parse and report the classification result
    if result is not None:
        postprocess(result, args.imgPath)
    else:
        print("No inference was run; pass --staticInfer and/or --dynamicInfer.")
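
# Example invocations (assuming the ../Resource layout referenced by the argparse defaults):
#   python Classifier.py --staticInfer
#   python Classifier.py --dynamicInfer --imgPath ../Resource/Images/ImageNet_01.jpg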