README.md 3.41 KB
Newer Older
lijian6's avatar
lijian6 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
<!--
 * @Author: lijian6
 * @email: lijian6@sugon.com
 * @Date: 2023-06-06
 * @LastEditTime: 2023-06-06
 * @FilePath: \lpr\README.md
-->
# Vision Transformer(ViT)

## 模型介绍

ViT的是将Transformer模型应用于计算机视觉领域,以替代传统的卷积神经网络(CNN)模型。

## 模型结构

Vision Transformer模型结构如下图所示主要包括三部分,patch embeding 部分、transformer encoder部分、MLP head部分。ViT将输入图片分为多个patch,再将每个patch投影为固定长度的向量送入Transformer,后续encoder的操作和原始Transformer中完全相同。但是因为对图片分类,因此在输入序列中加入一个特殊的token,该token对应的输出即为最后的类别预测。

## 数据集

使用ImageNet数据集做pretrain,pretrain之后的模型使用[flower_photos](https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz)做fine-tuning

## 训练

### 环境配置
[光源](https://www.sourcefind.cn/#/service-details)可拉取训练以及推理的docker镜像,推荐的镜像如下:
* 训练镜像:docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.10.0-centos7.6-dtk-22.10.1-py37-latest

[光合开发者社区](https://cancon.hpccube.com:65024/4/main/)可下载MIGraphX:

### Fine-tunning
模型的训练程序是train.py,预训练模型为[base_patch16_224_in21k](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth),需要先下载与训练模型。fine-tuning训练模型使用以下命令:

    python train.py 

Fine-tuning时可调整epoch参数来调整模型。

### 预训练模型
在weights文件夹下我们提供了一个预训练模型以及对应的fine-tuning模型和onnx模型。

### 推理
推理测试模型用infer_pytroch.py和infer_migraphx.py对训练出的模型进行推理,使用方法如下:

    python infer_pytroch.py ./flower_photos/daisy/
    python infer_migraphx.py --imgpath=./flower_photos/daisy/

## 代码使用简介

1. 下载好数据集,代码中默认使用的是花分类数据集,下载地址: [https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz](https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz)
2.`train.py`脚本中将`--data-path`设置成解压后的`flower_photos`文件夹绝对路径
3. 下载预训练权重,在`vit_model.py`文件中每个模型都有提供预训练权重的下载地址,根据自己使用的模型下载对应预训练权重
4.`train.py`脚本中将`--weights`参数设成下载好的预训练权重路径
5. 设置好数据集的路径`--data-path`以及预训练权重的路径`--weights`就能使用`train.py`脚本开始训练了(训练过程中会自动生成`class_indices.json`文件)
6.`predict.py`脚本中导入和训练脚本中同样的模型,并将`model_weight_path`设置成训练好的模型权重路径(默认保存在weights文件夹下)
7.`predict.py`脚本中将`img_path`设置成你自己需要预测的图片绝对路径
8. 设置好权重路径`model_weight_path`和预测的图片路径`img_path`就能使用`predict.py`脚本进行预测了
9. 如果要使用自己的数据集,请按照花分类数据集的文件结构进行摆放(即一个类别对应一个文件夹),并且将训练以及预测脚本中的`num_classes`设置成你自己数据的类别数