# ResNet50-QAT 本项目旨在对ResNet50模型执行量化感知训练,将其转换为onnx模型,并在TensorRT上运行。 ## 论文 **Deep Residual Learning for Image Recognition** * https://arxiv.org/pdf/1512.03385.pdf ## 模型结构 ResNet50包含深度残差网络,卷积,池化,全局平均池化以及分类层。 深度残差网络(ResNet): ResNet 引入了残差连接(residual connection)的概念,通过在网络中添加跨层的直接连接来解决深度神经网络中的梯度消失问题。这种连接允许信息在网络中更容易地向后传播,从而使得可以训练更深的网络模型。 深度结构: ResNet-50 具有深度的网络结构,包括 50 层卷积和全连接层。它包含多个残差块(residual blocks),每个残差块内部有多个卷积层和恒等映射(identity mapping)。通过这种方式,ResNet-50 能够学习到更复杂和抽象的特征表示。 卷积层和池化层: ResNet-50 使用了一系列卷积层和池化层,这些层用于从输入图像中提取特征。卷积层通过滤波器(filter)对输入图像进行卷积操作,从而提取图像中的局部特征。池化层则用于降低特征图的空间分辨率,减少模型的参数数量。 全局平均池化层: 在 ResNet-50 的最后一层卷积层之后,通常会添加一个全局平均池化层(Global Average Pooling Layer)。该层将特征图中的每个通道的特征取平均值,生成一个固定大小的特征向量作为输入,用于最终的分类任务。 分类层: 最后,ResNet-50 使用一个全连接的分类层,该层将特征向量映射到预定义的类别标签上。通常在分类层之前还会添加一个或多个全连接层和激活函数,用于增加模型的非线性能力。 ![alt text](readme_imgs/image-1.png) ## 算法原理 ResNet 引入了残差连接(residual connection)的概念,通过在网络中添加跨层的直接连接来传递信息,解决深度神经网络训练过程中的梯度消失和梯度爆炸问题,从而使得可以训练更深的网络模型。 ![alt text](readme_imgs/image-2.png) ## 环境配置 ### Anaconda (方法一) 1、本项目目前仅支持在N卡环境运行 python 3.9.18 torch 2.0.1 cuda 11 pip install -r requirements.txt pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com pytorch-quantization 2、TensorRT wget https://github.com/NVIDIA/TensorRT/archive/refs/tags/8.5.3.zip unzip [下载的压缩包] -d [解压路径] pip install 解压路径/python/tensorrt-8.5.3.1-cp39-none-linux_x86_64.whl ln -s 解压路径(绝对路径)/bin/trtexec /usr/local/bin/trtexec 注意:若需要`cu12`则将`requirements.txt`中的相关注释关闭,并安装。 ## 数据集 本项目使用CIFAR-10数据集,可直接运行`main.py`后自动下载并处理。 ## 训练 # --epochs表示训练或校准回合数 # --resume表示继续训练 # --qat表示校准(在训练基础模型时不能使用此参数) CUDA_VISIBLE_DEVICES=0,1 torchrun --nnodes=1 --nproc_per_node=2 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 main.py --epochs=N --resume --qat --batch_size=N --lr=X --num_classes=10 ## 推理 # N卡推理 trtexec --onnx=/path/to/onnx --saveEngine=./checkpoints/qat/last.trt --int8 python eval.py --device=0 # DCU卡推理 python evaluate_migraphx.py --device=0 ## result ![alt text](readme_imgs/image-3.png) ### 精度 ||原始模型(A800)|QAT模型(A800)|ONNX模型(A800)|TensorRT模型(A800)|MIGraphX模型| |:---|:---|:---|:---|:---|----| |Acc|0.9589|0.9584|0.9588|0.9584|| |推理时间|7.6061s|42.9348s|10.4021s|2.2839s|| ## 应用场景 ### 算法类别 `图像分类` ### 热点应用行业 `制造,交通,网安` ## 源码仓库及问题反馈 * https://developer.hpccube.com/codes/modelzoo/resnet50-qat_pytorch ## 参考资料 * https://docs.nvidia.com/deeplearning/tensorrt/pytorch-quantization-toolkit/docs/index.html