# ViT
## Paper
`An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
## Model Architecture
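ViT consists of three parts: patch embedding, a Transformer encoder, and an MLP head. In the patch-embedding stage, the image is split into patches whose linear embeddings, together with position embeddings and a learnable cls_token, form the input sequence; a standard Transformer encoder (with no decoder) is then applied directly to this sequence. Because ViT lacks the inductive biases built into CNNs (such as locality and translation equivariance), it underperforms CNNs on small and medium-sized datasets, but its performance keeps improving as model size and pre-training data grow.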
<div align=center>
    <img src="./doc/vit.png"/>
</div>

## Algorithm
The forward pass can be broken into five steps:

1. Split the image into fixed-size patches (16×16 pixels for ViT-B/16 and ViT-L/16).
2. Flatten each patch and map it through a linear projection layer to obtain a token embedding.
3. Prepend an extra learnable cls_token to the token embeddings and add the position embeddings; the result is the input to the encoder.
4. Pass this sequence through the Transformer encoder.
5. Feed the encoder output at the cls_token position through the MLP head to produce the classification output.
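Steps 1 to 3 can be sketched in NumPy as below. This is a minimal illustration rather than the vit_jax code: upstream vit_jax performs steps 1 and 2 with a single strided convolution (kernel size and stride equal to the patch size), which is equivalent to flattening the patches and applying one shared linear layer. The helper names `patchify` and `embed` and the random weights are hypothetical; the sizes assume ViT-B/16 on a 384×384 input.

```python
import numpy as np

def patchify(image: np.ndarray, patch: int) -> np.ndarray:
    """Step 1: split an (H, W, C) image into N = (H/patch)*(W/patch) flattened patches."""
    h, w, c = image.shape
    assert h % patch == 0 and w % patch == 0, "image size must be divisible by patch size"
    gh, gw = h // patch, w // patch
    x = image.reshape(gh, patch, gw, patch, c)    # (gh, p, gw, p, C)
    x = x.transpose(0, 2, 1, 3, 4)                # (gh, gw, p, p, C): group pixels by patch
    return x.reshape(gh * gw, patch * patch * c)  # (N, p*p*C), row-major patch order

def embed(image, patch, w_proj, b_proj, cls_token, pos_embed):
    """Steps 2-3: linear projection, prepend cls_token, add position embeddings."""
    tokens = patchify(image, patch) @ w_proj + b_proj         # (N, D)
    tokens = np.concatenate([cls_token[None, :], tokens], 0)  # (N + 1, D)
    return tokens + pos_embed                                 # (N + 1, D)

rng = np.random.default_rng(0)
H = W = 384              # input resolution (config.pp.crop in the fine-tuning example)
P, C, D = 16, 3, 768     # ViT-B/16: 16x16 patches, RGB input, hidden size 768
N = (H // P) * (W // P)  # 24 * 24 = 576 patches

# Hypothetical random parameters; a real model loads them from the checkpoint.
image = rng.standard_normal((H, W, C)).astype(np.float32)
w_proj = (rng.standard_normal((P * P * C, D)) * 0.02).astype(np.float32)
b_proj = np.zeros(D, np.float32)
cls_token = np.zeros(D, np.float32)
pos_embed = (rng.standard_normal((N + 1, D)) * 0.02).astype(np.float32)

x = embed(image, P, w_proj, b_proj, cls_token, pos_embed)
print(x.shape)
```

With 16×16 patches on a 384×384 image, the encoder therefore sees 24 × 24 = 576 patch tokens plus the cls_token, i.e. 577 tokens of width 768.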


## Environment Setup
### Docker (Option 1)
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/jax:0.4.23-ubuntu20.04-dtk24.04.1-py3.10
docker run -it --network=host --privileged=true --name=vit --device=/dev/kfd --device=/dev/dri --group-add video --shm-size=32G  --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro <imageID> /bin/bash  # replace <imageID> with the ID of the image pulled above

cd /your_code_path/vision_transformer
pip install flax==0.6.9 # note: flax force-installs specific versions of some dependent packages
pip install -r requirements.txt
pip uninstall -y tensorflow jax jaxlib  # replaced by tensorflow-cpu and the DCU jax/jaxlib wheels below
pip install tensorflow-cpu==2.14.0
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.1/jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.0/jaxlib-0.4.23+das1.0+git97306ab.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jaxlib-0.4.23+das1.0+git97306ab.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
```
### Dockerfile (Option 2)
```
docker build --no-cache -t vit:latest .
docker run -it --network=host --privileged=true --name=vit --device=/dev/kfd --device=/dev/dri --group-add video --shm-size=32G  --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro vit /bin/bash

cd /your_code_path/vision_transformer
pip install flax==0.6.9 # note: flax force-installs specific versions of some dependent packages
pip install -r requirements.txt
pip uninstall -y tensorflow jax jaxlib  # replaced by tensorflow-cpu and the DCU jax/jaxlib wheels below
pip install tensorflow-cpu==2.14.0
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.1/jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.0/jaxlib-0.4.23+das1.0+git97306ab.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jaxlib-0.4.23+das1.0+git97306ab.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
```
### Anaconda (Option 3)
1. The DCU-specific deep learning libraries required by this project can be downloaded from the Guanghe (光合) developer community: https://developer.hpccube.com/tool/
```
DTK software stack: dtk24.04.1
python:python3.10
jax:0.4.23
```
`Tip: the versions of the DCU-related tools above (DTK stack, Python, JAX) must match each other exactly.`

2. Install the remaining (non-DCU-specific) libraries as follows:
```
cd /your_code_path/vision_transformer
pip install flax==0.6.9 # note: flax force-installs specific versions of some dependent packages
pip install -r requirements.txt
pip install tensorflow-cpu==2.14.0
```
## Datasets
### Training datasets
`cifar10  cifar100`
The datasets are downloaded and preprocessed automatically by tensorflow_datasets when the training command runs; see vision_transformer/vit_jax/input_pipeline.py for the relevant code.

Note: if you get the error `All attempts to get a Google authentication bearer token failed..`, modify tensorflow_datasets as follows:
```
vim /usr/local/lib/python3.10/site-packages/tensorflow_datasets/core/utils/gcs_utils.py
# search for _is_gcs_disabled and change it to: _is_gcs_disabled = True
```
Dataset download and preprocessing settings are in ./configs/common.py. By default, datasets are stored in /root/tensorflow_datasets/ with the following layout:
```
/root/tensorflow_datasets
├── cifar10
│   └── 3.0.2
│       ├── cifar10-test.tfrecord-00000-of-00001
│       ├── cifar10-train.tfrecord-00000-of-00001
│       ├── dataset_info.json
│       ├── features.json
│       └── label.labels.txt
└── cifar100
    └── 3.0.2
        ├── cifar100-test.tfrecord-00000-of-00001
        ├── cifar100-train.tfrecord-00000-of-00001
        ├── coarse_label.labels.txt
        ├── dataset_info.json
        ├── features.json
        └── label.labels.txt
```
### Inference data
The image and label file used for inference can be downloaded from [scnet](http://113.200.138.88:18080/aidatasets/project-dependency/vision_transformer_jax) or with the commands below:
```
# ./dataset is the download directory; change it as needed
wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -P ./dataset
wget https://picsum.photos/384 -O ./dataset/picsum.jpg  # a random 384x384 image (picsum.photos returns a different image on each request)
```
Directory layout:
```
dataset
├── ilsvrc2012_wordnet_lemmas.txt
└── picsum.jpg
```
## Training
ImageNet-21k pre-trained checkpoints for fine-tuning can be downloaded from [scnet](http://113.200.138.88:18080/aimodels/findsource-dependency/vision_transformer_jax/-/tree/master/imagenet21k?ref_type=heads) or as follows:
```
cd /your_code_path/vision_transformer/test_result   # test_result is the checkpoint download directory; change it as needed
wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz
wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-L_16.npz
```
### Single node, single card
```
cd /your_code_path/vision_transformer
sh test.sh

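# Main settings in test.sh:
# workdir=$(pwd)/test_result/dcu/vit-$(date +%s)  # directory for logs and model checkpoints
# config=$(pwd)/vit_jax/configs/vit.py:$model_datasets  # model/dataset pair to fine-tune, e.g. b16,cifar10
# config.pretrained_dir=$(pwd)/test_result  # directory containing the pre-trained checkpoints
# config.accum_steps=64  # gradient-accumulation steps per batch (tpu=8, cpu=64)
# config.total_steps=500  # number of fine-tuning steps
# config.warmup_steps=50  # number of learning-rate warmup steps
# config.batch=512  # batch size
# config.pp.crop=384  # input image resolution (crop size)
# config.optim_dtype='bfloat16'  # optimizer precision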
```
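The interplay of `config.batch`, `config.accum_steps`, `config.warmup_steps` and `config.total_steps` can be illustrated in plain Python. This is a sketch under two assumptions: the learning rate ramps up linearly during warmup and then follows a cosine decay to zero (our assumption about the vit_jax schedule, which is configurable), and each per-device batch is split evenly into `accum_steps` micro-batches whose gradients are averaged. The function names and `base_lr=0.03` are hypothetical.

```python
import math

def lr_at(step: int, total_steps: int, warmup_steps: int, base_lr: float) -> float:
    """Assumed schedule: linear warmup for warmup_steps, then cosine decay to 0 at total_steps."""
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    progress = min(max(progress, 0.0), 1.0)  # 0 during warmup, 1 at the end
    lr = base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
    if warmup_steps:
        lr *= min(1.0, step / warmup_steps)  # linear ramp from 0 to base_lr
    return lr

def micro_batch_size(batch: int, accum_steps: int, num_devices: int = 1) -> int:
    """Images per forward/backward pass when each device's batch is split for accumulation."""
    per_device = batch // num_devices
    if per_device % accum_steps:
        raise ValueError("per-device batch must be divisible by accum_steps")
    return per_device // accum_steps

# Values from the fine-tuning settings above; base_lr is a hypothetical choice.
total_steps, warmup_steps, base_lr = 500, 50, 0.03
for step in (0, 25, 50, 275, 500):
    print(step, round(lr_at(step, total_steps, warmup_steps, base_lr), 5))
print(micro_batch_size(512, 64))  # 8 images per pass on a single card
```

On a single card, `batch=512` with `accum_steps=64` means only 8 images go through each forward/backward pass, trading throughput for a lower peak memory footprint.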
## Inference
Checkpoints pre-trained on ImageNet-21k and fine-tuned on ImageNet-2012 can be downloaded from [scnet](http://113.200.138.88:18080/aimodels/findsource-dependency/vision_transformer_jax/-/tree/master/imagenet21k+imagenet2012?ref_type=heads) or as follows:
```
cd /your_code_path/vision_transformer/test_result   # test_result is the checkpoint download directory; change it as needed
wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16.npz -O ViT-B_16_imagenet2012.npz
wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-L_16.npz -O ViT-L_16_imagenet2012.npz
```
```
cd /your_code_path/vision_transformer
python test.py  # the model and dataset paths can be changed inside test.py
```
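test.py prints the ten most probable ImageNet classes as `probability : label`, as in the results below. A minimal pure-Python sketch of that post-processing step follows, assuming the model returns one logit per ImageNet-1k class in the same order as the lines of ilsvrc2012_wordnet_lemmas.txt. The function names are hypothetical and not taken from test.py.

```python
import math

def load_labels(path: str) -> list[str]:
    """Read one class name per line (e.g. ilsvrc2012_wordnet_lemmas.txt)."""
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f]

def softmax(logits: list[float]) -> list[float]:
    """Convert logits to probabilities; subtracting the max avoids overflow in exp."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_lines(logits: list[float], labels: list[str], k: int = 10) -> list[str]:
    """Format the k most probable classes as 'probability : label', highest first."""
    probs = softmax(logits)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [f"{probs[i]:.5f} : {labels[i]}" for i in order]

# Toy example with three classes standing in for the 1000 ImageNet labels.
toy_labels = ["alp", "valley, vale", "ski"]
for line in top_k_lines([2.0, 1.0, 0.0], toy_labels, k=3):
    print(line)
```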
## Results
Test image:
<div align=center>
    <img src="./doc/picsum.jpg"/>
</div>

```
----ViT-B_16:
DCU inference results:
0.73861 : alp
0.24576 : valley, vale
0.00416 : lakeside, lakeshore
0.00404 : cliff, drop, drop-off
0.00094 : promontory, headland, head, foreland
0.00060 : mountain_tent
0.00055 : dam, dike, dyke
0.00033 : volcano
0.00031 : ski
0.00012 : solar_dish, solar_collector, solar_furnace
GPU inference results:
0.73976 : alp
0.24465 : valley, vale
0.00414 : lakeside, lakeshore
0.00404 : cliff, drop, drop-off
0.00094 : promontory, headland, head, foreland
0.00060 : mountain_tent
0.00054 : dam, dike, dyke
0.00033 : volcano
0.00031 : ski
0.00012 : solar_dish, solar_collector, solar_furnace

----ViT-L_16:
DCU inference results:
0.87382 : alp
0.11846 : valley, vale
0.00550 : cliff, drop, drop-off
0.00023 : mountain_tent
0.00017 : promontory, headland, head, foreland
0.00015 : lakeside, lakeshore
0.00013 : dam, dike, dyke
0.00006 : volcano
0.00006 : ski
0.00004 : sandbar, sand_bar
GPU inference results:
0.87399 : alp
0.11828 : valley, vale
0.00550 : cliff, drop, drop-off
0.00023 : mountain_tent
0.00017 : promontory, headland, head, foreland
0.00015 : lakeside, lakeshore
0.00013 : dam, dike, dyke
0.00006 : volcano
0.00006 : ski
0.00004 : sandbar, sand_bar
```
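In both listings above, the DCU and GPU runs return the same ten labels in the same order, so their agreement can be summarized by the largest per-class probability gap. The probabilities below are copied from those outputs; the `RESULTS` layout and the helper name are our own, for illustration only.

```python
# Top-10 probabilities copied from the outputs above (identical label order on DCU and GPU).
RESULTS = {
    "ViT-B_16": {
        "dcu": [0.73861, 0.24576, 0.00416, 0.00404, 0.00094, 0.00060, 0.00055, 0.00033, 0.00031, 0.00012],
        "gpu": [0.73976, 0.24465, 0.00414, 0.00404, 0.00094, 0.00060, 0.00054, 0.00033, 0.00031, 0.00012],
    },
    "ViT-L_16": {
        "dcu": [0.87382, 0.11846, 0.00550, 0.00023, 0.00017, 0.00015, 0.00013, 0.00006, 0.00006, 0.00004],
        "gpu": [0.87399, 0.11828, 0.00550, 0.00023, 0.00017, 0.00015, 0.00013, 0.00006, 0.00006, 0.00004],
    },
}

def max_abs_diff(a: list[float], b: list[float]) -> float:
    """Largest absolute difference between two equally ordered probability lists."""
    assert len(a) == len(b), "listings must have the same length"
    return max(abs(p - q) for p, q in zip(a, b))

for model, runs in RESULTS.items():
    gap = max_abs_diff(runs["dcu"], runs["gpu"])
    print(f"{model}: max |p_dcu - p_gpu| = {gap:.5f}")
```

This reports a maximum gap of 0.00115 for ViT-B_16 and 0.00018 for ViT-L_16, in both cases on one of the two top-ranked classes.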
### Accuracy
k800 ×1 (1410 MHz, 80 GB, CUDA 11.8):

| Parameters | acc     | loss     |
| -------------------------------- | ------- | -------- |
| model_datasets='b16,cifar10'<br>config.batch=512<br/>config.total_steps=500<br/>config.optim_dtype = 'bfloat16' | 0.98047 | 0.428023 |
| model_datasets='b16,cifar100'<br/>config.batch=512<br/>config.total_steps=500<br/>config.optim_dtype = 'bfloat16' | 0.89206 | 1.25078  |
| model_datasets='l16,cifar10'<br/>config.batch=512<br/>config.total_steps=500<br/>config.optim_dtype = 'bfloat16' | 0.98890 | 0.348941 |
| model_datasets='l16,cifar100'<br/>config.batch=512<br/>config.total_steps=500<br/>config.optim_dtype = 'bfloat16' | 0.91375 | 1.05141  |

k100 ×1 (1270 MHz, 64 GB, DTK 24.04.1):

| Parameters                                                   | acc     | loss     |
| ------------------------------------------------------------ | ------- | -------- |
| model_datasets='b16,cifar10'<br/>config.batch=512<br/>config.total_steps=500<br/>config.optim_dtype = 'bfloat16' | 0.98037 | 0.43239  |
| model_datasets='b16,cifar100'<br/>config.batch=512<br/>config.total_steps=500<br/>config.optim_dtype = 'bfloat16' | 0.89001 | 1.2273   |
| model_datasets='l16,cifar10'<br/>config.batch=512<br/>config.total_steps=500<br/>config.optim_dtype = 'bfloat16' | 0.98921 | 0.306221 |
| model_datasets='l16,cifar100'<br/>config.batch=512<br/>config.total_steps=500<br/>config.optim_dtype = 'bfloat16' | 0.91447 | 0.976117 |
## Application Scenarios
### Algorithm category
`Image classification`
### Key application industries
`Manufacturing, e-commerce, healthcare, media and broadcasting, education`
## Pre-trained Weights
- http://113.200.138.88:18080/aimodels/findsource-dependency/vision_transformer_jax
- https://console.cloud.google.com/storage/browser/vit_models/imagenet21k/  (for fine-tuning)
- https://console.cloud.google.com/storage/browser/vit_models/imagenet21k+imagenet2012/  (for inference)
## Source Repository and Issue Tracker
- https://developer.hpccube.com/codes/modelzoo/vision_transformer_jax
## References
- https://github.com/google-research/vision_transformer