Commit e4399a52 authored by suily's avatar suily
Browse files

添加README等

parent 5498e94a
# ViT
## 论文
`An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
## 模型结构
ViT主要包括patch embeding、transformer encoder、MLP head三部分:以图像块的线性嵌入为输入、添加位置嵌入和可学习的cls_token(patch embeding),并直接应用无decoder的Transformer进行学习。
由于没有归纳偏置,ViT在中小型数据集上性能不如CNN,但当模型和数据量提升时性能会持续提升。
<div align=center>
<img src="./doc/vit.png"/>
</div>
## 算法原理
整个模型结构可以分为五个步骤进行:
1、将图片切分成多个patch。
2、将得到的patches经过一个线性映射层后得到多个token embedding。
3、将得到的多个token embedding concat一个额外的cls_token,然后和位置编码相加,构成完整的encoder模块的输入。
4、 将相加后的结果传入Transformer Encoder模块。
5、Transformer Encoder 模块的输出经过MLP Head 模块做分类输出。
<div align=center>
<img src="./doc/vit.png"/>
</div>
## 环境配置
### Docker(方法一)
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/jax:0.4.23-ubuntu20.04-dtk24.04.1-py3.10
docker run -it --network=host --privileged=true --name=docker_name --device=/dev/kfd --device=/dev/dri --group-add video --shm-size=32G --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /path/your_code_data/:/path/your_code_data/ imageID /bin/bash
cd /your_code_path/vision_transformer
pip install -r requirements.txt # flax会强制安装某版本ai包
pip uninstall tensorflow
pip uninstall jax
pip uninstall jaxlib
pip install tensorflow-cpu==2.13.1
wget jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp310-cp310-manylinux_2_31_x86_64.whl
wget jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp310-cp310-manylinux_2_31_x86_64.whl
```
### Dockerfile(方法二)
```
cd ./docker
docker build --no-cache -t vision_transformer:latest .
docker run -it --network=host --privileged=true --name=docker_name --device=/dev/kfd --device=/dev/dri --group-add video --shm-size=32G --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /path/your_code_data/:/path/your_code_data/ imageID /bin/bash
cd /your_code_path/vision_transformer
pip install -r requirements.txt
pip install -r requirements.txt # flax会强制安装某版本ai包
pip uninstall tensorflow
pip uninstall jax
pip uninstall jaxlib
pip install tensorflow-cpu==2.13.1
wget jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp310-cp310-manylinux_2_31_x86_64.whl
wget jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp310-cp310-manylinux_2_31_x86_64.whl
```
### Anaconda(方法三)
1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装: https://developer.hpccube.com/tool/
```
DTK软件栈:dtk24.04.1
python:python3.10
jax:0.4.23
```
`Tips:以上dtk软件栈、python、jax等DCU相关工具版本需要严格一一对应`
2、其他非特殊库直接按照下面步骤进行安装
```
cd /your_code_path/vision_transformer
pip install -r requirements.txt
pip install tensorflow-cpu==2.13.1
```
## 数据集
`cifar10 cifar100`
数据集由tensorflow_datasets自动下载和处理,相关代码见vision_transformer/vit_jax/input_pipeline.py
注:若发生错误All attempts to get a Google authentication bearer token failed..,按以下代码更改
```
vim /usr/local/lib/python3.10/site-packages/tensorflow_datasets/core/utils/gcs_utils.py
搜索_is_gcs_disabled,修改为_is_gcs_disabled = True
```
数据集下载地址及处理设置见./configs/common.py,默认存储地址为/root/tensorflow_datasets/,数据集目录结构如下:
```
── cifar10
│   ├── 3.0.2
│    ├── cifar10-test.tfrecord-00000-of-00001
│    ├── cifar10-train.tfrecord-00000-of-00001
│    ├── dataset_info.json
│    ├── features.json
│ └── label.labels.txt
── cifar100
│   └── 3.0.2
│    ├── cifar100-test.tfrecord-00000-of-00001
│    ├── cifar100-train.tfrecord-00000-of-00001
│    ├── coarse_label.labels.txt
│    ├── dataset_info.json
│    ├── features.json
│ └── label.labels.txt
```
## 训练
检查点可通过以下方式进行下载:
```
cd /your_code_path/vision_transformer/test_result # test_result为检查点下载地址,可自订
wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz
```
### 单机单卡
```
cd /your_code_path/vision_transformer
sh test.sh
# workdir=$(pwd)/test_result/dcu/vit-$(date +%s) # 指定存储日志和模型数据的目录
# config=$(pwd)/vit_jax/configs/vit.py:$model_datasets # 指定用于微调的模型/数据集
# config.pretrained_dir=$(pwd)/test_result # 检查点所在目录
# config.accum_steps=64 # 累加梯度的轮次(tpu=8,cpu=64)
# config.total_steps=500 # 微调轮次
# config.warmup_steps=50 # 学习率衰减轮次
# config.batch=512 # 训练批次
# config.pp.crop=384 # 图像块的分辨率
# config.optim_dtype='bfloat16' # 精度
```
## 推理
```
python test.py
```
## result
此处填算法效果测试图(包括输入、输出)
<div align=center>
<img src="./doc/xxx.png"/>
</div>
### 精度
测试数据:[test data](链接),使用的加速卡:xxx。
根据测试结果情况填写表格:
| xxx | xxx | xxx | xxx | xxx |
| :------: | :------: | :------: | :------: |:------: |
| xxx | xxx | xxx | xxx | xxx |
| xxx | xx | xxx | xxx | xxx |
## 应用场景
### 算法类别
`图像识别`
### 热点应用行业
`制造,电商,医疗,广媒,教育`
## 源码仓库及问题反馈
- https://developer.hpccube.com/codes/modelzoo/vision_transformer_jax
## 参考资料
- https://github.com/FoundationVision/VAR
# 模型唯一标识
modelCode = 916
# 模型名称
modelName= vision_transformer_jax
# 模型描述
modelDescription=Google提出的一种图像识别模型,应用了无decoder的纯transformer结构(不依赖CNN)
# 应用场景
appScenario=推理,训练,图像识别,制造,电商,医疗,广媒,教育
# 框架类型
frameType=jax
...@@ -10,7 +10,7 @@ git+https://github.com/google/flaxformer ...@@ -10,7 +10,7 @@ git+https://github.com/google/flaxformer
ml-collections>=0.1.0 ml-collections>=0.1.0
numpy>=1.19.5 numpy>=1.19.5
pandas>=1.1.0 pandas>=1.1.0
tensorflow-cpu>=2.13.0 # tensorflow-cpu>=2.4.0 # Using tensorflow-cpu to have all GPU memory for JAX. tensorflow-cpu==2.13.1 # tensorflow-cpu>=2.4.0 # Using tensorflow-cpu to have all GPU memory for JAX.
tensorflow-datasets>=4.0.1 tensorflow-datasets>=4.0.1
tensorflow-probability>=0.11.1 tensorflow-probability>=0.11.1
# tensorflow-text>=2.9.0 # tensorflow-text>=2.9.0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment