README.md 2.43 KB
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# VGG16-QAT

本项目旨在对VGG16模型执行量化感知训练,将其转换为onnx模型,并在TensorRT上运行。

## 论文

**Very Deep Convolutional Networks for Large-Scale Image Recognition**

* https://arxiv.org/abs/1409.1556

## 模型结构

VGG网络由小的卷积滤波器组成,VGG16有三个全连接层和13个卷积层,此外,可在模型中添加`BatchNorm`以及`Dropout`层。

![Alt text](readme_imgs/image-1.png)

## 算法原理

VGG使用多个较小的卷积滤波器,这减少了网络在训练过程中过度拟合的倾向。3×3的滤波器是最佳大小,因为较小的大小无法捕捉左右和上下的信息。

![alt text](readme_imgs/image-2.png)

## 环境配置

### Anaconda (方法一)

1、本项目目前仅支持在N卡环境运行

    python 3.9.18
    torch 2.0.1
    cuda 11
    
    pip install -r requirements.txt

    pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com pytorch-quantization

2、TensorRT

    wget https://github.com/NVIDIA/TensorRT/archive/refs/tags/8.5.3.zip 

    unzip [下载的压缩包] -d [解压路径]

    pip install 解压路径/python/tensorrt-8.5.3.1-cp39-none-linux_x86_64.whl

    ln -s 解压路径(绝对路径)/bin/trtexec /usr/local/bin/trtexec


注意:若需要`cu12`则将`requirements.txt`中的相关注释关闭,并安装。

## 数据集

本项目使用CIFAR-10数据集,可直接运行`main.py`后自动下载并处理。

## 训练

    # --epochs表示训练或校准回合数
    # --resume表示继续训练
mashun1's avatar
mashun1 committed
58
59
    # --qat表示校准(在训练基础模型时不能使用此参数)
    CUDA_VISIBLE_DEVICES=0,1 torchrun --nnodes=1 --nproc_per_node=2 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 main.py --epochs=N --resume --qat --batch_size=N --lr=X --num_classes=10
mashun1's avatar
mashun1 committed
60
61
62

## 推理

mashun1's avatar
mashun1 committed
63
    trtexec --onnx=/path/to/onnx --saveEngine=./checkpoints/qat/last.trt --int8
mashun1's avatar
mashun1 committed
64
65
66
67
68
69
70
71
72

    python eval.py --device=0

## result

![alt text](readme_imgs/image-3.png)

### 精度

mashun1's avatar
mashun1 committed
73
||原始模型|QAT模型|ONNX模型|TensorRT模型|
mashun1's avatar
mashun1 committed
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
|:---|:---|:---|:---|:---|
|Acc|0.9189|0.9185|0.9181|0.9184|
|推理时间|5.5764s|13.7603s|4.2848s|2.9893s|

## 应用场景

### 算法类别

`图像分类`

### 热点应用行业

`制造,交通,网安`

## 源码仓库及问题反馈

* https://developer.hpccube.com/codes/modelzoo/vgg16-qat_pytorch

## 参考资料

* https://docs.nvidia.com/deeplearning/tensorrt/pytorch-quantization-toolkit/docs/index.html