Classifier.py 3.12 KB
Newer Older
yangql's avatar
yangql committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# -*- coding: utf-8 -*-
"""
分类器示例
"""
import cv2
import numpy as np
import onnxruntime as ort

def Preprocessing(pathOfImage):
    """Load an image file and turn it into a normalized NCHW float32 tensor.

    The image is read as 3-channel BGR, converted to RGB, resized so that its
    shorter side is 256 pixels (aspect ratio preserved), center-cropped to
    224x224, normalized per channel, and returned as a C-contiguous array of
    shape (1, 3, 224, 224).

    Args:
        pathOfImage: Path of the image file to read.

    Returns:
        numpy.ndarray with dtype float32 and shape (1, 3, 224, 224).

    Raises:
        FileNotFoundError: If OpenCV cannot read or decode ``pathOfImage``.
    """
    # cv2.imread signals failure by returning None instead of raising, so
    # check explicitly rather than letting cvtColor fail with an opaque error.
    image = cv2.imread(pathOfImage, cv2.IMREAD_COLOR)
    if image is None:
        raise FileNotFoundError(
            "Unable to read image (missing or undecodable file): %s" % pathOfImage
        )
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Resize so that the shorter side becomes 256 while keeping aspect ratio.
    height, width = image.shape[:2]
    ratio = 256.0 / min(height, width)
    if height > width:
        new_h, new_w = int(round(ratio * height)), 256
    else:
        new_h, new_w = 256, int(round(ratio * width))
    # cv2.resize expects the target size as (width, height).
    image = cv2.resize(image, (new_w, new_h))

    # Crop the central 224x224 window.
    h, w = image.shape[:2]
    start_x = w // 2 - 224 // 2
    start_y = h // 2 - 224 // 2
    image = image[start_y:start_y + 224, start_x:start_x + 224, :]

    # Per-channel normalization constants; they are applied in RGB order
    # because the image was converted from BGR above. Broadcasting over the
    # trailing channel axis replaces the explicit per-channel loop.
    mean_vec = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    stddev_vec = np.array([58.395, 57.12, 57.375], dtype=np.float32)
    pixels = (image.astype(np.float32) - mean_vec) / stddev_vec

    # HWC -> CHW. The transpose is only a strided view, so materialize it as
    # a C-contiguous float32 buffer before adding the batch dimension.
    chw = np.ascontiguousarray(pixels.transpose(2, 0, 1), dtype=np.float32)
    return chw.reshape(1, 3, 224, 224)

yangql's avatar
yangql committed
45
def postprocess(scores, pathOfImage):
    """Report the top-1 class for ``scores`` and save an annotated image.

    Class labels are read from '../Resource/synset.txt' (one label per line).
    The best-scoring label and its score are printed, and the label is drawn
    onto the input image through ``saveimage``.

    Args:
        scores: Network output; it is squeezed with ``np.squeeze`` before
            ranking.
        pathOfImage: Path of the image the scores were computed from; passed
            on to ``saveimage``.

    Returns:
        numpy.ndarray of class indices sorted by decreasing score.
    """
    # Explicit encoding so the label file is read the same way regardless of
    # the platform's default locale.
    with open('../Resource/synset.txt', 'r', encoding='utf-8') as f:
        labels = [line.rstrip() for line in f]

    preds = np.squeeze(scores)
    ranked = np.argsort(preds)[::-1]
    top = ranked[0]

    # NOTE(review): the value is printed as "probability" but no softmax is
    # applied in this function; confirm whether the model already emits
    # probabilities.
    print('class=%s ; probability=%f' % (labels[top], preds[top]))

    text = 'class=%s ' % (labels[top])
    saveimage(pathOfImage, text)

    return ranked

yangql's avatar
yangql committed
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def ort_seg_dcu(model_path,image):
    """Run a single ONNX Runtime inference using the ROCm execution provider.

    Args:
        model_path: Path of the ONNX model, passed to ``ort.InferenceSession``.
        image: Input array fed to the model's first input.

    Returns:
        The model's first output converted with ``np.array``.
    """
    # Session configuration: profiling off, basic graph optimizations only.
    options = ort.SessionOptions()
    options.enable_profiling = False
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC

    session = ort.InferenceSession(
        model_path,
        options,
        providers=['ROCMExecutionProvider'],
    )

    # Feed the image under the name of the model's first declared input.
    first_input = session.get_inputs()[0]
    outputs = session.run(None, input_feed={first_input.name: image})

    scores = np.array(outputs[0])
    print("ort result.shape:", scores.shape)
    return scores

yangql's avatar
yangql committed
79
80
81
82
83
84
85
86
87
88
89
def saveimage(pathOfImage,text):
    """Draw ``text`` on the image at ``pathOfImage`` and save the result.

    The annotated copy is written to './output_image.jpg'; the source file is
    not modified.

    Args:
        pathOfImage: Path of the image file to annotate.
        text: String drawn near the top-left corner of the image.

    Raises:
        FileNotFoundError: If OpenCV cannot read or decode ``pathOfImage``.
        OSError: If OpenCV reports that the output image was not written.
    """
    # cv2.imread returns None on failure instead of raising.
    annotated = cv2.imread(pathOfImage, cv2.IMREAD_COLOR)
    if annotated is None:
        raise FileNotFoundError(
            "Unable to read image (missing or undecodable file): %s" % pathOfImage
        )

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    font_color = (255, 0, 0)  # OpenCV images are BGR, so this renders as blue
    font_thickness = 1
    text_position = (5, 10)
    cv2.putText(annotated, text, text_position, font, font_scale, font_color, font_thickness)

    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite("./output_image.jpg", annotated):
        raise OSError("Failed to write ./output_image.jpg")
    # No HighGUI window is created here, so there is nothing to destroy;
    # calling cv2.destroyAllWindows() would also raise on headless OpenCV builds.

yangql's avatar
yangql committed
90
91
92
93
94
95
96
97
98
99
if __name__ == "__main__":
    # Relative paths of the sample model and the sample image.
    model_path = "../Resource/Models/resnet50-v2-7.onnx"
    pathOfImage = "../Resource/Images/ImageNet_01.jpg"

    # Preprocess the image, then run inference.
    image = Preprocessing(pathOfImage)
    result = ort_seg_dcu(model_path, image)

    # Decode and report the classification result.
    postprocess(result, pathOfImage)