ModelZoo / vision-transformers-cifar10_pytorch

Commit fa667fea, authored Jun 17, 2023 by Sugon_ldc: "add readme"
Parent: 459ecd48
Showing 2 changed files with 76 additions and 61 deletions:

* README.md: +68, -61
* model.properties: +8, -0
README.md (view file @ fa667fea)
# Vision-Transformers-cifar10_PyTorch

Let's train vision transformers for CIFAR-10!

## Model Introduction

This is an unofficial and elementary PyTorch implementation of `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`.

Vision Transformers (ViT) is a Transformer-based vision model that performs on par with state-of-the-art convolutional neural networks on image classification, replacing the traditional CNN with a Transformer architecture. Its core idea is to split the input image into a number of patches, arrange the patches into a sequence, and feed that sequence into a Transformer for feature extraction and classification.
## Model Structure

The Vision Transformers (ViT) - CIFAR10 model is built on the Transformer architecture: the input image is split into a number of patches, the patches are arranged into a sequence, and the sequence is fed into a Transformer for feature extraction and classification.

Concretely, the model takes a $32 \times 32$ color image as input and splits it into a $4 \times 4$ grid of patches, each of size $8 \times 8$. The patches are reshaped into a sequence and passed to the Transformer encoder for feature extraction. In the encoder, every sequence element (that is, every patch) is processed by a Multi-Head Self-Attention module and a Feed-Forward Network module, and the outputs are passed to the next encoder layer until the final classification result is produced.

The model is thus a stack of Transformer encoder layers, each consisting of a Multi-Head Self-Attention module and a Feed-Forward Network module. Residual connections between the layers propagate information effectively and mitigate the vanishing-gradient problem.

Finally, the encoder output is flattened into a one-dimensional tensor and classified by a fully connected layer. The whole model is trained with a cross-entropy loss. A minimal PyTorch sketch of this pipeline is given after the update list below.

### Updates

* Added [ConvMixer](https://openreview.net/forum?id=TVHS5Y4dNvM) implementation. Really simple! (2021/10)
* Added wandb train log to reproduce results. (2022/3)
* Added CaiT and ViT-small. (2022/3)
* Added SwinTransformers. (2022/3)
* Added MLP mixer. (2022/6)
* Changed default training settings for ViT.
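The patch-splitting-plus-encoder pipeline described above can be sketched in a few lines of PyTorch. This is an illustration only, not the repository's actual model code; the class name `MiniViT` and all hyperparameters are placeholders chosen to match the 32x32 input and 8x8 patches from the description.

```python
import torch
import torch.nn as nn

class MiniViT(nn.Module):
    """Minimal ViT: 32x32 image -> 4x4 grid of 8x8 patches -> Transformer encoder -> class logits."""
    def __init__(self, img_size=32, patch_size=8, dim=192, depth=6, heads=8, num_classes=10):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2      # 4 * 4 = 16 patches
        patch_dim = 3 * patch_size * patch_size          # 8 * 8 * 3 = 192 values per patch
        self.patch_size = patch_size
        self.to_embedding = nn.Linear(patch_dim, dim)    # linear patch embedding
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # One encoder layer = Multi-Head Self-Attention + Feed-Forward Network,
        # each wrapped in a residual connection inside nn.TransformerEncoderLayer.
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads,
                                           dim_feedforward=4 * dim,
                                           batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.head = nn.Linear(dim, num_classes)          # final fully connected classifier

    def forward(self, x):                                # x: (B, 3, 32, 32)
        p, (B, C, H, W) = self.patch_size, x.shape
        # Split the image into non-overlapping patches, flatten each patch into a vector.
        x = x.unfold(2, p, p).unfold(3, p, p)            # (B, C, 4, 4, 8, 8)
        x = x.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * p * p)  # (B, 16, 192)
        x = self.to_embedding(x)
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1) + self.pos_embedding  # prepend class token, add positions
        x = self.encoder(x)
        return self.head(x[:, 0])                        # classify from the class token

logits = MiniViT()(torch.randn(2, 3, 32, 32))
print(logits.shape)  # torch.Size([2, 10])
```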
## Dataset

The dataset used here is [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html). CIFAR-10 consists of 60,000 32x32 color images in 10 classes, with 6,000 images per class: 50,000 training images and 10,000 test images.

The dataset is divided into five training batches and one test batch, each with 10,000 images. The test batch contains exactly 1,000 randomly selected images from each class. The training batches contain the remaining images; some training batches may contain more images from one class than another, but between them the training batches contain exactly 5,000 images from each class.
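Each batch file is a Python pickle, as documented on the CIFAR-10 page linked above. A minimal loader sketch, assuming the standard `cifar-10-batches-py` directory produced by extracting the official archive:

```python
import pickle

def unpickle(path):
    # Each CIFAR-10 batch file is a pickled dict with keys such as
    # b'data' (a 10000x3072 uint8 array, row-major RGB) and b'labels'.
    with open(path, 'rb') as f:
        return pickle.load(f, encoding='bytes')

batch = unpickle('cifar-10-batches-py/data_batch_1')
images = batch[b'data'].reshape(-1, 3, 32, 32)  # 10,000 images in CHW layout
labels = batch[b'labels']                       # 10,000 integer class labels
print(images.shape, len(labels))
```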
# Usage example

`python train_cifar10.py` # vit-patchsize-4

`python train_cifar10.py --size 48` # vit-patchsize-4-imsize-48

`python train_cifar10.py --patch 2` # vit-patchsize-2

`python train_cifar10.py --net vit_small --n_epochs 400` # vit-small

`python train_cifar10.py --net vit_timm` # train with pretrained vit

`python train_cifar10.py --net convmixer --n_epochs 400` # train with convmixer

`python train_cifar10.py --net mlpmixer --n_epochs 500 --aug --lr 1e-3`

`python train_cifar10.py --net cait --n_epochs 200` # train with cait

`python train_cifar10.py --net swin --n_epochs 400` # train with SwinTransformers

`python train_cifar10.py --net res18` # resnet18+randaug

## Training and Inference

### Environment Setup

A prebuilt training image can be pulled from the SourceFind (光源) registry:

```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.10.0-centos7.6-dtk-22.10-py38-latest
```

Install dependencies:

```
pip install pandas==1.5.3
```

Create a symlink to the dataset:

```
ln -s datasets/cifar10 data
```

If you need a different path, you can instead edit train_cifar10.py in the project root: modify the trainset and testset variables around line 130.
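For reference, in the upstream code base trainset and testset are typically standard torchvision CIFAR-10 datasets, so redirecting the data path comes down to the `root` argument. A sketch of what the edited lines might look like (not the file's verbatim contents; the transforms are illustrative):

```python
import torchvision
import torchvision.transforms as transforms

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
transform_test = transforms.ToTensor()

# root='./data' resolves through the symlink above to datasets/cifar10;
# replace it with any custom dataset location.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=False, transform=transform_train)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=False, transform=transform_test)
```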
### Start Training

```
bash train.sh
# Four-GPU training by default; this can be changed in the train.sh script.
```

# Results

| Model | Accuracy | Train Log |
|:-----------:|:--------:|:--------:|
| ViT patch=2 | 80% | |
| ViT patch=4 Epoch@200 | 80% | [Log](https://wandb.ai/arutema47/cifar10-challange/reports/Untitled-Report--VmlldzoxNjU3MTU2?accessToken=3y3ib62e8b9ed2m2zb22dze8955fwuhljl5l4po1d5a3u9b7yzek1tz7a0d4i57r) |
| ViT patch=4 Epoch@500 | 88% | [Log](https://wandb.ai/arutema47/cifar10-challange/reports/Untitled-Report--VmlldzoxNjU3MTU2?accessToken=3y3ib62e8b9ed2m2zb22dze8955fwuhljl5l4po1d5a3u9b7yzek1tz7a0d4i57r) |
| ViT patch=8 | 30% | |
| ViT small | 80% | |
| MLP mixer | 88% | |
| CaiT | 80% | |
| Swin-t | 90% | |
| ViT small (timm transfer) | 97.5% | |
| ViT base (timm transfer) | 98.5% | |
| [ConvMixerTiny(no pretrain)](https://openreview.net/forum?id=TVHS5Y4dNvM) | 96.3% | [Log](https://wandb.ai/arutema47/cifar10-challange/reports/convmixer--VmlldzoyMjEyOTk1?accessToken=2w9nox10so11ixf7t0imdhxq1rf1ftgzyax4r9h896iekm2byfifz3b7hkv3klrt) |
| resnet18 | 93% | |
| resnet18+randaug | 95% | [Log](https://wandb.ai/arutema47/cifar10-challange/reports/Untitled-Report--VmlldzoxNjU3MTYz?accessToken=968duvoqt6xq7ep75ob0yppkzbxd0q03gxy2apytryv04a84xvj8ysdfvdaakij2) |

## Accuracy Data

| GPUs | Accuracy |
| :--: | :----: |
| 4 | 84.91% |
# Used in..

* Vision Transformer Pruning [arxiv](https://arxiv.org/abs/2104.08500) [github](https://github.com/Cydia2018/ViT-cifar10-pruning)

## Source Repository and Issue Feedback

https://developer.hpccube.com/codes/modelzoo/vision-transformers-cifar10_pytorch

## Reference

https://github.com/kentaroy47/vision-transformers-cifar10
model.properties (new file, view file @ fa667fea)
# Model name
modelName=Vision-Transformers-cifar10_PyTorch
# Model description
modelDescription=Vision Transformers (ViT) is a Transformer-based vision model
# Application scenarios (multiple tags separated by commas)
appScenario=training,pytorch,transformer
# Framework type (multiple tags separated by commas)
frameType=PyTorch
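model.properties is a plain key=value file with `#` comments. If you need to read it programmatically, here is a minimal parser sketch in Python (the ModelZoo platform presumably consumes this file on its own side):

```python
def load_properties(path):
    """Parse a key=value properties file, skipping blank lines and # comments."""
    props = {}
    with open(path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            key, _, value = line.partition('=')
            props[key.strip()] = value.strip()
    return props

props = load_properties('model.properties')
print(props['modelName'])  # Vision-Transformers-cifar10_PyTorch
```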