# Vision Transformer(ViT) ## 模型介绍 ViT的是将Transformer模型应用于计算机视觉领域,以替代传统的卷积神经网络(CNN)模型。 ## 模型结构 Vision Transformer模型结构如下图所示主要包括三部分,patch embeding 部分、transformer encoder部分、MLP head部分。ViT将输入图片分为多个patch,再将每个patch投影为固定长度的向量送入Transformer,后续encoder的操作和原始Transformer中完全相同。但是因为对图片分类,因此在输入序列中加入一个特殊的token,该token对应的输出即为最后的类别预测。 ## 数据集 使用ImageNet数据集做pretrain,pretrain之后的模型使用[flower_photos](https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz)做fine-tuning ## 训练 ### 环境配置 在[光源](https://www.sourcefind.cn/#/service-details)可拉取训练的docker镜像,推荐的镜像如下: * 训练镜像:docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.10.0-centos7.6-dtk-22.10.1-py37-latest 若想使用MIGraphX做推理,可在[光合开发者社区](https://cancon.hpccube.com:65024/4/main/)中下载MIGraphX. ### Fine-tunning 模型的训练程序是train.py,预训练模型为[base_patch16_224_in21k](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth),需要先下载预训练模型。fine-tuning训练模型使用以下命令: python train.py Fine-tuning时可调整epoch参数来调整模型。 ### 预训练模型 本项目使用[base_patch16_224_in21k](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth)为预训练模型,需自行下载。下载完预训练模型和数据集,即可开始fine-tun。可用torch进行推理,也可转为onnx模型使用MIGraphX进行推理。 ### 推理 推理测试模型用infer_pytroch.py和infer_migraphx.py对训练出的模型进行推理,使用方法如下: python infer_pytroch.py ./flower_photos/daisy/ python infer_migraphx.py --imgpath=./flower_photos/daisy/ ## 代码使用简介 1. 下载好数据集,代码中默认使用的是花分类数据集,下载地址: [https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz](https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz) 2. 在`train.py`脚本中将`--data-path`设置成解压后的`flower_photos`文件夹绝对路径 3. 下载预训练权重,在`vit_model.py`文件中每个模型都有提供预训练权重的下载地址,根据自己使用的模型下载对应预训练权重 4. 在`train.py`脚本中将`--weights`参数设成下载好的预训练权重路径 5. 设置好数据集的路径`--data-path`以及预训练权重的路径`--weights`就能使用`train.py`脚本开始训练了(训练过程中会自动生成`class_indices.json`文件) 6. 在`predict.py`脚本中导入和训练脚本中同样的模型,并将`model_weight_path`设置成训练好的模型权重路径(默认保存在weights文件夹下) 7. 在`predict.py`脚本中将`img_path`设置成你自己需要预测的图片绝对路径 8. 设置好权重路径`model_weight_path`和预测的图片路径`img_path`就能使用`predict.py`脚本进行预测了 9. 如果要使用自己的数据集,请按照花分类数据集的文件结构进行摆放(即一个类别对应一个文件夹),并且将训练以及预测脚本中的`num_classes`设置成你自己数据的类别数