Unet.py 2.06 KB
Newer Older
1
2
3
4
import numpy as np
import cv2
import migraphx

liucong's avatar
liucong committed
5
6
def Preprocessing(pil_img, newW, newH):
    assert newW > 0 and newH > 0, 'Scale is too small' 
7
8
9
    img_nd = cv2.cvtColor(pil_img, cv2.COLOR_BGR2RGB)    # BGR转换为RGB
    img_nd = cv2.resize(img_nd, (newW, newH))            # 将图像尺寸修改为256x256
    
10
11
12
    if len(img_nd.shape) == 2:
        img_nd = np.expand_dims(img_nd, axis=2)

13
14
15
16
17
18
19
    img_trans = img_nd.transpose((2, 0, 1))              # HWC转换为CHW
    img_trans = np.expand_dims(img_trans, 0)             # CHW扩展为NCHW
    img_trans = np.ascontiguousarray(img_trans)          # 保证内存连续存储
    img_trans = img_trans.astype(np.float32)             # 转换成浮点型数据
    if img_trans.max() > 1:                             
        img = img_trans / 255.0                          # 保证数据处于0-1之间的浮点数

liucong's avatar
liucong committed
20
    return img
21
22
23
24
25
26
27

def Sigmoid(x):
  return 1 / (1 + np.exp(-x))

if __name__ == '__main__':
    
    # 加载模型
liucong's avatar
liucong committed
28
    model = migraphx.parse_onnx("../../Resource/Models/unet_13_256.onnx")
29
30
31
32
33
34
35
36
    inputName = model.get_parameter_names()
    inputShape = model.get_parameter_shapes()
    print("inputName:{0} \ninputShape:{1}".format(inputName, inputShape))

    # 编译模型
    model.compile(migraphx.get_target("gpu"), device_id=0)   # device_id: 设置GPU设备,默认为0号设备

    # 图像预处理
liucong's avatar
liucong committed
37
38
    img = cv2.imread("../../Resource/Images/car1.jpeg")
    input_img = Preprocessing(img, 256, 256)
39
40
41
42
43
44
45
46
47
48
49
50

    # 模型推理
    mask = model.run({'inputs':input_img})      
    output_mask = np.array(mask[0])[0]                        # 获取推理结果,shape为(1,256,256)
    probs = Sigmoid(output_mask)                              # 计算sigmoid值

     # 0/1像素值,当大于0.996时,值为255,小于等于0.996时,值为0
    output_mask[probs > 0.996] = 255
    output_mask[probs <= 0.996] = 0

    output = output_mask.astype(np.uint8)[0]                  # 将浮点型转换为uint8整型,shape为(256,256)
    cv2.imwrite("output.jpg", output)                         # 保存图像分割结果