README.md 6.1 KB
Newer Older
suily's avatar
suily committed
1
2
3
4
5
6
7
8
9
10
11
12
13
# ViT
## 论文
`An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
## 模型结构
ViT主要包括patch embeding、transformer encoder、MLP head三部分:以图像块的线性嵌入为输入、添加位置嵌入和可学习的cls_token(patch embeding),并直接应用无decoder的Transformer进行学习。
由于没有归纳偏置,ViT在中小型数据集上性能不如CNN,但当模型和数据量提升时性能会持续提升。
<div align=center>
    <img src="./doc/vit.png"/>
</div>

## 算法原理
整个模型结构可以分为五个步骤进行:
suily's avatar
suily committed
14

suily's avatar
suily committed
15
1、将图片切分成多个patch。
suily's avatar
suily committed
16

suily's avatar
suily committed
17
2、将得到的patches经过一个线性映射层后得到多个token embedding。
suily's avatar
suily committed
18

suily's avatar
suily committed
19
3、将得到的多个token embedding concat一个额外的cls_token,然后和位置编码相加,构成完整的encoder模块的输入。
suily's avatar
suily committed
20

suily's avatar
suily committed
21
4、 将相加后的结果传入Transformer Encoder模块。
suily's avatar
suily committed
22

suily's avatar
suily committed
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
5、Transformer Encoder 模块的输出经过MLP Head 模块做分类输出。

<div align=center>
    <img src="./doc/vit.png"/>
</div>

## 环境配置
### Docker(方法一)
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/jax:0.4.23-ubuntu20.04-dtk24.04.1-py3.10
docker run -it --network=host --privileged=true --name=docker_name --device=/dev/kfd --device=/dev/dri --group-add video --shm-size=32G  --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /path/your_code_data/:/path/your_code_data/ imageID /bin/bash

cd /your_code_path/vision_transformer
pip install -r requirements.txt  # flax会强制安装某版本ai包
pip uninstall tensorflow
pip uninstall jax
pip uninstall jaxlib
pip install tensorflow-cpu==2.13.1
wget jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp310-cp310-manylinux_2_31_x86_64.whl
wget jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp310-cp310-manylinux_2_31_x86_64.whl
```
### Dockerfile(方法二)
```
cd ./docker
docker build --no-cache -t vision_transformer:latest .
docker run -it --network=host --privileged=true --name=docker_name --device=/dev/kfd --device=/dev/dri --group-add video --shm-size=32G  --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /path/your_code_data/:/path/your_code_data/ imageID /bin/bash

cd /your_code_path/vision_transformer
pip install -r requirements.txt
pip install -r requirements.txt  # flax会强制安装某版本ai包
pip uninstall tensorflow
pip uninstall jax
pip uninstall jaxlib
pip install tensorflow-cpu==2.13.1
wget jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp310-cp310-manylinux_2_31_x86_64.whl
wget jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp310-cp310-manylinux_2_31_x86_64.whl
```
### Anaconda(方法三)
1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装: https://developer.hpccube.com/tool/
```
DTK软件栈:dtk24.04.1
python:python3.10
jax:0.4.23
```
`Tips:以上dtk软件栈、python、jax等DCU相关工具版本需要严格一一对应`

2、其他非特殊库直接按照下面步骤进行安装
```
cd /your_code_path/vision_transformer
pip install -r requirements.txt
pip install tensorflow-cpu==2.13.1
```
## 数据集
`cifar10  cifar100`
数据集由tensorflow_datasets自动下载和处理,相关代码见vision_transformer/vit_jax/input_pipeline.py
注:若发生错误All attempts to get a Google authentication bearer token failed..,按以下代码更改
```
vim /usr/local/lib/python3.10/site-packages/tensorflow_datasets/core/utils/gcs_utils.py
搜索_is_gcs_disabled,修改为_is_gcs_disabled = True
```
数据集下载地址及处理设置见./configs/common.py,默认存储地址为/root/tensorflow_datasets/,数据集目录结构如下:
```
 ── cifar10
    │   ├── 3.0.2
    │             ├── cifar10-test.tfrecord-00000-of-00001
    │             ├── cifar10-train.tfrecord-00000-of-00001
    │             ├── dataset_info.json
    │             ├── features.json
    │             └── label.labels.txt
 ── cifar100
    │   └── 3.0.2
    │             ├── cifar100-test.tfrecord-00000-of-00001
    │             ├── cifar100-train.tfrecord-00000-of-00001
    │             ├── coarse_label.labels.txt
    │             ├── dataset_info.json
    │             ├── features.json
    │             └── label.labels.txt
```
## 训练
检查点可通过以下方式进行下载:
```
cd /your_code_path/vision_transformer/test_result   # test_result为检查点下载地址,可自订
wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz
```
### 单机单卡
```
cd /your_code_path/vision_transformer
sh test.sh

# workdir=$(pwd)/test_result/dcu/vit-$(date +%s) # 指定存储日志和模型数据的目录
# config=$(pwd)/vit_jax/configs/vit.py:$model_datasets # 指定用于微调的模型/数据集
# config.pretrained_dir=$(pwd)/test_result # 检查点所在目录
# config.accum_steps=64 # 累加梯度的轮次(tpu=8,cpu=64)
# config.total_steps=500 # 微调轮次
# config.warmup_steps=50 # 学习率衰减轮次
# config.batch=512 # 训练批次
# config.pp.crop=384 # 图像块的分辨率
# config.optim_dtype='bfloat16' # 精度
```
## 推理
```
python test.py
```
## result
此处填算法效果测试图(包括输入、输出)

<div align=center>
    <img src="./doc/xxx.png"/>
</div>

### 精度
测试数据:[test data](链接),使用的加速卡:xxx。

根据测试结果情况填写表格:
| xxx | xxx | xxx | xxx | xxx |
| :------: | :------: | :------: | :------: |:------: |
| xxx | xxx | xxx | xxx | xxx  |
| xxx | xx | xxx | xxx | xxx |
## 应用场景
### 算法类别
`图像识别`
### 热点应用行业
`制造,电商,医疗,广媒,教育`
## 源码仓库及问题反馈
- https://developer.hpccube.com/codes/modelzoo/vision_transformer_jax
## 参考资料
- https://github.com/FoundationVision/VAR