import numpy as np
import cv2
import migraphx

def Preprocessing(pil_img, newW, newH):
    assert newW > 0 and newH > 0, 'Scale is too small' 
    img_nd = cv2.cvtColor(pil_img, cv2.COLOR_BGR2RGB)    # BGR转换为RGB
    img_nd = cv2.resize(img_nd, (newW, newH))            # 将图像尺寸修改为256x256
    
    if len(img_nd.shape) == 2:                           # 获取图像的维度信息
        img_nd = np.expand_dims(img_nd, axis=2)          # 如果是2维的 扩充为3维  

    img_trans = img_nd.transpose((2, 0, 1))              # HWC转换为CHW     
    img_trans = np.expand_dims(img_trans, 0)             # CHW扩展为NCHW
    img_trans = np.ascontiguousarray(img_trans)          # 保证内存连续存储
    img_trans = img_trans.astype(np.float32)             # 转换成浮点型数据
    if img_trans.max() > 1:                             
        img = img_trans / 255.0                          # 保证数据处于0-1之间的浮点数

    return img

def Sigmoid(x):
  return 1 / (1 + np.exp(-x))

# 对通道维度执行Softmax
def softmax(arr):
    # 1：对通道维度计算指数，避免数值溢出（减去最大值）
    exp_vals = np.exp(arr - np.max(arr, axis=1, keepdims=True))
    # 2：计算通道维度的指数和
    sum_exp = np.sum(exp_vals, axis=1, keepdims=True)
    # 3：归一化得到Softmax结果
    return exp_vals / sum_exp  

if __name__ == '__main__':

    # 设置最大输入shape
    maxInput={"inputs":[1,3,256,256]}
    
    # 加载模型
    model = migraphx.parse_onnx("../Resource/Models/deeplabv3_resnet101.onnx", map_input_dims=maxInput)

    # 获取模型输入/输出节点信息
    inputs = model.get_inputs()
    outputs = model.get_outputs()

    # 编译模型
    model.compile(migraphx.get_target("gpu"), device_id=0)      # device_id: 设置GPU设备，默认为0号设备

    # 图像预处理
    img = cv2.imread("../Resource/Images/000001.jpg")
    input_img = Preprocessing(img, 513, 513)

    print(inputs)
    # 模型推理
    mask = model.run({'images':input_img})      
    result = mask[0]                                            # 得到第一个输出节点的结果
    # 对通道维度进行softmax
    softmax_result = softmax(result)                            
    # 计算通道维度最大值对应的索引（即类别索引）      
    max_indices = np.argmax(softmax_result, axis=1)             # 等价于 np.argmax(arr, axis=1, keepdims=False)
    # 使用预设颜色
    color_map = np.array([
        [0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255],       # 0-3类
        [255, 255, 0], [255, 0, 255], [0, 255, 255], [128, 0, 0],  # 4-7类
        [0, 128, 0], [0, 0, 128], [128, 128, 0], [128, 0, 128],  # 8-11类
        [0, 128, 128], [192, 192, 192], [128, 128, 128], [64, 0, 0],  # 12-15类
        [0, 64, 0], [0, 0, 64], [64, 64, 0], [64, 0, 64],       # 16-19类
        [0, 64, 64]                                             # 20类
    ], dtype=np.uint8)

    flat_index = max_indices[0]                                 # 取第0批的数据 
    # 将二维的类别索引图直接转换为三维的 RGB 彩色图像
    rgb_image = color_map[flat_index]
    cv2.imwrite("Result.jpg", rgb_image)                         # 保存图像分割结果

