# DATA2VEC ## 论文 `Efficient Self-supervised Learning with Contextualized Target Representations for Vision, Speech and Language`: - https://arxiv.org/abs/2212.07525 ## 模型结构 data2vec的主要思路就是先建立一个教师网络,首先计算来自图像、文本或语音的目标表征。然后对数据进行掩码遮盖掉部分输入,并用一个学生网络重复该过程预测教师模型得到的表征。也就是说,学生模型只能在接受「不完整输入信息」的同时预测「完整输入数据」的表示,为了保证两个模型的一致性,二者的参数是共享的,但在训练初期会让Teacher模型的参数更新更快。 ![img](./images/datavec.png) ## 算法原理 Meta AI发布的data2cec 2.0版本,在性能方面对上一代进行了改进:在精度相同的情况下,训练速度相比其他算法最高提升了16倍!主要解决的痛点是构建自监督模型需要大量的GPU做算力支撑才能完成训练。 data2vec 2.0通过以下三种方式提高原始 data2vec 算法的效率: - 为特定训练样例构建目标表征,并将该表征重用在掩码版本上。 - 类似于掩码自编码器(masked autoencoder, MAE),学生模型中的编码器网络并不运训练样例中的空白部分(blanked out)。 - 使用了一个更有效的解码器模型,不再依赖于Transformer网络,而是依赖于一个多层卷积网络。 ![img](./images/speedup.png) ## 环境配置 1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装: https://developer.hpccube.com/tool/ ``` DTK驱动:dtk23.04 python:python3.8 torch:torch1.13 torchvision:1.13 torchaudio:1.13.1 apex:1.13 ``` `Tips:以上dtk驱动、python、torch等DCU相关工具版本需要严格一一对应` 2、编译安装fairseq库 `Version: 0.12.2` ``` cd fairseq pip install --editable ./ #或python setup.py build_ext --inplace ``` 3、其它非特殊库参照requirements.txt安装 ``` pip install -r requirements.txt ``` ## 数据集 ILSVRC 2012: https://image-net.org/challenges/LSVRC/index.php `imagenet 2012` 的解压与整理方法参照链接: https://www.jianshu.com/p/a42b7d863825 整理完成后的数据目录结构如下: ``` data | train | n01440764 n01806143 ... val | n04286575 n04596742 ... test | images | test_x.JPEG test_xxx.JPEG ... ``` ## 训练 进入主目录: ``` cd fairseq && mkdir logs ``` ### 一、mpirun训练: **多机多卡:** ``` sbatch datavec_mpi.sh ``` **单机多卡**(需先单独申请线上节点): 需先修改[`arguments.py`](./fairseq/distributed/utils.py)中的`call_main`函数: `from` ``` #srun need auto rank: distributed_main_autorank torch.multiprocessing.spawn( fn=distributed_main_autorank, args=(main, cfg, kwargs), nprocs=1, join=True, ) ''' torch.multiprocessing.spawn( fn=distributed_main, args=(main, cfg, kwargs), nprocs=min( torch.cuda.device_count(), cfg.distributed_training.distributed_world_size, ), join=True, ) ''' ``` `to` ``` ''' #srun need auto rank: distributed_main_autorank torch.multiprocessing.spawn( fn=distributed_main_autorank, args=(main, cfg, kwargs), nprocs=1, join=True, ) ''' torch.multiprocessing.spawn( fn=distributed_main, args=(main, cfg, kwargs), nprocs=min( torch.cuda.device_count(), cfg.distributed_training.distributed_world_size, ), join=True, ) ``` 然后运行: ``` sh datavec.sh ``` ### 二、srun训练 **多机多卡:** ``` sbatch datavec_srun.sh ``` ## 推理 参照fairseq源文档[`README.md`](./examples/data2vec/README.md),推理可以不用多节点,单节点算力即可。 ## result ![img](./images/classify.png) ![img](./images/accuracy.png) ## 应用场景 ### 算法类别 `图像分类` ### 应用行业 `制造,环境,医疗,气象` ### 算法框架 `pytorch` ## 参考资料 - https://github.com/facebookresearch/fairseq