# ViT
## Paper
`An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
## Model Architecture
ViT consists of three parts: patch embedding, a Transformer encoder, and an MLP head. It takes linear embeddings of image patches as input, adds position embeddings and a learnable cls_token (patch embedding), and applies a decoder-free Transformer directly. Because it lacks the inductive biases of CNNs, ViT underperforms them on small and medium datasets, but its performance keeps improving as model size and data volume grow.
<div align=center>
    <img src="./doc/vit.png"/>
</div>

## Algorithm
The model pipeline can be broken down into five steps:

1. Split the image into multiple patches.

2. Pass the resulting patches through a linear projection layer to obtain token embeddings.

3. Concatenate an extra learnable cls_token to the token embeddings, then add the position embeddings; this forms the full input to the encoder module.

4. Feed the summed result into the Transformer Encoder module.

5. Pass the Transformer Encoder output through the MLP Head module to produce the classification output.
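Steps 1-3 above (patchify, project, prepend cls_token, add positions) can be sketched in plain NumPy. The projection matrix, cls_token, and position embeddings below are random stand-ins for the learned parameters, not the repository's actual weights:

```python
import numpy as np

rng = np.random.default_rng(0)

def patch_embed(image, patch=16, dim=768):
    """Steps 1-3: patchify, linearly project, prepend cls_token, add positions."""
    h, w, c = image.shape
    gh, gw = h // patch, w // patch
    # 1) split into gh*gw non-overlapping patches, each flattened to patch*patch*c values
    patches = (image.reshape(gh, patch, gw, patch, c)
                    .transpose(0, 2, 1, 3, 4)
                    .reshape(gh * gw, patch * patch * c))
    # 2) linear projection to the embedding dimension (random stand-in weights)
    proj = rng.normal(size=(patch * patch * c, dim))
    tokens = patches @ proj                      # (num_patches, dim)
    # 3) prepend a learnable cls_token, then add position embeddings
    cls_token = rng.normal(size=(1, dim))
    tokens = np.concatenate([cls_token, tokens], axis=0)
    pos_embed = rng.normal(size=(tokens.shape[0], dim))
    return tokens + pos_embed                    # (1 + num_patches, dim)

# A 384x384 RGB crop with 16x16 patches gives (384/16)^2 = 576 patches,
# so the encoder input has 577 tokens (576 + cls_token).
x = patch_embed(rng.normal(size=(384, 384, 3)))
print(x.shape)  # (577, 768)
```

Note how the token count is fixed by the crop and patch sizes, which is why config.pp.crop appears in the training settings later in this README.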

## Code Changes
Note: the repository already contains the modified code; no further changes are needed.
```
1. vision_transformer/vit_jax/train.py, in the train_and_evaluate function:

input_pipeline.get_datasets(config) # TODO: added to fix get_dataset_info failing to load dataset info directly
```

## Environment Setup
### Docker (Option 1)
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/jax:0.4.23-ubuntu20.04-dtk24.04.1-py3.10
docker run -it --network=host --privileged=true --name=vit --device=/dev/kfd --device=/dev/dri --group-add video --shm-size=32G  --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro <imageID> /bin/bash  # replace <imageID> with the ID of the image pulled above

cd /your_code_path/vision_transformer
pip install flax==0.6.9 # flax force-installs specific versions of some AI packages
pip install -r requirements.txt
pip uninstall tensorflow
pip uninstall jax
pip uninstall jaxlib
pip install tensorflow-cpu==2.14.0
wget https://download.sourcefind.cn:65024/directlink/4/jax/DAS1.1/jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
wget https://download.sourcefind.cn:65024/directlink/4/jax/DAS1.1/jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp310-cp310-manylinux_2_31_x86_64.whl
pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp310-cp310-manylinux_2_31_x86_64.whl
```
### Dockerfile (Option 2)
```
docker build --no-cache -t vit:latest .
docker run -it --network=host --privileged=true --name=vit --device=/dev/kfd --device=/dev/dri --group-add video --shm-size=32G  --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro vit /bin/bash

cd /your_code_path/vision_transformer
pip install flax==0.6.9 # flax force-installs specific versions of some AI packages
pip install -r requirements.txt
pip uninstall tensorflow
pip uninstall jax
pip uninstall jaxlib
pip install tensorflow-cpu==2.14.0
wget https://download.sourcefind.cn:65024/directlink/4/jax/DAS1.1/jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
wget https://download.sourcefind.cn:65024/directlink/4/jax/DAS1.1/jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp310-cp310-manylinux_2_31_x86_64.whl
pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp310-cp310-manylinux_2_31_x86_64.whl
```
### Anaconda (Option 3)
1. The DCU-specific deep learning libraries required by this project can be downloaded from the developer community: https://developer.sourcefind.cn/tool/
```
DTK stack: dtk24.04.1
python: python3.10
jax: 0.4.23
```
`Tips: the dtk stack, python, and jax versions above are DCU-specific and must correspond exactly to one another`

2. Other, non-DCU-specific libraries can be installed directly with the following steps:
```
cd /your_code_path/vision_transformer
pip install flax==0.6.9 # flax force-installs specific versions of some AI packages
pip install -r requirements.txt
pip install tensorflow-cpu==2.14.0
```
## Datasets
### Training datasets
`cifar10  cifar100`
The datasets are downloaded and processed automatically by tensorflow_datasets when the training command runs; see vision_transformer/vit_jax/input_pipeline.py.

Note: if the error `All attempts to get a Google authentication bearer token failed.` occurs, apply the following change:
```
vim /usr/local/lib/python3.10/site-packages/tensorflow_datasets/core/utils/gcs_utils.py
search for _is_gcs_disabled and change it to _is_gcs_disabled = True
```
The dataset download location and processing settings are defined in ./configs/common.py; the default storage path is /root/tensorflow_datasets/, with the following directory layout:
```
 ── cifar10
    └── 3.0.2
        ├── cifar10-test.tfrecord-00000-of-00001
        ├── cifar10-train.tfrecord-00000-of-00001
        ├── dataset_info.json
        ├── features.json
        └── label.labels.txt
 ── cifar100
    └── 3.0.2
        ├── cifar100-test.tfrecord-00000-of-00001
        ├── cifar100-train.tfrecord-00000-of-00001
        ├── coarse_label.labels.txt
        ├── dataset_info.json
        ├── features.json
        └── label.labels.txt
```
### Inference dataset
The image and label file used for inference can be downloaded as follows:
```
# ./dataset is the storage directory and can be customized
wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -P ./dataset
wget https://picsum.photos/384 -O ./dataset/picsum.jpg  # fetches a random image at 384 resolution
```
The directory layout is:
```
 ── dataset
    ├── ilsvrc2012_wordnet_lemmas.txt
    └── picsum.jpg
```
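Before inference, an image like the one above is cropped/resized (config.pp.crop=384) and its pixel values scaled to [-1, 1]. A minimal NumPy sketch of that preprocessing follows; the nearest-neighbour resize is an illustrative stand-in for the pipeline's actual image ops, and the dummy array stands in for ./dataset/picsum.jpg:

```python
import numpy as np

def preprocess(image, crop=384):
    """Nearest-neighbour resize to crop x crop, then scale uint8 pixels to [-1, 1]."""
    h, w, _ = image.shape
    rows = np.arange(crop) * h // crop      # source row index for each output row
    cols = np.arange(crop) * w // crop      # source column index for each output column
    resized = image[rows][:, cols]          # (crop, crop, 3)
    return resized.astype(np.float32) / 127.5 - 1.0

# Dummy 500x300 RGB image standing in for the downloaded picsum.jpg.
img = np.random.default_rng(0).integers(0, 256, size=(500, 300, 3), dtype=np.uint8)
x = preprocess(img)
print(x.shape, x.min() >= -1.0, x.max() <= 1.0)  # (384, 384, 3) True True
```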
## Training
Checkpoints can be downloaded as follows:
```
cd /your_code_path/vision_transformer/test_result   # test_result is the checkpoint directory and can be customized
wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz
wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-L_16.npz
```
### Single machine, single card
```
cd /your_code_path/vision_transformer
sh test.sh

# workdir=$(pwd)/test_result/dcu/vit-$(date +%s) # directory for logs and model checkpoints
# config=$(pwd)/vit_jax/configs/vit.py:$model_datasets # model/dataset combination to fine-tune
# config.pretrained_dir=$(pwd)/test_result # directory containing the pretrained checkpoints
# config.accum_steps=64 # gradient accumulation steps (tpu=8, cpu=64)
# config.total_steps=500 # total fine-tuning steps
# config.warmup_steps=50 # learning-rate warmup steps
# config.batch=512 # training batch size
# config.pp.crop=384 # image crop resolution
# config.optim_dtype='bfloat16' # optimizer precision
```
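The config.accum_steps option trades memory for compute: each batch is split into micro-batches whose gradients are averaged before the update, yielding the same gradient as one full-batch pass. A minimal NumPy sketch of the idea, with a toy linear-model gradient standing in for the real ViT loss:

```python
import numpy as np

def grad(w, x, y):
    """Gradient of the mean squared error 0.5*mean((x@w - y)^2) for a linear model."""
    return x.T @ (x @ w - y) / len(x)

rng = np.random.default_rng(0)
x = rng.normal(size=(512, 8))          # one full batch (config.batch=512)
y = rng.normal(size=(512,))
w = np.zeros(8)

# Full-batch gradient computed in one shot.
g_full = grad(w, x, y)

# The same gradient accumulated over micro-batches (config.accum_steps=64).
accum_steps = 64
micro = 512 // accum_steps             # 8 examples per micro-batch
g_accum = np.zeros_like(w)
for i in range(accum_steps):
    xb = x[i * micro:(i + 1) * micro]
    yb = y[i * micro:(i + 1) * micro]
    g_accum += grad(w, xb, yb) / accum_steps

print(np.allclose(g_full, g_accum))  # True
```

This is why the default differs by device (tpu=8, cpu=64): more accumulation steps mean smaller micro-batches fit in memory, at the cost of more sequential passes.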
## Inference
Checkpoints can be downloaded as follows:
```
cd /your_code_path/vision_transformer/test_result   # test_result is the checkpoint directory and can be customized
wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16.npz -O ViT-B_16_imagenet2012.npz
wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-L_16.npz -O ViT-L_16_imagenet2012.npz
```
```
cd /your_code_path/vision_transformer
python test.py  # the model directory and dataset directory can be changed inside the file
```
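The probability listings in the results below are the standard softmax top-k readout over the class logits, each probability paired with its lemma from ilsvrc2012_wordnet_lemmas.txt. That final step can be sketched as follows, with toy logits and a toy label list standing in for the real model output:

```python
import numpy as np

def top_k_readout(logits, labels, k=3):
    """Softmax over class logits, then the k most probable (prob, label) pairs."""
    z = logits - logits.max()                 # subtract max for numerical stability
    probs = np.exp(z) / np.exp(z).sum()
    order = np.argsort(probs)[::-1][:k]
    return [(round(float(probs[i]), 5), labels[i]) for i in order]

# Toy stand-ins for the real logits and the wordnet lemma file entries.
labels = ["alp", "valley, vale", "lakeside, lakeshore", "volcano"]
logits = np.array([4.0, 2.9, -1.2, -2.0])
for p, name in top_k_readout(logits, labels):
    print(f"{p:.5f} : {name}")
```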
## Results
The test image is:
<div align=center>
    <img src="./doc/picsum.jpg"/>
</div>

```
----ViT-B_16:
DCU inference results:
0.73861 : alp
0.24576 : valley, vale
0.00416 : lakeside, lakeshore
0.00404 : cliff, drop, drop-off
0.00094 : promontory, headland, head, foreland
0.00060 : mountain_tent
0.00055 : dam, dike, dyke
0.00033 : volcano
0.00031 : ski
0.00012 : solar_dish, solar_collector, solar_furnace
GPU inference results:
0.73976 : alp
0.24465 : valley, vale
0.00414 : lakeside, lakeshore
0.00404 : cliff, drop, drop-off
0.00094 : promontory, headland, head, foreland
0.00060 : mountain_tent
0.00054 : dam, dike, dyke
0.00033 : volcano
0.00031 : ski
0.00012 : solar_dish, solar_collector, solar_furnace

----ViT-L_16:
DCU inference results:
0.87382 : alp
0.11846 : valley, vale
0.00550 : cliff, drop, drop-off
0.00023 : mountain_tent
0.00017 : promontory, headland, head, foreland
0.00015 : lakeside, lakeshore
0.00013 : dam, dike, dyke
0.00006 : volcano
0.00006 : ski
0.00004 : sandbar, sand_bar
GPU inference results:
0.87399 : alp
0.11828 : valley, vale
0.00550 : cliff, drop, drop-off
0.00023 : mountain_tent
0.00017 : promontory, headland, head, foreland
0.00015 : lakeside, lakeshore
0.00013 : dam, dike, dyke
0.00006 : volcano
0.00006 : ski
0.00004 : sandbar, sand_bar
```
### Accuracy
k800*1 (1410MHz, 80G, cuda11.8):
| Parameters | acc | loss |
| ---------- | --- | ---- |
| model_datasets='b16,cifar10'<br>config.batch=512<br/>config.total_steps=500<br/>config.optim_dtype = 'bfloat16' | 0.98047 | 0.428023 |
| model_datasets='b16,cifar100'<br/>config.batch=512<br/>config.total_steps=500<br/>config.optim_dtype = 'bfloat16' | 0.89206 | 1.25078  |
| model_datasets='l16,cifar10'<br/>config.batch=512<br/>config.total_steps=500<br/>config.optim_dtype = 'bfloat16' | 0.98890 | 0.348941 |
| model_datasets='l16,cifar100'<br/>config.batch=512<br/>config.total_steps=500<br/>config.optim_dtype = 'bfloat16' | 0.91375 | 1.05141  |

k100*1 (1270MHz, 64G, dtk24.04.1):
| Parameters | acc | loss |
| ---------- | --- | ---- |
| model_datasets='b16,cifar10'<br/>config.batch=512<br/>config.total_steps=500<br/>config.optim_dtype = 'bfloat16' | 0.98037 | 0.43239  |
| model_datasets='b16,cifar100'<br/>config.batch=512<br/>config.total_steps=500<br/>config.optim_dtype = 'bfloat16' | 0.89001 | 1.2273   |
| model_datasets='l16,cifar10'<br/>config.batch=512<br/>config.total_steps=500<br/>config.optim_dtype = 'bfloat16' | 0.98921 | 0.306221 |
| model_datasets='l16,cifar100'<br/>config.batch=512<br/>config.total_steps=500<br/>config.optim_dtype = 'bfloat16' | 0.91447 | 0.976117 |
## Application Scenarios
### Algorithm category
`Image classification`
### Key application industries
`Manufacturing, e-commerce, healthcare, media, education`
## Pretrained Weights
- https://console.cloud.google.com/storage/browser/vit_models/imagenet21k/  (for fine-tuning)

- https://console.cloud.google.com/storage/browser/vit_models/imagenet21k+imagenet2012/  (for inference)
## Source Repository and Issue Reporting
- https://developer.sourcefind.cn/codes/modelzoo/vision_transformer_jax
## References
- https://github.com/google-research/vision_transformer