# bert4torch ## 模型介绍 bert4torch是一个基于pytorch的训练框架,前期以效仿和实现bert4keras的主要功能为主,方便加载多类预训练模型进行finetune,提供了中文注释方便用户理解模型结构。 ## 模型结构 BERT的主模型是BERT中最重要组件,BERT通过预训练(pre-training),具体来说,就是在主模型后再接个专门的模块计算预训练的损失(loss),预训练后就得到了主模型的参数(parameter),当应用到下游任务时,就在主模型后接个跟下游任务配套的模块,然后主模型赋上预训练的参数,下游任务模块随机初始化,然后微调(fine-tuning)就可以了(注意:微调的时候,主模型和下游任务模块两部分的参数一般都要调整,也可以冻结一部分,调整另一部分)。 主模型由三部分构成:**嵌入层**、**编码器**、**池化层**。 如图: ![img](https://images.cnblogs.com/cnblogs_com/wangzb96/1789835/o_200618140451BERT%E4%B9%8B%E4%B8%BB%E6%A8%A1%E5%9E%8B.png) 其中 - 输入:一个个小批(mini-batch),小批里是`batch_size`个序列(句子或句子对),每个序列由若干个离散编码向量组成。 - 嵌入层:将输入的序列转换成连续分布式表示(distributed representation),即词嵌入(word embedding)或词向量(word vector)。 - 编码器:对每个序列进行非线性表示。 - 池化层:取出`[CLS]`标记(token)的表示(representation)作为整个序列的表示。 - 输出:编码器最后一层输出的表示(序列中每个标记的表示)和池化层输出的表示(序列整体的表示)。 ## 环境配置 ### Docker 在光源可拉取docker镜像,拉取方式如下: ``` docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk-23.04-py37-latest ``` 安装依赖包和bert4torch ``` pip install -r requirements.txt cd bert4torch python3 setup.py install ``` ## 数据集和预训练模型 数据集下载地址:https://s3.bmio.net/kashgari/china-people-daily-ner-corpus.tar.gz 人民日报数据集存放在目录/datasets/bert-base-chinese目录下,然后解压。 预训练模型下载地址:https://huggingface.co/bert-base-chinese/tree/main 所有文件下载存放在目录/datasets/bert-base-chinese下。 训练数据目录结构如下: ``` dataset | bert-base-chinese | china-people-daily-ner-corpus config.json flax_model.msgpack pytorch_model.bin vocab.txt | example.dev example.test example.train ``` ## 训练 ### 修改配置文件 ``` cd examples/sequence_labeling/ # 修改训练脚本配置文件 crf.py # 单卡训练脚本 crf_ddp.py # 多卡训练脚本 多卡训练使用torch的ddp,在单卡训练代码基础上增加DDP的相关内容 仅修改配置文件路径,包括config_path, checkpoint_path, dict_path, train_dataloader, valid_dtaloader,根据需要调整batch_size大小。 注:如果需要测试fp16,可以修改crf_ddp.py和crf.py中model.compile(),添加use_amp=True。 ``` ### 单机单卡 ``` cd examples/sequence_labeling/ ./single_train.sh ``` ### 单机多卡 ``` cd examples/sequence_labeling/ ./multi_train.sh ``` ## 精度数据 | 卡数 | 类型 | batch_size | f1 | p | r | | ---- | ---- | ---------- | ------ | ------ | ------ | | 1 | fp32 | 64 | 0.9592 | 0.9643 | 0.9617 | | 1 | fp16 | 64 | 0.9559 | 0.9596 | 0.9545 | | 4 | fp32 | 256 | 0.9459 | 0.9398 | 0.9521 | | 4 | fp16 | 256 | 0.9438 | 0.9398 | 0.9505 | ## 源码仓库及问题反馈 - https://developer.hpccube.com/codes/modelzoo/bert4torch ## 参考资料 - https://github.com/Tongjilibo/bert4torch