Commit 0e29b9b7 authored by xuxo

yidong infer init

# Ignore files with these extensions (applied recursively to all subdirectories)
*.docx
*.tar
*.onnx
*.pt
*.mxr
*.bin
*.log
*.swp
# yidong-infer
When asking a question, please provide as much of the following information as possible:
### Basic information
- Your **operating system**:
- Your **Python** version:
- Your **PyTorch** version:
- Your **bert4torch** version:
- The **pretrained model** you are loading:
### Core code
```python
# Paste your core code here
```
### Output
```shell
# Paste your debug output here
```
### What you have tried
Please describe what you have already tried here
__pycache__
datasets/
*.pt
*.onnx
*.csv
*.json
*.log
bert4torch_test.ipynb
summary/
.idea
.vscode/launch.json
.pypirc
bert4torch.egg-info/
build/
dist/
.DS_Store
bert4torch_test.py
MIT License
Copyright (c) 2022 Bo仔很忙
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# bert4torch
**A concise PyTorch training framework that re-implements bert4keras**
[![licence](https://img.shields.io/github/license/Tongjilibo/bert4torch.svg?maxAge=3600)](https://github.com/Tongjilibo/bert4torch/blob/master/LICENSE)
[![GitHub release](https://img.shields.io/github/release/Tongjilibo/bert4torch.svg?maxAge=3600)](https://github.com/Tongjilibo/bert4torch/releases)
[![PyPI](https://img.shields.io/pypi/v/bert4torch?label=pypi%20package)](https://pypi.org/project/bert4torch/)
[![PyPI - Downloads](https://img.shields.io/pypi/dm/bert4torch)](https://pypistats.org/packages/bert4torch)
[![GitHub stars](https://img.shields.io/github/stars/Tongjilibo/bert4torch?style=social)](https://github.com/Tongjilibo/bert4torch)
[![GitHub Issues](https://img.shields.io/github/issues/Tongjilibo/bert4torch.svg)](https://github.com/Tongjilibo/bert4torch/issues)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/Tongjilibo/bert4torch/issues)
## 1. Installation
Install the stable release
```shell
pip install bert4torch
```
Install the latest development version
```shell
pip install git+https://www.github.com/Tongjilibo/bert4torch.git
```
- **Notes**: PyPI releases lag behind the development version on git; when using `git clone`, **mind the import path**, and check whether the pretrained weights need to be converted
- **Running the examples**: `git clone https://github.com/Tongjilibo/bert4torch`, then edit the pretrained-model and data paths in the example scripts to launch them
- **Training on your own data**: adapt the corresponding data-processing code blocks
- **Development environment**: developed against `torch==1.10`; if other versions turn out to be incompatible, feedback is welcome
## 2. Features
- **Core features**: load pretrained weights of bert, roberta, albert, xlnet, nezha, bart, RoFormer, RoFormer_V2, ELECTRA, GPT, GPT2, T5, GAU-alpha, ERNIE and more for further finetuning, and flexibly define your own models on top of bert
- **Rich examples**: solutions for [pretrain](https://github.com/Tongjilibo/bert4torch/blob/master/examples/pretrain), [sentence_classfication](https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_classfication), [sentence_embedding](https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_embedding), [sequence_labeling](https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling), [relation_extraction](https://github.com/Tongjilibo/bert4torch/blob/master/examples/relation_extraction), [seq2seq](https://github.com/Tongjilibo/bert4torch/blob/master/examples/seq2seq), [serving](https://github.com/Tongjilibo/bert4torch/blob/master/examples/serving/) and more
- **Validated experiments**: [benchmarked](https://github.com/Tongjilibo/bert4torch/blob/master/examples/Performance.md) on public datasets, using the [example datasets](https://github.com/Tongjilibo/bert4torch/blob/master/examples/README.md) listed there
- **Plug-and-play tricks**: common training [tricks](https://github.com/Tongjilibo/bert4torch/blob/master/examples/training_trick) are integrated and ready to use
- **Other features**: [load models from the transformers library](https://github.com/Tongjilibo/bert4torch/blob/master/examples/others/task_load_transformers_model.py) and use them together; concise and efficient API; live training progress bar; parameter counts via torchinfo; built-in Logger and Tensorboard logging of the training process; customizable fit loop for advanced needs
## 3. Quick start
- [Quick-start tutorial](https://github.com/Tongjilibo/bert4torch/blob/master/examples/tutorials/Tutorials.md), [tutorial examples](https://github.com/Tongjilibo/bert4torch/blob/master/examples/tutorials), [practical examples](https://github.com/Tongjilibo/bert4torch/blob/master/examples); a minimal sketch is shown below
- [Introduction to bert4torch (Zhihu)](https://zhuanlan.zhihu.com/p/486329434), [Getting started with bert4torch (Zhihu)](https://zhuanlan.zhihu.com/p/508890807), [bert4torch updated yet again (Zhihu)](https://zhuanlan.zhihu.com/p/560885427?)
- Background: a PyTorch re-implementation of Jianlin Su's [bert4keras](https://github.com/bojone/bert4keras); the initial version drew on [bert4pytorch](https://github.com/MuQiuJun-AI/bert4pytorch)
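A minimal feature-extraction sketch, distilled from the bundled example script; the checkpoint paths below are placeholders and assume a local Chinese BERT checkpoint with `vocab.txt`, `bert_config.json` and `pytorch_model.bin`:
```python
import torch
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer

# placeholder paths: point these at your local pretrained checkpoint
vocab_path = 'pretrain_ckpt/vocab.txt'
config_path = 'pretrain_ckpt/bert_config.json'
checkpoint_path = 'pretrain_ckpt/pytorch_model.bin'

tokenizer = Tokenizer(vocab_path, do_lower_case=True)          # build the tokenizer
model = build_transformer_model(config_path, checkpoint_path)  # build the model and load the weights

token_ids, segment_ids = tokenizer.encode('语言模型')
model.eval()
with torch.no_grad():
    hidden_states = model([torch.tensor([token_ids]), torch.tensor([segment_ids])])[0]
print(hidden_states.shape)  # [1, seq_len, hidden_size]
```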
## 4. Release notes
- **v0.2.1**: compatibility with torch<=1.7.1 where torch.div has no rounding_mode; add custom metrics; support resuming training from checkpoints; add default Logger and Tensorboard logging
- **v0.2.0**: compatibility with torch<1.9.0 where take_along_dim is missing; fix the position-embedding size-514 issue in bart; fix Sptokenizer not converting symbols; print a timestamp at the start of each epoch; add parallel_apply
- **v0.1.9**: add mixup/manifold_mixup/temporal_ensembling strategies; fix empty param.grad in the pgd strategy; tokenizer now supports batch input
- **v0.1.8**: fix the sudden loss spike during CRF training; fix the high GPU-memory usage of the xlnet token_type_ids input
- **v0.1.7**: add EarlyStop; CRF converts its mask to bool internally
- **v0.1.6**: add transformer_xl, xlnet and t5_pegasus models; prompt and pretraining examples; support extra embedding inputs; EMA strategy; fix tokenizer and sinusoid bugs
- **v0.1.5**: add GAU-alpha, mixed-precision training, gradient clipping, single-machine multi-GPU training (DP, DDP)
- **v0.1.4**: add VAT; fix an issue with the return value of apply_embedding on Linux
- **v0.1.3**: initial release
## 5. Changelog
- **2022-09-20**: add TensorRT example; support multiple schedulers at once (e.g. EMA + warmup); sanic + onnx deployment
- **2022-09-10**: add default Logger and Tensorboard logging; ONNX inference; add ERNIE model; fix the t5 norm_mode issue; allow hidden_size not divisible by num_attention_heads
- **2022-08-28**: add nl2sql example; add custom metrics; support resuming training from checkpoints
- **2022-08-21**: add W2NER and DiffCSE examples; print a timestamp at the start of each epoch; add parallel_apply; compatibility with torch<=1.7.1 where torch.div has no rounding_mode
- **2022-08-14**: add benchmark results for supervised sentence embeddings, relation extraction and text generation; compatibility with torch<1.9.0 where take_along_dim is missing; fix the position-embedding size-514 issue in bart; fix Sptokenizer not converting symbols
- **2022-07-27**: add mixup/manifold_mixup/temporal_ensembling strategies; fix empty param.grad in the pgd strategy; tokenizer now supports batch input; add uie example
- **2022-07-16**: fix the sudden loss spike during CRF training; fix the high GPU-memory usage of the xlnet token_type_ids input
- **2022-07-10**: add a Chinese financial FAQ example and the Tianchi news-classification top-1 solution; add EarlyStop; CRF converts its mask to bool internally
- **2022-06-29**: add NER experiments testing different CRF initializations; Chinese bert-whitening experiments
- **2022-06-13**: add seq2seq with prefix-tree decoding; add Chinese experiments on unsupervised semantic similarity with SimCSE/ESimCSE/PromptBert
- **2022-06-05**: add PromptBert, PET and P-tuning examples; fix the tokenizer mis-tokenizing special_tokens; add t5_pegasus
- **2022-05-29**: add transformer_xl and xlnet models; fix the bug of sinusoid position embeddings being overwritten by init_weight; EMA; Sohu sentiment-classification example
- **2022-05-17**: add pretraining code; support extra embedding inputs (e.g. POS tags, word-level embeddings)
- **2022-05-01**: add mixed-precision training, gradient clipping, single-machine multi-GPU training (DP, DDP)
- **2022-04-25**: add VAT and GAU-alpha examples; gradient accumulation; a custom fit() example
- **2022-04-15**: add ner_mrc, ner_span, roformer_v2 and roformer-sim examples
- **2022-04-05**: add GPLinker, TPlinker and SimBERT examples
- **2022-03-29**: add CoSENT, R-Drop and UDA examples
- **2022-03-22**: add GPT, GPT2 and T5 models
- **2022-03-12**: initial commit
## 6. Pretrained weights
- Some weights require loading a modified [config.json](https://github.com/Tongjilibo/bert4torch/blob/master/examples/convert_script/PLM_config.md)
| Model family | Source | Weights | Notes (if any) |
| ---- | ---- | ---- | ---- |
| bert | Google's original BERT (i.e. bert-base-chinese) | [tf](https://github.com/google-research/bert), [pytorch](https://huggingface.co/bert-base-chinese) | [tf-to-pytorch conversion](https://huggingface.co/docs/transformers/converting_tensorflow_models), [conversion script](https://github.com/Tongjilibo/bert4torch/blob/master/examples/convert_script/convert_bert-base-chinese.py)
| bert | HFL chinese-bert-wwm-ext | [tf/pytorch](https://github.com/ymcui/Chinese-BERT-wwm), [pytorch](https://huggingface.co/hfl/chinese-bert-wwm-ext) |
| roberta | HFL chinese-roberta-wwm-ext | [tf/pytorch](https://github.com/ymcui/Chinese-BERT-wwm), [pytorch](https://huggingface.co/hfl/chinese-roberta-wwm-ext)
| xlnet | HFL Chinese-XLNet | [tf/pytorch](https://github.com/ymcui/Chinese-XLNet) | [config](https://github.com/Tongjilibo/bert4torch/blob/master/examples/convert_script/PLM_config.md)
| electra | HFL Chinese-ELECTRA | [tf](https://github.com/ymcui/Chinese-ELECTRA), [pytorch](https://huggingface.co/hfl/chinese-electra-base-discriminator)
| macbert | HFL MacBERT | [tf](https://github.com/ymcui/MacBERT), [pytorch](https://huggingface.co/hfl/chinese-macbert-base)
| albert | brightmart | [tf](https://github.com/brightmart/albert_zh), [pytorch](https://github.com/lonePatient/albert_pytorch)
| ernie | Baidu ERNIE | [paddle](https://github.com/PaddlePaddle/ERNIE), [pytorch](https://huggingface.co/nghuyong) |
| roformer | ZhuiyiTechnology | [tf](https://github.com/ZhuiyiTechnology/roformer), [pytorch](https://huggingface.co/junnyu/roformer_chinese_base) |
| roformer_v2 | ZhuiyiTechnology | [tf](https://github.com/ZhuiyiTechnology/roformer-v2), [pytorch](https://huggingface.co/junnyu/roformer_v2_chinese_char_base) |
| simbert | ZhuiyiTechnology | [tf](https://github.com/ZhuiyiTechnology/simbert), [pytorch](https://huggingface.co/peterchou/simbert-chinese-base/tree/main) |
| roformer-sim | ZhuiyiTechnology | [tf](https://github.com/ZhuiyiTechnology/roformer-sim), [pytorch](https://huggingface.co/junnyu/roformer_chinese_sim_char_base) |
| gau-alpha | ZhuiyiTechnology | [tf](https://github.com/ZhuiyiTechnology/GAU-alpha) | [conversion script](https://github.com/Tongjilibo/bert4torch/blob/master/examples/convert_script/convert_GAU_alpha.py)
| nezha | Huawei | [tf](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/NEZHA-TensorFlow), [pytorch](https://github.com/lonePatient/NeZha_Chinese_PyTorch) |
| gpt | CDial-GPT | [pytorch](https://github.com/thu-coai/CDial-GPT) | [conversion script](https://github.com/Tongjilibo/bert4torch/blob/master/examples/convert_script/convert_gpt__CDial-GPT-LCCC.py)
| gpt2 | Tsinghua 2.6B cmp_lm | [pytorch](https://github.com/TsinghuaAI/CPM-1-Generate) | [conversion script](https://github.com/Tongjilibo/bert4torch/blob/master/examples/convert_script/convert_gpt2__cmp_lm_2.6b.py)
| gpt2 | Chinese GPT2_ML model | [tf](https://github.com/imcaspar/gpt2-ml), [pytorch](https://github.com/ghosthamlet/gpt2-ml-torch) | [conversion script](https://github.com/Tongjilibo/bert4torch/blob/master/examples/convert_script/convert_gpt2__gpt2-ml.py)
| t5 | UER | [pytorch](https://huggingface.co/uer/t5-base-chinese-cluecorpussmall) | [config](https://github.com/Tongjilibo/bert4torch/blob/master/examples/convert_script/PLM_config.md)
| mt5 | Google | [pytorch](https://huggingface.co/google/mt5-base) | [config](https://github.com/Tongjilibo/bert4torch/blob/master/examples/convert_script/PLM_config.md)
| t5_pegasus | ZhuiyiTechnology | [tf](https://github.com/ZhuiyiTechnology/t5-pegasus) | [conversion script](https://github.com/Tongjilibo/bert4torch/blob/master/examples/convert_script/convert_t5_pegasus.py)
| bart | Fudan | [pytorch](https://github.com/fastnlp/CPT) | [conversion script](https://github.com/Tongjilibo/bert4torch/blob/master/examples/convert_script/convert_bart_fudanNLP.py)
#! -*- coding: utf-8 -*-
__version__ = '0.2.2'
# Activation functions ported from the transformers library; the original bert4keras does not include these
import math
import torch
from packaging import version
from torch import nn
def _gelu_python(x):
"""
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def gelu_new(x):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
if version.parse(torch.__version__) < version.parse("1.4"):
gelu = _gelu_python
else:
gelu = nn.functional.gelu
def gelu_fast(x):
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
def quick_gelu(x):
return x * torch.sigmoid(1.702 * x)
def _silu_python(x):
"""
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
later.
"""
return x * torch.sigmoid(x)
if version.parse(torch.__version__) < version.parse("1.7"):
silu = _silu_python
else:
silu = nn.functional.silu
def _mish_python(x):
"""
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
visit the official repository for the paper: https://github.com/digantamisra98/Mish
"""
return x * torch.tanh(nn.functional.softplus(x))
if version.parse(torch.__version__) < version.parse("1.9"):
mish = _mish_python
else:
mish = nn.functional.mish
def linear_act(x):
return x
ACT2FN = {
"relu": nn.functional.relu,
"silu": silu,
"swish": silu,
"gelu": gelu,
"tanh": torch.tanh,
"gelu_new": gelu_new,
"gelu_fast": gelu_fast,
"quick_gelu": quick_gelu,
"mish": mish,
"linear": linear_act,
"sigmoid": torch.sigmoid,
"softmax": nn.Softmax(dim=-1)
}
def get_activation(activation_string):
if activation_string in ACT2FN:
return ACT2FN[activation_string]
else:
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
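# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the original module):
# look up an activation by the string name used in model configs and apply it.
if __name__ == '__main__':
    x = torch.randn(2, 4)
    act = get_activation('gelu_new')
    print(act(x).shape)  # torch.Size([2, 4])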
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
class FocalLoss(nn.Module):
'''Multi-class Focal loss implementation'''
def __init__(self, gamma=2, weight=None,ignore_index=-100):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.weight = weight
self.ignore_index=ignore_index
def forward(self, input, target):
"""
input: [N, C]
target: [N, ]
"""
logpt = F.log_softmax(input, dim=1)
pt = torch.exp(logpt)
logpt = (1-pt)**self.gamma * logpt
loss = F.nll_loss(logpt, target, self.weight,ignore_index=self.ignore_index)
return loss
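# Usage sketch for FocalLoss (illustrative, not part of the original file):
#   criterion = FocalLoss(gamma=2)
#   logits = torch.randn(8, 5)           # [N, C] raw scores, no softmax applied
#   labels = torch.randint(0, 5, (8,))   # [N] class indices
#   loss = criterion(logits, labels)     # easy examples are down-weighted by (1-pt)**gamma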
class LabelSmoothingCrossEntropy(nn.Module):
def __init__(self, eps=0.1, reduction='mean',ignore_index=-100):
super(LabelSmoothingCrossEntropy, self).__init__()
self.eps = eps
self.reduction = reduction
self.ignore_index = ignore_index
def forward(self, output, target):
c = output.size()[-1]
log_preds = F.log_softmax(output, dim=-1)
if self.reduction=='sum':
loss = -log_preds.sum()
else:
loss = -log_preds.sum(dim=-1)
if self.reduction=='mean':
loss = loss.mean()
return loss*self.eps/c + (1-self.eps) * F.nll_loss(log_preds, target, reduction=self.reduction,
ignore_index=self.ignore_index)
class MultilabelCategoricalCrossentropy(nn.Module):
"""多标签分类的交叉熵
说明:y_true和y_pred的shape一致,y_true的元素非0即1, 1表示对应的类为目标类,0表示对应的类为非目标类。
警告:请保证y_pred的值域是全体实数,换言之一般情况下y_pred不用加激活函数,尤其是不能加sigmoid或者softmax!预测
阶段则输出y_pred大于0的类。如有疑问,请仔细阅读并理解本文。
参考:https://kexue.fm/archives/7359
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def forward(self, y_pred, y_true):
""" y_true ([Tensor]): [..., num_classes]
y_pred ([Tensor]): [..., num_classes]
"""
y_pred = (1-2*y_true) * y_pred
y_pred_pos = y_pred - (1-y_true) * 1e12
y_pred_neg = y_pred - y_true * 1e12
y_pred_pos = torch.cat([y_pred_pos, torch.zeros_like(y_pred_pos[..., :1])], dim=-1)
y_pred_neg = torch.cat([y_pred_neg, torch.zeros_like(y_pred_neg[..., :1])], dim=-1)
pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
return (pos_loss + neg_loss).mean()
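# Usage sketch for MultilabelCategoricalCrossentropy (illustrative, not part of the original file):
#   criterion = MultilabelCategoricalCrossentropy()
#   y_pred = torch.randn(4, 10)                     # raw logits, no sigmoid/softmax
#   y_true = torch.randint(0, 2, (4, 10)).float()   # multi-hot targets
#   loss = criterion(y_pred, y_true)
#   predicted = (y_pred > 0)                        # at inference, classes with logit > 0 are predicted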
class SparseMultilabelCategoricalCrossentropy(nn.Module):
"""稀疏版多标签分类的交叉熵
说明:
1. y_true.shape=[..., num_positive],
y_pred.shape=[..., num_classes];
2. 请保证y_pred的值域是全体实数,换言之一般情况下y_pred不用加激活函数,尤其是不能加sigmoid或者softmax;
3. 预测阶段则输出y_pred大于0的类;
4. 详情请看:https://kexue.fm/archives/7359 。
"""
def __init__(self, mask_zero=False, epsilon=1e-7, **kwargs):
super().__init__(**kwargs)
self.mask_zero = mask_zero
self.epsilon = epsilon
def forward(self, y_pred, y_true):
zeros = torch.zeros_like(y_pred[..., :1])
y_pred = torch.cat([y_pred, zeros], dim=-1)
if self.mask_zero:
infs = zeros + float('inf')
y_pred = torch.cat([infs, y_pred[..., 1:]], dim=-1)
y_pos_2 = torch.gather(y_pred, dim=-1, index=y_true)
y_pos_1 = torch.cat([y_pos_2, zeros], dim=-1)
if self.mask_zero:
y_pred = torch.cat([-infs, y_pred[..., 1:]], dim=-1)
y_pos_2 = torch.gather(y_pred, dim=-1, index=y_true)
pos_loss = torch.logsumexp(-y_pos_1, dim=-1)
all_loss = torch.logsumexp(y_pred, dim=-1) # a
aux_loss = torch.logsumexp(y_pos_2, dim=-1) - all_loss # b-a
aux_loss = torch.clamp(1 - torch.exp(aux_loss), self.epsilon, 1) # 1-exp(b-a)
neg_loss = all_loss + torch.log(aux_loss) # a + log[1-exp(b-a)]
return pos_loss + neg_loss
class ContrastiveLoss(nn.Module):
"""对比损失:减小正例之间的距离,增大正例和反例之间的距离
公式:labels * distance_matrix.pow(2) + (1-labels)*F.relu(margin-distance_matrix).pow(2)
https://www.sbert.net/docs/package_reference/losses.html
"""
def __init__(self, margin=0.5, size_average=True, online=False):
super(ContrastiveLoss, self).__init__()
self.margin = margin
self.size_average = size_average
self.online = online
def forward(self, distances, labels, pos_id=1, neg_id=0):
if not self.online:
losses = 0.5 * (labels.float() * distances.pow(2) + (1 - labels).float() * F.relu(self.margin - distances).pow(2))
return losses.mean() if self.size_average else losses.sum()
else:
negs = distances[labels == neg_id]
poss = distances[labels == pos_id]
# select hard positive and hard negative pairs
negative_pairs = negs[negs < (poss.max() if len(poss) > 1 else negs.mean())]
positive_pairs = poss[poss > (negs.min() if len(negs) > 1 else poss.mean())]
positive_loss = positive_pairs.pow(2).sum()
negative_loss = F.relu(self.margin - negative_pairs).pow(2).sum()
return positive_loss + negative_loss
class RDropLoss(nn.Module):
'''Implementation of the R-Drop loss. Official project: https://github.com/dropreg/R-Drop
'''
def __init__(self, alpha=4, rank='adjacent'):
super().__init__()
self.alpha = alpha
# Two layouts are supported: 'adjacent' (the two forward passes interleaved at odd/even rows) and 'updown' (stacked top and bottom)
assert rank in {'adjacent', 'updown'}, "rank kwarg only support 'adjacent' and 'updown' "
self.rank = rank
self.loss_sup = nn.CrossEntropyLoss()
self.loss_rdrop = nn.KLDivLoss(reduction='none')
def forward(self, *args):
'''Two call signatures are supported: (y_pred, y_true), or (y_pred1, y_pred2, y_true)
'''
assert len(args) in {2, 3}, 'RDropLoss only support 2 or 3 input args'
# y_pred is a single tensor
if len(args) == 2:
y_pred, y_true = args
loss_sup = self.loss_sup(y_pred, y_true) # supervised loss computed over both forward passes
if self.rank == 'adjacent':
y_pred1 = y_pred[1::2]
y_pred2 = y_pred[::2]
elif self.rank == 'updown':
half_btz = y_true.shape[0] // 2
y_pred1 = y_pred[:half_btz]
y_pred2 = y_pred[half_btz:]
# y_pred given as two tensors
else:
y_pred1, y_pred2, y_true = args
loss_sup = self.loss_sup(y_pred1, y_true)
loss_rdrop1 = self.loss_rdrop(F.log_softmax(y_pred1, dim=-1), F.softmax(y_pred2, dim=-1))
loss_rdrop2 = self.loss_rdrop(F.log_softmax(y_pred2, dim=-1), F.softmax(y_pred1, dim=-1))
return loss_sup + torch.mean(loss_rdrop1 + loss_rdrop2) / 4 * self.alpha
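# Usage sketch for RDropLoss with rank='adjacent' (illustrative, not part of the original file):
# the same batch is forwarded twice through the model (dropout on) and the two sets of logits
# are interleaved so that rows 2k and 2k+1 belong to the same example.
#   criterion = RDropLoss(alpha=4, rank='adjacent')
#   logits = torch.randn(16, 3)            # 8 examples, each duplicated at adjacent rows
#   labels = torch.randint(0, 3, (16,))    # labels repeated to match the duplicated rows
#   loss = criterion(logits, labels)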
class UDALoss(nn.Module):
'''UDA loss. Subclass or wrap it when using, because forward needs global_step and total_steps.
https://arxiv.org/abs/1904.12848
'''
def __init__(self, tsa_schedule=None, total_steps=None, start_p=0, end_p=1, return_all_loss=True):
super().__init__()
self.loss_sup = nn.CrossEntropyLoss()
self.loss_unsup = nn.KLDivLoss(reduction='batchmean')
self.tsa_schedule = tsa_schedule
self.start = start_p
self.end = end_p
if self.tsa_schedule:
assert self.tsa_schedule in {'linear_schedule', 'exp_schedule', 'log_schedule'}, 'tsa_schedule config illegal'
self.return_all_loss = return_all_loss
def forward(self, y_pred, y_true_sup, global_step, total_steps):
sup_size = y_true_sup.size(0)
unsup_size = (y_pred.size(0) - sup_size) // 2
# supervised part: cross-entropy loss
y_pred_sup = y_pred[:sup_size]
if self.tsa_schedule is None:
loss_sup = self.loss_sup(y_pred_sup, y_true_sup)
else: # use TSA to drop supervised samples the model already predicts with high confidence
threshold = self.get_tsa_threshold(self.tsa_schedule, global_step, total_steps, self.start, self.end)
true_prob = torch.gather(F.softmax(y_pred_sup, dim=-1), dim=1, index=y_true_sup[:, None])
sel_rows = true_prob.lt(threshold).sum(dim=-1).gt(0) # keep only samples below the threshold
loss_sup = self.loss_sup(y_pred_sup[sel_rows], y_true_sup[sel_rows]) if sel_rows.sum() > 0 else 0
# unsupervised part: KL divergence here; cross-entropy also works
y_true_unsup = y_pred[sup_size:sup_size+unsup_size]
y_true_unsup = F.softmax(y_true_unsup.detach(), dim=-1)
y_pred_unsup = F.log_softmax(y_pred[sup_size+unsup_size:], dim=-1)
loss_unsup = self.loss_unsup(y_pred_unsup, y_true_unsup)
if self.return_all_loss:
return loss_sup + loss_unsup, loss_sup, loss_unsup
else:
return loss_sup + loss_unsup
@staticmethod
def get_tsa_threshold(schedule, global_step, num_train_steps, start, end):
training_progress = global_step / num_train_steps
if schedule == "linear_schedule":
threshold = training_progress
elif schedule == "exp_schedule":
scale = 5
threshold = math.exp((training_progress - 1) * scale)
elif schedule == "log_schedule":
scale = 5
threshold = 1 - math.exp((-training_progress) * scale)
return threshold * (end - start) + start
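# Usage sketch for UDALoss (illustrative, not part of the original file): y_pred stacks the logits
# of the supervised batch, the original unsupervised batch and its augmented copy, in that order,
# so the loss can slice them apart via y_true_sup.size(0).
#   criterion = UDALoss(tsa_schedule='linear_schedule')
#   y_pred = model(torch.cat([x_sup, x_unsup, x_unsup_aug]))   # [N_sup + 2*N_unsup, C]
#   loss, loss_sup, loss_unsup = criterion(y_pred, y_true_sup, global_step, total_steps)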
class TemporalEnsemblingLoss(nn.Module):
'''Implementation of temporal ensembling: on top of the supervised loss, add an MSE consistency loss.
Official project: https://github.com/s-laine/tempens
Third-party PyTorch implementation: https://github.com/ferretj/temporal-ensembling
When using it, the train_dataloader must be created with shuffle=False.
'''
def __init__(self, epochs, max_val=10.0, ramp_up_mult=-5.0, alpha=0.5, max_batch_num=100, hist_device='cpu'):
super().__init__()
self.loss_sup = nn.CrossEntropyLoss()
self.max_epochs = epochs
self.max_val = max_val
self.ramp_up_mult = ramp_up_mult
self.alpha = alpha
self.max_batch_num = max_batch_num # set to None to keep the history for all batches; resource-hungry for large datasets
self.hist_unsup = [] # history of unsupervised logits
self.hist_sup = [] # history of supervised logits
self.hist_device = hist_device
self.hist_input_y = [] # history of supervised labels y
assert (self.alpha >= 0) & (self.alpha < 1) # alpha == 1 would make the denominator in update() zero
def forward(self, y_pred_sup, y_pred_unsup, y_true_sup, epoch, bti):
self.same_batch_check(y_pred_sup, y_pred_unsup, y_true_sup, bti)
if (self.max_batch_num is None) or (bti < self.max_batch_num):
self.init_hist(bti, y_pred_sup, y_pred_unsup) # initialize the history
sup_ratio = float(len(y_pred_sup)) / (len(y_pred_sup) + len(y_pred_unsup)) # fraction of supervised samples
w = self.weight_schedule(epoch, sup_ratio)
sup_loss, unsup_loss = self.temporal_loss(y_pred_sup, y_pred_unsup, y_true_sup, bti)
# update the history
self.hist_unsup[bti] = self.update(self.hist_unsup[bti], y_pred_unsup.detach(), epoch)
self.hist_sup[bti] = self.update(self.hist_sup[bti], y_pred_sup.detach(), epoch)
# if bti == 0:  # for checking that the data order is identical across epochs
# print(w, sup_loss.item(), w * unsup_loss.item())
# print(y_true_sup)
return sup_loss + w * unsup_loss, sup_loss, w * unsup_loss
else:
return self.loss_sup(y_pred_sup, y_true_sup)
def same_batch_check(self, y_pred_sup, y_pred_unsup, y_true_sup, bti):
'''Check that the first few batches (hard-coded to 10 here) are identical across epochs
'''
if bti >= 10:
return
if bti >= len(self.hist_input_y):
self.hist_input_y.append(y_true_sup.to(self.hist_device))
else: # check
err_msg = 'TemporalEnsemblingLoss requests the same sort dataloader, you may need to set train_dataloader shuffle=False'
assert self.hist_input_y[bti].equal(y_true_sup.to(self.hist_device)), err_msg
def update(self, hist, y_pred, epoch):
'''Update the history logits, gating the mixing ratio with alpha
'''
Z = self.alpha * hist.to(y_pred) + (1. -self.alpha) * y_pred
output = Z * (1. / (1. - self.alpha ** (epoch + 1)))
return output.to(self.hist_device)
def weight_schedule(self, epoch, sup_ratio):
max_val = self.max_val * sup_ratio
if epoch == 0:
return 0.
elif epoch >= self.max_epochs:
return max_val
return max_val * np.exp(self.ramp_up_mult * (1. - float(epoch) / self.max_epochs) ** 2)
def temporal_loss(self, y_pred_sup, y_pred_unsup, y_true_sup, bti):
# MSE between current and temporal outputs
def mse_loss(out1, out2):
quad_diff = torch.sum((F.softmax(out1, dim=1) - F.softmax(out2, dim=1)) ** 2)
return quad_diff / out1.data.nelement()
sup_loss = self.loss_sup(y_pred_sup, y_true_sup)
# the original implementation treats sup and unsup as one tensor and computes the loss jointly; since they are split into two tensors here, we compute them separately
unsup_loss = mse_loss(y_pred_unsup, self.hist_unsup[bti].to(y_pred_unsup))
unsup_loss += mse_loss(y_pred_sup, self.hist_sup[bti].to(y_pred_sup))
return sup_loss, unsup_loss
def init_hist(self, bti, y_pred_sup, y_pred_unsup):
if bti >= len(self.hist_sup):
self.hist_sup.append(torch.zeros_like(y_pred_sup).to(self.hist_device))
self.hist_unsup.append(torch.zeros_like(y_pred_unsup).to(self.hist_device))
from torch.optim.lr_scheduler import LambdaLR
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
"""带warmup的schedule, 源自transformers包optimization.py中
参数
num_warmup_steps:
需要warmup的步数, 一般为 num_training_steps * warmup_proportion(warmup的比例, 建议0.05-0.15)
num_training_steps:
总的训练步数, 一般为 train_batches * num_epoch
"""
def lr_lambda(current_step: int):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
return LambdaLR(optimizer, lr_lambda, last_epoch)
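# Usage sketch (illustrative, not part of the original file):
#   optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
#   num_training_steps = len(train_dataloader) * num_epochs
#   scheduler = get_linear_schedule_with_warmup(optimizer,
#                                               num_warmup_steps=int(0.1 * num_training_steps),
#                                               num_training_steps=num_training_steps)
#   # per training step: optimizer.step(); scheduler.step()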
def extend_with_exponential_moving_average(model, decay=0.999):
class ExponentialMovingAverage():
''' Exponential moving average (EMA) of the model weights. It takes no part in gradient updates; it only tracks the moving-average parameters for use at prediction time.
Note the difference from adaptive-learning-rate optimizers such as Adam, which keep exponential moving averages of first- and second-order gradient moments; the two are completely different.
Example:
# initialization
ema = ExponentialMovingAverage(model, 0.999)
# during training, after the parameters have been updated, update the EMA weights as well
def train():
optimizer.step()
ema.step()
# before eval, call apply_ema_weights(); after eval, call restore_raw_weights() to restore the raw model weights
def evaluate():
ema.apply_ema_weights()
# evaluate
# to save the EMA model, call torch.save() before restore_raw_weights()
ema.restore_raw_weights()
'''
def __init__(self, model, decay):
self.model = model
self.decay = decay
# EMA weights (the moving-average weights of each layer at the current step)
self.ema_weights = {}
# during evaluation the raw model weights are saved here; after evaluation the model is restored from the EMA weights back to the raw weights
self.model_weights = {}
# initialize ema_weights from the model weights
for name, param in self.model.named_parameters():
if param.requires_grad:
self.ema_weights[name] = param.data.clone()
def step(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.ema_weights
new_average = (1.0 - self.decay) * param.data + self.decay * self.ema_weights[name]
self.ema_weights[name] = new_average.clone()
def apply_ema_weights(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.ema_weights
self.model_weights[name] = param.data
param.data = self.ema_weights[name]
def restore_raw_weights(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.model_weights
param.data = self.model_weights[name]
self.model_weights = {}
return ExponentialMovingAverage(model, decay)
#!/bin/bash
# This script post-processes the result logs of single-GPU BERT training and prints the model metrics required by the China Mobile procurement benchmark.
# Multi-GPU training of the model code must be adapted by each vendor; this script then needs to be modified accordingly.
# target accuracy
target_acc=0.89
echo "------------bert base report---------------"
# total number of samples processed in one epoch
total_samples=`cat log/fps.txt | grep "Epoch: 0" | grep "Current Train Samples" | awk '{sum+=$6} END {print sum}'`
echo "samples(one epoch): ${total_samples}"
echo "target acc: ${target_acc}"
echo "--------first achieve target acc-----------"
# accuracy when the target accuracy is first reached
first_achieve_acc=`cat log/fps.txt | grep "Current F1 score" | sort -n -k 6 | awk '$6 > '${target_acc}'' | sort -n -k 2 | head -n 1 | awk '{print$6}'`
# epoch at which the target accuracy is first reached
current_epoch=`cat log/fps.txt | grep "Current F1 score" | sort -n -k 6 | awk '$6 > '${target_acc}'' | sort -n -k 2 | head -n 1 | awk '{print$2}'`
# training time when the target accuracy is first reached
current_train_time=`cat log/fps.txt | grep "Epoch: ${current_epoch}[[:blank:]]" | grep "All Train Time" | awk '{print$6}' | sort -n -k 1 -r | head -n 1`
# evaluation time when the target accuracy is first reached
current_eval_time=`cat log/fps.txt | grep "Epoch: ${current_epoch}[[:blank:]]" | grep "All Eval Time" | awk '{print$6}' | head -n 1`
# total time when the target accuracy is first reached
current_total_time=`cat log/fps.txt | grep "Epoch: ${current_epoch}[[:blank:]]" | grep "All Time" | awk '{print$5}' | head -n 1`
# maximum training FPS so far
current_max_FPS=`cat log/fps.txt | awk '$2 <= '${current_epoch}'' | grep "Current Epoch FPS" | awk '{x[$2]+=$6}END{for(i in x){print i, x[i]}}' | sort -n -k 2 -r | head -n 1 | awk '{print$2}'`
# average training FPS so far
current_average_FPS=`awk -v ts=${total_samples} -v eps=${current_epoch} -v ttt=${current_train_time} 'BEGIN{print(ts*(eps+1)/ttt)}'`
# average end-to-end FPS so far
current_e2e_FPS=`awk -v ts=${total_samples} -v eps=${current_epoch} -v ttt=${current_total_time} 'BEGIN{print(ts*(eps+1)/ttt)}'`
echo "first achieve target acc: ${first_achieve_acc}"
current_epoch=`awk -v ce=${current_epoch} 'BEGIN{print(ce+1)}'`
echo "current epoch: ${current_epoch}"
echo "current train time: ${current_train_time}"
echo "current eval time: ${current_eval_time}"
echo "current total time: ${current_total_time}"
echo "current max FPS: ${current_max_FPS}"
echo "current average FPS: ${current_average_FPS}"
echo "current e2e FPS: ${current_e2e_FPS}"
echo "------------achieve best acc---------------"
# best accuracy achieved
best_acc=`cat log/fps.txt | grep "Current F1 score" | sort -n -k 6 -r | awk '{print$6}' | head -n 1`
# epoch at which the best accuracy is reached
best_acc_epoch=`cat log/fps.txt | grep "Current F1 score" | sort -n -k 6 -r -k 2 | awk '{print$2}' | head -n 1`
# training time when the best accuracy is reached
current_train_time=`cat log/fps.txt | grep "Epoch: ${best_acc_epoch}[[:blank:]]" | grep "All Train Time" | awk '{print$6}' | sort -n -k 1 -r | head -n 1`
# evaluation time when the best accuracy is reached
current_eval_time=`cat log/fps.txt | grep "Epoch: ${best_acc_epoch}[[:blank:]]" | grep "All Eval Time" | awk '{print$6}' | head -n 1`
# total time when the best accuracy is reached
current_total_time=`cat log/fps.txt | grep "Epoch: ${best_acc_epoch}[[:blank:]]" | grep "All Time" | awk '{print$5}' | head -n 1`
# maximum training FPS so far
current_max_FPS=`cat log/fps.txt | awk '$2 <= '${best_acc_epoch}'' | grep "Current Epoch FPS" | awk '{x[$2]+=$6}END{for(i in x){print i, x[i]}}' | sort -n -k 2 -r | head -n 1 | awk '{print$2}'`
# average training FPS so far
current_average_FPS=`awk -v ts=${total_samples} -v eps=${best_acc_epoch} -v ttt=${current_train_time} 'BEGIN{print(ts*(eps+1)/ttt)}'`
# average end-to-end FPS so far
current_e2e_FPS=`awk -v ts=${total_samples} -v eps=${best_acc_epoch} -v ttt=${current_total_time} 'BEGIN{print(ts*(eps+1)/ttt)}'`
echo "best acc: ${best_acc}"
best_acc_epoch=`awk -v ce=${best_acc_epoch} 'BEGIN{print(ce+1)}'`
echo "best acc epoch: ${best_acc_epoch}"
echo "current train time: ${current_train_time}"
echo "current eval time: ${current_eval_time}"
echo "current total time: ${current_total_time}"
echo "current max FPS: ${current_max_FPS}"
echo "current average FPS: ${current_average_FPS}"
echo "current e2e FPS: ${current_e2e_FPS}"
echo "-------------total time-------------------"
# total number of epochs
epoch_num=`cat log/fps.txt | sort -n -k 2 -r | head -n 1 | awk '{print$2}'`
# earliest start time
start_time=`cat log/time.txt | grep "Start" | sort -n -k 3 | awk '{print$3}'`
# latest end time
end_time=`cat log/time.txt | grep "End" | sort -n -k 3 -r | awk '{print$3}'`
# total wall-clock runtime
all_time=`awk -v st=${start_time} -v et=${end_time} 'BEGIN{print(et-st)}'`
epoch_num=`awk -v ce=${epoch_num} 'BEGIN{print(ce+1)}'`
echo "total epoch number: ${epoch_num}"
echo "all time: ${all_time}"
# 1. Text classification
## 1.1 Comparison across pretrained models
- [Sentiment classification dataset](https://github.com/bojone/bert4keras/blob/master/examples/datasets/sentiment.zip) + classification on the [CLS] position
| solution | epoch | valid_acc | test_acc | comment |
| ---- | ---- | ---- | ---- | ---- |
| albert_small | 10/10 | 94.46 | 93.98 | small variant |
| bert | 6/10 | 94.72 | 94.11 | —— |
| robert | 4/10 | 94.77 | 94.64 | —— |
| nezha | 7/10 | 95.07 | 94.72 | —— |
| xlnet | 6/10 | 95.00 | 94.24 | —— |
| electra | 10/10 | 94.94 | 94.78 | —— |
| roformer | 9/10 | 94.85 | 94.42 | —— |
| roformer_v2 | 3/10 | 95.78 | 96.09 | —— |
| gau_alpha | 2/10 | 95.25 | 94.46 | —— |
## 1.2 Comparison across training tricks
- Trick comparison + [sentiment classification dataset](https://github.com/bojone/bert4keras/blob/master/examples/datasets/sentiment.zip) + [CLS] classification + no segment_input
| solution | epoch | valid_acc | test_acc | comment |
| ---- | ---- | ---- | ---- | ---- |
| bert | 10/10 | 94.90 | 94.78 | —— |
| fgm | 4/10 | 95.34 | 94.99 | —— |
| pgd | 6/10 | 95.34 | 94.64 | —— |
| gradient_penalty | 7/10 | 95.07 | 94.81 | —— |
| vat | 8/10 | 95.21 | 95.03 | —— |
| ema | 7/10 | 95.21 | 94.86 | —— |
| ema+warmup | 7/10 | 95.51 | 95.12 | —— |
| mix_up | 6/10 | 95.12 | 94.42 | —— |
| R-drop | 9/10 | 95.25 | 94.94 | —— |
| UDA | 8/10 | 94.90 | 95.56 | —— |
| semi-vat | 10/10 | 95.34 | 95.38 | —— |
| temporal_ensembling | 8/10 | 94.94 | 94.90 | —— |
# 2. Sequence labeling
- [People's Daily dataset](http://s3.bmio.net/kashgari/china-people-daily-ner-corpus.tar.gz) + bert pretrained model
- Metrics on the valid set
| solution | epoch | f1_token | f1_entity | comment |
| ---- | ---- | ---- | ---- | ---- |
| bert+crf | 18/20 | 96.89 | 96.05 | —— |
| bert+crf+init | 18/20 | 96.93 | 96.08 | CRF weights initialized from the training data |
| bert+crf+freeze | 11/20 | 96.89 | 96.13 | CRF weights generated from the training data (not trained) |
| bert+cascade+crf | 5/20 | 98.10 | 96.26 | fewer CRF classes, so f1_token is inflated |
| bert+crf+posseg | 13/20 | 97.32 | 96.55 | with POS-tag input |
| bert+global_pointer | 18/20 | —— | 95.66 | —— |
| bert+efficient_global_pointer | 17/20 | —— | 96.55 | —— |
| bert+mrc | 7/20 | —— | 95.75 | —— |
| bert+span | 13/20 | —— | 96.31 | —— |
| bert+tplinker_plus | 20/20 | —— | 95.71 | clearly limited by sequence length |
| uie | 20/20 | —— | 96.57 | zero-shot: f1=60.8; few-shot 100 samples: f1=85.82; 200 samples: f1=86.40 |
| W2NER | 18/20 | 97.37 | 96.32 | high GPU-memory requirements |
# 3. Text representation
## 3.1 Unsupervised semantic similarity
- bert pretrained model + unsupervised finetuning + [CLS] sentence embedding (except PromptBert)
- Five Chinese datasets + best value over 5 epochs + Spearman correlation on valid
- Further finetuning brings small gains on some datasets
- Experiments show that dropout_rate has a large impact on the results
| solution | ATEC | BQ | LCQMC | PAWSX | STS-B | comment |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| Bert-whitening | 26.79 | 31.81| 56.34 | 17.22 | 67.45 | [CLS], no dimensionality reduction |
| CT | 30.65 | 44.50| 68.67 | 16.20 | 69.27 | dropout=0.1; slow convergence, trained for 10 epochs |
| CT_In_Batch_Neg | 32.47 | 47.09| 68.56 | 27.50 | 74.00 | dropout=0.1 |
| TSDAE | —— | 46.65| 65.30 | 12.54 | —— | dropout=0.1; —— means the metric was abnormal and not recorded |
| SimCSE | 33.90 | 50.29| 71.81 | 13.14 | 71.09 | dropout=0.3 |
| ESimCSE | 34.05 | 50.54| 71.58 | 12.53 | 71.27 | dropout=0.3 |
| DiffCSE | 33.04 | 48.17| 71.51 | 12.91 | 71.10 | dropout=0.3; no clear gain |
| PromptBert | 33.98 | 49.89| 73.18 | 13.30 | 73.42 | dropout=0.3 |
## 3.2 Supervised semantic similarity
- bert pretrained model + finetuning on the training data + [CLS] sentence embedding
- Five Chinese datasets + best value over 5 epochs + Spearman correlation on valid/test
- STS-B is a 5-class task; the others are binary
| solution | ATEC | BQ | LCQMC | PAWSX | STS-B | comment |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| CoSENT |50.61 / 49.81|72.84 / 71.61|77.79 / 78.74|55.00 / 56.00|83.48 / 80.06| |
| ContrastiveLoss |50.02 / 49.19|72.52 / 70.98|77.49 / 78.27|58.21 / 57.65|69.87 / 68.58| STS-B converted to binary |
| InfoNCE |47.77 / 46.99|69.86 / 68.14|71.74 / 74.54|52.82 / 54.21|83.31 / 78.72| STS-B converted to binary |
|concat CrossEntropy|48.71 / 47.62|72.16 / 70.07|78.44 / 78.77|51.46 / 52.28|61.31 / 56.62| STS-B converted to binary |
| CosineMSELoss |46.89 / 45.86|72.27 / 71.35|75.29 / 77.19|54.92 / 54.35|81.64 / 77.76| STS-B normalized to 0-1 |
# 4. Relation extraction
- [Baidu relation extraction dataset](http://ai.baidu.com/broad/download?dataset=sked)
| solution | f1 | comment |
| ---- | ---- | ---- |
| CasRel | 81.87 | |
| gplinker | 81.88 | |
| tplinker | 74.49 | seq_len=64, not fully converged |
| tplinker_plus | 79.30 | seq_len=64 |
# 5. Text generation
- [CSL dataset](https://github.com/CLUEbenchmark/CLGE); note this is the version with roughly 10k training samples; dev/test metrics respectively
| solution | Rouge-L | Rouge-1 | Rouge-2 | BLEU | comment |
| ---- | ---- | ---- | ---- | ---- | ---- |
|bert+unlim|63.65 / 63.01|66.25 / 66.34|54.48 / 54.81|44.21 / 44.60| |
| bart |64.62 / 64.99|67.72 / 68.40|56.08 / 57.26|46.15 / 47.67| |
| mt5 |67.67 / 65.98|70.39 / 69.36|59.60 / 59.05|50.34 / 50.11| |
|t5_pegasus|66.07 / 66.11|68.94 / 69.61|57.12 / 58.38|46.14 / 47.95| |
| uer_t5 |63.59 / 63.11|66.56 / 66.48|54.65 / 54.82|44.27 / 44.60| |
#! -*- coding: utf-8 -*-
# Sanity check: feature extraction
import torch
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer
root_model_path = "F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12"
vocab_path = root_model_path + "/vocab.txt"
config_path = root_model_path + "/bert_config.json"
checkpoint_path = root_model_path + '/pytorch_model.bin'
tokenizer = Tokenizer(vocab_path, do_lower_case=True) # build the tokenizer
model = build_transformer_model(config_path, checkpoint_path) # build the model and load the weights
# encoding test
token_ids, segment_ids = tokenizer.encode(u'语言模型')
token_ids, segment_ids = torch.tensor([token_ids]), torch.tensor([segment_ids])
print('\n ===== predicting =====\n')
model.eval()
with torch.no_grad():
print(model([token_ids, segment_ids])[0])
"""
Expected output:
[[[-0.63251007 0.2030236 0.07936534 ... 0.49122632 -0.20493352
0.2575253 ]
[-0.7588351 0.09651865 1.0718756 ... -0.6109694 0.04312154
0.03881441]
[ 0.5477043 -0.792117 0.44435206 ... 0.42449304 0.41105673
0.08222899]
[-0.2924238 0.6052722 0.49968526 ... 0.8604137 -0.6533166
0.5369075 ]
[-0.7473459 0.49431565 0.7185162 ... 0.3848612 -0.74090636
0.39056838]
[-0.8741375 -0.21650358 1.338839 ... 0.5816864 -0.4373226
0.56181806]]]
"""
#! -*- coding: utf-8 -*-
# Sanity check: Gibbs sampling combined with the MLM head
from tqdm import tqdm
import numpy as np
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer
import torch
import torch.nn as nn
root_model_path = "F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12"
vocab_path = root_model_path + "/vocab.txt"
config_path = root_model_path + "/bert_config.json"
checkpoint_path = root_model_path + '/pytorch_model.bin'
tokenizer = Tokenizer(vocab_path, do_lower_case=True) # build the tokenizer
model = build_transformer_model(
config_path=config_path, checkpoint_path=checkpoint_path, with_mlm='softmax'
) # build the model and load the weights
sentences = []
init_sent = u'科学技术是第一生产力。' # seed sentence, or None to start from all [MASK] tokens
minlen, maxlen = 8, 32
steps = 10000
converged_steps = 1000
vocab_size = tokenizer._vocab_size
if init_sent is None:
length = np.random.randint(minlen, maxlen + 1)
tokens = ['[CLS]'] + ['[MASK]'] * length + ['[SEP]']
token_ids = tokenizer.tokens_to_ids(tokens)
segment_ids = [0] * len(token_ids)
else:
token_ids, segment_ids = tokenizer.encode(init_sent)
length = len(token_ids) - 2
device='cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()
with torch.no_grad():
for _ in tqdm(range(steps), desc='Sampling'):
# Gibbs sampling step: randomly mask one token, then resample it from the MLM distribution.
i = np.random.choice(length) + 1
token_ids[i] = tokenizer._token_mask_id
token_ids_tensor, segment_ids_tensor = torch.tensor([token_ids], device=device), torch.tensor([segment_ids], device=device)
_, probas = model([token_ids_tensor, segment_ids_tensor])
probas = probas[0, i]
token = np.random.choice(vocab_size, p=probas.cpu().numpy())
token_ids[i] = token
sentences.append(tokenizer.decode(token_ids))
print(u'A few random samples: ')
for _ in range(10):
print(np.random.choice(sentences[converged_steps:]))