README.md 2.45 KB
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# VGG16-QAT

本项目旨在对VGG16模型执行量化感知训练,将其转换为onnx模型,并在TensorRT上运行。

## 论文

**Very Deep Convolutional Networks for Large-Scale Image Recognition**

* https://arxiv.org/abs/1409.1556

## 模型结构

VGG网络由小的卷积滤波器组成,VGG16有三个全连接层和13个卷积层,此外,可在模型中添加`BatchNorm`以及`Dropout`层。

![Alt text](readme_imgs/image-1.png)

## 算法原理

VGG使用多个较小的卷积滤波器,这减少了网络在训练过程中过度拟合的倾向。3×3的滤波器是最佳大小,因为较小的大小无法捕捉左右和上下的信息。

![alt text](readme_imgs/image-2.png)

## 环境配置

### Anaconda (方法一)

1、本项目目前仅支持在N卡环境运行

    python 3.9.18
    torch 2.0.1
    cuda 11
    
    pip install -r requirements.txt

    pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com pytorch-quantization

2、TensorRT

    wget https://github.com/NVIDIA/TensorRT/archive/refs/tags/8.5.3.zip 

    unzip [下载的压缩包] -d [解压路径]

    pip install 解压路径/python/tensorrt-8.5.3.1-cp39-none-linux_x86_64.whl

    ln -s 解压路径(绝对路径)/bin/trtexec /usr/local/bin/trtexec


注意:若需要`cu12`则将`requirements.txt`中的相关注释关闭,并安装。

## 数据集

本项目使用CIFAR-10数据集,可直接运行`main.py`后自动下载并处理。

## 训练

    # --epochs表示训练或校准回合数
    # --resume表示继续训练
    # --calibrate表示校准(在训练基础模型时不能使用此参数)
    CUDA_VISIBLE_DEVICES=0,1 torchrun --nnodes=1 --nproc_per_node=2 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 main.py --epochs=N --resume --calibrate --batch_size=N --lr=X --num_classes=10

## 推理

mashun1's avatar
mashun1 committed
63
    trtexec --onnx=/path/to/onnx --saveEngine=./checkpoints/calibrated/last.trt --int8
mashun1's avatar
mashun1 committed
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95

    python eval.py --device=0

## result

![alt text](readme_imgs/image-3.png)

### 精度

||原始模型|QAT校准模型|ONNX模型|TensorRT模型|
|:---|:---|:---|:---|:---|
|Acc|0.9189|0.9185|0.9181|0.9184|
|推理时间|5.5764s|13.7603s|4.2848s|2.9893s|

## 应用场景

### 算法类别

`图像分类`

### 热点应用行业

`制造,交通,网安`

## 源码仓库及问题反馈

* https://developer.hpccube.com/codes/modelzoo/vgg16-qat_pytorch

## 参考资料

* https://docs.nvidia.com/deeplearning/tensorrt/pytorch-quantization-toolkit/docs/index.html