<!--
 * @Author: lijian6
 * @email: lijian6@sugon.com
 * @Date: 2023-06-06
 * @LastEditTime: 2023-06-06
 * @FilePath: README.md
-->
# Vision Transformer (ViT)

## Model Introduction

ViT applies the Transformer architecture to computer vision as an alternative to traditional convolutional neural network (CNN) models.

## Model Architecture

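As shown in the figure below, the Vision Transformer consists of three main parts: the patch embedding, the Transformer encoder, and the MLP head. ViT splits the input image into patches and projects each patch into a fixed-length vector before feeding the resulting sequence into the Transformer; the encoder itself works exactly as in the original Transformer. Because the task is image classification, a special class token is added to the input sequence, and the encoder output at that token's position is used for the final class prediction.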
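
The patch-embedding arithmetic can be made concrete with a minimal pure-Python sketch. It assumes the ViT-Base/16 configuration implied by the checkpoint name `base_patch16_224_in21k` (224x224 RGB input, 16x16 patches). The helper names `vit_sequence_shape` and `patchify` are hypothetical and not part of this repository; real implementations do the split and the linear projection together on tensors, typically with a convolution whose kernel size and stride both equal the patch size.

```python
def vit_sequence_shape(image_size, patch_size, channels=3):
    """Return (num_patches, patch_vector_len, seq_len) for a square input image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    # Each patch is flattened to patch_size * patch_size * channels values
    # before a learned linear projection maps it to the embedding dimension.
    patch_vector_len = patch_size * patch_size * channels
    # One extra position for the class token added to the sequence.
    seq_len = num_patches + 1
    return num_patches, patch_vector_len, seq_len


def patchify(image, patch_size):
    """Split an H x W x C image (nested lists) into flattened patches, row by row."""
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image height and width must be divisible by patch_size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            flat = []
            for row in range(top, top + patch_size):
                for col in range(left, left + patch_size):
                    flat.extend(image[row][col])  # all channels of one pixel
            patches.append(flat)
    return patches


if __name__ == "__main__":
    print(vit_sequence_shape(224, 16))  # (196, 768, 197)
    # Toy 4x4 single-channel image whose pixel values are 0..15.
    toy = [[[r * 4 + c] for c in range(4)] for r in range(4)]
    print(patchify(toy, 2)[0])  # [0, 1, 4, 5]
```

For ViT-Base the embedding dimension is also 768, so the 197-token sequence entering the encoder has shape 197 x 768.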

## Dataset

The model is pretrained on ImageNet (the `in21k` checkpoint used below was trained on ImageNet-21k), and the pretrained model is then fine-tuned on [flower_photos](https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz).
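
flower_photos unpacks into one sub-directory per class (daisy, dandelion, roses, sunflowers, tulips). Fine-tuning needs these images split into training and validation sets; the standard-library sketch below shows one way to do it. The function name `split_flower_photos`, the 20% validation ratio, and the fixed seed are illustrative assumptions and not necessarily what train.py does.

```python
import os
import random


def split_flower_photos(root, val_rate=0.2, seed=0,
                        exts=(".jpg", ".jpeg", ".png")):
    """Split a one-folder-per-class image tree into train/val lists.

    Returns (class_names, train_items, val_items), where each item is an
    (image_path, class_index) pair. Class indices follow sorted folder names.
    """
    rng = random.Random(seed)  # fixed seed keeps the split reproducible
    # Only directories count as classes; stray files such as a LICENSE are skipped.
    class_names = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    train_items, val_items = [], []
    for index, name in enumerate(class_names):
        folder = os.path.join(root, name)
        files = sorted(f for f in os.listdir(folder) if f.lower().endswith(exts))
        rng.shuffle(files)
        num_val = int(len(files) * val_rate)  # per-class validation count
        for i, f in enumerate(files):
            item = (os.path.join(folder, f), index)
            (val_items if i < num_val else train_items).append(item)
    return class_names, train_items, val_items
```

Usage: `classes, train, val = split_flower_photos("./flower_photos")`. Splitting per class keeps the class proportions roughly equal in both sets.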

## Training and Inference

### Environment Setup

The training Docker image can be pulled from [光源 (SourceFind)](https://www.sourcefind.cn/#/service-details). The recommended image is:
* Training image: `docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.10.0-centos7.6-dtk-22.10.1-py37-latest`

To run inference with MIGraphX, download and install MIGraphX from the [光合开发者社区 (Guanghe Developer Community)](https://cancon.hpccube.com:65024/4/main/).

### Fine-tuning

The training script is train.py, and the pretrained model is [base_patch16_224_in21k](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth), which must be downloaded first. Fine-tune the model with the following command:

    python train.py 

During fine-tuning, the number of training epochs can be changed through the epoch parameter.
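
Because base_patch16_224_in21k was trained to classify ImageNet-21k classes, its classifier head does not match the 5 flower classes and is usually discarded before fine-tuning. The sketch below assumes timm-style parameter names (`head.weight`, `head.bias`, and an optional `pre_logits.*` layer in the in21k checkpoints); inspect the real keys with `state_dict.keys()` before relying on it. The helper `drop_classifier_keys` is hypothetical, and dropping `pre_logits.*` assumes the fine-tuned model has no pre-logits layer.

```python
def drop_classifier_keys(state_dict, prefixes=("head.", "pre_logits.")):
    """Return a copy of state_dict without the pretrained classifier parameters.

    The prefixes are assumed timm-style names. The remaining backbone
    parameters can then be loaded into a model that has a freshly
    initialized 5-class head, for example:

        state = torch.load("jx_vit_base_patch16_224_in21k-e5005f0a.pth",
                           map_location="cpu")
        model.load_state_dict(drop_classifier_keys(state), strict=False)
    """
    return {
        key: value
        for key, value in state_dict.items()
        if not key.startswith(prefixes)  # str.startswith accepts a tuple
    }
```

Loading with `strict=False` lets the new head keep its random initialization while every backbone weight comes from the checkpoint.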

### Pretrained Model

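This project uses [base_patch16_224_in21k](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth) as the pretrained model, which must be downloaded separately. Once the pretrained model and the dataset are downloaded, fine-tuning can begin. Inference can be run with PyTorch, or the model can be converted to ONNX and run with MIGraphX.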
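
Whether the model runs in PyTorch or as ONNX under MIGraphX, it produces one logit per class, and the prediction is the class with the highest softmax probability. Below is a minimal post-processing sketch in plain Python. The class order (flower_photos folder names sorted alphabetically) is an assumption and must match the label mapping used during training.

```python
import math

# Assumed label order: flower_photos folders sorted alphabetically.
CLASS_NAMES = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top1(logits, class_names=CLASS_NAMES):
    """Return (class_name, probability) for the highest-scoring class."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_names[best], probs[best]


if __name__ == "__main__":
    name, prob = top1([0.1, 4.0, -1.0, 0.5, 0.2])
    print(name, round(prob, 3))
```

In practice the logits would be the model's output for one image, converted to a Python list; the same function then serves both the PyTorch and the MIGraphX path.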

### Inference

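Run inference on the trained model with infer_pytroch.py (PyTorch) or infer_migraphx.py (MIGraphX). Both scripts take a directory of images; the examples below use the daisy folder of flower_photos: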

    python infer_pytroch.py ./flower_photos/daisy/
    python infer_migraphx.py --imgpath=./flower_photos/daisy/

## Accuracy

The test data is [flower_photos](https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz), and the accelerator card is a DCU Z100.

| Engine | Model Path| Data | Accuracy(%) |
| :------: | :------: | :------: | :------: |
| PyTorch | models/model.onnx | daisy | 98.4 |
| PyTorch | models/model.onnx | dandelion | 98.2 |
| PyTorch | models/model.onnx | roses | 90.0 |
| PyTorch | models/model.onnx | sunflowers | 97.4 |
| PyTorch | models/model.onnx | tulips | 95.4 |
| MIGraphX | models/model.onnx | daisy | 98.4 |
| MIGraphX | models/model.onnx | dandelion | 98.8 |
| MIGraphX | models/model.onnx | roses | 91.3 |
| MIGraphX | models/model.onnx | sunflowers | 97.4 |
| MIGraphX | models/model.onnx | tulips | 95.0 |
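
The table reports one accuracy per flower class. Assuming each row is the percentage of images in that class folder (the folder passed to the inference scripts) that the model assigns to the correct class, the figure can be computed as below. `folder_accuracy` is a hypothetical helper, and the predictions in the example are made-up values for illustration, not results from this model.

```python
def folder_accuracy(predictions, true_class):
    """Percentage of predictions for one class folder that match true_class."""
    if not predictions:
        raise ValueError("no predictions to score")
    correct = sum(1 for p in predictions if p == true_class)
    return 100.0 * correct / len(predictions)


if __name__ == "__main__":
    # Made-up predictions for four images from a hypothetical daisy folder.
    preds = ["daisy", "daisy", "dandelion", "daisy"]
    print(folder_accuracy(preds, "daisy"))  # 75.0
```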

## Source Repository and Issue Feedback

https://developer.hpccube.com/codes/modelzoo/vit_migraphx.git