# 分类器 ## 模型简介 本示例使用了经典的mnist模型,模型下载地址:https://github.com/onnx/models/blob/main/vision/classification/mnist/model/mnist-12.onnx,模型结构如下图所示(可以通过netron (https://netron.app/) 查看),该模型的输入shape为[1,1,28,28] ,数据排布为NCHW,输出是10个类别的概率(未归一化)。 ![image-20221212165226581](../Images/Classifier_01.png) ​ ## 预处理 在将数据输入到模型之前,需要对图像做如下预处理操作: 1. 转换为单通道灰度图 2. resize到28x28 3. 将像素值归一化到[0.0, 1.0] 4. 转换数据排布为NCHW 本示例代码采用了OpenCV的cv::dnn::blobFromImages()函数实现了预处理操作: ``` ErrorCode Classifier::Classify(const std::vector &srcImages,std::vector> &predictions) { ... // 预处理 cv::Mat inputBlob; cv::dnn::blobFromImages(srcImages,// 输入数据,支持多张图像 inputBlob, // 输出数据 scale, // 缩放系数,这里为1/255.0 inputSize, // 模型输入大小,这里为28x28 meanValue, // 均值,这里不需要减均值,所以设置为0.0 swapRB, // 单通道图像,这里设置为0 false); ... } ``` cv::dnn::blobFromImages()函数支持多个输入图像,首先将输入图像resize到inputSize,然后减去均值meanValue,最后乘以scale并转换为NCHW,最终将转换好的数据保存到inputBlob中,然后就可以输入到模型中执行推理了。 ## 推理 完成预处理后,就可以执行推理了: ``` ErrorCode Classifier::Classify(const std::vector &srcImages,std::vector> &predictions) { ... // 预处理 // 输入数据 migraphx::parameter_map inputData; inputData[inputName]= migraphx::argument{inputShape, (float*)inputBlob.data}; // 推理 std::vector results = net.eval(inputData); // 获取输出节点的属性 migraphx::argument result = results[0]; // 获取第一个输出节点的数据 ... } ``` 1. inputData表示MIGraphX的输入数据,inputData是一个映射关系,每个输入节点名都会对应一个输入数据,如果有多个输入,则需要为每个输入节点名创建数据,inputName表示输入节点名,这里为Input3,migraphx::argument{inputShape, (float*)inputBlob.data}表示该节点名对应的数据,这里是通过前面预处理的数据inputBlob来创建的,第一个参数表示数据的shape,第二个参数表示数据指针。 2. net.eval(inputData)返回模型的推理结果,由于这里只有一个输出节点,所以std::vector中只有一个数据,results[0]表示第一个输出节点,这里对应Plus214_Output_0节点,获取输出数据之后,就可以对输出数据进行各种操作了。 3. 由于该模型输出的是一个未归一化的概率,所以如果需要得到每一类的实际的概率值,还需要计算softmax。 ## 运行示例 根据samples工程中的README.md构建成功C++ samples后,在build目录下输入如下命令运行该示例: ``` ./MIGraphX_Samples 0 ``` 输出结果为: ``` ... ========== 0 result ========== label:0,confidence:0.000000 label:1,confidence:0.000034 label:2,confidence:0.000012 label:3,confidence:0.000169 label:4,confidence:0.000044 label:5,confidence:0.000001 label:6,confidence:0.000000 label:7,confidence:0.000725 label:8,confidence:0.000278 label:9,confidence:0.998737 ``` 由于示例图像为数字9,所以结果中label为9的概率最高。 ![image-20221212173655309](../Images/Classifier_02.png)