# ViT ## 论文 `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - https://arxiv.org/abs/2010.11929 ## 模型结构 ViT主要包括patch embeding、transformer encoder、MLP head三部分:以图像块的线性嵌入为输入、添加位置嵌入和可学习的cls_token(patch embeding),并直接应用无decoder的Transformer进行学习。 由于没有归纳偏置,ViT在中小型数据集上性能不如CNN,但当模型和数据量提升时性能会持续提升。
## 算法原理 整个模型结构可以分为五个步骤进行: 1、将图片切分成多个patch。 2、将得到的patches经过一个线性映射层后得到多个token embedding。 3、将得到的多个token embedding concat一个额外的cls_token,然后和位置编码相加,构成完整的encoder模块的输入。 4、 将相加后的结果传入Transformer Encoder模块。 5、Transformer Encoder 模块的输出经过MLP Head 模块做分类输出。
## 环境配置 ### Docker(方法一) ``` docker pull image.sourcefind.cn:5000/dcu/admin/base/jax:0.4.23-ubuntu20.04-dtk24.04.1-py3.10 docker run -it --network=host --privileged=true --name=docker_name --device=/dev/kfd --device=/dev/dri --group-add video --shm-size=32G --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /path/your_code_data/:/path/your_code_data/ imageID /bin/bash cd /your_code_path/vision_transformer pip install -r requirements.txt # flax会强制安装某版本ai包 pip uninstall tensorflow pip uninstall jax pip uninstall jaxlib pip install tensorflow-cpu==2.13.1 wget jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp310-cp310-manylinux_2_31_x86_64.whl wget jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl pip install jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp310-cp310-manylinux_2_31_x86_64.whl ``` ### Dockerfile(方法二) ``` cd ./docker docker build --no-cache -t vision_transformer:latest . docker run -it --network=host --privileged=true --name=docker_name --device=/dev/kfd --device=/dev/dri --group-add video --shm-size=32G --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v /path/your_code_data/:/path/your_code_data/ imageID /bin/bash cd /your_code_path/vision_transformer pip install -r requirements.txt pip install -r requirements.txt # flax会强制安装某版本ai包 pip uninstall tensorflow pip uninstall jax pip uninstall jaxlib pip install tensorflow-cpu==2.13.1 wget jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp310-cp310-manylinux_2_31_x86_64.whl wget jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl pip install jax-0.4.23+das1.1.git387bd43.abi1.dtk2404-py3-none-any.whl pip install jaxlib-0.4.23+das1.1.git387bd43.abi1.dtk2404-cp310-cp310-manylinux_2_31_x86_64.whl ``` ### Anaconda(方法三) 1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装: https://developer.hpccube.com/tool/ ``` DTK软件栈:dtk24.04.1 python:python3.10 jax:0.4.23 ``` `Tips:以上dtk软件栈、python、jax等DCU相关工具版本需要严格一一对应` 2、其他非特殊库直接按照下面步骤进行安装 ``` cd /your_code_path/vision_transformer pip install -r requirements.txt pip install tensorflow-cpu==2.13.1 ``` ## 数据集 `cifar10 cifar100` 数据集由tensorflow_datasets自动下载和处理,相关代码见vision_transformer/vit_jax/input_pipeline.py 注:若发生错误All attempts to get a Google authentication bearer token failed..,按以下代码更改 ``` vim /usr/local/lib/python3.10/site-packages/tensorflow_datasets/core/utils/gcs_utils.py 搜索_is_gcs_disabled,修改为_is_gcs_disabled = True ``` 数据集下载地址及处理设置见./configs/common.py,默认存储地址为/root/tensorflow_datasets/,数据集目录结构如下: ``` ── cifar10 │   ├── 3.0.2 │    ├── cifar10-test.tfrecord-00000-of-00001 │    ├── cifar10-train.tfrecord-00000-of-00001 │    ├── dataset_info.json │    ├── features.json │ └── label.labels.txt ── cifar100 │   └── 3.0.2 │    ├── cifar100-test.tfrecord-00000-of-00001 │    ├── cifar100-train.tfrecord-00000-of-00001 │    ├── coarse_label.labels.txt │    ├── dataset_info.json │    ├── features.json │ └── label.labels.txt ``` ## 训练 检查点可通过以下方式进行下载: ``` cd /your_code_path/vision_transformer/test_result # test_result为检查点下载地址,可自订 wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz ``` ### 单机单卡 ``` cd /your_code_path/vision_transformer sh test.sh # workdir=$(pwd)/test_result/dcu/vit-$(date +%s) # 指定存储日志和模型数据的目录 # config=$(pwd)/vit_jax/configs/vit.py:$model_datasets # 指定用于微调的模型/数据集 # config.pretrained_dir=$(pwd)/test_result # 检查点所在目录 # config.accum_steps=64 # 累加梯度的轮次(tpu=8,cpu=64) # config.total_steps=500 # 微调轮次 # config.warmup_steps=50 # 学习率衰减轮次 # config.batch=512 # 训练批次 # config.pp.crop=384 # 图像块的分辨率 # config.optim_dtype='bfloat16' # 精度 ``` ## 推理 ``` python test.py ``` ## result 此处填算法效果测试图(包括输入、输出)
### 精度 测试数据:[test data](链接),使用的加速卡:xxx。 根据测试结果情况填写表格: | xxx | xxx | xxx | xxx | xxx | | :------: | :------: | :------: | :------: |:------: | | xxx | xxx | xxx | xxx | xxx | | xxx | xx | xxx | xxx | xxx | ## 应用场景 ### 算法类别 `图像识别` ### 热点应用行业 `制造,电商,医疗,广媒,教育` ## 源码仓库及问题反馈 - https://developer.hpccube.com/codes/modelzoo/vision_transformer_jax ## 参考资料 - https://github.com/FoundationVision/VAR