# ViT

## Paper

`An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`

- https://arxiv.org/abs/2010.11929

## Model Structure

ViT consists of three parts: patch embedding, a Transformer encoder, and an MLP head. The image is split into patches whose linear embeddings form the input sequence; a learnable cls_token is prepended and position embeddings are added (patch embedding), and a decoder-free Transformer is then applied directly. Because ViT lacks the inductive biases built into CNNs, it underperforms CNNs on small and medium datasets, but its performance keeps improving as model size and training data grow.
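The three components map directly onto a module structure. The sketch below is a minimal Flax linen illustration of that layout, assuming standard `flax.linen` layers and the usual ViT-B/16 hyperparameters; it is not the repo's actual `vit_jax/models.py` implementation, and all names are illustrative.

```
import jax.numpy as jnp
import flax.linen as nn


class EncoderBlock(nn.Module):
    """Pre-norm Transformer block: self-attention followed by an MLP."""
    hidden_dim: int = 768
    mlp_dim: int = 3072
    num_heads: int = 12

    @nn.compact
    def __call__(self, x):
        y = nn.LayerNorm()(x)
        y = nn.SelfAttention(num_heads=self.num_heads)(y)
        x = x + y
        y = nn.LayerNorm()(x)
        y = nn.Dense(self.mlp_dim)(y)
        y = nn.gelu(y)
        y = nn.Dense(self.hidden_dim)(y)
        return x + y


class SimpleViT(nn.Module):
    """Patch embedding -> Transformer encoder -> MLP head on the cls_token."""
    num_classes: int = 1000
    patch: int = 16
    hidden_dim: int = 768
    depth: int = 12

    @nn.compact
    def __call__(self, images):                       # images: [B, H, W, 3]
        # Patch embedding: a strided conv is equivalent to "cut into 16x16 patches + linear projection".
        x = nn.Conv(self.hidden_dim, (self.patch, self.patch),
                    strides=(self.patch, self.patch))(images)
        b, h, w, c = x.shape
        x = x.reshape(b, h * w, c)                    # [B, N, hidden_dim]
        # Prepend the learnable cls_token and add position embeddings.
        cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
        x = jnp.concatenate([jnp.tile(cls, (b, 1, 1)), x], axis=1)
        pos = self.param('pos_embedding', nn.initializers.normal(0.02), (1, x.shape[1], c))
        x = x + pos
        # Encoder-only (decoder-free) Transformer.
        for _ in range(self.depth):
            x = EncoderBlock(self.hidden_dim)(x)
        x = nn.LayerNorm()(x)
        # MLP head: classify from the encoder output at the cls_token position.
        return nn.Dense(self.num_classes)(x[:, 0])
```

For a 384×384 input this yields (384/16)² = 576 patch tokens plus the cls_token, i.e. a sequence of 577 tokens of width 768, which matches the fine-tuning resolution used later in this README.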
## Algorithm

The forward pass can be broken down into five steps:

1. Split the image into patches.
2. Pass the patches through a linear projection layer to obtain token embeddings.
3. Concatenate an extra learnable cls_token to the token embeddings, then add the position embeddings; this forms the full input to the encoder.
4. Feed the result into the Transformer Encoder.
5. Pass the Transformer Encoder output through the MLP Head to produce the classification output.
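To make the five steps concrete, the snippet below traces the tensor shapes for a single 384×384 image with 16×16 patches; a random projection matrix stands in for the learned linear layer and is only there to make the shapes visible (a sketch, not the repo's pipeline).

```
import jax
import jax.numpy as jnp

img = jnp.zeros((384, 384, 3))                         # one input image
P, D = 16, 768                                         # patch size and embedding dimension

# Step 1: split into patches -> (384/16)^2 = 576 patches, each flattened to 16*16*3 = 768 values.
patches = img.reshape(24, P, 24, P, 3).transpose(0, 2, 1, 3, 4).reshape(576, P * P * 3)

# Step 2: linear projection to token embeddings (random weights stand in for the learned layer).
W = jax.random.normal(jax.random.PRNGKey(0), (P * P * 3, D)) * 0.02
tokens = patches @ W                                   # [576, 768]

# Step 3: concatenate the cls_token and add position embeddings -> 577 tokens.
cls_token = jnp.zeros((1, D))
pos_embedding = jnp.zeros((577, D))
x = jnp.concatenate([cls_token, tokens], axis=0) + pos_embedding

# Step 4: x is what goes into the Transformer Encoder (shape is preserved: [577, 768]).
# Step 5: the MLP Head reads only the encoder output at the cls_token position, x[0].
print(patches.shape, tokens.shape, x.shape)            # (576, 768) (576, 768) (577, 768)
```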
## Environment Setup

### Docker (Option 1)

```
docker pull image.sourcefind.cn:5000/dcu/admin/base/jax:0.4.23-ubuntu20.04-dtk24.04.1-py3.10
docker run -it --network=host --privileged=true --name=vit --device=/dev/kfd --device=/dev/dri --group-add video --shm-size=32G --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro <image_id> /bin/bash  # replace <image_id> with the ID of the image pulled above
cd /your_code_path/vision_transformer
pip install flax==0.6.9  # flax force-installs specific versions of some AI packages
pip install -r requirements.txt
pip uninstall -y tensorflow
pip uninstall -y jax
pip uninstall -y jaxlib
pip install tensorflow-cpu==2.14.0
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.1/jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.0/jaxlib-0.4.23+das1.0+git97306ab.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jaxlib-0.4.23+das1.0+git97306ab.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
```

### Dockerfile (Option 2)

```
docker build --no-cache -t vit:latest .
docker run -it --network=host --privileged=true --name=vit --device=/dev/kfd --device=/dev/dri --group-add video --shm-size=32G --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro vit /bin/bash
cd /your_code_path/vision_transformer
pip install flax==0.6.9  # flax force-installs specific versions of some AI packages
pip install -r requirements.txt
pip uninstall -y tensorflow
pip uninstall -y jax
pip uninstall -y jaxlib
pip install tensorflow-cpu==2.14.0
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.1/jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
wget https://cancon.hpccube.com:65024/directlink/4/jax/DAS1.0/jaxlib-0.4.23+das1.0+git97306ab.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl
pip install jaxlib-0.4.23+das1.0+git97306ab.abi1.dtk2404-cp310-cp310-manylinux2014_x86_64.whl
```

### Anaconda (Option 3)

1. The DCU-specific deep learning libraries required by this project can be downloaded and installed from the 光合开发者社区 (developer community): https://developer.hpccube.com/tool/

```
DTK software stack: dtk24.04.1
python: python3.10
jax: 0.4.23
```

`Tips: the DTK stack, python, jax and other DCU-related tool versions above must match each other exactly.`

2. The remaining ordinary dependencies are installed as follows:

```
cd /your_code_path/vision_transformer
pip install flax==0.6.9  # flax force-installs specific versions of some AI packages
pip install -r requirements.txt
pip install tensorflow-cpu==2.14.0
```

## Datasets

### Training datasets

`cifar10 cifar100`

The datasets are downloaded and processed automatically by tensorflow_datasets when the training command is run; see vision_transformer/vit_jax/input_pipeline.py for the relevant code.

Note: if the error `All attempts to get a Google authentication bearer token failed..` occurs, apply the following change:

```
vim /usr/local/lib/python3.10/site-packages/tensorflow_datasets/core/utils/gcs_utils.py
# locate the definition of _is_gcs_disabled and change it to:
_is_gcs_disabled = True
```

The dataset download location and preprocessing settings are defined in ./configs/common.py; the default storage path is /root/tensorflow_datasets/, with the following directory structure:

```
── cifar10
│   └── 3.0.2
│       ├── cifar10-test.tfrecord-00000-of-00001
│       ├── cifar10-train.tfrecord-00000-of-00001
│       ├── dataset_info.json
│       ├── features.json
│       └── label.labels.txt
── cifar100
│   └── 3.0.2
│       ├── cifar100-test.tfrecord-00000-of-00001
│       ├── cifar100-train.tfrecord-00000-of-00001
│       ├── coarse_label.labels.txt
│       ├── dataset_info.json
│       ├── features.json
│       └── label.labels.txt
```
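As a quick sanity check that tensorflow_datasets can download and read cifar10 before launching training, the minimal snippet below uses only standard tfds calls; the actual training input pipeline lives in vit_jax/input_pipeline.py and is driven by the configs.

```
import tensorflow_datasets as tfds

# Download (if necessary) and prepare cifar10; files land under /root/tensorflow_datasets/ by default.
builder = tfds.builder('cifar10')
builder.download_and_prepare()

# Read one training example and inspect it.
ds = builder.as_dataset(split='train', shuffle_files=False)
example = next(iter(tfds.as_numpy(ds)))
print(example['image'].shape, example['label'])  # (32, 32, 3) and an integer label in [0, 10)
```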
### Inference dataset

The image and label file used for inference can be downloaded from [scnet](http://113.200.138.88:18080/aidatasets/project-dependency/vision_transformer_jax) or with the commands below:

```
# ./dataset is the storage path and can be customised
wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -P ./dataset
wget https://picsum.photos/384 -O ./dataset/picsum.jpg  # fetches a random image at 384 resolution
```

The directory structure is as follows:

```
── dataset
│   ├── ilsvrc2012_wordnet_lemmas.txt
│   └── picsum.jpg
```

## Training

Checkpoints can be downloaded from [scnet](http://113.200.138.88:18080/aimodels/findsource-dependency/vision_transformer_jax/-/tree/master/imagenet21k?ref_type=heads) or as follows:

```
cd /your_code_path/vision_transformer/test_result  # test_result is the checkpoint directory and can be customised
wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz
wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-L_16.npz
```

### Single node, single card

```
cd /your_code_path/vision_transformer
sh test.sh
# workdir=$(pwd)/test_result/dcu/vit-$(date +%s)        # directory for logs and model data
# config=$(pwd)/vit_jax/configs/vit.py:$model_datasets  # model/dataset combination to fine-tune
# config.pretrained_dir=$(pwd)/test_result              # directory containing the checkpoints
# config.accum_steps=64                                 # gradient accumulation steps (tpu=8, cpu=64)
# config.total_steps=500                                # number of fine-tuning steps
# config.warmup_steps=50                                # learning-rate warmup steps
# config.batch=512                                      # batch size
# config.pp.crop=384                                    # input crop resolution
# config.optim_dtype='bfloat16'                         # optimizer precision
```

## Inference

Checkpoints can be downloaded from [scnet](http://113.200.138.88:18080/aimodels/findsource-dependency/vision_transformer_jax/-/tree/master/imagenet21k+imagenet2012?ref_type=heads) or as follows:

```
cd /your_code_path/vision_transformer/test_result  # test_result is the checkpoint directory and can be customised
wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16.npz -O ViT-B_16_imagenet2012.npz
wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-L_16.npz -O ViT-L_16_imagenet2012.npz
```

```
cd /your_code_path/vision_transformer
python test.py  # the model directory and dataset directory can be modified inside this file
```
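test.py performs the full inference pass; the sketch below only illustrates the steps around it, namely preparing the 384-resolution input and turning logits into the top-10 list shown in the results. `run_vit` is a hypothetical stand-in for loading the .npz checkpoint and applying the ViT model (test.py and the vit_jax package do this part); the preprocessing and readout use standard PIL/NumPy/JAX calls.

```
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image

# Load the test image, resize to 384x384 and scale pixel values to [-1, 1].
img = Image.open('./dataset/picsum.jpg').convert('RGB').resize((384, 384))
x = (np.asarray(img, dtype=np.float32) / 255.0 - 0.5) * 2.0
x = x[None, ...]                                       # add a batch dimension: [1, 384, 384, 3]

# ImageNet class names, one per line, in the same order as the checkpoint's output logits.
labels = open('./dataset/ilsvrc2012_wordnet_lemmas.txt').read().splitlines()

def run_vit(batch):
    # Hypothetical placeholder: in test.py this loads ViT-B_16_imagenet2012.npz and applies the model.
    return jnp.zeros((batch.shape[0], len(labels)))

# Softmax over the classes and print the top-10 predictions, matching the result format below.
probs = jax.nn.softmax(run_vit(x)[0])
for idx in np.argsort(np.asarray(probs))[::-1][:10]:
    print(f'{probs[idx]:.5f} : {labels[idx]}')
```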
## Result

The test image is the `picsum.jpg` downloaded above.

```
----ViT-B_16:
dcu inference results:
0.73861 : alp
0.24576 : valley, vale
0.00416 : lakeside, lakeshore
0.00404 : cliff, drop, drop-off
0.00094 : promontory, headland, head, foreland
0.00060 : mountain_tent
0.00055 : dam, dike, dyke
0.00033 : volcano
0.00031 : ski
0.00012 : solar_dish, solar_collector, solar_furnace
gpu inference results:
0.73976 : alp
0.24465 : valley, vale
0.00414 : lakeside, lakeshore
0.00404 : cliff, drop, drop-off
0.00094 : promontory, headland, head, foreland
0.00060 : mountain_tent
0.00054 : dam, dike, dyke
0.00033 : volcano
0.00031 : ski
0.00012 : solar_dish, solar_collector, solar_furnace
----ViT-L_16:
dcu inference results:
0.87382 : alp
0.11846 : valley, vale
0.00550 : cliff, drop, drop-off
0.00023 : mountain_tent
0.00017 : promontory, headland, head, foreland
0.00015 : lakeside, lakeshore
0.00013 : dam, dike, dyke
0.00006 : volcano
0.00006 : ski
0.00004 : sandbar, sand_bar
gpu inference results:
0.87399 : alp
0.11828 : valley, vale
0.00550 : cliff, drop, drop-off
0.00023 : mountain_tent
0.00017 : promontory, headland, head, foreland
0.00015 : lakeside, lakeshore
0.00013 : dam, dike, dyke
0.00006 : volcano
0.00006 : ski
0.00004 : sandbar, sand_bar
```

### Accuracy

k800*1 (1410 MHz, 80G, cuda11.8):

| Parameters | acc | loss |
| ---------- | --- | ---- |
| model_datasets='b16,cifar10'<br>config.batch=512<br>config.total_steps=500<br>config.optim_dtype='bfloat16' | 0.98047 | 0.428023 |
| model_datasets='b16,cifar100'<br>config.batch=512<br>config.total_steps=500<br>config.optim_dtype='bfloat16' | 0.89206 | 1.25078 |
| model_datasets='l16,cifar10'<br>config.batch=512<br>config.total_steps=500<br>config.optim_dtype='bfloat16' | 0.98890 | 0.348941 |
| model_datasets='l16,cifar100'<br>config.batch=512<br>config.total_steps=500<br>config.optim_dtype='bfloat16' | 0.91375 | 1.05141 |

k100*1 (1270 MHz, 64G, dtk24.04.1):

| Parameters | acc | loss |
| ---------- | --- | ---- |
| model_datasets='b16,cifar10'<br>config.batch=512<br>config.total_steps=500<br>config.optim_dtype='bfloat16' | 0.98037 | 0.43239 |
| model_datasets='b16,cifar100'<br>config.batch=512<br>config.total_steps=500<br>config.optim_dtype='bfloat16' | 0.89001 | 1.2273 |
| model_datasets='l16,cifar10'<br>config.batch=512<br>config.total_steps=500<br>config.optim_dtype='bfloat16' | 0.98921 | 0.306221 |
| model_datasets='l16,cifar100'<br>config.batch=512<br>config.total_steps=500<br>config.optim_dtype='bfloat16' | 0.91447 | 0.976117 |

## Application Scenarios

### Algorithm category

`Image classification`

### Key application industries

`Manufacturing, e-commerce, healthcare, broadcast media, education`

## Pretrained Weights

- http://113.200.138.88:18080/aimodels/findsource-dependency/vision_transformer_jax
- https://console.cloud.google.com/storage/browser/vit_models/imagenet21k/ (for fine-tuning)
- https://console.cloud.google.com/storage/browser/vit_models/imagenet21k+imagenet2012/ (for inference)

## Source Repository & Issue Feedback

- https://developer.hpccube.com/codes/modelzoo/vision_transformer_jax

## References

- https://github.com/google-research/vision_transformer