README.md 4.04 KB
Newer Older
huaerkl's avatar
v1.0  
huaerkl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# DATA2VEC
## 论文
`Efficient Self-supervised Learning with Contextualized Target Representations for Vision, Speech and Language`
- https://arxiv.org/abs/2212.07525
## 模型结构
data2vec的主要思路就是先建立一个教师网络,首先计算来自图像、文本或语音的目标表征。然后对数据进行掩码遮盖掉部分输入,并用一个学生网络重复该过程预测教师模型得到的表征。也就是说,学生模型只能在接受「不完整输入信息」的同时预测「完整输入数据」的表示,为了保证两个模型的一致性,二者的参数是共享的,但在训练初期会让Teacher模型的参数更新更快。

![img](./images/datavec.png)
## 算法原理
Meta AI发布的data2cec 2.0版本,在性能方面对上一代进行了改进:在精度相同的情况下,训练速度相比其他算法最高提升了16倍!主要解决的痛点是构建自监督模型需要大量的GPU做算力支撑才能完成训练。

data2vec 2.0通过以下三种方式提高原始 data2vec 算法的效率:
- 为特定训练样例构建目标表征,并将该表征重用在掩码版本上。
- 类似于掩码自编码器(masked autoencoder, MAE),学生模型中的编码器网络并不运训练样例中的空白部分(blanked out)。
- 使用了一个更有效的解码器模型,不再依赖于Transformer网络,而是依赖于一个多层卷积网络。

![img](./images/speedup.png)
## 环境配置

1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装:
https://developer.hpccube.com/tool/
```
DTK驱动:dtk23.04
python:python3.8
torch:torch1.13
torchvision:1.13
torchaudio:1.13.1
apex:1.13
```
`Tips:以上dtk驱动、python、torch等DCU相关工具版本需要严格一一对应`

2、编译安装fairseq库

`Version: 0.12.2`
```
cd fairseq
pip install --editable ./ #或python setup.py build_ext --inplace
```
3、其它非特殊库参照requirements.txt安装
```
pip install -r requirements.txt
```
## 数据集

ILSVRC 2012:
https://image-net.org/challenges/LSVRC/index.php

`imagenet 2012` 的解压与整理方法参照链接:
https://www.jianshu.com/p/a42b7d863825

整理完成后的数据目录结构如下:
```
data
    |
    train
        |
        n01440764
        n01806143
        ...
    val
        |
        n04286575
        n04596742
        ...
    test
        |
        images
            |
            test_x.JPEG
            test_xxx.JPEG
            ...
```
## 训练
进入主目录:
```
cd fairseq && mkdir logs
```
### 一、mpirun训练:
**多机多卡:**
```
sbatch datavec_mpi.sh
```
**单机多卡**(需先单独申请线上节点):
需先修改[`arguments.py`](./fairseq/distributed/utils.py)中的`call_main`函数:

`from`
```
#srun need auto rank: distributed_main_autorank
torch.multiprocessing.spawn(
    fn=distributed_main_autorank,
    args=(main, cfg, kwargs),
    nprocs=1,
    join=True,
)
'''
torch.multiprocessing.spawn(
    fn=distributed_main,
    args=(main, cfg, kwargs),
    nprocs=min(
        torch.cuda.device_count(),
        cfg.distributed_training.distributed_world_size,
    ),
    join=True,
)
'''
```
`to`
```
'''
#srun need auto rank: distributed_main_autorank
torch.multiprocessing.spawn(
    fn=distributed_main_autorank,
    args=(main, cfg, kwargs),
    nprocs=1,
    join=True,
)
'''
torch.multiprocessing.spawn(
    fn=distributed_main,
    args=(main, cfg, kwargs),
    nprocs=min(
        torch.cuda.device_count(),
        cfg.distributed_training.distributed_world_size,
    ),
    join=True,
)
```
然后运行:
```
sh datavec.sh
```
### 二、srun训练
**多机多卡:**
```
sbatch datavec_srun.sh
```
## 推理
参照fairseq源文档[`README.md`](./examples/data2vec/README.md),推理可以不用多节点,单节点算力即可。
## result
![img](./images/classify.png)
![img](./images/accuracy.png)
## 应用场景
### 算法类别
`图像分类`
### 应用行业
`制造,环境,医疗,气象`
### 算法框架
`pytorch`
## 参考资料
- https://github.com/facebookresearch/fairseq