# Vision Transformer(ViT) ## 模型介绍 ViT的是将Transformer模型应用于计算机视觉领域,以替代传统的卷积神经网络(CNN)模型。 ## 模型结构 Vision Transformer模型结构如下图所示主要包括三部分,patch embeding 部分、transformer encoder部分、MLP head部分。ViT将输入图片分为多个patch,再将每个patch投影为固定长度的向量送入Transformer,后续encoder的操作和原始Transformer中完全相同。但是因为对图片分类,因此在输入序列中加入一个特殊的token,该token对应的输出即为最后的类别预测。 ## 数据集 使用ImageNet数据集做pretrain,pretrain之后的模型使用[flower_photos](https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz)做fine-tuning ## 训练和推理 ### 环境配置 在[光源](https://www.sourcefind.cn/#/service-details)可拉取训练的docker镜像,推荐的镜像如下: * 训练镜像:docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.10.0-centos7.6-dtk-22.10.1-py37-latest 若想使用MIGraphX做推理,可在[光合开发者社区](https://cancon.hpccube.com:65024/4/main/)中下载MIGraphX并安装. ### Fine-tunning 模型的训练程序是train.py,预训练模型为[base_patch16_224_in21k](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth),需要先下载预训练模型。fine-tuning训练模型使用以下命令: python train.py Fine-tuning时可调整epoch参数来调整模型。 ### 预训练模型 本项目使用[base_patch16_224_in21k](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth)为预训练模型,需自行下载。下载完预训练模型和数据集,即可开始fine-tun。可用torch进行推理,也可转为onnx模型使用MIGraphX进行推理。 ### 推理 推理测试用infer_pytroch.py和infer_migraphx.py对训练出的模型进行推理,使用方法如下: python infer_pytroch.py ./flower_photos/daisy/ python infer_migraphx.py --imgpath=./flower_photos/daisy/ ## 准确率数据 测试数据使用的是[flower_photos](https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz),使用的加速卡是DCU Z100 | Engine | Model Path| Data | Accuracy(%) | | :------: | :------: | :------: | :------: | | Pythorch | models/model.onnx | daisy | 98.4 | | Pythorch | models/model.onnx | dandelion | 98.2 | | Pythorch | models/model.onnx | roses | 90.0 | | Pythorch | models/model.onnx | sunflowers | 97.4 | | Pythorch | models/model.onnx | tulips | 95.4 | | MIGraphX | models/model.onnx | daisy | 98.4 | | MIGraphX | models/model.onnx | dandelion | 98.8 | | MIGraphX | models/model.onnx | roses | 91.3 | | MIGraphX | models/model.onnx | sunflowers | 97.4 | | MIGraphX | models/model.onnx | tulips | 95.0 | ## 源码仓库及问题反馈 https://developer.hpccube.com/codes/modelzoo/vit_migraphx.git