"vscode:/vscode.git/clone" did not exist on "9c36ddcd2b7a1a2e1f5b5362a379538916de23cd"
DeepLabV3.py 5.9 KB
Newer Older
shangxl's avatar
shangxl committed
1
2
3
import numpy as np
import cv2
import migraphx
4
5
import argparse
import os
shangxl's avatar
shangxl committed
6
7
8
9

def Preprocessing(pil_img, newW, newH):
    """Convert a BGR image into a float32 NCHW batch containing one image.

    Args:
        pil_img: image array in BGR channel order, as returned by
            ``cv2.imread``. Despite the name, this is not a PIL image.
        newW: target width in pixels; must be positive.
        newH: target height in pixels; must be positive.

    Returns:
        ``np.float32`` array of shape ``(1, C, newH, newW)``. If any value is
        greater than 1 the data is divided by 255 so it lies in [0, 1];
        otherwise it is returned unscaled (previously this case crashed).

    Raises:
        ValueError: if ``newW``/``newH`` is not positive or ``pil_img`` is None.
    """
    if newW <= 0 or newH <= 0:
        raise ValueError('Scale is too small')
    if pil_img is None:
        # cv2.imread returns None instead of raising when a file cannot be read.
        raise ValueError('Input image is None (was the file readable?)')

    img_nd = cv2.cvtColor(pil_img, cv2.COLOR_BGR2RGB)   # BGR -> RGB
    img_nd = cv2.resize(img_nd, (newW, newH))           # cv2.resize takes (width, height)

    if img_nd.ndim == 2:
        # Single-channel data comes back 2-D; add a trailing channel axis.
        img_nd = np.expand_dims(img_nd, axis=2)

    # HWC -> CHW -> NCHW, then a contiguous float32 copy.
    img_trans = img_nd.transpose((2, 0, 1))[np.newaxis]
    img_trans = np.ascontiguousarray(img_trans, dtype=np.float32)

    if img_trans.max() > 1:
        img_trans = img_trans / 255.0                   # scale 8-bit range into [0, 1]

    return img_trans

24
25
26
27
28
29
def AllocateOutputMemory(model):
    """Allocate one device buffer for every output node of ``model``.

    Args:
        model: a compiled migraphx program. ``model.get_outputs()`` is indexed
            by output name to obtain the shape passed to
            ``migraphx.allocate_gpu``.

    Returns:
        dict mapping each output name to the buffer returned by
        ``migraphx.allocate_gpu(s=<that output's shape>)``.
    """
    output_shapes = model.get_outputs()
    return {name: migraphx.allocate_gpu(s=output_shapes[name])
            for name in output_shapes.keys()}
shangxl's avatar
shangxl committed
30
31
32

# Softmax along one axis (the channel axis of an NCHW tensor by default).
def softmax(arr, axis=1):
    """Return the softmax of ``arr`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so that large
    logits do not overflow ``np.exp``; this does not change the result.

    Args:
        arr: array-like of logits. It is converted with ``np.asarray``, so any
            object NumPy can convert is accepted.
        axis: axis to normalize over. Defaults to 1, the channel axis for
            NCHW data, which matches the original behavior.

    Returns:
        ndarray with the same shape as ``arr`` whose values along ``axis``
        sum to 1.
    """
    arr = np.asarray(arr)
    shifted = arr - np.max(arr, axis=axis, keepdims=True)
    exp_vals = np.exp(shifted)
    return exp_vals / np.sum(exp_vals, axis=axis, keepdims=True)

def _load_folder_images(folder_path, extensions, newW, newH):
    """Load every matching image in ``folder_path`` as one preprocessed NCHW batch.

    Files whose lower-cased name ends with one of ``extensions`` are read with
    ``cv2.imread`` and passed through ``Preprocessing``. Filenames are sorted
    because ``os.listdir`` order is arbitrary; sorting keeps batch index ``i``
    (and therefore ``Result_<i+1>.jpg``) tied to the same file on every run.

    Args:
        folder_path: directory to scan (not recursive).
        extensions: tuple of lower-case filename suffixes, e.g. ``('.jpg',)``.
        newW: width passed to ``Preprocessing``.
        newH: height passed to ``Preprocessing``.

    Returns:
        float32 array of shape ``(num_images, C, newH, newW)``.

    Raises:
        FileNotFoundError: if a matching file cannot be decoded by OpenCV.
        ValueError: if no matching file is found.
    """
    batch = []
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith(extensions):
            continue
        file_path = os.path.join(folder_path, filename)
        img = cv2.imread(file_path)
        if img is None:  # cv2.imread signals failure by returning None, not raising
            raise FileNotFoundError("Failed to read image: " + file_path)
        batch.append(Preprocessing(img, newW, newH))
    if not batch:
        raise ValueError("No images ending with {} found in {}".format(extensions, folder_path))
    return np.concatenate(batch, axis=0)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("loadMode", type=int, help="0:DeepLabV3 Single Image Sample.\t 1:DeepLabV3 Multiple Image Sample.")
    parser.add_argument("--enable_offload_copy", action="store_true")
    precision_group = parser.add_mutually_exclusive_group()
    precision_group.add_argument("--int8", action="store_true")
    precision_group.add_argument("--fp16", action="store_true")

    args = parser.parse_args()
    loadMode = args.loadMode
    useInt8 = args.int8
    useFP16 = args.fp16
    offloadCopy = args.enable_offload_copy

    imageFolder = "../Resource/Images/"
    imageExtensions = ('.jpg',)   # trailing comma makes this a tuple; ('.jpg') is a plain str
    inputSize = 513               # model input height and width

    # Load the input image(s): mode 0 uses one fixed file, any other mode the whole folder.
    if loadMode == 0:
        singleImagePath = "../Resource/Images/000001.jpg"
        img = cv2.imread(singleImagePath)
        if img is None:
            raise FileNotFoundError("Failed to read image: " + singleImagePath)
        input_img = Preprocessing(img, inputSize, inputSize)
    else:
        input_img = _load_folder_images(imageFolder, imageExtensions, inputSize, inputSize)

    # Input dims handed to the ONNX parser. The batch dimension follows the
    # number of images actually loaded (1 in mode 0, the folder's image count otherwise).
    maxInput = {"input": [input_img.shape[0], 3, inputSize, inputSize]}

    # Load the model
    model = migraphx.parse_onnx("../Resource/Models/deeplabv3_resnet101.onnx", map_input_dims=maxInput)

    # Query the input node's name and shape
    inputs = model.get_inputs()
    inputName = model.get_parameter_names()[0]
    inputShape = inputs[inputName].lens()

    # Quantization
    if useInt8:
        calibrate_img = _load_folder_images(imageFolder, imageExtensions, inputSize, inputSize)
        # Split calibration data into batches of the model's input batch size so
        # every calibration entry has the same shape as the model input.
        batchSize = inputShape[0]
        calibrate_batches = [calibrate_img[start:start + batchSize]
                             for start in range(0, calibrate_img.shape[0] - batchSize + 1, batchSize)]
        if not calibrate_batches:
            raise ValueError("Need at least {} calibration images in {}".format(batchSize, imageFolder))
        # NOTE(review): the slices are views into calibrate_img; both stay referenced
        # here in case migraphx.argument wraps the buffer without copying — confirm.
        calibration = [{inputName: migraphx.argument(chunk)} for chunk in calibrate_batches]
        migraphx.quantize_int8(model, migraphx.get_target("gpu"), calibration)
    if useFP16:
        migraphx.quantize_fp16(model)

    if offloadCopy:
        # Compile the model; host arrays are passed directly to run()
        model.compile(migraphx.get_target("gpu"), device_id=0)      # device_id: GPU device to use (default 0)
        # Inference
        mask = model.run({inputName: input_img})
        result = np.array(mask[0])                                  # first output node as a NumPy array
    else:
        # Compile the model with offload_copy disabled
        model.compile(migraphx.get_target("gpu"), offload_copy=False, device_id=0)      # device_id: GPU device to use (default 0)
        modelData = AllocateOutputMemory(model)                     # device buffers that receive the output data
        modelData[inputName] = migraphx.to_gpu(migraphx.argument(input_img))
        # Inference
        mask = model.run(modelData)
        result = np.array(migraphx.from_gpu(mask[0]))               # copy the first output back to host, as NumPy

    # Softmax over the channel axis, then the most probable class per pixel.
    # Assumes result is (N, num_classes, H, W) — TODO confirm against the model.
    softmax_result = softmax(result)
    max_indices = np.argmax(softmax_result, axis=1)

    # Fixed RGB palette, one row per class
    color_map = np.array([
        [0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255],       # classes 0-3
        [255, 255, 0], [255, 0, 255], [0, 255, 255], [128, 0, 0],  # classes 4-7
        [0, 128, 0], [0, 0, 128], [128, 128, 0], [128, 0, 128],  # classes 8-11
        [0, 128, 128], [192, 192, 192], [128, 128, 128], [64, 0, 0],  # classes 12-15
        [0, 64, 0], [0, 0, 64], [64, 64, 0], [64, 0, 64],       # classes 16-19
        [0, 64, 64]                                             # class 20
    ], dtype=np.uint8)

    for i, class_map in enumerate(max_indices):
        rgb_image = color_map[class_map]                            # 2-D class-index map -> (H, W, 3) RGB image
        bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)      # cv2.imwrite expects BGR channel order
        fileName = "Result_" + str(i + 1) + ".jpg"
        cv2.imwrite(fileName, bgr_image)                            # save the segmentation result

    print("Segmentation results have been saved to Python directory")
shangxl's avatar
shangxl committed
134
135