Commit 66a1d0d0 authored by yangzhong

Initial commit of the bert4torch project

MIT License
Copyright (c) 2022 Bo仔很忙
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# bert4torch
## Model introduction
bert4torch is a PyTorch-based training framework. Its initial goal is to reproduce the main functionality of bert4keras: it makes it convenient to load many kinds of pretrained models for fine-tuning, and it provides Chinese comments to help users understand the model structures.
## Model structure
The main model is BERT's most important component. BERT is first pre-trained: concretely, a dedicated module is attached after the main model to compute the pre-training loss, and pre-training yields the main model's parameters. When applying the model to a downstream task, a module matching that task is attached after the main model instead; the main model is assigned the pre-trained parameters, the downstream module is randomly initialized, and the combination is then fine-tuned. (Note: during fine-tuning, the parameters of both the main model and the downstream module are usually updated, although one part can also be frozen while the other is tuned.)
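For example, a downstream head can be attached like this (a minimal sketch using the bert4torch API shown later in this document; the 768 hidden size is the bert-base default, and `config_path`/`checkpoint_path` are placeholders):
```python
import torch.nn as nn
from bert4torch.models import build_transformer_model, BaseModel

class DownstreamModel(BaseModel):
    def __init__(self, config_path, checkpoint_path, num_labels=2):
        super().__init__()
        # main model, initialized from the pre-trained parameters
        self.bert = build_transformer_model(config_path, checkpoint_path, with_pool=True)
        # downstream module, randomly initialized and learned during fine-tuning
        self.head = nn.Linear(768, num_labels)
        # optionally freeze the backbone and tune only the head:
        # for p in self.bert.parameters():
        #     p.requires_grad = False

    def forward(self, token_ids, segment_ids):
        _, pooled = self.bert([token_ids, segment_ids])
        return self.head(pooled)
```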
The main model consists of three parts: the **embedding layer**, the **encoder**, and the **pooling layer**, as shown below:
![img](https://images.cnblogs.com/cnblogs_com/wangzb96/1789835/o_200618140451BERT%E4%B9%8B%E4%B8%BB%E6%A8%A1%E5%9E%8B.png)
where
- Input: mini-batches of `batch_size` sequences (sentences or sentence pairs), each sequence consisting of a number of discrete token ids.
- Embedding layer: converts the input sequences into continuous distributed representations, i.e. word embeddings (word vectors).
- Encoder: computes a non-linear representation of each sequence.
- Pooling layer: takes the representation of the `[CLS]` token as the representation of the whole sequence.
- Output: the representations from the last encoder layer (one per token) and the pooling layer's representation (one per sequence).
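The two outputs can be inspected directly (a sketch; `config_path`/`checkpoint_path` are placeholders, and 21128 is the bert-base-chinese vocabulary size):
```python
import torch
from bert4torch.models import build_transformer_model

model = build_transformer_model(config_path, checkpoint_path, with_pool=True)
token_ids = torch.randint(0, 21128, (2, 16))  # a mini-batch: batch_size=2, seq_len=16
segment_ids = torch.zeros_like(token_ids)
hidden_states, pooled = model([token_ids, segment_ids])
print(hidden_states.shape)  # [2, 16, 768]: one representation per token, from the last encoder layer
print(pooled.shape)         # [2, 768]: the [CLS]-based representation of each sequence
```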
## Environment setup
### Docker
The Docker image can be pulled from the SourceFind (光源) registry as follows:
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.13.1-centos7.6-dtk-23.04-py37-latest
```
Install the dependencies and bert4torch:
```
pip install -r requirements.txt
cd bert4torch
python3 setup.py install
```
## Dataset and pretrained model
Dataset download: https://s3.bmio.net/kashgari/china-people-daily-ner-corpus.tar.gz. Place the People's Daily NER corpus under the /datasets/bert-base-chinese directory and extract it there.
Pretrained model download: https://huggingface.co/bert-base-chinese/tree/main. Download all of the files into the same /datasets/bert-base-chinese directory.
The training data directory is laid out as follows:
```
dataset
└── bert-base-chinese
    ├── china-people-daily-ner-corpus
    │   ├── example.dev
    │   ├── example.test
    │   └── example.train
    ├── config.json
    ├── flax_model.msgpack
    ├── pytorch_model.bin
    └── vocab.txt
```
## Training
### Modify the configuration
```
cd examples/sequence_labeling/
# training scripts whose configuration needs editing:
crf.py      # single-GPU training script
crf_ddp.py  # multi-GPU training script; uses torch DDP, i.e. the single-GPU code plus the DDP-specific parts
# Only the configuration needs changing: config_path, checkpoint_path, dict_path,
# train_dataloader, valid_dataloader; adjust batch_size as needed.
# Note: to test fp16, add use_amp=True to model.compile() in crf.py and crf_ddp.py,
# as sketched below.
```
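For reference, the fp16 change mentioned in the note might look like this (a sketch; `use_amp` is the flag named above):
```python
model.compile(
    loss=nn.CrossEntropyLoss(),
    optimizer=optim.Adam(model.parameters(), lr=2e-5),
    use_amp=True,  # enable mixed-precision (fp16) training
)
```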
### Single node, single GPU
```
cd examples/sequence_labeling/
./single_train.sh
```
### Single node, multiple GPUs
```
cd examples/sequence_labeling/
./multi_train.sh
```
## Accuracy
| GPUs | Precision | batch_size | F1 | P | R |
| ---- | ---- | ---------- | ------ | ------ | ------ |
| 1 | fp32 | 64 | 0.9592 | 0.9643 | 0.9617 |
| 1 | fp16 | 64 | 0.9559 | 0.9596 | 0.9545 |
| 4 | fp32 | 256 | 0.9459 | 0.9398 | 0.9521 |
| 4 | fp16 | 256 | 0.9438 | 0.9398 | 0.9505 |
## Source repository and issue tracker
- https://developer.hpccube.com/codes/modelzoo/bert4torch
## References
- https://github.com/Tongjilibo/bert4torch
# bert4torch tutorial
## 1. Example modeling workflow
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
# import locations assumed to follow the bert4torch package layout
from bert4torch.models import build_transformer_model, BaseModel
from bert4torch.snippets import Callback, ListDataset
from bert4torch.tokenizers import Tokenizer

# dict_path, config_path, checkpoint_path, batch_size and valid_dataloader are assumed defined elsewhere
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)

# load the dataset; you can also subclass Dataset yourself
class MyDataset(ListDataset):
    @staticmethod
    def load_data(filenames):
        """Read the text files and arrange them into the required format."""
        D = []
        return D

def collate_fn(batch):
    '''Turn the batch produced by load_data above into Tensors on the target device.
    Note: the return value is split into features and labels; the features may be a list or tuple.
    '''
    batch_token_ids, batch_segment_ids, batch_labels = [], [], []
    return [batch_token_ids, batch_segment_ids], torch.tensor(batch_labels).flatten()

# build the dataloader
train_dataloader = DataLoader(MyDataset('file_path'), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# define the model on top of bert, using binary text classification as the example
class Model(BaseModel):
    def __init__(self) -> None:
        super().__init__()
        self.bert = build_transformer_model(config_path, checkpoint_path, with_pool=True)
        self.dropout = nn.Dropout(0.1)
        self.dense = nn.Linear(768, 2)

    def forward(self, token_ids, segment_ids):
        # the model returned by build_transformer_model only accepts list/tuple arguments,
        # so a single input must be wrapped as [token_ids]
        hidden_states, pooled_output = self.bert([token_ids, segment_ids])
        output = self.dropout(pooled_output)
        output = self.dense(output)
        return output

model = Model().to(device)

# define the loss and optimizer; custom ones are supported
model.compile(
    loss=nn.CrossEntropyLoss(),  # a custom loss is allowed
    optimizer=optim.Adam(model.parameters(), lr=2e-5),  # a custom optimizer is allowed
    scheduler=None,  # a custom scheduler is allowed
    metrics=['accuracy']
)

# define the evaluation function
def evaluate(data):
    total, right = 0., 0.
    for x_true, y_true in data:
        y_pred = model.predict(x_true).argmax(axis=1)
        total += len(y_true)
        right += (y_true == y_pred).sum().item()
    return right / total

class Evaluator(Callback):
    """Evaluate and save; defined here to run only at the end of each epoch."""
    def __init__(self):
        self.best_val_acc = 0.

    def on_epoch_end(self, global_step, epoch, logs=None):
        val_acc = evaluate(valid_dataloader)
        if val_acc > self.best_val_acc:
            self.best_val_acc = val_acc
            model.save_weights('best_model.pt')
        print(f'val_acc: {val_acc:.5f}, best_val_acc: {self.best_val_acc:.5f}\n')

if __name__ == '__main__':
    evaluator = Evaluator()
    model.fit(train_dataloader, epochs=20, steps_per_epoch=100, grad_accumulation_steps=2, callbacks=[evaluator])
```
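After training, the saved checkpoint can be reloaded for inference (a sketch reusing the names defined above):
```python
# reload the best checkpoint and predict on the validation set
model.load_weights('best_model.pt')
for x_true, _ in valid_dataloader:
    y_pred = model.predict(x_true).argmax(axis=1)
    print(y_pred)
    break
```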
## 2. Main modules
### 1) Data processing
#### a. Trimming the vocabulary and building the tokenizer
```python
from bert4torch.tokenizers import Tokenizer, load_vocab  # assumed import location

token_dict, keep_tokens = load_vocab(
    dict_path=dict_path,  # path to the vocab file
    simplified=True,  # filter out redundant tokens such as [unused1]
    startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],  # tokens forced to the front, e.g. [UNK] moves from bert's default position 103 to position 1
)
tokenizer = Tokenizer(token_dict, do_lower_case=True)  # if no trimming is needed, this line alone defines the tokenizer
```
#### b. Handy helper functions
- `text_segmentate()`: truncates the total length to at most maxlen; accepts multiple sequences and trims the longest one each time; indices marks the token positions that get dropped
- `tokenizer.encode()`: converts text to token_ids, adding [CLS] at the start and [SEP] at the end by default; returns token_ids and segment_ids; equivalent to calling `tokenizer.tokenize()` followed by `tokenizer.tokens_to_ids()`
- `tokenizer.decode()`: converts token_ids back to text, removing special tokens such as [CLS], [SEP] and [UNK] by default; equivalent to `tokenizer.ids_to_tokens()` plus some post-processing
- `sequence_padding()`: pads sequences to a common length; takes a list of lists/ndarrays/tensors and returns an ndarray or tensor
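A short usage sketch of these helpers (`dict_path` is a placeholder; 101/102 are bert's [CLS]/[SEP] ids):
```python
tokenizer = Tokenizer(dict_path, do_lower_case=True)
token_ids, segment_ids = tokenizer.encode('今天天气不错')  # [CLS]/[SEP] added by default
text = tokenizer.decode(token_ids)                        # special tokens removed
padded = sequence_padding([token_ids, [101, 102]])        # pad to a common length
```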
### 2) Model definition
- Creating the model
```python
'''
When the model is called, if with_pool, with_nsp or with_mlm is set, the return value is
[hidden_states, pool_emb/nsp_emb, mlm_scores] in that order; otherwise only hidden_states is returned.
'''
build_transformer_model(
    config_path=config_path,  # path to the model config file
    checkpoint_path=checkpoint_path,  # path to the model weights; the default None loads no pretrained weights
    model='bert',  # model architecture to load; a custom nn.Module-based Model can also be passed here
    application='encoder',  # model application; encoder, lm and unilm are supported
    segment_vocab_size=2,  # number of type_token_ids, 2 by default; set to 0 if no segment_ids are passed
    with_pool=False,  # whether to include the pooler
    with_nsp=False,  # whether to include the NSP head
    with_mlm=False,  # whether to include the MLM head
    return_model_config=False,  # whether to also return the model config
    output_all_encoded_layers=False,  # whether to return the hidden states of all layers
)
```
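For example, with only `with_mlm=True` set, the call would presumably return `[hidden_states, mlm_scores]` (a sketch based on the docstring above):
```python
model = build_transformer_model(config_path, checkpoint_path, with_mlm=True)
hidden_states, mlm_scores = model([token_ids, segment_ids])
probas = mlm_scores.softmax(dim=-1)  # [batch_size, seq_len, vocab_size]
```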
- Defining the loss, optimizer, scheduler, etc.
```python
'''
Define the loss and optimizer to use; custom ones are supported.
'''
model.compile(
    loss=nn.CrossEntropyLoss(),  # a custom loss is allowed
    optimizer=optim.Adam(model.parameters(), lr=2e-5),  # a custom optimizer is allowed
    scheduler=None,  # a custom scheduler is allowed
    adversarial_train={'name': 'fgm'},  # adversarial-training trick; fgm, pgd, gradient_penalty and vat are supported
    metrics=['accuracy']  # fields printed by default, such as the loss, need not be listed
)
```
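For intuition, FGM itself is roughly the following (a standalone sketch of the technique, not bert4torch's internal implementation; the `word_embeddings` parameter-name filter is an assumption):
```python
import torch

class FGM:
    '''Fast Gradient Method: perturb the word embeddings along the gradient direction.'''
    def __init__(self, model, epsilon=1.0):
        self.model, self.epsilon, self.backup = model, epsilon, {}

    def attack(self, emb_name='word_embeddings'):
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name and param.grad is not None:
                self.backup[name] = param.data.clone()  # save clean weights
                norm = torch.norm(param.grad)
                if norm != 0:
                    param.data.add_(self.epsilon * param.grad / norm)  # add perturbation

    def restore(self):
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]  # restore clean weights
        self.backup = {}

# per training step (sketch): loss.backward(); fgm.attack();
# adv_loss = criterion(model(x), y); adv_loss.backward(); fgm.restore(); optimizer.step()
```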
- Custom models
```python
'''
Various modifications on top of bert, e.g. last2layer_average, token_first_last_average;
see the sketch after this block.
'''
class Model(BaseModel):
    # must inherit from BaseModel
    def __init__(self):
        super().__init__()
        self.bert = build_transformer_model(config_path, checkpoint_path)

    def forward(self):
        pass
```
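One such modification might look like this (a sketch assuming `output_all_encoded_layers=True` makes the model return the hidden states of every layer, per the option listed above):
```python
class Last2AverageModel(BaseModel):
    def __init__(self):
        super().__init__()
        self.bert = build_transformer_model(config_path, checkpoint_path,
                                            output_all_encoded_layers=True)

    def forward(self, token_ids, segment_ids):
        encoded_layers = self.bert([token_ids, segment_ids])  # hidden states of all layers
        return (encoded_layers[-1] + encoded_layers[-2]) / 2  # last2layer_average
```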
- [Custom training loop](https://github.com/Tongjilibo/bert4torch/blob/master/examples/others/task_custom_fit_progress.py)
```python
'''
A custom fit() loop, for when the built-in fit() does not meet your needs.
'''
from itertools import cycle

class Model(BaseModel):
    def fit(self, train_dataloader, steps_per_epoch, epochs):
        train_dataloader = cycle(train_dataloader)
        self.train()
        for epoch in range(epochs):
            for bti in range(steps_per_epoch):
                train_X, train_y = next(train_dataloader)
                output = self.forward(*train_X)
                loss = self.criterion(output, train_y)
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
```
- Saving and loading the model
```python
'''
prefix: whether to save weights under their original keys, e.g. the word embedding's original
key is bert.embeddings.word_embeddings.weight. The default None disables this; for a custom
model built on BaseModel, set it to the name of the member variable holding the bert model,
and set it to '' when the transformer is used directly.
This mainly makes the weights easy to load from other training frameworks.
'''
model.save_weights(save_path, prefix=None)
model.load_weights(load_path, strict=True, prefix=None)
```
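For instance, if the `BaseModel` subclass stores the transformer in `self.bert`, the original-key save described above would presumably be:
```python
model.save_weights('pytorch_model.bin', prefix='bert')  # keys like bert.embeddings.word_embeddings.weight
```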
- [Training with a transformers model](https://github.com/Tongjilibo/bert4torch/blob/master/examples/others/task_load_transformers_model.py)
```python
from transformers import AutoModelForSequenceClassification

class Model(BaseModel):
    def __init__(self):
        super().__init__()
        self.bert = AutoModelForSequenceClassification.from_pretrained("file_path", num_labels=2)

    def forward(self, token_ids, attention_mask, segment_ids):
        output = self.bert(input_ids=token_ids, attention_mask=attention_mask, token_type_ids=segment_ids)
        return output.logits
```
### 3) Model evaluation
```python
'''Callbacks can run at several points during training.
'''
class Evaluator(Callback):
    """Evaluate and save."""
    def __init__(self):
        self.best_val_acc = 0.

    def on_train_begin(self, logs=None):  # at the start of training
        pass

    def on_train_end(self, logs=None):  # at the end of training
        pass

    def on_batch_begin(self, global_step, batch, logs=None):  # at the start of each batch
        pass

    def on_batch_end(self, global_step, batch, logs=None):  # at the end of each batch
        # e.g. write logs or tensorboard records every N steps here;
        # avoid print() in batch_begin/batch_end so the progress bar is not disrupted
        pass

    def on_epoch_begin(self, global_step, epoch, logs=None):  # at the start of each epoch
        pass

    def on_epoch_end(self, global_step, epoch, logs=None):  # at the end of each epoch
        val_acc = evaluate(valid_dataloader)
        if val_acc > self.best_val_acc:
            self.best_val_acc = val_acc
            model.save_weights('best_model.pt')
        print(f'val_acc: {val_acc:.5f}, best_val_acc: {self.best_val_acc:.5f}\n')
```
## 3. Other features
### 1) Multi-GPU training on a single node
#### a. Using DataParallel
```python
'''There are two DP styles: either forward computes only the logits, or forward computes the loss directly.
The second style is recommended, as it partially alleviates load imbalance; see the sketch after this block.
'''
from bert4torch.models import BaseModelDP
# =========== prepare the data and define the model ===========
model = BaseModelDP(model)  # wrap the model so DP uses multiple gpus
model.compile(
    loss=lambda x, _: x.mean(),  # mean of the losses computed on the individual gpus
    optimizer=optim.Adam(model.parameters(), lr=2e-5),
)
```
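A sketch of the recommended second style, where `forward` computes the loss itself (this assumes the labels are packed into the feature tuple by `collate_fn`):
```python
class ModelWithLoss(BaseModel):
    def __init__(self):
        super().__init__()
        self.bert = build_transformer_model(config_path, checkpoint_path, with_pool=True)
        self.dense = nn.Linear(768, 2)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, token_ids, segment_ids, labels):
        _, pooled = self.bert([token_ids, segment_ids])
        return self.loss_fn(self.dense(pooled), labels)  # each GPU returns a scalar loss
```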
#### b. Using DistributedDataParallel
```python
'''DDP is started from the command line via torch.distributed.launch,
e.g. python -m torch.distributed.launch --nproc_per_node=4 train.py
'''
import argparse
import torch
from bert4torch.models import BaseModelDDP

# command-line arguments are required
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1)
args = parser.parse_args()
torch.cuda.set_device(args.local_rank)
device = torch.device('cuda', args.local_rank)
torch.distributed.init_process_group(backend='nccl')
# =========== prepare the data and define the model ===========
# wrap the model so DDP uses multiple gpus; master_rank is the local_rank that prints the training progress
model = BaseModelDDP(model, master_rank=0, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=False)
# define the loss and optimizer; custom ones are supported
model.compile(
    loss=lambda x, _: x,  # pass the loss computed in forward straight through
    optimizer=optim.Adam(model.parameters(), lr=2e-5),
)
```
### 2) Logging the training run to tensorboard
```python
from tensorboardX import SummaryWriter
writer = SummaryWriter(log_dir='./logs')  # the log directory is a placeholder

class Evaluator(Callback):
    """Evaluate every N steps and write the results to tensorboard."""
    def on_batch_end(self, global_step, batch, logs=None):
        if global_step % 100 == 0:
            writer.add_scalar("train/loss", logs['loss'], global_step)
            val_acc = evaluate(valid_dataloader)
            writer.add_scalar("valid/acc", val_acc, global_step)
```
### 3) Printing the trainable parameters
```python
from torchinfo import summary
summary(model, input_data=next(iter(train_dataloader))[0])
```
Metadata-Version: 2.1
Name: bert4torch
Version: 0.1.9
Summary: an elegant bert4torch
Home-page: https://github.com/Tongjilibo/bert4torch
Author: Tongjilibo
License: MIT Licence
Platform: UNKNOWN
License-File: LICENSE
bert4torch: https://github.com/Tongjilibo/bert4torch
LICENSE
README.md
setup.py
bert4torch/__init__.py
bert4torch/activations.py
bert4torch/layers.py
bert4torch/losses.py
bert4torch/models.py
bert4torch/optimizers.py
bert4torch/snippets.py
bert4torch/tokenizers.py
bert4torch.egg-info/PKG-INFO
bert4torch.egg-info/SOURCES.txt
bert4torch.egg-info/dependency_links.txt
bert4torch.egg-info/requires.txt
bert4torch.egg-info/top_level.txt
#! -*- coding: utf-8 -*-
__version__ = '0.1.9'
# Activations ported from transformers; the original bert4keras does not include these
import math
import torch
from packaging import version
from torch import nn
def _gelu_python(x):
"""
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def gelu_new(x):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
if version.parse(torch.__version__) < version.parse("1.4"):
gelu = _gelu_python
else:
gelu = nn.functional.gelu
def gelu_fast(x):
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
def quick_gelu(x):
return x * torch.sigmoid(1.702 * x)
def _silu_python(x):
"""
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
later.
"""
return x * torch.sigmoid(x)
if version.parse(torch.__version__) < version.parse("1.7"):
silu = _silu_python
else:
silu = nn.functional.silu
def _mish_python(x):
"""
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
visit the official repository for the paper: https://github.com/digantamisra98/Mish
"""
return x * torch.tanh(nn.functional.softplus(x))
if version.parse(torch.__version__) < version.parse("1.9"):
mish = _mish_python
else:
mish = nn.functional.mish
def linear_act(x):
return x
ACT2FN = {
"relu": nn.functional.relu,
"silu": silu,
"swish": silu,
"gelu": gelu,
"tanh": torch.tanh,
"gelu_new": gelu_new,
"gelu_fast": gelu_fast,
"quick_gelu": quick_gelu,
"mish": mish,
"linear": linear_act,
"sigmoid": torch.sigmoid,
"softmax": nn.Softmax(dim=-1)
}
def get_activation(activation_string):
if activation_string in ACT2FN:
return ACT2FN[activation_string]
else:
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
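# Usage sketch (illustrative, not part of the original file):
#   act = get_activation('gelu')   # look up an activation by name
#   y = act(torch.randn(3))        # apply it elementwise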
import torch
from torch.functional import Tensor
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from bert4torch.snippets import get_sinusoid_encoding_table, take_along_dim
from bert4torch.activations import get_activation
from typing import List, Optional
import random
import warnings
class LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12, conditional_size=False, weight=True, bias=True, norm_mode='normal', **kwargs):
"""layernorm 层,这里自行实现,目的是为了兼容 conditianal layernorm,使得可以做条件文本生成、条件分类等任务
条件layernorm来自于苏剑林的想法,详情:https://spaces.ac.cn/archives/7124
"""
super(LayerNorm, self).__init__()
# for compatibility with roformer_v2, which has no weight
if weight:
self.weight = nn.Parameter(torch.ones(hidden_size))
# for compatibility with t5, which has no bias term and uses RMSNorm
if bias:
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.norm_mode = norm_mode
self.eps = eps
self.conditional_size = conditional_size
if conditional_size:
# conditional layernorm, used for conditional text generation;
# initialized with all zeros so the pretrained weights are undisturbed in the initial state
self.dense1 = nn.Linear(conditional_size, hidden_size, bias=False)
self.dense1.weight.data.uniform_(0, 0)
self.dense2 = nn.Linear(conditional_size, hidden_size, bias=False)
self.dense2.weight.data.uniform_(0, 0)
def forward(self, x):
inputs = x[0]
if self.norm_mode == 'rmsnorm':
# t5 uses RMSNorm
variance = inputs.to(torch.float32).pow(2).mean(-1, keepdim=True)
o = inputs * torch.rsqrt(variance + self.eps)
else:
u = inputs.mean(-1, keepdim=True)
s = (inputs - u).pow(2).mean(-1, keepdim=True)
o = (inputs - u) / torch.sqrt(s + self.eps)
if not hasattr(self, 'weight'):
self.weight = 1
if not hasattr(self, 'bias'):
self.bias = 0
if self.conditional_size:
cond = x[1]
for _ in range(len(inputs.shape) - len(cond.shape)):
cond = cond.unsqueeze(dim=1)
return (self.weight + self.dense1(cond)) * o + (self.bias + self.dense2(cond))
else:
return self.weight * o + self.bias
class MultiHeadAttentionLayer(nn.Module):
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, attention_scale=True,
return_attention_scores=False, bias=True, **kwargs):
super(MultiHeadAttentionLayer, self).__init__()
assert hidden_size % num_attention_heads == 0
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.attention_head_size = int(hidden_size / num_attention_heads)
self.attention_scale = attention_scale
self.return_attention_scores = return_attention_scores
self.bias = bias
self.q = nn.Linear(hidden_size, hidden_size, bias=bias)
self.k = nn.Linear(hidden_size, hidden_size, bias=bias)
self.v = nn.Linear(hidden_size, hidden_size, bias=bias)
self.o = nn.Linear(hidden_size, hidden_size, bias=bias)
self.dropout = nn.Dropout(attention_probs_dropout_prob)
self.a_bias, self.p_bias = kwargs.get('a_bias'), kwargs.get('p_bias')
if self.p_bias == 'typical_relative': # nezha
self.relative_positions_encoding = RelativePositionsEncoding(qlen=kwargs.get('max_position'),
klen=kwargs.get('max_position'),
embedding_size=self.attention_head_size,
max_relative_position=kwargs.get('max_relative_position'))
elif self.p_bias == 'rotary': # roformer
self.relative_positions_encoding = RoPEPositionEncoding(max_position=kwargs.get('max_position'), embedding_size=self.attention_head_size)
elif self.p_bias == 't5_relative': # t5
self.relative_positions = RelativePositionsEncodingT5(qlen=kwargs.get('max_position'),
klen=kwargs.get('max_position'),
relative_attention_num_buckets=kwargs.get('relative_attention_num_buckets'),
is_decoder=kwargs.get('is_decoder'))
self.relative_positions_encoding = nn.Embedding(kwargs.get('relative_attention_num_buckets'), self.num_attention_heads)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
# hidden_states shape: [batch_size, seq_q, hidden_size]
# attention_mask shape: [batch_size, 1, 1, seq_q] or [batch_size, 1, seq_q, seq_q]
# encoder_hidden_states shape: [batch_size, seq_k, hidden_size]
# encoder_attention_mask shape: [batch_size, 1, 1, seq_k]
mixed_query_layer = self.q(hidden_states)
if encoder_hidden_states is not None:
mixed_key_layer = self.k(encoder_hidden_states)
mixed_value_layer = self.v(encoder_hidden_states)
attention_mask = encoder_attention_mask
else:
mixed_key_layer = self.k(hidden_states)
mixed_value_layer = self.v(hidden_states)
# mixed_query_layer shape: [batch_size, query_len, hidden_size]
# mixed_key_layer shape: [batch_size, key_len, hidden_size]
# mixed_value_layer shape: [batch_size, value_len, hidden_size]
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# query_layer shape: [batch_size, num_attention_heads, query_len, attention_head_size]
# key_layer shape: [batch_size, num_attention_heads, key_len, attention_head_size]
# value_layer shape: [batch_size, num_attention_heads, value_len, attention_head_size]
if self.p_bias == 'rotary':
query_layer = self.relative_positions_encoding(query_layer)
key_layer = self.relative_positions_encoding(key_layer)
# transpose the last two dims of k, then take the dot product of q and k to get the attention scores
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
# attention_scores shape: [batch_size, num_attention_heads, query_len, key_len]
if (self.p_bias == 'typical_relative') and hasattr(self, 'relative_positions_encoding'):
relations_keys = self.relative_positions_encoding(attention_scores.shape[-1], attention_scores.shape[-1]) # [to_seq_len, to_seq_len, d_hid]
# old implementation, kept to help readers follow the dimension transforms
# query_layer_t = query_layer.permute(2, 0, 1, 3)
# query_layer_r = query_layer_t.contiguous().view(from_seq_length, batch_size * num_attention_heads, self.attention_head_size)
# key_position_scores = torch.matmul(query_layer_r, relations_keys.permute(0, 2, 1))
# key_position_scores_r = key_position_scores.view(from_seq_length, batch_size, num_attention_heads, from_seq_length)
# key_position_scores_r_t = key_position_scores_r.permute(1, 2, 0, 3)
# new implementation
key_position_scores_r_t = torch.einsum('bnih,ijh->bnij', query_layer, relations_keys)
attention_scores = attention_scores + key_position_scores_r_t
elif (self.p_bias == 't5_relative') and hasattr(self, 'relative_positions_encoding'):
relations_keys = self.relative_positions(attention_scores.shape[-1], attention_scores.shape[-1])
key_position_scores_r_t = self.relative_positions_encoding(relations_keys).permute([2, 0, 1]).unsqueeze(0)
attention_scores = attention_scores + key_position_scores_r_t
# optionally scale the attention scores
if self.attention_scale:
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# apply the attention mask: positions where the mask is 0 are pushed to a large negative value,
# so after softmax their attention_probs are almost 0 and nothing attends to the masked positions
if attention_mask is not None:
# attention_scores = attention_scores.masked_fill(attention_mask == 0, -1e10)
attention_mask = (1.0 - attention_mask) * -10000.0 # so the incoming mask is 1 for non-padding and 0 for padding
attention_scores = attention_scores + attention_mask
# normalize the attention scores to the 0-1 range
attention_probs = F.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer) # [batch_size, num_attention_heads, query_len, attention_head_size]
if (self.p_bias == 'typical_relative') and hasattr(self, 'relative_positions_encoding'):
relations_values = self.relative_positions_encoding(attention_scores.shape[-1], attention_scores.shape[-1])
# old implementation, kept to help readers follow the dimension transforms
# attention_probs_t = attention_probs.permute(2, 0, 1, 3)
# attentions_probs_r = attention_probs_t.contiguous().view(from_seq_length, batch_size * num_attention_heads, to_seq_length)
# value_position_scores = torch.matmul(attentions_probs_r, relations_values)
# value_position_scores_r = value_position_scores.view(from_seq_length, batch_size, num_attention_heads, self.attention_head_size)
# value_position_scores_r_t = value_position_scores_r.permute(1, 2, 0, 3)
# new implementation
value_position_scores_r_t = torch.einsum('bnij,ijh->bnih', attention_probs, relations_values)
context_layer = context_layer + value_position_scores_r_t
# context_layer shape: [batch_size, query_len, num_attention_heads, attention_head_size]
# after dimension-shuffling ops such as transpose/permute, the tensor is no longer stored contiguously in memory,
# but view requires contiguous storage, so contiguous() is called first to get a contiguous copy;
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
context_layer = context_layer.view(*new_context_layer_shape)
# optionally return the attention scores
if self.return_attention_scores:
# the attention_scores returned here are pre-softmax; normalization can be applied externally
return self.o(context_layer), attention_scores
else:
return self.o(context_layer)
class PositionWiseFeedForward(nn.Module):
def __init__(self, hidden_size, intermediate_size, dropout_rate=0.5, hidden_act='gelu', is_dropout=False, bias=True, **kwargs):
# The original tf bert has no dropout after the activation, but google AI's bert-pytorch project adds one;
# pytorch's official TransformerEncoderLayer also has one, as in: self.linear2(self.dropout(self.activation(self.linear1(src))));
# why the implementations disagree is unclear, and the extra layer probably makes little difference;
# is_dropout and dropout_rate control this: for the original transformer, use the defaults; for bert, set is_dropout=False, in which case dropout_rate is unused.
super(PositionWiseFeedForward, self).__init__()
self.is_dropout = is_dropout
self.intermediate_act_fn = get_activation(hidden_act)
self.intermediateDense = nn.Linear(hidden_size, intermediate_size, bias=bias)
self.outputDense = nn.Linear(intermediate_size, hidden_size, bias=bias)
if self.is_dropout:
self.dropout = nn.Dropout(dropout_rate)
def forward(self, x):
# x shape: (batch size, seq len, hidden_size)
if self.is_dropout:
x = self.dropout(self.intermediate_act_fn(self.intermediateDense(x)))
else:
x = self.intermediate_act_fn(self.intermediateDense(x))
# x shape: (batch size, seq len, intermediate_size)
x = self.outputDense(x)
# x shape: (batch size, seq len, hidden_size)
return x
class GatedAttentionUnit(nn.Module):
'''Gated attention unit.
Paper: https://arxiv.org/abs/2202.10447
Introduction: https://kexue.fm/archives/8934
Note: the additive relative position bias is not included.
Reference pytorch project: https://github.com/lucidrains/FLASH-pytorch
'''
def __init__(self, hidden_size, attention_key_size, intermediate_size, attention_probs_dropout_prob, hidden_act,
is_dropout=False, attention_scale=True, bias=True, normalization='softmax_plus', **kwargs):
super().__init__()
self.intermediate_size = intermediate_size
self.attention_head_size = attention_key_size
self.attention_scale = attention_scale
self.is_dropout = is_dropout
self.normalization = normalization
self.hidden_fn = get_activation(hidden_act)
self.dropout = nn.Dropout(attention_probs_dropout_prob)
self.i_dense = nn.Linear(hidden_size, self.intermediate_size*2+attention_key_size, bias=bias)
self.offsetscale = self.OffsetScale(attention_key_size, heads=2, bias=bias)
self.o_dense = nn.Linear(self.intermediate_size, hidden_size, bias=bias)
self.a_bias, self.p_bias = kwargs.get('a_bias'), kwargs.get('p_bias')
if self.p_bias == 'rotary': # RoPE
self.relative_positions_encoding = RoPEPositionEncoding(max_position=kwargs.get('max_position'), embedding_size=self.attention_head_size)
def forward(self, hidden_states, attention_mask):
# input projection
hidden_states = self.hidden_fn(self.i_dense(hidden_states))
u, v, qk = hidden_states.split([self.intermediate_size, self.intermediate_size, self.attention_head_size], dim=-1)
q, k = self.offsetscale(qk) # affine transform (per-head offset and scale)
# apply RoPE
if self.p_bias == 'rotary':
q = self.relative_positions_encoding(q)
k = self.relative_positions_encoding(k)
# Attention
attention_scores = torch.einsum('b i d, b j d -> b i j', q, k) # [btz, seq_len, seq_len]
if self.attention_scale:
# seq_len = hidden_states.shape[1]
# attention_scores = F.relu(attention_scores/seq_len) ** 2
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
attention_mask = (1.0 - attention_mask) * -1e12
attention_scores = attention_scores + attention_mask.squeeze(1)
# normalization
attention_scores = self.attention_normalize(attention_scores, -1, self.normalization)
if self.is_dropout:
attention_scores = self.dropout(attention_scores)
# compute the output
out = self.o_dense(u * torch.einsum('b i j, b j d -> b i d', attention_scores, v))
return out
def attention_normalize(self, a, dim=-1, method='softmax'):
"""不同的注意力归一化方案
softmax:常规/标准的指数归一化;
squared_relu:来自 https://arxiv.org/abs/2202.10447 ;
softmax_plus:来自 https://kexue.fm/archives/8823 。
"""
if method == 'softmax':
return F.softmax(a, dim=dim)
else:
mask = (a > -1e11).float()
l = torch.maximum(torch.sum(mask, dim=dim, keepdims=True), torch.tensor(1).to(mask))
if method == 'squared_relu':
return F.relu(a)**2 / l
elif method == 'softmax_plus':
return F.softmax(a * torch.log(l) / torch.log(torch.tensor(512)).to(mask), dim=dim)
return a
class OffsetScale(nn.Module):
'''Affine transform (per-head offset and scale).
'''
def __init__(self, head_size, heads=1, bias=True):
super().__init__()
self.gamma = nn.Parameter(torch.ones(heads, head_size))
self.bias = bias
if bias:
self.beta = nn.Parameter(torch.zeros(heads, head_size))
nn.init.normal_(self.gamma, std = 0.02)
def forward(self, x):
out = torch.einsum('... d, h d -> ... h d', x, self.gamma)
if self.bias:
out = out + self.beta
return out.unbind(dim = -2)
class BertEmbeddings(nn.Module):
"""
The embeddings layer: constructs the word, position and token_type embeddings.
"""
def __init__(self, vocab_size, embedding_size, hidden_size, max_position, segment_vocab_size, shared_segment_embeddings, drop_rate, conditional_size=False, **kwargs):
super(BertEmbeddings, self).__init__()
self.shared_segment_embeddings = shared_segment_embeddings
self.word_embeddings = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
# position encoding
if kwargs.get('p_bias') == 'sinusoid':
self.position_embeddings = SinusoidalPositionEncoding(max_position, embedding_size)
elif kwargs.get('p_bias') in {'rotary', 'typical_relative', 't5_relative', 'other_relative'}:
# with relative position encodings, no PositionEmbeddings is declared
pass
elif max_position > 0:
self.position_embeddings = nn.Embedding(max_position, embedding_size)
# segment encoding
if (segment_vocab_size > 0) and (not shared_segment_embeddings):
self.segment_embeddings = nn.Embedding(segment_vocab_size, embedding_size)
# emb_scale
self.emb_scale = kwargs.get('emb_scale', 1) # specific to transform_xl / xlnet
# LayerNorm
self.layerNorm = LayerNorm(embedding_size, eps=1e-12, conditional_size=conditional_size, **kwargs)
self.dropout = nn.Dropout(drop_rate)
# if embedding_size != hidden_size, add another linear layer (used for albert's embedding factorization)
if embedding_size != hidden_size:
self.embedding_hidden_mapping_in = nn.Linear(embedding_size, hidden_size)
def forward(self, token_ids, segment_ids=None, conditional_emb=None, additional_embs=None):
if (not token_ids.requires_grad) and (token_ids.dtype in {torch.long, torch.int}):
words_embeddings = self.word_embeddings(token_ids)
else:
words_embeddings = token_ids # custom word_embedding; currently only used by VAT
if hasattr(self, 'segment_embeddings'):
segment_ids = torch.zeros_like(token_ids) if segment_ids is None else segment_ids
segment_embeddings = self.segment_embeddings(segment_ids)
embeddings = words_embeddings + segment_embeddings
elif self.shared_segment_embeddings: # segment embeddings share weights with the word embeddings
segment_ids = torch.zeros_like(token_ids) if segment_ids is None else segment_ids
segment_embeddings = self.word_embeddings(segment_ids)
embeddings = words_embeddings + segment_embeddings
else:
embeddings = words_embeddings
# extra embeddings, e.g. part-of-speech
if additional_embs is not None:
for emb in additional_embs:
embeddings += emb
if hasattr(self, 'position_embeddings'):
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device)
position_ids = position_ids.unsqueeze(0).repeat(token_ids.shape[0], 1)
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
if self.emb_scale != 1:
embeddings = embeddings * self.emb_scale # specific to transform_xl / xlnet
if hasattr(self, 'layerNorm'):
embeddings = self.layerNorm((embeddings, conditional_emb))
embeddings = self.dropout(embeddings)
if hasattr(self, 'embedding_hidden_mapping_in'):
embeddings = self.embedding_hidden_mapping_in(embeddings)
return embeddings
class BertLayer(nn.Module):
"""
A Transformer layer.
Order: Attention --> Add --> LayerNorm --> Feed Forward --> Add --> LayerNorm
Notes: 1. dropout layers are omitted above, which does not mean there is none; each sub-layer uses dropout slightly differently, so take care to distinguish them
2. the Feed Forward block of the original Transformer encoder contains two linear layers;
config.intermediate_size is both the output size of the first linear layer and the input size of the second
"""
def __init__(self, hidden_size, num_attention_heads, dropout_rate, attention_probs_dropout_prob, intermediate_size, hidden_act,
is_dropout=False, conditional_size=False, **kwargs):
super(BertLayer, self).__init__()
self.multiHeadAttention = MultiHeadAttentionLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, **kwargs)
self.dropout1 = nn.Dropout(dropout_rate)
self.layerNorm1 = LayerNorm(hidden_size, eps=1e-12, conditional_size=conditional_size, **kwargs)
self.feedForward = PositionWiseFeedForward(hidden_size, intermediate_size, dropout_rate, hidden_act, is_dropout=is_dropout, **kwargs)
self.dropout2 = nn.Dropout(dropout_rate)
self.layerNorm2 = LayerNorm(hidden_size, eps=1e-12, conditional_size=conditional_size, **kwargs)
self.is_decoder = kwargs.get('is_decoder')
if self.is_decoder:
self.crossAttention = MultiHeadAttentionLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, **kwargs)
self.dropout3 = nn.Dropout(dropout_rate)
self.layerNorm3 = LayerNorm(hidden_size, eps=1e-12, conditional_size=conditional_size, **kwargs)
def forward(self, hidden_states, attention_mask, conditional_emb=None, encoder_hidden_states=None, encoder_attention_mask=None):
self_attn_output = self.multiHeadAttention(hidden_states, attention_mask) # when is_decoder is True, this attention_mask is triangular
hidden_states = hidden_states + self.dropout1(self_attn_output)
hidden_states = self.layerNorm1((hidden_states, conditional_emb))
# cross attention
if self.is_decoder and encoder_hidden_states is not None:
cross_attn_output = self.crossAttention(hidden_states, None, encoder_hidden_states, encoder_attention_mask)
hidden_states = hidden_states + self.dropout3(cross_attn_output)
hidden_states = self.layerNorm3((hidden_states, conditional_emb))
self_attn_output2 = self.feedForward(hidden_states)
hidden_states = hidden_states + self.dropout2(self_attn_output2)
hidden_states = self.layerNorm2((hidden_states, conditional_emb))
return hidden_states
class T5Layer(BertLayer):
"""T5的Encoder的主体是基于Self-Attention的模块
顺序:LN --> Att --> Add --> LN --> FFN --> Add
"""
def __init__(self, *args, version='t5.1.0', **kwargs):
super().__init__(*args, **kwargs)
# for the t5.1.1 architecture, the FFN layer must be replaced
if version.endswith('t5.1.1'):
kwargs['dropout_rate'] = args[2]
kwargs['hidden_act'] = args[5]
self.feedForward = self.T5PositionWiseFeedForward(hidden_size=args[0], intermediate_size=args[4], **kwargs)
# the decoder has a crossAttention block in the middle
if self.is_decoder and hasattr(self.crossAttention, 'relative_positions_encoding'):
del self.crossAttention.relative_positions_encoding
del self.crossAttention.relative_positions
def forward(self, hidden_states, attention_mask, conditional_emb=None, encoder_hidden_states=None, encoder_attention_mask=None):
# bert applies layernorm after attn/ffn, whereas OpenAI gpt2 applies it before
x = self.layerNorm1((hidden_states, conditional_emb))
self_attn_output = self.multiHeadAttention(x, attention_mask)
hidden_states = hidden_states + self.dropout1(self_attn_output)
# cross attention
if self.is_decoder and encoder_hidden_states is not None:
x = self.layerNorm3((hidden_states, conditional_emb))
cross_attn_output = self.crossAttention(x, None, encoder_hidden_states, encoder_attention_mask)
hidden_states = hidden_states + self.dropout3(cross_attn_output)
x = self.layerNorm2((hidden_states, conditional_emb))
ffn_output = self.feedForward(x)
hidden_states = hidden_states + self.dropout2(ffn_output)
return hidden_states
class T5PositionWiseFeedForward(PositionWiseFeedForward):
'''Based on the transformers package: https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
'''
def __init__(self, hidden_size, intermediate_size, **kwargs):
super().__init__(hidden_size, intermediate_size, **kwargs)
self.intermediateDense = nn.Linear(hidden_size, intermediate_size, bias=False)
self.intermediateDense1 = nn.Linear(hidden_size, intermediate_size, bias=False)
self.outputDense = nn.Linear(intermediate_size, hidden_size, bias=False)
def forward(self, x):
# x shape: (batch size, seq len, hidden_size)
x_gelu = self.intermediate_act_fn(self.intermediateDense(x))
x_linear = self.intermediateDense1(x)
x = x_gelu * x_linear
if self.is_dropout:
x = self.dropout(x)
# x shape: (batch size, seq len, intermediate_size)
x = self.outputDense(x)
# x shape: (batch size, seq len, hidden_size)
return x
class XlnetLayer(BertLayer):
'''Transformer_XL layer.
Order: Attention --> Add --> LayerNorm --> Feed Forward --> Add --> LayerNorm
'''
def __init__(self, hidden_size, num_attention_heads, dropout_rate, attention_probs_dropout_prob, intermediate_size, hidden_act, **kwargs):
super().__init__(hidden_size, num_attention_heads, dropout_rate, attention_probs_dropout_prob, intermediate_size, hidden_act, **kwargs)
self.pre_lnorm = kwargs.get('pre_lnorm')
# the multi-head attention layer has no bias
self.multiHeadAttention = self.RelPartialLearnableMultiHeadAttn(hidden_size, num_attention_heads, attention_probs_dropout_prob, bias=False, **kwargs)
def forward(self, hidden_states, segment_ids, pos_emb, attention_mask, mems_i, conditional_emb=None):
# concatenate mems and the query; mems_i: [btz, m_len, hdsz], w: [btz, q_len, hdsz], cat: [btz, k_len, hdsz]
hidden_states_cat = torch.cat([mems_i, hidden_states], 1) if mems_i is not None else hidden_states
# Attn
if self.pre_lnorm:
hidden_states_cat = self.layerNorm1((hidden_states_cat, conditional_emb))
self_attn_output = self.multiHeadAttention(hidden_states, hidden_states_cat, pos_emb, attention_mask, segment_ids)
hidden_states = hidden_states + self.dropout1(self_attn_output)
if not self.pre_lnorm: # post_lnorm
hidden_states = self.layerNorm1((hidden_states, conditional_emb))
# FFN
x = self.layerNorm2((hidden_states, conditional_emb)) if self.pre_lnorm else hidden_states
self_attn_output2 = self.feedForward(x)
hidden_states = hidden_states + self.dropout2(self_attn_output2)
if not self.pre_lnorm: # post_lnorm
hidden_states = self.layerNorm2((hidden_states, conditional_emb))
return hidden_states
class RelPartialLearnableMultiHeadAttn(MultiHeadAttentionLayer):
'''Transformer_XL style relative position encoding, rewritten here in MultiHeadAttentionLayer's batch_first code style.
'''
def __init__(self, *args, r_w_bias=None, r_r_bias=None, r_s_bias=None, **kwargs):
super().__init__(*args, **kwargs)
segment_vocab_size = kwargs.get('segment_vocab_size')
if r_r_bias is None or r_w_bias is None: # biases are not shared
self.r_r_bias = nn.Parameter(torch.FloatTensor(self.num_attention_heads, self.attention_head_size)) # global position bias (v)
self.r_w_bias = nn.Parameter(torch.FloatTensor(self.num_attention_heads, self.attention_head_size)) # global content bias (u)
if segment_vocab_size > 0:
self.r_s_bias = nn.Parameter(torch.FloatTensor(self.num_attention_heads, self.attention_head_size)) # global segment bias
else: # shared across all layers
self.r_r_bias = r_r_bias
self.r_w_bias = r_w_bias
self.r_s_bias = r_s_bias
if segment_vocab_size > 0:
# self.seg_embed = nn.Embedding(segment_vocab_size, self.hidden_size)
self.seg_embed = nn.Parameter(torch.FloatTensor(segment_vocab_size, self.num_attention_heads, self.attention_head_size))
self.r = nn.Linear(self.hidden_size, self.hidden_size, bias=self.bias)
self.rel_shift_opt = kwargs.get('rel_shift_opt')
@staticmethod
def rel_shift(x, zero_triu=False):
'''Used by transformer_xl: shift left so that the upper-right triangle is all 0 and each diagonal carries a single value; x: [btz, n_head, q_len, k_len]
'''
q_len, k_len = x.size(2), x.size(-1)
zero_pad = torch.zeros((*x.size()[:2], q_len, 1), device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(*x.size()[:2], k_len + 1, q_len)
x = x_padded[:,:,1:,:].view_as(x)
if zero_triu:
ones = torch.ones((q_len, k_len), device=x.device)
x = x * torch.tril(ones, k_len - q_len)[None,None,:,:]
return x
@staticmethod
def rel_shift_bnij(x, klen=-1):
'''Used by xlnet.
'''
x_size = x.shape
x = x.reshape(x_size[0], x_size[1], x_size[3], x_size[2])
x = x[:, :, 1:, :]
x = x.reshape(x_size[0], x_size[1], x_size[2], x_size[3] - 1)
x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long))
# x = x[:, :, :, :klen]
return x
def forward(self, w, cat, r, attention_mask=None, seg_mat=None):
# w: token embeddings [btz, q_len, hdsz]; cat: w concatenated with mem_i [btz, k_len, hdsz]; r: relative position vectors [r_len, hdsz]
qlen, rlen, bsz = w.size(1), r.size(0), w.size(0)
mixed_query_layer = self.q(cat)[:, -qlen:, :] # keep only the query part, not the mem part
mixed_key_layer = self.k(cat)
mixed_value_layer = self.v(cat)
w_head_q = self.transpose_for_scores(mixed_query_layer) # [btz, n_head, q_len, d_head]
w_head_k = self.transpose_for_scores(mixed_key_layer) # [btz, n_head, k_len, d_head]
w_head_v = self.transpose_for_scores(mixed_value_layer) # [btz, n_head, k_len, d_head]
r_head_k = self.r(r) # [hdsz, nhead*headsize] = [r_len, 1, nhead*headsize]
r_head_k = r_head_k.view(rlen, self.num_attention_heads, self.attention_head_size) # rlen x n_head x d_head
#### compute attention score
rw_head_q = w_head_q + self.r_w_bias.unsqueeze(1) # [btz, n_head, q_len, d_head]
AC = torch.einsum('bnid,bnjd->bnij', (rw_head_q, w_head_k)) # [btz, n_head, q_len, k_len]
rr_head_q = w_head_q + self.r_r_bias.unsqueeze(1) # [btz, n_head, q_len, d_head]
BD = torch.einsum('bnid,jnd->bnij', (rr_head_q, r_head_k)) # [btz, n_head, q_len, k_len]
BD = self.rel_shift_bnij(BD, klen=AC.shape[3]) if self.rel_shift_opt == 'xlnet' else self.rel_shift(BD)
if hasattr(self, 'seg_embed') and (self.r_r_bias is not None):
# # the previous approach needed an Embedding plus load_variable and variable_mapping, and easily blew up GPU memory
# w_head_s = self.seg_embed(seg_mat) # [btz, q_len, klen, hdsz]
# w_head_s = w_head_s.reshape(*w_head_s.shape[:3], self.num_attention_heads, self.attention_head_size)
# rs_head_q = w_head_q + self.r_s_bias.unsqueeze(1)
# EF = torch.einsum('bnid,bijnd->bnij', (rs_head_q, w_head_s)) # [btz, n_head, q_len, k_len]
seg_mat = F.one_hot(seg_mat, 2).float()
EF = torch.einsum("bnid,snd->ibns", w_head_q + self.r_s_bias.unsqueeze(1), self.seg_embed)
EF = torch.einsum("bijs,ibns->bnij", seg_mat, EF)
else:
EF = 0
# # [btz, n_head, q_len, k_len]
attention_scores = AC + BD + EF
if self.attention_scale:
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
#### compute attention probability
if attention_mask is not None and attention_mask.any().item():
# attention_mask = (1.0 - attention_mask) * -10000.0
# attention_scores = attention_scores + attention_mask # changed here: the original -10000 is not close enough to -inf
attention_mask = (1.0 - attention_mask)
attention_scores = attention_scores.float().masked_fill(attention_mask.bool(), -1e30).type_as(attention_mask)
# [btz, n_head, q_len, k_len]
attention_probs = F.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
context_layer = torch.matmul(attention_probs, w_head_v) # [batch_size, num_attention_heads, query_len, attention_head_size]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
context_layer = context_layer.view(*new_context_layer_shape)
# optionally return the attention scores
if self.return_attention_scores:
# the attention_scores returned here are pre-softmax; normalization can be applied externally
return self.o(context_layer), attention_scores
else:
return self.o(context_layer)
class AdaptiveEmbedding(nn.Module):
'''Transformer_XL's adaptive embedding: different vocabulary ranges use different dimensions.
For example, frequent words can use 1024 or 512 dims while rare words use 256 or 64, with a Linear layer projecting all of them to a common size.
'''
def __init__(self, vocab_size, embedding_size, hidden_size, cutoffs, div_val=1, sample_softmax=False, **kwargs):
super().__init__()
self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.cutoffs = cutoffs + [vocab_size]
self.div_val = div_val
self.hidden_size = hidden_size
self.emb_scale = hidden_size ** 0.5
self.cutoff_ends = [0] + self.cutoffs
self.emb_layers = nn.ModuleList()
self.emb_projs = nn.ParameterList()
if div_val == 1:
self.emb_layers.append(nn.Embedding(vocab_size, embedding_size, sparse=sample_softmax > 0))
if hidden_size != embedding_size:
self.emb_projs.append(nn.Parameter(torch.FloatTensor(hidden_size, embedding_size)))
else:
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
d_emb_i = embedding_size // (div_val ** i)
self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
self.emb_projs.append(nn.Parameter(torch.FloatTensor(hidden_size, d_emb_i)))
def forward(self, token_ids):
if self.div_val == 1: # only a single embedding
embed = self.emb_layers[0](token_ids) # [btz, seq_len, embedding_size]
if self.hidden_size != self.embedding_size:
embed = nn.functional.linear(embed, self.emb_projs[0])
else:
param = next(self.parameters())
inp_flat = token_ids.view(-1)
emb_flat = torch.zeros([inp_flat.size(0), self.hidden_size], dtype=param.dtype, device=param.device)
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
indices_i = mask_i.nonzero().squeeze()
if indices_i.numel() == 0:
continue
inp_i = inp_flat.index_select(0, indices_i) - l_idx
emb_i = self.emb_layers[i](inp_i)
emb_i = nn.functional.linear(emb_i, self.emb_projs[i])
emb_flat.index_copy_(0, indices_i, emb_i)
embed_shape = token_ids.size() + (self.hidden_size,)
embed = emb_flat.view(embed_shape)
embed.mul_(self.emb_scale)
return embed
class Identity(nn.Module):
def __init__(self, *args, **kwargs):
super(Identity, self).__init__()
def forward(self, *args):
return args[0]
class XlnetPositionsEncoding(nn.Module):
'''Relative position encoding used by xlnet and transformer_xl.
The difference from SinusoidalPositionEncoding: one interleaves sin and cos, the other concatenates them front and back.
'''
def __init__(self, embedding_size):
super().__init__()
self.demb = embedding_size
inv_freq = 1 / (10000 ** (torch.arange(0.0, embedding_size, 2.0) / embedding_size))
self.register_buffer("inv_freq", inv_freq)
def forward(self, pos_seq):
sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
return pos_emb
class RelativePositionsEncoding(nn.Module):
"""nezha用的google相对位置编码
来自论文:https://arxiv.org/abs/1803.02155
"""
def __init__(self, qlen, klen, embedding_size, max_relative_position=127):
super(RelativePositionsEncoding, self).__init__()
# build the relative position matrix
vocab_size = max_relative_position * 2 + 1
distance_mat = torch.arange(klen)[None, :] - torch.arange(qlen)[:, None] # column index minus row index, [query_len, key_len]
distance_mat_clipped = torch.clamp(distance_mat, -max_relative_position, max_relative_position)
final_mat = distance_mat_clipped + max_relative_position
# position matrix encoded with the sinusoid encoding table
embeddings_table = get_sinusoid_encoding_table(vocab_size, embedding_size)
# implementation 1
# flat_relative_positions_matrix = final_mat.view(-1)
# one_hot_relative_positions_matrix = torch.nn.functional.one_hot(flat_relative_positions_matrix, num_classes=vocab_size).float()
# position_embeddings = torch.matmul(one_hot_relative_positions_matrix, embeddings_table)
# my_shape = list(final_mat.size())
# my_shape.append(embedding_size)
# position_embeddings = position_embeddings.view(my_shape)
# implementation 2
# position_embeddings = take_along_dim(embeddings_table, final_mat.flatten().unsqueeze(1), dim=0)
# position_embeddings = position_embeddings.reshape(*final_mat.shape, embeddings_table.shape[-1]) # [seq_len, seq_len, hdsz]
# self.register_buffer('position_embeddings', position_embeddings)
# implementation 3
position_embeddings = nn.Embedding.from_pretrained(embeddings_table, freeze=True)(final_mat)
self.register_buffer('position_embeddings', position_embeddings)
def forward(self, qlen, klen):
return self.position_embeddings[:qlen, :klen, :]
class RelativePositionsEncodingT5(nn.Module):
"""Google T5的相对位置编码
来自论文:https://arxiv.org/abs/1910.10683
"""
def __init__(self, qlen, klen, relative_attention_num_buckets, is_decoder=False):
super(RelativePositionsEncodingT5, self).__init__()
# build the relative position matrix
context_position = torch.arange(qlen, dtype=torch.long)[:, None]
memory_position = torch.arange(klen, dtype=torch.long)[None, :]
relative_position = memory_position - context_position # shape (qlen, klen)
relative_position = self._relative_position_bucket(
relative_position, # shape (qlen, klen)
bidirectional=not is_decoder,
num_buckets=relative_attention_num_buckets,
)
self.register_buffer('relative_position', relative_position)
def forward(self, qlen, klen):
return self.relative_position[:qlen, :klen]
@staticmethod
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
'''Taken directly from the transformers implementation.
'''
ret = 0
n = -relative_position
if bidirectional:
num_buckets //= 2
ret += (n < 0).to(torch.long) * num_buckets # mtf.to_int32(mtf.less(n, 0)) * num_buckets
n = torch.abs(n)
else:
n = torch.max(n, torch.zeros_like(n))
# now n is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = n < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
).to(torch.long)
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
ret += torch.where(is_small, n, val_if_large)
return ret
class SinusoidalPositionEncoding(nn.Module):
"""定义Sin-Cos位置Embedding
"""
def __init__(self, max_position, embedding_size):
super(SinusoidalPositionEncoding, self).__init__()
self.position_embeddings = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(max_position, embedding_size), freeze=True)
def forward(self, position_ids):
return self.position_embeddings(position_ids)
class RoPEPositionEncoding(nn.Module):
"""旋转式位置编码: https://kexue.fm/archives/8265
"""
def __init__(self, max_position, embedding_size):
super(RoPEPositionEncoding, self).__init__()
position_embeddings = get_sinusoid_encoding_table(max_position, embedding_size) # [seq_len, hdsz]
cos_position = position_embeddings[:, 1::2].repeat_interleave(2, dim=-1)
sin_position = position_embeddings[:, ::2].repeat_interleave(2, dim=-1)
# register_buffer is used so that the outermost model.to(device) moves these; no device needs to be set internally
self.register_buffer('cos_position', cos_position)
self.register_buffer('sin_position', sin_position)
def forward(self, qw, seq_dim=-2):
# by default the last two dimensions are [seq_len, hdsz]
seq_len = qw.shape[seq_dim]
qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], dim=-1).reshape_as(qw)
return qw * self.cos_position[:seq_len] + qw2 * self.sin_position[:seq_len]
class CRF(nn.Module):
'''Conditional random field: https://github.com/lonePatient/BERT-NER-Pytorch/blob/master/models/layers/crf.py
'''
def __init__(self, num_tags: int, init_transitions: Optional[List[np.ndarray]] = None, freeze=False) -> None:
if num_tags <= 0:
raise ValueError(f'invalid number of tags: {num_tags}')
super().__init__()
self.num_tags = num_tags
if (init_transitions is None) and (not freeze):
self.start_transitions = nn.Parameter(torch.empty(num_tags))
self.end_transitions = nn.Parameter(torch.empty(num_tags))
self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
nn.init.uniform_(self.start_transitions, -0.1, 0.1)
nn.init.uniform_(self.end_transitions, -0.1, 0.1)
nn.init.uniform_(self.transitions, -0.1, 0.1)
elif init_transitions is not None:
transitions = torch.tensor(init_transitions[0], dtype=torch.float)
start_transitions = torch.tensor(init_transitions[1], dtype=torch.float)
end_transitions = torch.tensor(init_transitions[2], dtype=torch.float)
if not freeze:
self.transitions = nn.Parameter(transitions)
self.start_transitions = nn.Parameter(start_transitions)
self.end_transitions = nn.Parameter(end_transitions)
else:
self.register_buffer('transitions', transitions)
self.register_buffer('start_transitions', start_transitions)
self.register_buffer('end_transitions', end_transitions)
def __repr__(self) -> str:
return f'{self.__class__.__name__}(num_tags={self.num_tags})'
def forward(self, emissions: torch.Tensor, mask: torch.ByteTensor,
tags: torch.LongTensor, reduction: str = 'mean') -> torch.Tensor:
"""Compute the conditional log likelihood of a sequence of tags given emission scores.
emissions: [btz, seq_len, num_tags]
mask: [btz, seq_len]
tags: [btz, seq_len]
"""
if reduction not in ('none', 'sum', 'mean', 'token_mean'):
raise ValueError(f'invalid reduction: {reduction}')
if mask.dtype != torch.uint8:
mask = mask.byte()
self._validate(emissions, tags=tags, mask=mask)
# shape: (batch_size,)
numerator = self._compute_score(emissions, tags, mask)
# shape: (batch_size,)
denominator = self._compute_normalizer(emissions, mask)
# shape: (batch_size,)
llh = denominator - numerator
if reduction == 'none':
return llh
if reduction == 'sum':
return llh.sum()
if reduction == 'mean':
return llh.mean()
return llh.sum() / mask.float().sum()
def decode(self, emissions: torch.Tensor, mask: Optional[torch.ByteTensor] = None,
nbest: Optional[int] = None, pad_tag: Optional[int] = None) -> List[List[List[int]]]:
"""Find the most likely tag sequence using Viterbi algorithm.
"""
if nbest is None:
nbest = 1
if mask is None:
mask = torch.ones(emissions.shape[:2], dtype=torch.uint8, device=emissions.device)
if mask.dtype != torch.uint8:
mask = mask.byte()
self._validate(emissions, mask=mask)
best_path = self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag)
return best_path[0] if nbest == 1 else best_path
def _validate(self, emissions: torch.Tensor, tags: Optional[torch.LongTensor] = None,
mask: Optional[torch.ByteTensor] = None) -> None:
if emissions.dim() != 3:
raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
if emissions.size(2) != self.num_tags:
raise ValueError(f'expected last dimension of emissions is {self.num_tags}, '
f'got {emissions.size(2)}')
if tags is not None:
if emissions.shape[:2] != tags.shape:
raise ValueError('the first two dimensions of emissions and tags must match, '
f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')
if mask is not None:
if emissions.shape[:2] != mask.shape:
raise ValueError('the first two dimensions of emissions and mask must match, '
f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
no_empty_seq_bf = mask[:, 0].all()
if not no_empty_seq_bf:
raise ValueError('mask of the first timestep must all be on')
def _compute_score(self, emissions: torch.Tensor, tags: torch.LongTensor, mask: torch.ByteTensor) -> torch.Tensor:
# emissions: (batch_size, seq_length, num_tags)
# tags: (batch_size, seq_length)
# mask: (batch_size, seq_length)
batch_size, seq_length = tags.shape
mask = mask.float()
# Start transition score and first emission
# shape: (batch_size,)
score = self.start_transitions[tags[:, 0]]
score += emissions[torch.arange(batch_size), 0, tags[:, 0]]
for i in range(1, seq_length):
# Transition score to next tag, only added if next timestep is valid (mask == 1)
# shape: (batch_size,)
score += self.transitions[tags[:, i - 1], tags[:, i]] * mask[:, i]
# Emission score for next tag, only added if next timestep is valid (mask == 1)
# shape: (batch_size,)
score += emissions[torch.arange(batch_size), i, tags[:, i]] * mask[:, i]
# End transition score
# shape: (batch_size,)
seq_ends = mask.long().sum(dim=1) - 1
# shape: (batch_size,)
last_tags = tags[torch.arange(batch_size), seq_ends]
# shape: (batch_size,)
score += self.end_transitions[last_tags]
return score
def _compute_normalizer(self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor:
# emissions: (batch_size, seq_length, num_tags)
# mask: (batch_size, seq_length)
seq_length = emissions.size(1)
# Start transition score and first emission; score has size of
# (batch_size, num_tags) where for each batch, the j-th column stores
# the score that the first timestep has tag j
# shape: (batch_size, num_tags)
score = self.start_transitions + emissions[:, 0]
for i in range(1, seq_length):
# Broadcast score for every possible next tag
# shape: (batch_size, num_tags, 1)
broadcast_score = score.unsqueeze(2)
# Broadcast emission score for every possible current tag
# shape: (batch_size, 1, num_tags)
broadcast_emissions = emissions[:, i].unsqueeze(1)
# Compute the score tensor of size (batch_size, num_tags, num_tags) where
# for each sample, entry at row i and column j stores the sum of scores of all
# possible tag sequences so far that end with transitioning from tag i to tag j
# and emitting
# shape: (batch_size, num_tags, num_tags)
next_score = broadcast_score + self.transitions + broadcast_emissions
# Sum over all possible current tags, but we're in score space, so a sum
# becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
# all possible tag sequences so far, that end in tag i
# shape: (batch_size, num_tags)
next_score = torch.logsumexp(next_score, dim=1)
# Set score to the next score if this timestep is valid (mask == 1)
# shape: (batch_size, num_tags)
score = torch.where(mask[:, i].unsqueeze(1).bool(), next_score, score)
# End transition score
# shape: (batch_size, num_tags)
score += self.end_transitions
# Sum (log-sum-exp) over all possible tags
# shape: (batch_size,)
return torch.logsumexp(score, dim=1)
def _viterbi_decode_nbest(self, emissions: torch.FloatTensor, mask: torch.ByteTensor,
nbest: int, pad_tag: Optional[int] = None) -> List[List[List[int]]]:
# emissions: (batch_size, seq_length, num_tags)
# mask: (batch_size, seq_length)
# return: (nbest, batch_size, seq_length)
if pad_tag is None:
pad_tag = 0
device = emissions.device
batch_size, seq_length = mask.shape
# Start transition and first emission
# shape: (batch_size, num_tags)
score = self.start_transitions + emissions[:, 0]
history_idx = torch.zeros((batch_size, seq_length, self.num_tags, nbest), dtype=torch.long, device=device)
oor_idx = torch.zeros((batch_size, self.num_tags, nbest), dtype=torch.long, device=device)
oor_tag = torch.full((batch_size, seq_length, nbest), pad_tag, dtype=torch.long, device=device)
# - score is a tensor of size (batch_size, num_tags) where for every batch,
# value at column j stores the score of the best tag sequence so far that ends
# with tag j
# - history_idx saves where the best tags candidate transitioned from; this is used
# when we trace back the best tag sequence
# - oor_idx saves the best tags candidate transitioned from at the positions
# where mask is 0, i.e. out of range (oor)
# Viterbi algorithm recursive case: we compute the score of the best tag sequence
# for every possible next tag
for i in range(1, seq_length):
if i == 1:
broadcast_score = score.unsqueeze(-1)
broadcast_emission = emissions[:, i].unsqueeze(1)
# shape: (batch_size, num_tags, num_tags)
next_score = broadcast_score + self.transitions + broadcast_emission
else:
broadcast_score = score.unsqueeze(-1)
broadcast_emission = emissions[:, i].unsqueeze(1).unsqueeze(2)
# shape: (batch_size, num_tags, nbest, num_tags)
next_score = broadcast_score + self.transitions.unsqueeze(1) + broadcast_emission
# Find the top `nbest` maximum scores over all possible current tags
# shape: (batch_size, nbest, num_tags)
next_score, indices = next_score.view(batch_size, -1, self.num_tags).topk(nbest, dim=1)
if i == 1:
score = score.unsqueeze(-1).expand(-1, -1, nbest)
indices = indices * nbest
# convert to shape: (batch_size, num_tags, nbest)
next_score = next_score.transpose(2, 1)
indices = indices.transpose(2, 1)
# Set score to the next score if this timestep is valid (mask == 1)
# and save the index that produces the next score
# shape: (batch_size, num_tags, nbest)
score = torch.where(mask[:, i].unsqueeze(-1).unsqueeze(-1).bool(), next_score, score)
indices = torch.where(mask[:, i].unsqueeze(-1).unsqueeze(-1).bool(), indices, oor_idx)
history_idx[:, i - 1] = indices
# End transition score shape: (batch_size, num_tags, nbest)
end_score = score + self.end_transitions.unsqueeze(-1)
_, end_tag = end_score.view(batch_size, -1).topk(nbest, dim=1)
# shape: (batch_size,)
seq_ends = mask.long().sum(dim=1) - 1
# insert the best tag at each sequence end (last position with mask == 1)
history_idx.scatter_(1, seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest),
end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest))
# The most probable path for each sequence
best_tags_arr = torch.zeros((batch_size, seq_length, nbest), dtype=torch.long, device=device)
best_tags = torch.arange(nbest, dtype=torch.long, device=device).view(1, -1).expand(batch_size, -1)
for idx in range(seq_length - 1, -1, -1):
best_tags = torch.gather(history_idx[:, idx].view(batch_size, -1), 1, best_tags)
best_tags_arr[:, idx] = torch.div(best_tags.data.view(batch_size, -1), nbest, rounding_mode='floor')
return torch.where(mask.unsqueeze(-1).bool(), best_tags_arr, oor_tag).permute(2, 0, 1)
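# Usage sketch for the decoder above (illustrative only; shapes are assumed and
# `crf` stands for an instance of the surrounding CRF module):
# emissions = torch.randn(2, 5, crf.num_tags) # (batch_size, seq_length, num_tags)
# mask = torch.ones(2, 5, dtype=torch.uint8) # all positions valid
# paths = crf._viterbi_decode_nbest(emissions, mask, nbest=3) # (3, 2, 5): 3 best tag paths per sample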
class BERT_WHITENING():
def __init__(self):
self.kernel = None
self.bias = None
def compute_kernel_bias(self, sentence_vec):
'''PyTorch implementation of bert-whitening
'''
vecs = torch.cat(sentence_vec, dim=0)
self.bias = -vecs.mean(dim=0, keepdims=True)
cov = torch.cov(vecs.T) # covariance matrix
u, s, vh = torch.linalg.svd(cov)
W = torch.matmul(u, torch.diag(s**0.5))
self.kernel = torch.linalg.inv(W.T)
def save_whiten(self, path):
whiten = {'kernel': self.kernel, 'bias': self.bias}
torch.save(whiten, path)
def load_whiten(self, path):
whiten = torch.load(path)
self.kernel = whiten['kernel']
self.bias = whiten['bias']
def transform_and_normalize(self, vecs):
"""应用变换,然后标准化
"""
if not (self.kernel is None or self.bias is None):
vecs = (vecs + self.bias).mm(self.kernel)
return vecs / (vecs**2).sum(axis=1, keepdims=True)**0.5
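# Usage sketch (illustrative only; `encoder_outputs` is an assumed list of
# (batch_size, hidden_size) sentence-vector tensors collected from a BERT encoder):
# whitening = BERT_WHITENING()
# whitening.compute_kernel_bias(encoder_outputs) # fit kernel/bias on the corpus vectors
# normed = whitening.transform_and_normalize(encoder_outputs[0]) # whitened, unit-norm vectors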
class GlobalPointer(nn.Module):
"""全局指针模块
将序列的每个(start, end)作为整体来进行判断
参考:https://kexue.fm/archives/8373
"""
def __init__(self, hidden_size, heads, head_size, RoPE=True, max_len=512, use_bias=True, tril_mask=True):
super().__init__()
self.heads = heads
self.head_size = head_size
self.RoPE = RoPE
self.tril_mask = tril_mask
self.dense = nn.Linear(hidden_size, heads * head_size * 2, bias=use_bias)
if self.RoPE:
self.position_embedding = RoPEPositionEncoding(max_len, head_size)
def forward(self, inputs, mask=None):
''' inputs: [..., hdsz]
mask: [btz, seq_len], 0 at padding positions
'''
sequence_output = self.dense(inputs) # [..., heads*head_size*2]
sequence_output = torch.stack(torch.chunk(sequence_output, self.heads, dim=-1), dim=-2) # [..., heads, head_size*2]
qw, kw = sequence_output[..., :self.head_size], sequence_output[..., self.head_size:] # [..., heads, head_size]
# RoPE encoding
if self.RoPE:
qw = self.position_embedding(qw)
kw = self.position_embedding(kw)
# compute the inner product
logits = torch.einsum('bmhd,bnhd->bhmn', qw, kw) # [btz, heads, seq_len, seq_len]
# mask out padding
if mask is not None:
attention_mask1 = 1 - mask.unsqueeze(1).unsqueeze(3) # [btz, 1, seq_len, 1]
attention_mask2 = 1 - mask.unsqueeze(1).unsqueeze(2) # [btz, 1, 1, seq_len]
logits = logits.masked_fill(attention_mask1.bool(), value=-float('inf'))
logits = logits.masked_fill(attention_mask2.bool(), value=-float('inf'))
# mask out the lower triangle
if self.tril_mask:
logits = logits - torch.tril(torch.ones_like(logits), -1) * 1e12
# scale and return
return logits / self.head_size**0.5
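# Usage sketch (illustrative only; shapes assumed): scores every (start, end) span for
# each of the `heads` entity types, yielding [btz, heads, seq_len, seq_len] logits.
# gp = GlobalPointer(hidden_size=768, heads=2, head_size=64)
# hidden = torch.randn(4, 32, 768) # e.g. last hidden states of a BERT encoder
# mask = torch.ones(4, 32, dtype=torch.long)
# span_logits = gp(hidden, mask) # [4, 2, 32, 32]; spans with logits > 0 are predicted entities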
class EfficientGlobalPointer(nn.Module):
"""更加参数高效的GlobalPointer
参考:https://kexue.fm/archives/8877
"""
def __init__(self, hidden_size, heads, head_size, RoPE=True, max_len=512, use_bias=True, tril_mask=True):
super().__init__()
self.heads = heads
self.head_size = head_size
self.RoPE = RoPE
self.tril_mask = tril_mask
self.p_dense = nn.Linear(hidden_size, head_size * 2, bias=use_bias)
self.q_dense = nn.Linear(head_size * 2, heads * 2, bias=use_bias)
if self.RoPE:
self.position_embedding = RoPEPositionEncoding(max_len, head_size)
def forward(self, inputs, mask=None):
''' inputs: [..., hdsz]
mask: [btz, seq_len], 0 at padding positions
'''
sequence_output = self.p_dense(inputs) # [..., head_size*2]
qw, kw = sequence_output[..., :self.head_size], sequence_output[..., self.head_size:] # [..., head_size]
# RoPE encoding
if self.RoPE:
qw = self.position_embedding(qw)
kw = self.position_embedding(kw)
# compute the inner product
logits = torch.einsum('bmd,bnd->bmn', qw, kw) / self.head_size**0.5 # [btz, seq_len, seq_len], score for whether the span is an entity
bias_input = self.q_dense(sequence_output) # [..., heads*2]
bias = torch.stack(torch.chunk(bias_input, self.heads, dim=-1), dim=-2).transpose(1,2) # [btz, heads, seq_len, 2]
logits = logits.unsqueeze(1) + bias[..., :1] + bias[..., 1:].transpose(2, 3) # [btz, heads, seq_len, seq_len]
# mask out padding
if mask is not None:
attention_mask1 = 1 - mask.unsqueeze(1).unsqueeze(3) # [btz, 1, seq_len, 1]
attention_mask2 = 1 - mask.unsqueeze(1).unsqueeze(2) # [btz, 1, 1, seq_len]
logits = logits.masked_fill(attention_mask1.bool(), value=-float('inf'))
logits = logits.masked_fill(attention_mask2.bool(), value=-float('inf'))
# mask out the lower triangle
if self.tril_mask:
logits = logits - torch.tril(torch.ones_like(logits), -1) * 1e12
return logits
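# Note: compared with GlobalPointer above, the span score here comes from a single
# head_size*2 projection, and the per-head scores are added as low-rank biases, so the
# projection cost no longer grows with the number of heads. A sketch (shapes assumed):
# egp = EfficientGlobalPointer(hidden_size=768, heads=12, head_size=64)
# span_logits = egp(torch.randn(4, 32, 768), torch.ones(4, 32, dtype=torch.long)) # [4, 12, 32, 32]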
class TplinkerHandshakingKernel(nn.Module):
'''Implementation of Tplinker's HandshakingKernel
'''
def __init__(self, hidden_size, shaking_type, inner_enc_type=''):
super().__init__()
self.shaking_type = shaking_type
if shaking_type == "cat":
self.combine_fc = nn.Linear(hidden_size * 2, hidden_size)
elif shaking_type == "cat_plus":
self.combine_fc = nn.Linear(hidden_size * 3, hidden_size)
elif shaking_type == "cln":
self.tp_cln = LayerNorm(hidden_size, conditional_size=hidden_size)
elif shaking_type == "cln_plus":
self.tp_cln = LayerNorm(hidden_size, conditional_size=hidden_size)
self.inner_context_cln = LayerNorm(hidden_size, conditional_size=hidden_size)
self.inner_enc_type = inner_enc_type
if inner_enc_type == "mix_pooling":
self.lamtha = nn.Parameter(torch.rand(hidden_size))
elif inner_enc_type == "lstm":
self.inner_context_lstm = nn.LSTM(hidden_size, hidden_size, num_layers=1, bidirectional=False, batch_first=True)
# a torch.gather-based implementation to avoid the python loop; only the cat mode is implemented so far
# tag_ids = [(i, j) for i in range(maxlen) for j in range(maxlen) if j >= i]
# gather_idx = torch.tensor(tag_ids, dtype=torch.long).flatten()[None, :, None]
# self.register_buffer('gather_idx', gather_idx)
def enc_inner_hiddens(self, seq_hiddens, inner_enc_type="lstm"):
# seq_hiddens: (batch_size, seq_len, hidden_size)
def pool(sequence, pooling_type):
if pooling_type == "mean_pooling":
pooling = torch.mean(sequence, dim=-2)
elif pooling_type == "max_pooling":
pooling, _ = torch.max(sequence, dim=-2)
elif pooling_type == "mix_pooling":
pooling = self.lamtha * torch.mean(sequence, dim=-2) + (1 - self.lamtha) * torch.max(sequence, dim=-2)[0]
return pooling
if "pooling" in inner_enc_type:
inner_context = torch.stack([pool(seq_hiddens[:, :i+1, :], inner_enc_type) for i in range(seq_hiddens.size()[1])], dim=1)
elif inner_enc_type == "lstm":
inner_context, _ = self.inner_context_lstm(seq_hiddens)
return inner_context
def forward(self, seq_hiddens):
'''
seq_hiddens: (batch_size, seq_len, hidden_size)
return:
shaking_hiddens: (batch_size, (1 + seq_len) * seq_len / 2, hidden_size), e.g. (32, 5+4+3+2+1, 5)
'''
seq_len = seq_hiddens.size()[-2]
shaking_hiddens_list = []
for ind in range(seq_len):
hidden_each_step = seq_hiddens[:, ind, :]
visible_hiddens = seq_hiddens[:, ind:, :] # ind: only look back
repeat_hiddens = hidden_each_step[:, None, :].repeat(1, seq_len - ind, 1)
if self.shaking_type == "cat":
shaking_hiddens = torch.cat([repeat_hiddens, visible_hiddens], dim=-1)
shaking_hiddens = torch.tanh(self.combine_fc(shaking_hiddens))
elif self.shaking_type == "cat_plus":
inner_context = self.enc_inner_hiddens(visible_hiddens, self.inner_enc_type)
shaking_hiddens = torch.cat([repeat_hiddens, visible_hiddens, inner_context], dim=-1)
shaking_hiddens = torch.tanh(self.combine_fc(shaking_hiddens))
elif self.shaking_type == "cln":
shaking_hiddens = self.tp_cln([visible_hiddens, repeat_hiddens])
elif self.shaking_type == "cln_plus":
inner_context = self.enc_inner_hiddens(visible_hiddens, self.inner_enc_type)
shaking_hiddens = self.tp_cln([visible_hiddens, repeat_hiddens])
shaking_hiddens = self.inner_context_cln([shaking_hiddens, inner_context])
shaking_hiddens_list.append(shaking_hiddens)
long_shaking_hiddens = torch.cat(shaking_hiddens_list, dim=1)
return long_shaking_hiddens
# def handshaking_kernel(self, last_hidden_state):
# '''Get the sequence ids corresponding to the pairs (0,0), (0,1), ..., (99,99)
# '''
# btz, _, hdsz = last_hidden_state.shape
# gather_idx = self.gather_idx.repeat(btz, 1, hdsz)
# concat_hidden_states = torch.gather(last_hidden_state, dim=1, index=gather_idx) # [btz, pair_len*2, hdsz]
# concat_hidden_states = concat_hidden_states.reshape(btz, -1, 2, hdsz) # concat方式 [btz, pair_len, 2, hdsz]
# shaking_hiddens = torch.cat(torch.chunk(concat_hidden_states, chunks=2, dim=-2), dim=-1).squeeze(-2) # [btz, pair_len, hdsz*2]
# return shaking_hiddens
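# Usage sketch (illustrative only): the kernel flattens all (i, j) pairs with j >= i into
# one sequence of length seq_len * (seq_len + 1) / 2:
# kernel = TplinkerHandshakingKernel(hidden_size=768, shaking_type='cat')
# shaking = kernel(torch.randn(2, 5, 768)) # [2, 15, 768], 15 = 5+4+3+2+1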
class MixUp(nn.Module):
'''Implementation of mixup
method: 'embed' and 'encoder' apply mixup at the embedding and encoder outputs respectively; None leaves the mixing to downstream processing; 'hidden' applies mixup to a hidden layer
'''
def __init__(self, method='encoder', alpha=1.0, layer_mix=None):
super().__init__()
assert method in {'embed', 'encoder', 'hidden', None}
self.method = method
self.alpha = alpha
self.perm_index = None
self.lam = 0
self.layer_mix = layer_mix # index of the hidden layer to mix
def get_perm(self, inputs):
if isinstance(inputs, torch.Tensor):
return inputs[self.perm_index]
elif isinstance(inputs, (list, tuple)):
return [inp[self.perm_index] if isinstance(inp, torch.Tensor) else inp for inp in inputs]
def mix_up(self, output, output1):
if isinstance(output, torch.Tensor):
return self.lam * output + (1.0-self.lam) * output1
elif isinstance(output, (list, tuple)):
output_final = []
for i in range(len(output)):
if output[i] is None: # conditional_emb=None
output_final.append(output[i])
elif (not output[i].requires_grad) and (output[i].dtype in {torch.long, torch.int}):
# not an embedding-like tensor
output_final.append(torch.max(output[i], output1[i]))
else:
output_final.append(self.lam * output[i] + (1.0-self.lam) * output1[i])
return output_final
else:
raise ValueError('Illegal model output')
def encode(self, model, inputs):
batch_size = inputs[0].shape[0]
device = inputs[0].device
self.lam = np.random.beta(self.alpha, self.alpha)
self.perm_index = torch.randperm(batch_size).to(device)
if self.method is None:
output = model(inputs)
output1 = self.get_perm(output)
return [output, output1]
elif self.method == 'encoder':
output = model(inputs)
output1 = self.get_perm(output)
output_final = self.mix_up(output, output1)
elif self.method == 'embed':
output = model.apply_embeddings(inputs)
output1 = self.get_perm(output)
output_final = self.mix_up(output, output1)
# Main
output_final = model.apply_main_layers(output_final)
# Final
output_final = model.apply_final_layers(output_final)
elif self.method == 'hidden':
if self.layer_mix is None:
# for now only encoderLayer is considered, not decoderLayer or seq2seq structures
try:
layer_mix = random.randint(0, len(model.encoderLayer))
except:
warnings.warn('LayerMix random failed')
layer_mix = 0
else:
layer_mix = self.layer_mix
def apply_on_layer_end(l_i, output):
if l_i == layer_mix:
output1 = self.get_perm(output)
return self.mix_up(output, output1)
else:
return output
model.apply_on_layer_end = apply_on_layer_end
output_final = model(inputs)
return output_final
def forward(self, criterion, y_pred, y_true):
'''Compute the loss
'''
y_true1 = y_true[self.perm_index]
return self.lam * criterion(y_pred, y_true) + (1 - self.lam) * criterion(y_pred, y_true1)
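# Usage sketch (illustrative only; `model`, `train_X` and `y_true` are assumed to follow
# this repo's conventions, and the criterion is an ordinary classification loss):
# mixup = MixUp(method='encoder', alpha=1.0)
# y_pred = mixup.encode(model, train_X) # forward pass with mixed representations
# loss = mixup(nn.CrossEntropyLoss(), y_pred, y_true) # lam-weighted loss over both label sets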
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
class FocalLoss(nn.Module):
'''Multi-class Focal loss implementation'''
def __init__(self, gamma=2, weight=None, ignore_index=-100):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.weight = weight
self.ignore_index = ignore_index
def forward(self, input, target):
"""
input: [N, C]
target: [N, ]
"""
logpt = F.log_softmax(input, dim=1)
pt = torch.exp(logpt)
logpt = (1-pt)**self.gamma * logpt
loss = F.nll_loss(logpt, target, self.weight, ignore_index=self.ignore_index)
return loss
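# A minimal sketch (values assumed): relative to plain cross entropy, the (1-pt)**gamma
# factor down-weights samples the model already classifies confidently:
# criterion = FocalLoss(gamma=2)
# logits, labels = torch.randn(8, 5), torch.randint(0, 5, (8,))
# loss = criterion(logits, labels)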
class LabelSmoothingCrossEntropy(nn.Module):
def __init__(self, eps=0.1, reduction='mean', ignore_index=-100):
super(LabelSmoothingCrossEntropy, self).__init__()
self.eps = eps
self.reduction = reduction
self.ignore_index = ignore_index
def forward(self, output, target):
c = output.size()[-1]
log_preds = F.log_softmax(output, dim=-1)
if self.reduction == 'sum':
loss = -log_preds.sum()
else:
loss = -log_preds.sum(dim=-1)
if self.reduction == 'mean':
loss = loss.mean()
return loss*self.eps/c + (1-self.eps) * F.nll_loss(log_preds, target, reduction=self.reduction,
ignore_index=self.ignore_index)
class MultilabelCategoricalCrossentropy(nn.Module):
"""多标签分类的交叉熵
说明:y_true和y_pred的shape一致,y_true的元素非0即1, 1表示对应的类为目标类,0表示对应的类为非目标类。
警告:请保证y_pred的值域是全体实数,换言之一般情况下y_pred不用加激活函数,尤其是不能加sigmoid或者softmax!预测
阶段则输出y_pred大于0的类。如有疑问,请仔细阅读并理解本文。
参考:https://kexue.fm/archives/7359
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def forward(self, y_pred, y_true):
""" y_true ([Tensor]): [..., num_classes]
y_pred ([Tensor]): [..., num_classes]
"""
y_pred = (1-2*y_true) * y_pred
y_pred_pos = y_pred - (1-y_true) * 1e12
y_pred_neg = y_pred - y_true * 1e12
y_pred_pos = torch.cat([y_pred_pos, torch.zeros_like(y_pred_pos[..., :1])], dim=-1)
y_pred_neg = torch.cat([y_pred_neg, torch.zeros_like(y_pred_neg[..., :1])], dim=-1)
pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
return (pos_loss + neg_loss).mean()
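# A minimal sketch (values assumed): y_pred are raw scores without any activation, and at
# inference time the predicted label set is simply the classes whose score exceeds 0:
# criterion = MultilabelCategoricalCrossentropy()
# y_pred = torch.randn(4, 10) # raw logits, no sigmoid/softmax
# y_true = torch.randint(0, 2, (4, 10)).float() # multi-hot targets
# loss = criterion(y_pred, y_true)
# predicted = y_pred > 0 # the inference rule described in the docstring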
class SparseMultilabelCategoricalCrossentropy(nn.Module):
"""稀疏版多标签分类的交叉熵
说明:
1. y_true.shape=[..., num_positive],
y_pred.shape=[..., num_classes];
2. 请保证y_pred的值域是全体实数,换言之一般情况下y_pred不用加激活函数,尤其是不能加sigmoid或者softmax;
3. 预测阶段则输出y_pred大于0的类;
4. 详情请看:https://kexue.fm/archives/7359 。
"""
def __init__(self, mask_zero=False, epsilon=1e-7, **kwargs):
super().__init__(**kwargs)
self.mask_zero = mask_zero
self.epsilon = epsilon
def forward(self, y_pred, y_true):
zeros = torch.zeros_like(y_pred[..., :1])
y_pred = torch.cat([y_pred, zeros], dim=-1)
if self.mask_zero:
infs = zeros + float('inf')
y_pred = torch.cat([infs, y_pred[..., 1:]], dim=-1)
y_pos_2 = torch.gather(y_pred, dim=-1, index=y_true)
y_pos_1 = torch.cat([y_pos_2, zeros], dim=-1)
if self.mask_zero:
y_pred = torch.cat([-infs, y_pred[..., 1:]], dim=-1)
y_pos_2 = torch.gather(y_pred, dim=-1, index=y_true)
pos_loss = torch.logsumexp(-y_pos_1, dim=-1)
all_loss = torch.logsumexp(y_pred, dim=-1) # a
aux_loss = torch.logsumexp(y_pos_2, dim=-1) - all_loss # b-a
aux_loss = torch.clamp(1 - torch.exp(aux_loss), self.epsilon, 1) # 1-exp(b-a)
neg_loss = all_loss + torch.log(aux_loss) # a + log[1-exp(b-a)]
return pos_loss + neg_loss
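# A minimal sketch (values assumed): here y_true holds the indices of the positive classes
# instead of a multi-hot vector; with mask_zero=True, index 0 acts as the padding label:
# criterion = SparseMultilabelCategoricalCrossentropy(mask_zero=True)
# y_pred = torch.randn(4, 10) # raw logits over num_classes
# y_true = torch.randint(1, 10, (4, 3)) # indices of up to 3 positive classes per sample
# loss = criterion(y_pred, y_true).mean()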
class ContrastiveLoss(nn.Module):
"""对比损失:减小正例之间的距离,增大正例和反例之间的距离
公式:labels * distance_matrix.pow(2) + (1-labels)*F.relu(margin-distance_matrix).pow(2)
https://www.sbert.net/docs/package_reference/losses.html
"""
def __init__(self, margin=0.5, size_average=True, online=False):
super(ContrastiveLoss, self).__init__()
self.margin = margin
self.size_average = size_average
self.online = online
def forward(self, distances, labels, pos_id=1, neg_id=0):
if not self.online:
losses = 0.5 * (labels.float() * distances.pow(2) + (1 - labels).float() * F.relu(self.margin - distances).pow(2))
return losses.mean() if self.size_average else losses.sum()
else:
negs = distances[labels == neg_id]
poss = distances[labels == pos_id]
# select hard positive and hard negative pairs
negative_pairs = negs[negs < (poss.max() if len(poss) > 1 else negs.mean())]
positive_pairs = poss[poss > (negs.min() if len(negs) > 1 else poss.mean())]
positive_loss = positive_pairs.pow(2).sum()
negative_loss = F.relu(self.margin - negative_pairs).pow(2).sum()
return positive_loss + negative_loss
class RDropLoss(nn.Module):
'''Implementation of the R-Drop loss; official repo: https://github.com/dropreg/R-Drop
'''
def __init__(self, alpha=4, rank='adjacent'):
super().__init__()
self.alpha = alpha
# two layouts are supported: interleaved (adjacent odd/even rows) and stacked (upper/lower halves)
assert rank in {'adjacent', 'updown'}, "rank kwarg only supports 'adjacent' and 'updown'"
self.rank = rank
self.loss_sup = nn.CrossEntropyLoss()
self.loss_rdrop = nn.KLDivLoss(reduction='none')
def forward(self, *args):
'''Supports two call signatures: (y_pred, y_true) or (y_pred1, y_pred2, y_true)
'''
assert len(args) in {2, 3}, 'RDropLoss only support 2 or 3 input args'
# y_pred is a single tensor
if len(args) == 2:
y_pred, y_true = args
loss_sup = self.loss_sup(y_pred, y_true) # computed over both copies
if self.rank == 'adjacent':
y_pred1 = y_pred[1::2]
y_pred2 = y_pred[::2]
elif self.rank == 'updown':
half_btz = y_true.shape[0] // 2
y_pred1 = y_pred[:half_btz]
y_pred2 = y_pred[half_btz:]
# y_pred is given as two tensors
else:
y_pred1, y_pred2, y_true = args
loss_sup = self.loss_sup(y_pred1, y_true)
loss_rdrop1 = self.loss_rdrop(F.log_softmax(y_pred1, dim=-1), F.softmax(y_pred2, dim=-1))
loss_rdrop2 = self.loss_rdrop(F.log_softmax(y_pred2, dim=-1), F.softmax(y_pred1, dim=-1))
return loss_sup + torch.mean(loss_rdrop1 + loss_rdrop2) / 4 * self.alpha
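# Usage sketch (illustrative only): with rank='updown', the same batch is assumed to be
# forwarded twice (two different dropout masks) and stacked along dim 0:
# criterion = RDropLoss(alpha=4, rank='updown')
# y_pred = model(torch.cat([x, x])) # [2*btz, num_classes], assumed duplicated forward pass
# loss = criterion(y_pred, torch.cat([y_true, y_true]))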
class UDALoss(nn.Module):
'''UDALoss; subclass it before use, since forward needs global_step and total_steps
https://arxiv.org/abs/1904.12848
'''
def __init__(self, tsa_schedule=None, total_steps=None, start_p=0, end_p=1, return_all_loss=True):
super().__init__()
self.loss_sup = nn.CrossEntropyLoss()
self.loss_unsup = nn.KLDivLoss(reduction='batchmean')
self.tsa_schedule = tsa_schedule
self.start = start_p
self.end = end_p
if self.tsa_schedule:
assert self.tsa_schedule in {'linear_schedule', 'exp_schedule', 'log_schedule'}, 'tsa_schedule config illegal'
self.return_all_loss = return_all_loss
def forward(self, y_pred, y_true_sup, global_step, total_steps):
sup_size = y_true_sup.size(0)
unsup_size = (y_pred.size(0) - sup_size) // 2
# supervised part, with cross-entropy loss
y_pred_sup = y_pred[:sup_size]
if self.tsa_schedule is None:
loss_sup = self.loss_sup(y_pred_sup, y_true_sup)
else: # use TSA to drop supervised samples whose predicted probability is already high
threshold = self.get_tsa_threshold(self.tsa_schedule, global_step, total_steps, self.start, self.end)
true_prob = torch.gather(F.softmax(y_pred_sup, dim=-1), dim=1, index=y_true_sup[:, None])
sel_rows = true_prob.lt(threshold).sum(dim=-1).gt(0) # keep only the samples below the threshold
loss_sup = self.loss_sup(y_pred_sup[sel_rows], y_true_sup[sel_rows]) if sel_rows.sum() > 0 else 0
# unsupervised part; KL divergence here, cross entropy also works
y_true_unsup = y_pred[sup_size:sup_size+unsup_size]
y_true_unsup = F.softmax(y_true_unsup.detach(), dim=-1)
y_pred_unsup = F.log_softmax(y_pred[sup_size+unsup_size:], dim=-1)
loss_unsup = self.loss_unsup(y_pred_unsup, y_true_unsup)
if self.return_all_loss:
return loss_sup + loss_unsup, loss_sup, loss_unsup
else:
return loss_sup + loss_unsup
@staticmethod
def get_tsa_threshold(schedule, global_step, num_train_steps, start, end):
training_progress = global_step / num_train_steps
if schedule == "linear_schedule":
threshold = training_progress
elif schedule == "exp_schedule":
scale = 5
threshold = math.exp((training_progress - 1) * scale)
elif schedule == "log_schedule":
scale = 5
threshold = 1 - math.exp((-training_progress) * scale)
return threshold * (end - start) + start
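# Usage sketch (illustrative only): the batch fed to the model is assumed to be laid out as
# [supervised, unsupervised original, unsupervised augmented] along dim 0; `trainer` below is
# an assumed object exposing the current global_step and total_steps:
# class MyUDALoss(UDALoss):
#     def forward(self, y_pred, y_true_sup):
#         return super().forward(y_pred, y_true_sup, trainer.global_step, trainer.total_steps)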
class TemporalEnsemblingLoss(nn.Module):
'''Implementation of Temporal Ensembling: an MSE consistency loss is added on top of the supervised loss
official repo: https://github.com/s-laine/tempens
third-party pytorch implementation: https://github.com/ferretj/temporal-ensembling
Note: the train_dataloader must be created with shuffle=False
'''
def __init__(self, epochs, max_val=10.0, ramp_up_mult=-5.0, alpha=0.5, max_batch_num=100, hist_device='cpu'):
super().__init__()
self.loss_sup = nn.CrossEntropyLoss()
self.max_epochs = epochs
self.max_val = max_val
self.ramp_up_mult = ramp_up_mult
self.alpha = alpha
self.max_batch_num = max_batch_num # None keeps the full data history, which is resource-hungry on large datasets
self.hist_unsup = [] # history of unsupervised logits
self.hist_sup = [] # history of supervised logits
self.hist_device = hist_device
self.hist_input_y = [] # history of supervised labels y
assert (self.alpha >= 0) & (self.alpha < 1) # alpha == 1 would make the update denominator zero
def forward(self, y_pred_sup, y_pred_unsup, y_true_sup, epoch, bti):
self.same_batch_check(y_pred_sup, y_pred_unsup, y_true_sup, bti)
if (self.max_batch_num is None) or (bti < self.max_batch_num):
self.init_hist(bti, y_pred_sup, y_pred_unsup) # initialize the history
sup_ratio = float(len(y_pred_sup)) / (len(y_pred_sup) + len(y_pred_unsup)) # proportion of supervised samples
w = self.weight_schedule(epoch, sup_ratio)
sup_loss, unsup_loss = self.temporal_loss(y_pred_sup, y_pred_unsup, y_true_sup, bti)
# update the history
self.hist_unsup[bti] = self.update(self.hist_unsup[bti], y_pred_unsup.detach(), epoch)
self.hist_sup[bti] = self.update(self.hist_sup[bti], y_pred_sup.detach(), epoch)
# if bti == 0: # used to check that the batch order is identical across epochs
# print(w, sup_loss.item(), w * unsup_loss.item())
# print(y_true_sup)
return sup_loss + w * unsup_loss, sup_loss, w * unsup_loss
else:
return self.loss_sup(y_pred_sup, y_true_sup)
def same_batch_check(self, y_pred_sup, y_pred_unsup, y_true_sup, bti):
'''The first few batches (hard-coded to 10) must be identical across epochs
'''
if bti >= 10:
return
if bti >= len(self.hist_input_y):
self.hist_input_y.append(y_true_sup.to(self.hist_device))
else: # check
err_msg = 'TemporalEnsemblingLoss requires a consistently ordered dataloader, you may need to set train_dataloader shuffle=False'
assert self.hist_input_y[bti].equal(y_true_sup.to(self.hist_device)), err_msg
def update(self, hist, y_pred, epoch):
'''Update the history logits, gated by alpha
'''
Z = self.alpha * hist.to(y_pred) + (1. - self.alpha) * y_pred
output = Z * (1. / (1. - self.alpha ** (epoch + 1)))
return output.to(self.hist_device)
def weight_schedule(self, epoch, sup_ratio):
max_val = self.max_val * sup_ratio
if epoch == 0:
return 0.
elif epoch >= self.max_epochs:
return max_val
return max_val * np.exp(self.ramp_up_mult * (1. - float(epoch) / self.max_epochs) ** 2)
def temporal_loss(self, y_pred_sup, y_pred_unsup, y_true_sup, bti):
# MSE between current and temporal outputs
def mse_loss(out1, out2):
quad_diff = torch.sum((F.softmax(out1, dim=1) - F.softmax(out2, dim=1)) ** 2)
return quad_diff / out1.data.nelement()
sup_loss = self.loss_sup(y_pred_sup, y_true_sup)
# the original implementation concatenates sup and unsup into one tensor; here they are two separate tensors, so the two terms are computed separately
unsup_loss = mse_loss(y_pred_unsup, self.hist_unsup[bti].to(y_pred_unsup))
unsup_loss += mse_loss(y_pred_sup, self.hist_sup[bti].to(y_pred_sup))
return sup_loss, unsup_loss
def init_hist(self, bti, y_pred_sup, y_pred_unsup):
if bti >= len(self.hist_sup):
self.hist_sup.append(torch.zeros_like(y_pred_sup).to(self.hist_device))
self.hist_unsup.append(torch.zeros_like(y_pred_unsup).to(self.hist_device))
import torch
import torch.nn as nn
import copy
import json
import re
from bert4torch.layers import LayerNorm, BertEmbeddings, BertLayer, Identity, T5Layer, GatedAttentionUnit, XlnetLayer
from bert4torch.layers import AdaptiveEmbedding, XlnetPositionsEncoding
from bert4torch.snippets import metric_mapping, search_layer, insert_arguments, delete_arguments, get_kw
from bert4torch.snippets import ProgbarLogger, EarlyStopping, FGM, PGD, VAT, IterDataset, take_along_dim
from bert4torch.activations import get_activation
import warnings
class BaseModel(nn.Module):
def __init__(self):
super(BaseModel, self).__init__()
# exposed mainly so that external code can access these attributes
self.global_step, self.local_step, self.total_steps, self.epoch, self.train_dataloader = 0, 0, 0, 0, None
self.callbacks = []
def compile(self, loss, optimizer, scheduler=None, max_grad_norm=None, use_amp=False, metrics=None, adversarial_train={'name': ''}):
'''Configure loss, optimizer, metrics, etc.
loss: loss function
optimizer: optimizer
scheduler: learning-rate scheduler
max_grad_norm: gradient-clipping threshold, disabled by default
use_amp: whether to use mixed precision, disabled by default
metrics: metrics printed during training; loss-related metrics are always printed; currently supports accuracy
'''
self.criterion = loss
self.optimizer = optimizer
self.scheduler = scheduler
self.max_grad_norm = max_grad_norm
self.use_amp = use_amp
if use_amp:
assert adversarial_train['name'] not in {'vat', 'gradient_penalty'}, 'Running amp together with vat/gradient_penalty adversarial training is not supported in the current version'
from torch.cuda.amp import autocast
self.autocast = autocast
self.scaler = torch.cuda.amp.GradScaler()
if metrics is None:
metrics = []
self.metrics = ['loss'] + [i for i in metrics if i!='loss']
# adversarial training
self.adversarial = adversarial_train
self.adversarial_initialize()
def adversarial_initialize(self):
'''Initialize adversarial training
'''
assert self.adversarial['name'] in {'', 'fgm', 'pgd', 'vat', 'gradient_penalty'}, 'adversarial_train supports fgm, pgd, vat and gradient_penalty modes'
self.adversarial['epsilon'] = self.adversarial.get('epsilon', 1.0)
self.adversarial['emb_name'] = self.adversarial.get('emb_name', 'word_embeddings')
if self.adversarial['name'] == 'fgm':
self.ad_train = FGM(self)
elif self.adversarial['name'] == 'pgd':
self.adversarial['K'] = self.adversarial.get('K', 3) # number of attack steps
self.adversarial['alpha'] = self.adversarial.get('alpha', 0.3) # attack step size
self.ad_train = PGD(self)
elif self.adversarial['name'] == 'gradient_penalty':
pass
elif self.adversarial['name'] == 'vat':
self.adversarial['K'] = self.adversarial.get('K', 3)
self.adversarial['noise_var'] = self.adversarial.get('noise_var', 1e-5) # noise variance
self.adversarial['noise_gamma'] = self.adversarial.get('noise_gamma', 1e-6) # eps
self.adversarial['adv_step_size'] = self.adversarial.get('adv_step_size', 1e-3) # step size
self.adversarial['adv_alpha'] = self.adversarial.get('adv_alpha', 1) # weight of the adversarial loss
self.adversarial['norm_type'] = self.adversarial.get('norm_type', 'l2') # normalization type
self.ad_train = VAT(self, **self.adversarial)
def adversarial_training(self, train_X, train_y, output, loss, loss_detail, grad_accumulation_steps):
'''Adversarial training
'''
if self.adversarial['name'] == 'fgm':
self.ad_train.attack(**self.adversarial) # the embedding is perturbed in place
output, loss, loss_detail = self.train_step(train_X, train_y, grad_accumulation_steps)
loss.backward() # backprop: the adversarial gradient is accumulated on top of the normal one
# restore the embedding parameters, since updates must be applied to the original embedding rather than the perturbed one
self.ad_train.restore(**self.adversarial)
elif self.adversarial['name'] == 'pgd':
self.ad_train.backup_grad() # back up the gradients
for t in range(self.adversarial['K']):
# add the adversarial perturbation on the embedding; back up param.data on the first attack
self.ad_train.attack(**self.adversarial, is_first_attack=(t==0))
if t != self.adversarial['K']-1:
self.optimizer.zero_grad() # accumulate the perturbation rather than the gradient
else:
self.ad_train.restore_grad() # restore the normal gradients
output, loss, loss_detail = self.train_step(train_X, train_y, grad_accumulation_steps)
loss.backward() # backprop: the adversarial gradient is accumulated on top of the normal one
self.ad_train.restore(**self.adversarial) # restore the embedding parameters
# gradient penalty
elif self.adversarial['name'] == 'gradient_penalty':
para = search_layer(self, self.adversarial['emb_name'], retrun_first=True)
gp = (para.grad ** 2).sum()
loss += 0.5 * gp * self.adversarial['epsilon']
loss.backward()
# virtual adversarial training
elif self.adversarial['name'] == 'vat':
logit = output[0] if isinstance(output, (list, tuple)) else output
adv_loss = self.ad_train.virtual_adversarial_training(train_X, logit)
loss_detail.update({'loss_sup': loss.item(), 'loss_unsup': adv_loss})
loss += (adv_loss if adv_loss else 0)
loss.backward()
return loss, loss_detail
def train_step(self, train_X, train_y, grad_accumulation_steps):
'''Forward pass that returns the loss
'''
def args_segmentate(train_X):
'''Decide whether to unpack the arguments
'''
if isinstance(train_X, torch.Tensor): # tensors are not unpacked
pass
elif isinstance(self, (BaseModelDP, BaseModelDDP)):
if self.module.forward.__code__.co_argcount >= 3:
return True
elif self.forward.__code__.co_argcount >= 3:
return True
return False
if self.use_amp:
with self.autocast():
output = self.forward(*train_X) if args_segmentate(train_X) else self.forward(train_X)
loss_detail = self.criterion(output, train_y)
else:
output = self.forward(*train_X) if args_segmentate(train_X) else self.forward(train_X)
loss_detail = self.criterion(output, train_y)
if isinstance(loss_detail, torch.Tensor):
loss = loss_detail
loss_detail = {}
elif isinstance(loss_detail, dict):
loss = loss_detail['loss'] # the other entries are extra losses, used only for logging
del loss_detail['loss']
elif isinstance(loss_detail, (tuple, list)):
loss = loss_detail[0]
loss_detail = {f'loss{i}':v for i, v in enumerate(loss_detail[1:], start=1)}
else:
raise ValueError('Return loss only support Tensor/dict/tuple/list format')
# gradient accumulation
loss = loss / grad_accumulation_steps if grad_accumulation_steps > 1 else loss
return output, loss, loss_detail
def callback_fun(self, mode, logs={}):
'''Dispatch all callbacks from one place, which makes trigger conditions easier to handle
'''
# in distributed DDP training, only master_rank runs the callbacks
if isinstance(self, BaseModelDDP) and self.master_rank!=torch.distributed.get_rank():
return
if mode == 'train_begin':
for callback in self.callbacks:
callback.on_train_begin()
elif mode == 'epoch_begin':
for callback in self.callbacks:
callback.on_epoch_begin(self.global_step, self.epoch, logs)
elif mode == 'batch_begin':
for callback in self.callbacks:
callback.on_batch_begin(self.global_step, self.local_step, logs)
elif mode == 'batch_end':
for callback in self.callbacks:
callback.on_batch_end(self.global_step, self.local_step, logs)
elif mode == 'epoch_end':
for callback in self.callbacks:
callback.on_epoch_end(self.global_step, self.epoch, logs)
elif mode == 'train_end':
for callback in self.callbacks:
callback.on_train_end()
elif mode == 'dataloader_end':
for callback in self.callbacks:
callback.on_dataloader_end()
def fit(self, train_dataloader, steps_per_epoch=None, epochs=1, grad_accumulation_steps=1, callbacks=[]):
if isinstance(train_dataloader.dataset, IterDataset):
assert steps_per_epoch is not None, 'IterDataset should specify steps_per_epoch'
steps_per_epoch = len(train_dataloader) if steps_per_epoch is None else steps_per_epoch
self.total_steps = steps_per_epoch * epochs
self.global_step = 0
self.train_dataloader = train_dataloader # kept as a member so that external callbacks can modify it
train_dataloader_iter = iter(self.train_dataloader) # not regenerated on every epoch
self.callbacks = [ProgbarLogger(epochs, steps_per_epoch, self.metrics)] + (callbacks if isinstance(callbacks, (list, tuple)) else [callbacks])
self.callback_fun('train_begin')
# epoch: current epoch
# global_step: current global training step
# local_step: step within the current epoch; across epochs the same local_step may map to different batches (identical when steps_per_epoch=None)
# bti: index within the dataloader; across epochs the same bti usually maps to the same batch, unless the dataloader is regenerated
self.bti = 0
for epoch in range(epochs):
self.epoch = epoch
self.callback_fun('epoch_begin')
for local_step in range(steps_per_epoch):
self.local_step = local_step
# iterate the dataloader manually; avoid itertools.cycle, which has caused variables not to be released
try:
batch = next(train_dataloader_iter)
except StopIteration:
self.callback_fun('dataloader_end') # for large datasets that dynamically read files and regenerate the dataloader, e.g. pretraining
train_dataloader_iter = iter(self.train_dataloader) # with shuffle=True the order is regenerated as well
self.bti = 0
batch = next(train_dataloader_iter)
train_X, train_y = batch
# get the batch size; at most two levels of nesting are allowed, e.g. ((token_ids1, mask1), (token_ids2, mask2))
if isinstance(train_X, (list, tuple)):
if isinstance(train_X[0], (list, tuple)):
btz = train_X[0][0].size(0)
else:
btz = train_X[0].size(0)
elif isinstance(train_X, torch.Tensor):
btz = train_X.size(0)
else:
raise ValueError('Input only support [list, tuple, tensor]')
logs = {'batch': self.local_step, 'size': btz}
self.callback_fun('batch_begin', logs)
self.train() # switch to train mode
# argument-count check: >= 3 means multiple inputs, 2 means a single input
output, loss, loss_detail = self.train_step(train_X, train_y, grad_accumulation_steps)
retain_graph = True if self.adversarial['name'] in {'gradient_penalty', 'vat'} else False
if self.use_amp: # mixed precision
scale_before_step = self.scaler.get_scale()
self.scaler.scale(loss).backward(retain_graph=retain_graph)
else:
loss.backward(retain_graph=retain_graph)
# adversarial training
loss, loss_detail = self.adversarial_training(train_X, train_y, output, loss, loss_detail, grad_accumulation_steps)
# parameter update; the actual number of updates is divided by grad_accumulation_steps, so adjust the total training steps accordingly
if (self.global_step+1) % grad_accumulation_steps == 0:
skip_scheduler = False
# mixed precision
if self.use_amp:
self.scaler.unscale_(self.optimizer)
if self.max_grad_norm is not None: # gradient clipping
torch.nn.utils.clip_grad_norm_(self.parameters(), self.max_grad_norm)
self.scaler.step(self.optimizer)
self.scaler.update()
skip_scheduler = self.scaler.get_scale() != scale_before_step
else:
if self.max_grad_norm is not None: # gradient clipping
torch.nn.utils.clip_grad_norm_(self.parameters(), self.max_grad_norm)
self.optimizer.step()
self.optimizer.zero_grad() # clear the gradients
if (self.scheduler is not None) and not skip_scheduler:
self.scheduler.step()
# logging
logs.update({'loss': loss.item()})
logs_loss_detail = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in loss_detail.items()}
logs.update(logs_loss_detail)
if self.global_step == 0:
self.callbacks[0].add_metrics(list(logs_loss_detail.keys()), add_position=1)
for metric in self.metrics:
tmp = metric_mapping(metric, output, train_y) # built-in accuracy-style metrics
if tmp is not None:
logs[metric] = tmp
self.callback_fun('batch_end', logs)
self.bti += 1
self.global_step += 1
self.callback_fun('epoch_end', logs)
# early-stopping strategy
callback_tmp = [callback_tmp for callback_tmp in self.callbacks if isinstance(callback_tmp, EarlyStopping)]
if callback_tmp and callback_tmp[0].stopped_epoch > 0:
break
self.callback_fun('train_end', logs)
@torch.no_grad()
def predict(self, input_tensor_list, return_all=None):
self.eval()
if self.forward.__code__.co_argcount >= 3:
output = self.forward(*input_tensor_list)
else:
output = self.forward(input_tensor_list)
if return_all is None:
return output
elif isinstance(output, (tuple, list)) and isinstance(return_all, int) and return_all < len(output):
return output[return_all]
else:
raise ValueError('Return format error')
def load_weights(self, load_path, strict=True, prefix=None):
state_dict = torch.load(load_path, map_location='cpu')
if prefix is None:
self.load_state_dict(state_dict, strict=strict)
else:
# loads weights that were saved by save_weights in the raw (official) key format
eval_str = 'self.variable_mapping()' if prefix == '' else f'self.{prefix}.variable_mapping()'
mapping = {v:k for k, v in eval(eval_str).items()}
mapping = mapping if prefix == '' else {k:f'{prefix}.{v}' for k,v in mapping.items()}
state_dict_raw = {}
for k, v in state_dict.items():
k = mapping.get(k, k)
state_dict_raw[k] = v
self.load_state_dict(state_dict_raw, strict=strict)
def save_weights(self, save_path, prefix=None):
if prefix is None:
torch.save(self.state_dict(), save_path)
else:
# save under the original keys from variable_mapping(), so that other official code can load the model
eval_str = 'self.variable_mapping()' if prefix == '' else f'self.{prefix}.variable_mapping()'
mapping = eval(eval_str)
mapping = mapping if prefix == '' else {f'{prefix}.{k}':v for k,v in mapping.items()}
state_dict_raw = {}
for k, v in self.state_dict().items():
k = mapping.get(k, k)
state_dict_raw[k] = v
torch.save(state_dict_raw, save_path)
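# Usage sketch of the training interface above (illustrative only; `MyModel`, `train_dataloader`
# and the optimizer setup are assumptions, not part of this file):
# model = MyModel() # a subclass of BaseModel implementing forward()
# model.compile(loss=nn.CrossEntropyLoss(), optimizer=torch.optim.Adam(model.parameters(), lr=2e-5),
#               metrics=['accuracy'], adversarial_train={'name': 'fgm'})
# model.fit(train_dataloader, epochs=3, steps_per_epoch=None, grad_accumulation_steps=1)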
class BaseModelDP(BaseModel, nn.DataParallel):
'''Multi-GPU training via DataParallel
'''
def __init__(self, *args, **kwargs):
nn.DataParallel.__init__(self, *args, **kwargs)
class BaseModelDDP(BaseModel, nn.parallel.DistributedDataParallel):
'''Multi-GPU training via DistributedDataParallel
'''
def __init__(self, *args, master_rank=0, **kwargs):
self.master_rank = master_rank # rank responsible for printing the progress bar
nn.parallel.DistributedDataParallel.__init__(self, *args, **kwargs)
class BERT_BASE(BaseModel):
"""模型基类
"""
def __init__(
self,
vocab_size, # vocabulary size
hidden_size, # hidden (encoding) dimension
num_hidden_layers, # total number of Transformer layers
num_attention_heads, # number of attention heads
intermediate_size, # hidden dimension of the FeedForward layer
hidden_act, # activation of the FeedForward hidden layer
dropout_rate=None, # dropout ratio
attention_probs_dropout_prob=None, # dropout ratio of the attention matrix
embedding_size=None, # embedding_size; falls back to the config-file value if not given
attention_head_size=None, # head_size of V in attention
attention_key_size=None, # head_size of Q and K in attention
initializer_range=0.02, # std of the weight initialization
sequence_length=None, # whether to fix the sequence length
keep_tokens=None, # list of token ids to keep
compound_tokens=None, # extended embeddings
residual_attention_scores=False, # add a residual to the attention matrix
ignore_invalid_weights=False, # allow skipping weights that do not exist
keep_hidden_layers=None, # ids of the hidden layers to keep
hierarchical_position=None, # whether to use hierarchically decomposed position encodings
**kwargs
):
super(BERT_BASE, self).__init__()
if keep_tokens is not None:
vocab_size = len(keep_tokens)
if compound_tokens is not None:
vocab_size += len(compound_tokens)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.attention_head_size = attention_head_size or self.hidden_size // self.num_attention_heads
self.attention_key_size = attention_key_size or self.attention_head_size
self.intermediate_size = intermediate_size
self.dropout_rate = dropout_rate or 0
self.attention_probs_dropout_prob = attention_probs_dropout_prob or 0
self.hidden_act = hidden_act
self.embedding_size = embedding_size or hidden_size
self.initializer_range = initializer_range
self.sequence_length = sequence_length
self.keep_tokens = keep_tokens
self.compound_tokens = compound_tokens
self.attention_bias = None
self.position_bias = None
self.attention_scores = None
self.residual_attention_scores = residual_attention_scores
self.ignore_invalid_weights = ignore_invalid_weights
self.keep_hidden_layers = set(range(num_hidden_layers)) if keep_hidden_layers is None else set(keep_hidden_layers)
self.hierarchical_position = hierarchical_position
def build(
self,
attention_caches=None,
layer_norm_cond=None,
layer_norm_cond_hidden_size=None,
layer_norm_cond_hidden_act=None,
additional_input_layers=None,
**kwargs
):
"""模型构建函数
attention_caches: 为Attention的K,V的缓存序列字典,格式为{Attention层名: [K缓存, V缓存]};
layer_norm_*系列参数: 实现Conditional Layer Normalization时使用,用来实现以“固定长度向量”为条件的条件Bert。
"""
# additional_input
# if additional_input_layers is not None:
# if not isinstance(additional_input_layers, list):
# self.additional_input_layers = [additional_input_layers]
# else:
# self.additional_input_layers = additional_input_layers
# Other
self.attention_caches = attention_caches or {}
# self.layer_norm_conds = [
# layer_norm_cond,
# layer_norm_cond_hidden_size,
# layer_norm_cond_hidden_act or 'linear',
# ]
self.output_all_encoded_layers = kwargs.get('output_all_encoded_layers', False)
def forward(self, inputs):
"""定义模型的执行流程
"""
# Embedding
outputs = self.apply_embeddings(inputs)
# Main
outputs = self.apply_main_layers(outputs)
# Final
outputs = self.apply_final_layers(outputs)
return outputs
def init_model_weights(self, module):
""" 初始化权重
"""
if isinstance(module, (nn.Linear, nn.Embedding)) and (module.weight.requires_grad):
# bert参数初始化, tf版本在linear和Embedding层使用的是截断正太分布, pytorch没有实现该函数,
# 此种初始化对于加载预训练模型后进行finetune没有任何影响,
# cf https://github.com/pytorch/pytorch/pull/5617
# 固定的相对位置编码如Sinusoidal无需初始化
module.weight.data.normal_(mean=0.0, std=self.initializer_range)
elif isinstance(module, LayerNorm):
if hasattr(module, 'bias') and module.bias.requires_grad: # models such as T5 use rmsnorm
module.bias.data.zero_()
if hasattr(module, 'weight') and module.weight.requires_grad:
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and (module.bias is not None) and (module.bias.requires_grad):
module.bias.data.zero_()
def variable_mapping(self):
"""构建pytorch层与checkpoint的变量名之间的映射表
"""
return {}
def load_load_variable(self):
raise NotImplementedError
def load_embeddings(self, embeddings):
"""根据keep_tokens和compound_tokens对embedding进行修改
"""
if self.keep_tokens is not None:
embeddings = embeddings[self.keep_tokens]
if self.compound_tokens is not None:
ext_embeddings = []
for item in self.compound_tokens:
try:
ext_embeddings.append(torch.mean(embeddings[item], 0) * torch.ones_like(embeddings[item]))
except IndexError:
ext_embeddings.append(torch.mean(embeddings, 0, keepdim=True))
warnings.warn('Initialized ext_embeddings from compound_tokens not present in the embedding index')
embeddings = torch.cat([embeddings] + ext_embeddings, 0)
return embeddings
def load_pos_embeddings(self, embeddings):
"""根据hierarchical_position对pos_embedding进行修改
"""
if self.hierarchical_position is not None:
alpha = 0.4 if self.hierarchical_position is True else self.hierarchical_position
embeddings = embeddings - alpha * embeddings[:1]
embeddings = embeddings / (1 - alpha)
position_index = torch.arange(self.max_position)[:, None]
# for compatibility with older pytorch versions that lack take_along_dim
embeddings_x = take_along_dim(embeddings, torch.div(position_index, embeddings.size(0), rounding_mode='trunc'), dim=0)
embeddings_y = take_along_dim(embeddings, position_index % embeddings.size(0), dim=0)
embeddings = alpha * embeddings_x + (1 - alpha) * embeddings_y
return embeddings
def load_weights_from_pytorch_checkpoint(self, checkpoint, mapping=None):
"""根据mapping从checkpoint加载权重
"""
file_state_dict = torch.load(checkpoint, map_location='cpu') # load the checkpoint file
mapping = mapping or self.variable_mapping()
parameters_set = set([i[0] for i in self.named_parameters()]) # trainable variables
# if a name exists both in the checkpoint and in the model structure but not in the preset mapping, add it to the mapping,
# mainly so that extra layers added by subclasses of BERT are also loaded from the checkpoint automatically
for layer_name in parameters_set:
if (layer_name in file_state_dict) and (layer_name not in mapping):
mapping.update({layer_name: layer_name})
state_dict_new = {}
for new_key, old_key in mapping.items():
if new_key not in self.state_dict():
continue
elif old_key in file_state_dict: # present in the mapping and in the checkpoint
state_dict_new[new_key] = self.load_variable(file_state_dict, old_key)
elif (old_key not in file_state_dict) and (not self.ignore_invalid_weights):
# present in the mapping but missing from the checkpoint
print(f'[WARNING] {old_key} not found in pretrain models')
if new_key in parameters_set:
parameters_set.remove(new_key)
# Parameters that could not be loaded from the pretrained weights
if not self.ignore_invalid_weights:
for key in parameters_set:
print(f'[WARNING] Parameter {key} not loaded from pretrain models')
del file_state_dict
# load the ckpt weights into the model structure
self.load_state_dict(state_dict_new, strict=False)
# def get_inputs(self):
# pass
# def set_inputs(self, inputs, additional_input_layers=None):
# """设置input和inputs属性
# """
# pass
def apply_embeddings(self, inputs):
raise NotImplementedError
def apply_main_layers(self, inputs):
raise NotImplementedError
def apply_final_layers(self, inputs):
raise NotImplementedError
def apply_on_layer_begin(self, l_i, inputs):
'''Hook for operating on the inputs of each layer block
'''
return inputs
def apply_on_layer_end(self, l_i, inputs):
'''Hook for operating on the outputs of each layer block
'''
return inputs
def compute_attention_bias(self, inputs=None):
"""定义每一层的Attention Bias
"""
return self.attention_bias
def compute_position_bias(self, inputs=None):
"""定义每一层的Position Bias(一般相对位置编码用)
"""
return self.position_bias
def set_outputs(self, outputs):
"""设置output和oututs属性
"""
if not isinstance(outputs, list):
outputs = [outputs]
outputs = outputs[:]
self.outputs = outputs
if len(outputs) > 1:
self.output = outputs
else:
self.output = outputs[0]
class LM_Mask(object):
"""定义下三角Attention Mask(语言模型用)
"""
def compute_attention_bias(self, inputs=None):
"""通过idxs序列的比较来得到对应的mask
"""
seq_len = inputs[0].shape[1]
attention_bias = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.long, device=inputs[0].device), diagonal=0)
self.attention_bias = attention_bias.unsqueeze(0).unsqueeze(1)
return self.attention_bias
def extend_with_language_model(InputModel):
"""添加下三角的Attention Mask(语言模型用)
"""
class LanguageModel(LM_Mask, InputModel):
"""带下三角Attention Mask的派生模型
"""
def __init__(self, *args, **kwargs):
kwargs['with_mlm'] = kwargs.get('with_mlm') or True
super(LanguageModel, self).__init__(*args, **kwargs)
return LanguageModel
class UniLM_Mask(object):
"""定义UniLM的Attention Mask(Seq2Seq模型用)
其中source和target的分区,由segment_ids来表示。
UniLM: https://arxiv.org/abs/1905.03197
"""
def compute_attention_bias(self, inputs=None):
"""通过idxs序列的比较来得到对应的mask
"""
segment_ids = inputs[1]
attention_bias = torch.cumsum(segment_ids, dim=1)
attention_bias = (attention_bias.unsqueeze(1)) <= (attention_bias.unsqueeze(2))
self.attention_bias = attention_bias.unsqueeze(1).long()
return self.attention_bias
def extend_with_unified_language_model(InputModel):
"""添加UniLM的Attention Mask(Seq2Seq模型用)
"""
class UnifiedLanguageModel(UniLM_Mask, InputModel):
"""带UniLM的Attention Mask的派生模型
UniLM: https://arxiv.org/abs/1905.03197
"""
def __init__(self, *args, **kwargs):
kwargs['with_mlm'] = kwargs.get('with_mlm') or True
super(UnifiedLanguageModel, self).__init__(*args, **kwargs)
return UnifiedLanguageModel
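# Usage sketch (illustrative only): the two helpers above wrap an existing model class with
# the corresponding attention mask, e.g. a UniLM-style seq2seq over the BERT class below:
# Seq2SeqBERT = extend_with_unified_language_model(BERT)
# model = Seq2SeqBERT(vocab_size=21128, hidden_size=768, num_hidden_layers=12,
#                     num_attention_heads=12, intermediate_size=3072, hidden_act='gelu',
#                     max_position=512) # with_mlm is forced on by the wrapper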
class BERT(BERT_BASE):
"""构建BERT模型
"""
def __init__(
self,
max_position, # maximum sequence length
segment_vocab_size=2, # number of segments
with_pool=False, # whether to include the pooler part
with_nsp=False, # whether to include the NSP part
with_mlm=False, # whether to include the MLM part
custom_position_ids=False, # whether position ids are passed in manually
custom_attention_mask=False, # whether the attention_mask is passed in manually
shared_segment_embeddings=False, # if True, segments share the token embedding
layer_norm_cond=None, # conditional layer_norm
layer_add_embs=None, # additional embeddings, e.g. custom part-of-speech, tone or word-level embeddings
is_dropout=False,
token_pad_ids=0, # 0 is the default padding id; note that google's mt5 does not pad with 0
**kwargs # remaining arguments
):
super(BERT, self).__init__(**kwargs)
self.max_position = max_position
self.segment_vocab_size = segment_vocab_size
self.with_pool = with_pool
self.with_nsp = with_nsp
self.with_mlm = with_mlm
self.custom_position_ids = custom_position_ids
self.custom_attention_mask = custom_attention_mask
self.shared_segment_embeddings = shared_segment_embeddings
self.is_dropout = is_dropout
self.token_pad_ids = token_pad_ids
if self.with_nsp and not self.with_pool:
self.with_pool = True
self.layer_norm_conds = layer_norm_cond
self.layer_add_embs = layer_add_embs
self.conditional_size = layer_norm_cond.weight.size(1) if layer_norm_cond is not None else None
self.embeddings = BertEmbeddings(self.vocab_size, self.embedding_size, self.hidden_size, self.max_position, self.segment_vocab_size, self.shared_segment_embeddings,
self.dropout_rate, self.conditional_size, **get_kw(BertEmbeddings, kwargs))
kwargs['max_position'] = self.max_position # needed by relative position encodings
layer = BertLayer(self.hidden_size, self.num_attention_heads, self.dropout_rate, self.attention_probs_dropout_prob, self.intermediate_size, self.hidden_act,
is_dropout=self.is_dropout, conditional_size=self.conditional_size, **get_kw(BertLayer, kwargs))
self.encoderLayer = nn.ModuleList([copy.deepcopy(layer) if layer_id in self.keep_hidden_layers else Identity() for layer_id in range(self.num_hidden_layers)])
if self.with_pool:
# pooler part (extracts the CLS vector)
self.pooler = nn.Linear(self.hidden_size, self.hidden_size)
self.pooler_activation = nn.Tanh() if self.with_pool is True else get_activation(self.with_pool)
if self.with_nsp:
# Next Sentence Prediction part
# nsp takes pooled_output as input, so with_pool=True is a precondition for nsp
self.nsp = nn.Linear(self.hidden_size, 2)
else:
self.pooler = None
self.pooler_activation = None
if self.with_mlm:
self.mlmDense = nn.Linear(self.hidden_size, self.hidden_size)
self.transform_act_fn = get_activation(self.hidden_act)
self.mlmLayerNorm = LayerNorm(self.hidden_size, eps=1e-12, conditional_size=self.conditional_size)
self.mlmDecoder = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
if kwargs.get('tie_emb_prj_weight') is True:
self.mlmDecoder.weight = self.embeddings.word_embeddings.weight
self.mlmBias = nn.Parameter(torch.zeros(self.vocab_size))
self.mlmDecoder.bias = self.mlmBias
# subclasses of BERT may declare new parameters, which cannot be initialized uniformly here
def apply_embeddings(self, inputs):
"""BERT的embedding是token、position、segment三者embedding之和
默认顺序是token_ids, segment_ids(若有), position_ids(若有), custom_attention_mask(若有), conditional_input(若有)
"""
token_ids = inputs[0]
index_ = 1
if self.segment_vocab_size > 0:
segment_ids = inputs[index_]
index_ += 1
else:
segment_ids = None
if self.custom_position_ids: # not used yet, kept for now
position_ids = inputs[index_]
index_ += 1
else:
position_ids = None
# build a 3D attention mask matrix from token_ids, with shape [batch_size, 1, 1, to_seq_length],
# so that it can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] for multi-head attention
if self.custom_attention_mask:
attention_mask = inputs[index_].long().unsqueeze(1).unsqueeze(2)
index_ += 1
elif (not token_ids.requires_grad) and (token_ids.dtype in {torch.long, torch.int}): # regular token_ids
attention_mask = (token_ids != self.token_pad_ids).long().unsqueeze(1).unsqueeze(2) # 0 is the mask value by default
if self.token_pad_ids < 0:
token_ids = token_ids * attention_mask[:,0,0,:]
else: # custom word_embedding, currently only used by VAT
attention_mask = self.attention_mask_cache
self.attention_mask_cache = attention_mask # cache the attention_mask used last time
self.compute_attention_bias([token_ids, segment_ids]) # adjust the mask for lm or unilm
if self.attention_bias is not None:
attention_mask = attention_mask * self.attention_bias # padding is not attendable
# attention_mask = self.attention_bias # padding is attendable
# pytorch >= 1.5 can raise a StopIteration error here
# https://github.com/huggingface/transformers/issues/3936
# https://github.com/huggingface/transformers/issues/4189
try:
attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # 兼容fp16
except StopIteration:
attention_mask = attention_mask.to(dtype=torch.float32)
# zeros in the mask matrix are turned into large negative numbers, so that positions that should not be attended get near-zero scores after softmax
# attention_mask = (1.0 - attention_mask) * -10000.0
# conditional layer_norm
if self.layer_norm_conds is None:
conditional_emb = None
else:
conditional_emb = self.layer_norm_conds(inputs[index_])
index_ += 1
# additional embeddings, e.g. custom part-of-speech, tone or word-level embeddings
if isinstance(self.layer_add_embs, nn.Module): # a single one
additional_embs = [self.layer_add_embs(inputs[index_])]
index_ += 1
elif isinstance(self.layer_add_embs, (tuple, list)): # several
additional_embs = []
for layer in self.layer_add_embs:
assert isinstance(layer, nn.Module), 'Layer_add_embs element should be nn.Module'
additional_embs.append(layer(inputs[index_]))
index_ += 1
else:
additional_embs = None
# feed into the embedding layer
hidden_states = self.embeddings(token_ids, segment_ids, conditional_emb, additional_embs)
return [hidden_states, attention_mask, conditional_emb] + inputs[index_:]
def apply_main_layers(self, inputs):
"""BERT的主体是基于Self-Attention的模块
顺序:Att --> Add --> LN --> FFN --> Add --> LN
默认第一个是hidden_states, 第二个是attention_mask, 第三个是conditional_emb
"""
hidden_states, attention_mask, conditional_emb = inputs[:3]
if len(inputs[3:]) >= 2:
encoder_hidden_state, encoder_attention_mask = inputs[3], inputs[4]
else:
encoder_hidden_state, encoder_attention_mask = None, None
encoded_layers = [hidden_states] # include the embedding output
layer_inputs = [hidden_states, attention_mask, conditional_emb, encoder_hidden_state, encoder_attention_mask]
for l_i, layer_module in enumerate(self.encoderLayer):
layer_inputs = self.apply_on_layer_begin(l_i, layer_inputs)
hidden_states = layer_module(*layer_inputs)
layer_inputs[0] = hidden_states
layer_inputs = self.apply_on_layer_end(l_i, layer_inputs)
if self.output_all_encoded_layers:
encoded_layers.append(hidden_states)
if not self.output_all_encoded_layers:
encoded_layers.append(hidden_states)
return [encoded_layers, conditional_emb]
def apply_final_layers(self, inputs):
"""根据剩余参数决定输出
"""
# 获取最后一层隐藏层的输出
encoded_layers, conditional_emb = inputs
sequence_output = encoded_layers[-1]
# whether to return only the last layer's output
if not self.output_all_encoded_layers:
encoded_layers = encoded_layers[-1]
# whether to add the pool layer
if self.with_pool:
pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0]))
else:
pooled_output = None
# whether to add nsp
if self.with_pool and self.with_nsp:
nsp_scores = self.nsp(pooled_output)
else:
nsp_scores = None
# whether to add mlm
if self.with_mlm:
mlm_hidden_state = self.mlmDense(sequence_output)
mlm_hidden_state = self.transform_act_fn(mlm_hidden_state)
mlm_hidden_state = self.mlmLayerNorm((mlm_hidden_state, conditional_emb))
mlm_scores = self.mlmDecoder(mlm_hidden_state)
mlm_activation = get_activation('linear' if self.with_mlm is True else self.with_mlm)
mlm_scores = mlm_activation(mlm_scores)
else:
mlm_scores = None
outputs = [value for value in [encoded_layers, pooled_output, mlm_scores, nsp_scores] if value is not None]
return outputs if len(outputs) > 1 else outputs[0]
def load_variable(self, state_dict, name, prefix='bert'):
"""加载单个变量的函数
"""
variable = state_dict[name]
if name in {
f'{prefix}.embeddings.word_embeddings.weight',
'cls.predictions.bias',
'cls.predictions.decoder.weight',
'cls.predictions.decoder.bias'
}:
return self.load_embeddings(variable)
elif name == f'{prefix}.embeddings.position_embeddings.weight':
return self.load_pos_embeddings(variable)
elif name == 'cls.seq_relationship.weight':
return variable.T
else:
return variable
def variable_mapping(self, prefix='bert'):
mapping = {
'embeddings.word_embeddings.weight': f'{prefix}.embeddings.word_embeddings.weight',
'embeddings.position_embeddings.weight': f'{prefix}.embeddings.position_embeddings.weight',
'embeddings.segment_embeddings.weight': f'{prefix}.embeddings.token_type_embeddings.weight',
'embeddings.layerNorm.weight': f'{prefix}.embeddings.LayerNorm.weight',
'embeddings.layerNorm.bias': f'{prefix}.embeddings.LayerNorm.bias',
'pooler.weight': f'{prefix}.pooler.dense.weight',
'pooler.bias': f'{prefix}.pooler.dense.bias',
'nsp.weight': 'cls.seq_relationship.weight',
'nsp.bias': 'cls.seq_relationship.bias',
'mlmDense.weight': 'cls.predictions.transform.dense.weight',
'mlmDense.bias': 'cls.predictions.transform.dense.bias',
'mlmLayerNorm.weight': 'cls.predictions.transform.LayerNorm.weight',
'mlmLayerNorm.bias': 'cls.predictions.transform.LayerNorm.bias',
'mlmBias': 'cls.predictions.bias',
'mlmDecoder.weight': 'cls.predictions.decoder.weight',
'mlmDecoder.bias': 'cls.predictions.decoder.bias'
}
for i in range(self.num_hidden_layers):
prefix_i = f'{prefix}.encoder.layer.%d.' % i
mapping.update({f'encoderLayer.{i}.multiHeadAttention.q.weight': prefix_i + 'attention.self.query.weight',
f'encoderLayer.{i}.multiHeadAttention.q.bias': prefix_i + 'attention.self.query.bias',
f'encoderLayer.{i}.multiHeadAttention.k.weight': prefix_i + 'attention.self.key.weight',
f'encoderLayer.{i}.multiHeadAttention.k.bias': prefix_i + 'attention.self.key.bias',
f'encoderLayer.{i}.multiHeadAttention.v.weight': prefix_i + 'attention.self.value.weight',
f'encoderLayer.{i}.multiHeadAttention.v.bias': prefix_i + 'attention.self.value.bias',
f'encoderLayer.{i}.multiHeadAttention.o.weight': prefix_i + 'attention.output.dense.weight',
f'encoderLayer.{i}.multiHeadAttention.o.bias': prefix_i + 'attention.output.dense.bias',
f'encoderLayer.{i}.layerNorm1.weight': prefix_i + 'attention.output.LayerNorm.weight',
f'encoderLayer.{i}.layerNorm1.bias': prefix_i + 'attention.output.LayerNorm.bias',
f'encoderLayer.{i}.feedForward.intermediateDense.weight': prefix_i + 'intermediate.dense.weight',
f'encoderLayer.{i}.feedForward.intermediateDense.bias': prefix_i + 'intermediate.dense.bias',
f'encoderLayer.{i}.feedForward.outputDense.weight': prefix_i + 'output.dense.weight',
f'encoderLayer.{i}.feedForward.outputDense.bias': prefix_i + 'output.dense.bias',
f'encoderLayer.{i}.layerNorm2.weight': prefix_i + 'output.LayerNorm.weight',
f'encoderLayer.{i}.layerNorm2.bias': prefix_i + 'output.LayerNorm.bias'
})
return mapping
class ALBERT(BERT):
def __init__(self, *args, **kwargs):
super(ALBERT, self).__init__(*args, **kwargs)
self.encoderLayer = nn.ModuleList([self.encoderLayer[0]]) # keep only the first layer (ALBERT shares it across all layers)
def apply_main_layers(self, inputs):
"""BERT的主体是基于Self-Attention的模块
顺序:Att --> Add --> LN --> FFN --> Add --> LN
"""
hidden_states, attention_mask, conditional_emb = inputs[:3]
if len(inputs[3:]) >= 2:
encoder_hidden_state, encoder_attention_mask = inputs[3], inputs[4]
else:
encoder_hidden_state, encoder_attention_mask = None, None
encoded_layers = [hidden_states] # include the embedding output
layer_inputs = [hidden_states, attention_mask, conditional_emb, encoder_hidden_state, encoder_attention_mask]
for l_i in range(self.num_hidden_layers):
layer_inputs = self.apply_on_layer_begin(l_i, layer_inputs)
hidden_states = self.encoderLayer[0](*layer_inputs)
layer_inputs[0] = hidden_states
layer_inputs = self.apply_on_layer_end(l_i, layer_inputs)
if self.output_all_encoded_layers:
encoded_layers.append(hidden_states)
if not self.output_all_encoded_layers:
encoded_layers.append(hidden_states)
return [encoded_layers, conditional_emb]
def variable_mapping(self, prefix='albert'):
mapping = {
'embeddings.word_embeddings.weight': f'{prefix}.embeddings.word_embeddings.weight',
'embeddings.position_embeddings.weight': f'{prefix}.embeddings.position_embeddings.weight',
'embeddings.segment_embeddings.weight': f'{prefix}.embeddings.token_type_embeddings.weight',
'embeddings.layerNorm.weight': f'{prefix}.embeddings.LayerNorm.weight',
'embeddings.layerNorm.bias': f'{prefix}.embeddings.LayerNorm.bias',
'embeddings.embedding_hidden_mapping_in.weight': f'{prefix}.encoder.embedding_hidden_mapping_in.weight',
'embeddings.embedding_hidden_mapping_in.bias': f'{prefix}.encoder.embedding_hidden_mapping_in.bias',
'pooler.weight': f'{prefix}.pooler.weight',
'pooler.bias': f'{prefix}.pooler.bias',
            'nsp.weight': 'sop_classifier.classifier.weight',  # reuse the name nsp for ALBERT's SOP head
'nsp.bias': 'sop_classifier.classifier.bias',
'mlmDense.weight': 'predictions.dense.weight',
'mlmDense.bias': 'predictions.dense.bias',
'mlmLayerNorm.weight': 'predictions.LayerNorm.weight',
'mlmLayerNorm.bias': 'predictions.LayerNorm.bias',
'mlmBias': 'predictions.bias',
'mlmDecoder.weight': 'predictions.decoder.weight',
'mlmDecoder.bias': 'predictions.decoder.bias'
}
i = 0
prefix_i = f'{prefix}.encoder.albert_layer_groups.{i}.albert_layers.{i}.'
mapping.update({f'encoderLayer.{i}.multiHeadAttention.q.weight': prefix_i + 'attention.query.weight',
f'encoderLayer.{i}.multiHeadAttention.q.bias': prefix_i + 'attention.query.bias',
f'encoderLayer.{i}.multiHeadAttention.k.weight': prefix_i + 'attention.key.weight',
f'encoderLayer.{i}.multiHeadAttention.k.bias': prefix_i + 'attention.key.bias',
f'encoderLayer.{i}.multiHeadAttention.v.weight': prefix_i + 'attention.value.weight',
f'encoderLayer.{i}.multiHeadAttention.v.bias': prefix_i + 'attention.value.bias',
f'encoderLayer.{i}.multiHeadAttention.o.weight': prefix_i + 'attention.dense.weight',
f'encoderLayer.{i}.multiHeadAttention.o.bias': prefix_i + 'attention.dense.bias',
f'encoderLayer.{i}.layerNorm1.weight': prefix_i + 'attention.LayerNorm.weight',
f'encoderLayer.{i}.layerNorm1.bias': prefix_i + 'attention.LayerNorm.bias',
f'encoderLayer.{i}.feedForward.intermediateDense.weight': prefix_i + 'ffn.weight',
f'encoderLayer.{i}.feedForward.intermediateDense.bias': prefix_i + 'ffn.bias',
f'encoderLayer.{i}.feedForward.outputDense.weight': prefix_i + 'ffn_output.weight',
f'encoderLayer.{i}.feedForward.outputDense.bias': prefix_i + 'ffn_output.bias',
f'encoderLayer.{i}.layerNorm2.weight': prefix_i + 'full_layer_layer_norm.weight',
f'encoderLayer.{i}.layerNorm2.bias': prefix_i + 'full_layer_layer_norm.bias'
})
return mapping
def load_variable(self, state_dict, name):
"""加载单个变量的函数
"""
variable = state_dict[name]
if name in {
'albert.embeddings.word_embeddings.weight',
'predictions.bias',
'predictions.decoder.weight',
'predictions.decoder.bias'
}:
return self.load_embeddings(variable)
elif name == 'albert.embeddings.position_embeddings.weight':
return self.load_pos_embeddings(variable)
elif name == 'sop_classifier.classifier.weight':
return variable.T
else:
return variable
class ALBERT_Unshared(ALBERT):
def __init__(self, *args, **kwargs):
        super(ALBERT_Unshared, self).__init__(*args, **kwargs)
self.encoderLayer = nn.ModuleList([copy.deepcopy(self.encoderLayer[0]) for _ in range(self.num_hidden_layers)])
def apply_main_layers(self, inputs):
"""BERT的主体是基于Self-Attention的模块
顺序:Att --> Add --> LN --> FFN --> Add --> LN
"""
hidden_states, attention_mask, conditional_emb = inputs
if len(inputs[3:]) >= 2:
encoder_hidden_state, encoder_attention_mask = inputs[3], inputs[4]
else:
encoder_hidden_state, encoder_attention_mask = None, None
        encoded_layers = [hidden_states]  # include the embedding output
layer_inputs = [hidden_states, attention_mask, conditional_emb, encoder_hidden_state, encoder_attention_mask]
for i in range(self.num_hidden_layers):
layer_inputs = self.apply_on_layer_begin(i, layer_inputs)
hidden_states = self.encoderLayer[i](*layer_inputs)
layer_inputs[0] = hidden_states
layer_inputs = self.apply_on_layer_end(i, layer_inputs)
if self.output_all_encoded_layers:
encoded_layers.append(hidden_states)
if not self.output_all_encoded_layers:
encoded_layers.append(hidden_states)
return [encoded_layers, conditional_emb]
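# ALBERT keeps a single encoderLayer and runs it num_hidden_layers times
# (cross-layer parameter sharing), while ALBERT_Unshared deep-copies it so each
# layer owns independent weights. A hypothetical stand-alone demo of the
# parameter-count difference:
def _demo_albert_sharing(num_hidden_layers=12):
    layer = nn.Linear(4, 4)
    shared = nn.ModuleList([layer])  # ALBERT: one set of weights, reused
    unshared = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)])  # ALBERT_Unshared
    def count(m):
        return sum(p.numel() for p in m.parameters())
    assert count(unshared) == num_hidden_layers * count(shared)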
class NEZHA(BERT):
"""华为推出的NAZHA模型
链接:https://arxiv.org/abs/1909.00204
"""
def __init__(self, *args, **kwargs):
        # p_bias disables absolute position embeddings at the embedding stage; max_relative_position defaults to 64
kwargs.update({'p_bias': 'typical_relative', 'max_relative_position': kwargs.get('max_relative_position', 64)})
super(NEZHA, self).__init__(*args, **kwargs)
class RoFormer(BERT):
"""旋转式位置编码的BERT模型
链接:https://kexue.fm/archives/8265
"""
def __init__(self, *args, **kwargs):
kwargs.update({'p_bias': 'rotary'})
super(RoFormer, self).__init__(*args, **kwargs)
def load_variable(self, state_dict, name, prefix='roformer'):
return super().load_variable(state_dict, name, prefix)
def variable_mapping(self, prefix='roformer'):
mapping = super().variable_mapping(prefix)
        del mapping['embeddings.position_embeddings.weight']  # no absolute position embeddings
return mapping
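# The rotary trick from https://kexue.fm/archives/8265 rotates each (even, odd)
# pair of query/key channels by a position-dependent angle, so that q·k depends
# only on the relative position. A minimal stand-alone sketch; the library's own
# implementation lives in its attention layer and may differ in detail:
def _demo_rotary(x):
    # x: [btz, heads, seq_len, head_dim]
    seq_len, dim = x.shape[2], x.shape[3]
    freq = 10000 ** (-torch.arange(0, dim, 2, dtype=torch.float) / dim)
    ang = torch.arange(seq_len, dtype=torch.float)[:, None] * freq[None, :]  # [seq_len, dim/2]
    sin, cos = ang.sin(), ang.cos()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2)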
class RoFormerV2(RoFormer):
"""RoFormerV2
改动:去掉bias,简化Norm,优化初始化等。目前初始化暂时还用的bert的初始化,finetune不受影响
"""
@delete_arguments('with_pool', 'with_nsp')
def __init__(self, *args, **kwargs):
kwargs.update({'p_bias': 'rotary', 'weight': False, 'bias': False, 'norm_mode': 'rmsnorm'})
super(RoFormerV2, self).__init__(*args, **kwargs)
if self.with_mlm:
del self.mlmLayerNorm
del self.mlmBias
del self.mlmDense
self.mlmDecoder.register_parameter('bias', None)
def variable_mapping(self, prefix='roformer'):
mapping = super().variable_mapping(prefix)
mapping_new = {}
for k, v in mapping.items():
if (not re.search('bias|layernorm', k.lower())) and (not re.search('bias|layernorm', v.lower())):
mapping_new[k] = v
return mapping_new
def apply_final_layers(self, inputs):
"""根据剩余参数决定输出
"""
# 获取最后一层隐藏层的输出
encoded_layers, conditional_emb = inputs
sequence_output = encoded_layers[-1]
        # whether to return only the last layer's output
if not self.output_all_encoded_layers:
encoded_layers = encoded_layers[-1]
        # optionally add the MLM head
if self.with_mlm:
mlm_scores = self.mlmDecoder(sequence_output)
else:
mlm_scores = None
outputs = [value for value in [encoded_layers, mlm_scores] if value is not None]
return outputs if len(outputs) > 1 else outputs[0]
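# norm_mode='rmsnorm' (used by RoFormerV2, GAU_alpha and T5) drops the mean
# subtraction and bias of standard LayerNorm and only rescales by the root mean
# square. A minimal stand-alone sketch, not the library's own LayerNorm class:
def _demo_rmsnorm(x, weight, eps=1e-12):
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return weight * x * torch.rsqrt(variance + eps)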
class GAU_alpha(RoFormerV2):
def __init__(self, *args, **kwargs):
kwargs.update({'p_bias': 'rotary', 'weight': False, 'bias': False, 'norm_mode': 'rmsnorm', 'normalization': 'softmax_plus'})
super().__init__(*args, **kwargs)
layer = self.GAU_Layer(**kwargs)
self.encoderLayer = nn.ModuleList([copy.deepcopy(layer) if layer_id in self.keep_hidden_layers else Identity() for layer_id in range(self.num_hidden_layers)])
def load_variable(self, state_dict, name, prefix=''):
variable = state_dict[name]
return self.load_embeddings(variable) if name in {'embeddings.word_embeddings.weight', 'mlmDecoder.weight'} else variable
def variable_mapping(self, prefix=''):
        '''The convert script has already renamed the checkpoint keys to bert4torch's naming,
        so the mapping is simply the identity.
        '''
return {k: k for k, _ in self.named_parameters()}
class GAU_Layer(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.gau = GatedAttentionUnit(**kwargs)
self.dropout1 = nn.Dropout(kwargs.get('dropout_rate'))
self.layerNorm1 = LayerNorm(**kwargs)
def forward(self, hidden_states, attention_mask, conditional_emb=None, encoder_hidden_states=None, encoder_attention_mask=None):
gau_hidden_states = self.gau(hidden_states, attention_mask)
hidden_states = hidden_states + self.dropout1(gau_hidden_states)
hidden_states = self.layerNorm1((hidden_states, conditional_emb))
return hidden_states
class ELECTRA(BERT):
"""Google推出的ELECTRA模型
链接:https://arxiv.org/abs/2003.10555
"""
@insert_arguments(with_discriminator=False)
@delete_arguments('with_pool', 'with_mlm', 'with_nsp')
def __init__(self, max_position, **kwargs):
super(ELECTRA, self).__init__(max_position, **kwargs)
if self.with_discriminator:
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
self.dense_act = get_activation(self.hidden_act)
self.dense_prediction = nn.Linear(self.hidden_size, 1)
self.dense_prediction_act = get_activation('sigmoid') if self.with_discriminator is True else get_activation(self.with_discriminator)
def apply_final_layers(self, inputs):
        hidden_states = super().apply_final_layers(inputs)  # the only output is hidden_states
if self.with_discriminator:
logit = self.dense_act(self.dense(hidden_states))
return [hidden_states, self.dense_prediction_act(self.dense_prediction(logit))]
else:
return hidden_states
def load_variable(self, state_dict, name):
"""加载单个变量的函数
"""
return super().load_variable(state_dict, name, prefix='electra')
def variable_mapping(self):
mapping = super(ELECTRA, self).variable_mapping(prefix='electra')
mapping.update({'dense.weight': 'discriminator_predictions.dense.weight',
'dense.bias': 'discriminator_predictions.dense.bias',
'dense_prediction.weight': 'discriminator_predictions.dense_prediction.weight',
'dense_prediction.bias': 'discriminator_predictions.dense_prediction.bias'}
)
for del_key in ['pooler.weight', 'pooler.bias', 'nsp.weight', 'nsp.bias', 'mlmDense.weight', 'mlmDense.bias',
'mlmLayerNorm.weight', 'mlmLayerNorm.bias', 'mlmBias', 'mlmDecoder.weight', 'mlmDecoder.bias']:
del mapping[del_key]
return mapping
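# With with_discriminator=True, apply_final_layers returns [hidden_states, probs],
# where probs comes from a dense transform plus a width-1 projection and (by
# default) a sigmoid, i.e. per-token probabilities that a token was replaced.
# A hedged usage sketch, assuming the usual [token_ids, segment_ids] input list:
def _demo_electra_discriminator(electra_model, token_ids, segment_ids):
    hidden_states, probs = electra_model([token_ids, segment_ids])
    return probs.squeeze(-1) > 0.5  # [btz, seq_len] mask of suspected replaced tokens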
class Encoder(BERT):
def __init__(self, *args, **kwargs):
kwargs['vocab_size'] = kwargs.get('src_vocab_size', kwargs['vocab_size'])
super().__init__(*args, **kwargs)
        # the encoder must additionally return encoder_attention_mask
self.encoder_attention_mask = None
def forward(self, inputs):
"""因为encoder需要返回encoder_attention_mask,因此这里从新定义一下,多返回一个参数
"""
# Embedding
outputs = self.apply_embeddings(inputs)
encoder_attention_mask = [outputs[1]]
# Main
outputs = self.apply_main_layers(outputs)
# Final
outputs = self.apply_final_layers(outputs)
return ([outputs] if isinstance(outputs, torch.Tensor) else outputs) + encoder_attention_mask
class Decoder(LM_Mask, BERT):
@delete_arguments('with_pool', 'with_mlm', 'with_nsp')
def __init__(self, *args, with_lm=True, tie_emb_prj_weight=True, **kwargs):
kwargs['vocab_size'] = kwargs.get('tgt_vocab_size', kwargs['vocab_size'])
        kwargs['is_decoder'] = True  # mark this model as a decoder
super().__init__(*args, **kwargs)
self.decoderLayer = self.encoderLayer
del self.encoderLayer
self.with_lm = with_lm
        # project hidden_states to logits
if self.with_lm:
self.final_dense = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
            if tie_emb_prj_weight:  # tie the decoder's input embedding to the output projection
self.final_dense.weight = self.embeddings.word_embeddings.weight
self.x_logit_scale = (self.hidden_size ** -0.5)
else:
self.x_logit_scale = 1.
def apply_main_layers(self, inputs):
"""Dencoder主体是基于Self-Attention、Cross-Attention的模块
顺序:Att1 --> Add --> LN --> Att2 --> Add --> LN --> FFN --> Add --> LN
"""
hidden_states, attention_mask, conditional_emb, encoder_hidden_state, encoder_attention_mask = inputs[:5]
        decoded_layers = [hidden_states]  # include the embedding output
layer_inputs = [hidden_states, attention_mask, conditional_emb, encoder_hidden_state, encoder_attention_mask]
for i, layer_module in enumerate(self.decoderLayer):
layer_inputs = self.apply_on_layer_begin(i, layer_inputs)
hidden_states = layer_module(*layer_inputs)
layer_inputs[0] = hidden_states
layer_inputs = self.apply_on_layer_end(i, layer_inputs)
if self.output_all_encoded_layers:
decoded_layers.append(hidden_states)
if not self.output_all_encoded_layers:
decoded_layers.append(hidden_states)
return [decoded_layers, conditional_emb]
def apply_final_layers(self, inputs):
outputs = []
        hidden_states = super().apply_final_layers(inputs)  # hidden_states of the decoder's top layer, [btz, seq_len, hdsz]
outputs.append(hidden_states)
if self.with_lm:
            logits = self.final_dense(hidden_states) * self.x_logit_scale  # logits of shape [btz, seq_len, vocab_size]
            activation = get_activation('linear' if self.with_lm is True else self.with_lm)  # apply an activation, usually linear or softmax
logits = activation(logits)
outputs.append(logits)
return outputs
def variable_mapping(self, prefix='bert'):
raw_mapping = super().variable_mapping(prefix)
mapping = {}
for k, v in raw_mapping.items():
mapping[k.replace('encoderLayer', 'decoderLayer')] = v
# for i in range(self.num_hidden_layers):
# prefix_i = f'{prefix}.encoder.layer.%d.' % i
# mapping.update({
# f'decoderLayer.{i}.crossAttention.q.weight': prefix_i + 'crossattention.self.query.weight',
# f'decoderLayer.{i}.crossAttention.q.bias': prefix_i + 'crossattention.self.query.bias',
# f'decoderLayer.{i}.crossAttention.k.weight': prefix_i + 'crossattention.self.key.weight',
# f'decoderLayer.{i}.crossAttention.k.bias': prefix_i + 'crossattention.self.key.bias',
# f'decoderLayer.{i}.crossAttention.v.weight': prefix_i + 'crossattention.self.value.weight',
# f'decoderLayer.{i}.crossAttention.v.bias': prefix_i + 'crossattention.self.value.bias',
# f'decoderLayer.{i}.crossAttention.o.weight': prefix_i + 'crossattention.output.dense.weight',
# f'decoderLayer.{i}.crossAttention.o.bias': prefix_i + 'crossattention.output.dense.bias',
# f'decoderLayer.{i}.layerNorm3.weight': prefix_i + 'crossattention.output.LayerNorm.weight',
# f'decoderLayer.{i}.layerNorm3.bias': prefix_i + 'crossattention.output.LayerNorm.bias'
# })
return mapping
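# tie_emb_prj_weight=True assigns the same Parameter object to the input
# embedding and the output projection, so a single tensor receives both
# gradients; the logits are then scaled by hidden_size**-0.5 to compensate.
# A minimal stand-alone demo of the tying:
def _demo_tie_emb_prj(vocab_size=100, hidden_size=16):
    emb = nn.Embedding(vocab_size, hidden_size)
    prj = nn.Linear(hidden_size, vocab_size, bias=False)
    prj.weight = emb.weight  # shared Parameter: same storage, same gradients
    assert prj.weight.data_ptr() == emb.weight.data_ptr()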
class Transformer(BERT_BASE):
    '''Encoder-decoder architecture
    '''
@delete_arguments('with_pool', 'with_mlm', 'with_nsp')
def __init__(self, *args, tie_emb_src_tgt_weight=False, **kwargs):
super(Transformer, self).__init__(*args, **kwargs)
# encoder
self.encoder = Encoder(*args, **kwargs)
self.encoder.build(**kwargs)
# decoder
self.decoder = Decoder(*args, **kwargs)
self.decoder.build(**kwargs)
if tie_emb_src_tgt_weight:
            # share embedding weights between encoder and decoder
assert self.encoder.vocab_size == self.decoder.vocab_size, "To share word embedding, the vocab size of src/tgt shall be the same."
self.encoder.embeddings.word_embeddings.weight = self.decoder.embeddings.word_embeddings.weight
def forward(self, inputs):
"""定义模型的执行流程
"""
encoder_input, decoder_input = inputs[:2]
# encoder
# encoder_emb = self.encoder.apply_embeddings(encoder_input)
# encode_outputs = self.encoder.apply_main_layers(encoder_emb)
# encoder_hidden_state = self.encoder.apply_final_layers(encode_outputs)
# encoder_attention_mask = encoder_emb[1]
encoder_hidden_state, encoder_attention_mask = self.encoder(encoder_input)
# decoder
# decoder_emb = self.decoder.apply_embeddings(decoder_input)
# decoder_outputs = self.decoder.apply_main_layers([*decoder_emb, encoder_hidden_state, encoder_attention_mask])
# decoder_outputs = self.decoder.apply_final_layers(decoder_outputs) # [hidden_states, logits]
decoder_outputs = self.decoder(decoder_input + [encoder_hidden_state, encoder_attention_mask])
        return [encoder_hidden_state] + decoder_outputs  # return both encoder_hidden_state and the decoder outputs, to support multi-task setups
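# A hedged usage sketch of the seq2seq forward pass: each side is itself an
# input list (token_ids, plus segment_ids where applicable), and the return is
# [encoder_hidden_state, decoder_hidden_state(, logits when with_lm=True)]:
def _demo_seq2seq_forward(seq2seq_model, src_token_ids, tgt_token_ids):
    return seq2seq_model([[src_token_ids], [tgt_token_ids]])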
class BART(Transformer):
    '''Encoder-decoder architecture
    '''
def __init__(self, *args, tie_emb_src_tgt_weight=True, **kwargs):
super(BART, self).__init__(*args, tie_emb_src_tgt_weight=tie_emb_src_tgt_weight, **kwargs)
self.tie_emb_src_tgt_weight = tie_emb_src_tgt_weight
def load_variable(self, state_dict, name, prefix=''):
"""加载单个变量的函数
"""
variable = state_dict[name]
if name in {
'shared.weight',
'encoder.embed_tokens.weight',
'decoder.embed_tokens.weight',
}:
return self.load_embeddings(variable)
elif name in {'encoder.embed_positions.weight', 'decoder.embed_positions.weight'}:
return self.load_pos_embeddings(variable)
else:
return variable
def variable_mapping(self, prefix=''):
        # inspecting the checkpoint shows a 'shared.weight' key
mapping = {
'encoder.embeddings.word_embeddings.weight': 'shared.weight' if self.tie_emb_src_tgt_weight else 'encoder.embed_tokens.weight',
'encoder.embeddings.position_embeddings.weight': 'encoder.embed_positions.weight',
'encoder.embeddings.layerNorm.weight': 'encoder.layernorm_embedding.weight',
'encoder.embeddings.layerNorm.bias': 'encoder.layernorm_embedding.bias',
'decoder.embeddings.word_embeddings.weight': 'shared.weight' if self.tie_emb_src_tgt_weight else 'decoder.embed_tokens.weight',
'decoder.embeddings.position_embeddings.weight': 'decoder.embed_positions.weight',
'decoder.embeddings.layerNorm.weight': 'decoder.layernorm_embedding.weight',
'decoder.embeddings.layerNorm.bias': 'decoder.layernorm_embedding.bias',
}
for i in range(self.num_hidden_layers):
mapping.update(
{
f'encoder.encoderLayer.{i}.multiHeadAttention.q.weight': f'encoder.layers.{i}.self_attn.q_proj.weight',
f'encoder.encoderLayer.{i}.multiHeadAttention.q.bias': f'encoder.layers.{i}.self_attn.q_proj.bias',
f'encoder.encoderLayer.{i}.multiHeadAttention.k.weight': f'encoder.layers.{i}.self_attn.k_proj.weight',
f'encoder.encoderLayer.{i}.multiHeadAttention.k.bias': f'encoder.layers.{i}.self_attn.k_proj.bias',
f'encoder.encoderLayer.{i}.multiHeadAttention.v.weight': f'encoder.layers.{i}.self_attn.v_proj.weight',
f'encoder.encoderLayer.{i}.multiHeadAttention.v.bias': f'encoder.layers.{i}.self_attn.v_proj.bias',
f'encoder.encoderLayer.{i}.multiHeadAttention.o.weight': f'encoder.layers.{i}.self_attn.out_proj.weight',
f'encoder.encoderLayer.{i}.multiHeadAttention.o.bias': f'encoder.layers.{i}.self_attn.out_proj.bias',
f'encoder.encoderLayer.{i}.layerNorm1.weight': f'encoder.layers.{i}.self_attn_layer_norm.weight',
f'encoder.encoderLayer.{i}.layerNorm1.bias': f'encoder.layers.{i}.self_attn_layer_norm.bias',
f'encoder.encoderLayer.{i}.feedForward.intermediateDense.weight': f'encoder.layers.{i}.fc1.weight',
f'encoder.encoderLayer.{i}.feedForward.intermediateDense.bias': f'encoder.layers.{i}.fc1.bias',
f'encoder.encoderLayer.{i}.feedForward.outputDense.weight': f'encoder.layers.{i}.fc2.weight',
f'encoder.encoderLayer.{i}.feedForward.outputDense.bias': f'encoder.layers.{i}.fc2.bias',
f'encoder.encoderLayer.{i}.layerNorm2.weight': f'encoder.layers.{i}.final_layer_norm.weight',
f'encoder.encoderLayer.{i}.layerNorm2.bias': f'encoder.layers.{i}.final_layer_norm.bias',
f'decoder.decoderLayer.{i}.multiHeadAttention.q.weight': f'decoder.layers.{i}.self_attn.q_proj.weight',
f'decoder.decoderLayer.{i}.multiHeadAttention.q.bias': f'decoder.layers.{i}.self_attn.q_proj.bias',
f'decoder.decoderLayer.{i}.multiHeadAttention.k.weight': f'decoder.layers.{i}.self_attn.k_proj.weight',
f'decoder.decoderLayer.{i}.multiHeadAttention.k.bias': f'decoder.layers.{i}.self_attn.k_proj.bias',
f'decoder.decoderLayer.{i}.multiHeadAttention.v.weight': f'decoder.layers.{i}.self_attn.v_proj.weight',
f'decoder.decoderLayer.{i}.multiHeadAttention.v.bias': f'decoder.layers.{i}.self_attn.v_proj.bias',
f'decoder.decoderLayer.{i}.multiHeadAttention.o.weight': f'decoder.layers.{i}.self_attn.out_proj.weight',
f'decoder.decoderLayer.{i}.multiHeadAttention.o.bias': f'decoder.layers.{i}.self_attn.out_proj.bias',
f'decoder.decoderLayer.{i}.layerNorm1.weight': f'decoder.layers.{i}.self_attn_layer_norm.weight',
f'decoder.decoderLayer.{i}.layerNorm1.bias': f'decoder.layers.{i}.self_attn_layer_norm.bias',
f'decoder.decoderLayer.{i}.crossAttention.q.weight': f'decoder.layers.{i}.encoder_attn.q_proj.weight',
f'decoder.decoderLayer.{i}.crossAttention.q.bias': f'decoder.layers.{i}.encoder_attn.q_proj.bias',
f'decoder.decoderLayer.{i}.crossAttention.k.weight': f'decoder.layers.{i}.encoder_attn.k_proj.weight',
f'decoder.decoderLayer.{i}.crossAttention.k.bias': f'decoder.layers.{i}.encoder_attn.k_proj.bias',
f'decoder.decoderLayer.{i}.crossAttention.v.weight': f'decoder.layers.{i}.encoder_attn.v_proj.weight',
f'decoder.decoderLayer.{i}.crossAttention.v.bias': f'decoder.layers.{i}.encoder_attn.v_proj.bias',
f'decoder.decoderLayer.{i}.crossAttention.o.weight': f'decoder.layers.{i}.encoder_attn.out_proj.weight',
f'decoder.decoderLayer.{i}.crossAttention.o.bias': f'decoder.layers.{i}.encoder_attn.out_proj.bias',
f'decoder.decoderLayer.{i}.layerNorm3.weight': f'decoder.layers.{i}.encoder_attn_layer_norm.weight',
f'decoder.decoderLayer.{i}.layerNorm3.bias': f'decoder.layers.{i}.encoder_attn_layer_norm.bias',
f'decoder.decoderLayer.{i}.feedForward.intermediateDense.weight': f'decoder.layers.{i}.fc1.weight',
f'decoder.decoderLayer.{i}.feedForward.intermediateDense.bias': f'decoder.layers.{i}.fc1.bias',
f'decoder.decoderLayer.{i}.feedForward.outputDense.weight': f'decoder.layers.{i}.fc2.weight',
f'decoder.decoderLayer.{i}.feedForward.outputDense.bias': f'decoder.layers.{i}.fc2.bias',
f'decoder.decoderLayer.{i}.layerNorm2.weight': f'decoder.layers.{i}.final_layer_norm.weight',
f'decoder.decoderLayer.{i}.layerNorm2.bias': f'decoder.layers.{i}.final_layer_norm.bias'
})
return mapping
class T5_Encoder(Encoder):
@insert_arguments(version='t5.1.0')
def __init__(self, *args, **kwargs):
        kwargs.update({'p_bias': 't5_relative', 'relative_attention_num_buckets': kwargs.get('relative_attention_num_buckets'), 'version': self.version,
                       'bias': False, 'norm_mode': 'rmsnorm'})  # p_bias disables absolute position embeddings; T5 uses no biases and RMSNorm
super().__init__(*args, **kwargs)
del self.embeddings.layerNorm
        # T5 applies layernorm before each sub-block (pre-LN), so the layer is redefined here
layer = T5Layer(self.hidden_size, self.num_attention_heads, self.dropout_rate, self.attention_probs_dropout_prob, self.intermediate_size, self.hidden_act, is_dropout=self.is_dropout,
conditional_size=self.conditional_size, **get_kw(BertLayer, kwargs))
self.encoderLayer = nn.ModuleList([copy.deepcopy(layer) for _ in range(self.num_hidden_layers)])
        # tie the relative position encoding weights of layers 2+ to layer 1, so it is effectively computed only by the first layer
for i in range(1, self.num_hidden_layers):
self.encoderLayer[i].multiHeadAttention.relative_positions_encoding.weight = self.encoderLayer[0].multiHeadAttention.relative_positions_encoding.weight
self.final_layer_norm = LayerNorm(self.hidden_size, eps=1e-12, conditional_size=self.conditional_size, bias=False, mode='rmsnorm')
self.dropout = nn.Dropout(self.dropout_rate)
def apply_final_layers(self, inputs):
hidden_states = super().apply_final_layers(inputs)
return self.dropout(self.final_layer_norm([hidden_states]))
def load_variable(self, state_dict, name, prefix=''):
"""加载单个变量的函数
"""
variable = state_dict[name]
if name in {'encoder.embed_tokens.weight', 'shared.weight'}:
return self.load_embeddings(variable)
else:
return variable
def variable_mapping(self, prefix=''):
        # inspecting the checkpoint shows a 'shared.weight' key
mapping = {f'{prefix}embeddings.word_embeddings.weight': 'encoder.embed_tokens.weight',
f'{prefix}encoderLayer.0.multiHeadAttention.relative_positions_encoding.weight': 'encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight',
f'{prefix}final_layer_norm.weight': 'encoder.final_layer_norm.weight'}
for i in range(self.num_hidden_layers):
mapping.update(
{
f'{prefix}encoderLayer.{i}.multiHeadAttention.q.weight': f'encoder.block.{i}.layer.0.SelfAttention.q.weight',
f'{prefix}encoderLayer.{i}.multiHeadAttention.k.weight': f'encoder.block.{i}.layer.0.SelfAttention.k.weight',
f'{prefix}encoderLayer.{i}.multiHeadAttention.v.weight': f'encoder.block.{i}.layer.0.SelfAttention.v.weight',
f'{prefix}encoderLayer.{i}.multiHeadAttention.o.weight': f'encoder.block.{i}.layer.0.SelfAttention.o.weight',
f'{prefix}encoderLayer.{i}.layerNorm1.weight': f'encoder.block.{i}.layer.0.layer_norm.weight',
f'{prefix}encoderLayer.{i}.feedForward.outputDense.weight': f'encoder.block.{i}.layer.1.DenseReluDense.wo.weight',
f'{prefix}encoderLayer.{i}.layerNorm2.weight': f'encoder.block.{i}.layer.1.layer_norm.weight',
})
if self.version.endswith('t5.1.0'):
mapping.update({f'{prefix}encoderLayer.{i}.feedForward.intermediateDense.weight': f'encoder.block.{i}.layer.1.DenseReluDense.wi.weight'})
elif self.version.endswith('t5.1.1'):
mapping.update({f'{prefix}encoderLayer.{i}.feedForward.intermediateDense.weight': f'encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight',
f'{prefix}encoderLayer.{i}.feedForward.intermediateDense1.weight': f'encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight'})
return mapping
class T5_Decoder(Decoder):
@insert_arguments(version='t5.1.0')
def __init__(self, *args, **kwargs):
        kwargs.update({'p_bias': 't5_relative', 'relative_attention_num_buckets': kwargs.get('relative_attention_num_buckets'), 'version': self.version,
                       'bias': False, 'norm_mode': 'rmsnorm'})  # p_bias disables absolute position embeddings; T5 uses no biases and RMSNorm
super().__init__(*args, **kwargs)
del self.embeddings.layerNorm
        # T5 applies layernorm before each sub-block (pre-LN), so the layer is redefined here
layer = T5Layer(self.hidden_size, self.num_attention_heads, self.dropout_rate, self.attention_probs_dropout_prob, self.intermediate_size, self.hidden_act, is_dropout=self.is_dropout,
conditional_size=self.conditional_size, is_decoder=True, **get_kw(BertLayer, kwargs))
self.decoderLayer = nn.ModuleList([copy.deepcopy(layer) for _ in range(self.num_hidden_layers)])
        # tie the relative position encoding weights of layers 2+ to layer 1, so it is effectively computed only by the first layer
for i in range(1, self.num_hidden_layers):
self.decoderLayer[i].multiHeadAttention.relative_positions_encoding.weight = self.decoderLayer[0].multiHeadAttention.relative_positions_encoding.weight
self.final_layer_norm = LayerNorm(self.hidden_size, eps=1e-12, conditional_size=self.conditional_size, bias=False, mode='rmsnorm')
self.dropout = nn.Dropout(self.dropout_rate)
def apply_final_layers(self, inputs):
        inputs[0][1] = self.dropout(self.final_layer_norm([inputs[0][1]]))  # layernorm the last hidden_states before projecting to logits
return super().apply_final_layers(inputs)
def load_variable(self, state_dict, name, prefix=''):
"""加载单个变量的函数
"""
variable = state_dict[name]
if name in {f'decoder.embed_tokens.weight', 'lm_head.weight', 'shared.weight'}:
return self.load_embeddings(variable)
else:
return variable
def variable_mapping(self, prefix=''):
        # inspecting the checkpoint shows a 'shared.weight' key
mapping = {f'{prefix}embeddings.word_embeddings.weight': 'decoder.embed_tokens.weight',
f'{prefix}decoderLayer.0.multiHeadAttention.relative_positions_encoding.weight': 'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight',
f'{prefix}final_layer_norm.weight': 'decoder.final_layer_norm.weight',
f'{prefix}final_dense.weight': 'lm_head.weight'}
for i in range(self.num_hidden_layers):
mapping.update(
{
f'{prefix}decoderLayer.{i}.multiHeadAttention.q.weight': f'decoder.block.{i}.layer.0.SelfAttention.q.weight',
f'{prefix}decoderLayer.{i}.multiHeadAttention.k.weight': f'decoder.block.{i}.layer.0.SelfAttention.k.weight',
f'{prefix}decoderLayer.{i}.multiHeadAttention.v.weight': f'decoder.block.{i}.layer.0.SelfAttention.v.weight',
f'{prefix}decoderLayer.{i}.multiHeadAttention.o.weight': f'decoder.block.{i}.layer.0.SelfAttention.o.weight',
f'{prefix}decoderLayer.{i}.layerNorm1.weight': f'decoder.block.{i}.layer.0.layer_norm.weight',
f'{prefix}decoderLayer.{i}.crossAttention.q.weight': f'decoder.block.{i}.layer.1.EncDecAttention.q.weight',
f'{prefix}decoderLayer.{i}.crossAttention.k.weight': f'decoder.block.{i}.layer.1.EncDecAttention.k.weight',
f'{prefix}decoderLayer.{i}.crossAttention.v.weight': f'decoder.block.{i}.layer.1.EncDecAttention.v.weight',
f'{prefix}decoderLayer.{i}.crossAttention.o.weight': f'decoder.block.{i}.layer.1.EncDecAttention.o.weight',
f'{prefix}decoderLayer.{i}.layerNorm3.weight': f'decoder.block.{i}.layer.1.layer_norm.weight',
f'{prefix}decoderLayer.{i}.feedForward.outputDense.weight': f'decoder.block.{i}.layer.2.DenseReluDense.wo.weight',
f'{prefix}decoderLayer.{i}.layerNorm2.weight': f'decoder.block.{i}.layer.2.layer_norm.weight',
})
if self.version.endswith('t5.1.0'):
mapping.update({f'{prefix}decoderLayer.{i}.feedForward.intermediateDense.weight': f'decoder.block.{i}.layer.2.DenseReluDense.wi.weight'})
elif self.version.endswith('t5.1.1'):
mapping.update({f'{prefix}decoderLayer.{i}.feedForward.intermediateDense.weight': f'decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight',
f'{prefix}decoderLayer.{i}.feedForward.intermediateDense1.weight': f'decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight'})
return mapping
class T5(Transformer):
"""Google的T5模型(Encoder-Decoder)
"""
@delete_arguments('with_pool', 'with_mlm', 'with_nsp')
def __init__(self, *args, tie_emb_src_tgt_weight=True, **kwargs):
super(T5, self).__init__(*args, **kwargs)
self.tie_emb_src_tgt_weight = tie_emb_src_tgt_weight
# encoder
self.encoder = T5_Encoder(*args, **kwargs)
self.encoder.build(**kwargs)
# decoder
self.decoder = T5_Decoder(*args, **kwargs)
self.decoder.build(**kwargs)
def load_variable(self, state_dict, name, prefix=''):
"""加载单个变量的函数
"""
variable = state_dict[name]
if name in {'shared.weight', 'encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'}:
return self.load_embeddings(variable)
else:
return variable
def variable_mapping(self, prefix=''):
mapping = self.encoder.variable_mapping(prefix='encoder.')
mapping.update(self.decoder.variable_mapping(prefix='decoder.'))
if self.tie_emb_src_tgt_weight:
mapping.update({'encoder.embeddings.word_embeddings.weight': 'shared.weight',
'decoder.embeddings.word_embeddings.weight': 'shared.weight'})
return mapping
class GPT(LM_Mask, BERT):
"""构建GPT模型
链接:https://github.com/openai/finetune-transformer-lm
"""
@insert_arguments(final_activation='softmax')
@delete_arguments('with_pool', 'with_mlm', 'with_nsp')
def __init__(self, max_position, **kwargs):
"""GPT的embedding是token、position、segment三者embedding之和,跟BERT的主要区别是三者相加之后没有加LayerNormalization层。
使用LM_Mask实现预训练ckpt中的bias参数,最后的全连接层由于和embedding层权重一致,因此直接从word_embedding取
"""
super(GPT, self).__init__(max_position, **kwargs)
del self.embeddings.layerNorm
self.dense = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
self.dense.weight = self.embeddings.word_embeddings.weight
self.final_activation = get_activation(self.final_activation)
def apply_final_layers(self, inputs):
hidden_state = super().apply_final_layers(inputs)
logit = self.dense(hidden_state)
return self.final_activation(logit)
def load_variable(self, state_dict, name):
return super(GPT, self).load_variable(state_dict, name, prefix='gpt')
def variable_mapping(self):
"""映射到GPT权重格式
"""
mapping = super(GPT, self).variable_mapping(prefix='gpt')
return mapping
class GPT2(LM_Mask, BERT):
"""构建GPT模型
链接:https://github.com/openai/finetune-transformer-lm
"""
@insert_arguments(final_activation='softmax')
@delete_arguments('with_pool', 'with_mlm', 'with_nsp')
def __init__(self, max_position, **kwargs):
"""GPT2的embedding是token、position两者embedding之和
1、跟BERT的主要区别是三者相加之后没有加LayerNormalization层。
2、bert的layernorm是在attn/ffc之后,OpenAi-gpt2是在之前。
使用LM_Mask实现预训练ckpt中的bias参数,最后的全连接层由于和embedding层权重一致,因此直接从word_embedding取
"""
super(GPT2, self).__init__(max_position, **kwargs)
del self.embeddings.layerNorm
layer = self.Gpt2Layer(self.hidden_size, self.num_attention_heads, self.dropout_rate, self.attention_probs_dropout_prob, self.intermediate_size, self.hidden_act, is_dropout=self.is_dropout, conditional_size=self.conditional_size)
self.encoderLayer = nn.ModuleList([copy.deepcopy(layer) if layer_id in self.keep_hidden_layers else Identity() for layer_id in range(self.num_hidden_layers)])
self.LayerNormFinal = LayerNorm(self.hidden_size, eps=1e-12, conditional_size=self.conditional_size)
self.dense = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
self.dense.weight = self.embeddings.word_embeddings.weight
self.final_activation = get_activation(self.final_activation)
def apply_final_layers(self, inputs):
hidden_state = super().apply_final_layers(inputs)
logit = self.dense(self.LayerNormFinal([hidden_state]))
return self.final_activation(logit)
def load_variable(self, state_dict, name):
return super(GPT2, self).load_variable(state_dict, name, prefix='gpt2')
def variable_mapping(self):
"""映射到GPT权重格式
"""
mapping = super(GPT2, self).variable_mapping(prefix='gpt2')
mapping.update({'LayerNormFinal.weight': 'gpt2.LayerNormFinal.weight',
'LayerNormFinal.bias': 'gpt2.LayerNormFinal.bias'})
return mapping
class Gpt2Layer(BertLayer):
        '''Not defined in layer.py because this layer is specific to the gpt2 model and not reusable.
        Order: LN --> Att --> Add --> LN --> FFN --> Add
        '''
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, hidden_states, attention_mask, conditional_emb=None, encoder_hidden_states=None, encoder_attention_mask=None):
            # BERT applies layernorm after attn/ffn; OpenAI GPT-2 applies it before
x = self.layerNorm1((hidden_states, conditional_emb))
self_attn_output = self.multiHeadAttention(x, attention_mask)
hidden_states = hidden_states + self.dropout1(self_attn_output)
x = self.layerNorm2((hidden_states, conditional_emb))
ffn_output = self.feedForward(x)
hidden_states = hidden_states + self.dropout2(ffn_output)
return hidden_states
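# The two block orderings side by side, with attn/ffn/ln as plain callables;
# hypothetical helpers for illustration only:
def _post_ln_block(x, attn, ffn, ln1, ln2):  # BERT: Att --> Add --> LN --> FFN --> Add --> LN
    x = ln1(x + attn(x))
    return ln2(x + ffn(x))
def _pre_ln_block(x, attn, ffn, ln1, ln2):  # GPT-2: LN --> Att --> Add --> LN --> FFN --> Add
    x = x + attn(ln1(x))
    return x + ffn(ln2(x))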
class GPT2_ML(LM_Mask, BERT):
"""构建GPT2_ML模型
链接: https://github.com/imcaspar/gpt2-ml
注意:GPT2_ML虽然号称GPT2,但是它的结构其实更接近GPT,它自称GPT2的原因大概是因为它开源的版本参数量达到了GPT2的15亿参数。
看完ckpt中的key,和GPT的区别是embedding后也有layernorm,和bert的区别是第一个跳跃链接是在layernorm前,bert是在之后
"""
@insert_arguments(final_activation='softmax')
@delete_arguments('with_pool', 'with_mlm', 'with_nsp')
def __init__(self, max_position, **kwargs):
super().__init__(max_position, **kwargs)
layer = self.Gpt2MlLayer(self.hidden_size, self.num_attention_heads, self.dropout_rate, self.attention_probs_dropout_prob, self.intermediate_size, self.hidden_act, is_dropout=self.is_dropout, conditional_size=self.conditional_size)
self.encoderLayer = nn.ModuleList([copy.deepcopy(layer) if layer_id in self.keep_hidden_layers else Identity() for layer_id in range(self.num_hidden_layers)])
self.dense = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
self.dense.weight = self.embeddings.word_embeddings.weight
self.final_activation = get_activation(self.final_activation)
def apply_final_layers(self, inputs):
hidden_state = super().apply_final_layers(inputs)
logit = self.dense(hidden_state)
return self.final_activation(logit)
def load_variable(self, state_dict, name):
return super(GPT2_ML, self).load_variable(state_dict, name, prefix='gpt2_ml')
def variable_mapping(self):
"""映射到GPT2权重格式
"""
mapping = super(GPT2_ML, self).variable_mapping(prefix='gpt2_ml')
return mapping
class Gpt2MlLayer(BertLayer):
        '''Not defined in layer.py because this layer is specific to the gpt2_ml model and not reusable.
        Order: Att --> Add --> LN --> FFN --> Add --> LN
        '''
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, hidden_states, attention_mask, conditional_emb=None, encoder_hidden_states=None, encoder_attention_mask=None):
self_attn_output = self.multiHeadAttention(hidden_states, attention_mask)
hidden_states = hidden_states + self.dropout1(self_attn_output)
x = self.layerNorm1((hidden_states, conditional_emb))
            # BERT's residual connection branches off after layerNorm; gpt2_ml branches off before it
ffn_output = self.feedForward(x)
hidden_states = hidden_states + self.dropout2(ffn_output)
hidden_states = self.layerNorm2((hidden_states, conditional_emb))
return hidden_states
class Transformer_XL(BERT):
    '''Transformer-XL model (checkpoint loading verified)
    Project: https://github.com/kimiyoung/transformer-xl
    Differences from the original:
    1) the original AdaptiveEmbedding is simplified (optional) and ProjectedAdaptiveLogSoftmax is not used;
       last_hidden_state is output directly
    2) mems are initialized as zero tensors and include the last layer; the original project initializes empty tensors
    3) SinusoidalPositionEncoding usually interleaves sin and cos; here all sin components come first, then all cos
    4) the attention_mask in multi_attn uses 1e30 in place of the original 1000
    '''
@delete_arguments('with_pool', 'with_nsp', 'with_mlm')
@insert_arguments(with_lm=False)
def __init__(self, *args, mem_len=0, same_length=False, clamp_len=-1, **kwargs):
        # p_bias disables absolute position embeddings at the embedding stage
kwargs.update({'p_bias': 'other_relative'})
super().__init__(*args, **kwargs)
self.mem_len, self.same_length, self.clamp_len = mem_len, same_length, clamp_len
self.attn_type = kwargs.get('attn_type', 0)
# embedding
if kwargs.get('adaptive_embedding'):
cutoffs, div_val, sample_softmax = kwargs.get('cutoffs', []), kwargs.get('div_val', 1), kwargs.get('sample_softmax', False)
self.embeddings = AdaptiveEmbedding(self.vocab_size, self.embedding_size, self.hidden_size, cutoffs, div_val, sample_softmax, **get_kw(AdaptiveEmbedding, kwargs))
else:
self.embeddings = nn.Embedding(self.vocab_size, self.embedding_size)
self.pos_embeddings = XlnetPositionsEncoding(self.embedding_size)
self.dropout = nn.Dropout(self.dropout_rate)
        # whether each layer has its own r_w_bias/r_r_bias, or they are shared globally
if not kwargs.get('untie_r'):
            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.num_attention_heads, self.attention_head_size))  # global content bias
            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.num_attention_heads, self.attention_head_size))  # global position bias
            if self.segment_vocab_size > 0:
                self.r_s_bias = nn.Parameter(torch.FloatTensor(self.num_attention_heads, self.attention_head_size))  # global segment bias
else:
self.r_w_bias, self.r_r_bias = None, None
self.r_s_bias = None
# transformer block
layer = XlnetLayer(self.hidden_size, self.num_attention_heads, self.dropout_rate, self.attention_probs_dropout_prob, self.intermediate_size,
self.hidden_act, is_dropout=self.is_dropout, conditional_size=self.conditional_size, r_w_bias=self.r_w_bias, r_r_bias=self.r_r_bias,
r_s_bias=None, **get_kw(BertLayer, kwargs))
self.encoderLayer = nn.ModuleList([copy.deepcopy(layer) if layer_id in self.keep_hidden_layers else Identity() for layer_id in range(self.num_hidden_layers)])
        # projection head
if self.with_lm:
self.dense = nn.Linear(self.hidden_size, self.vocab_size, bias=True)
def init_mems(self, bsz):
        '''Initialize mems, the per-layer hidden-state memory of length mem_len.
        '''
if isinstance(self.mem_len, (int, float)) and (self.mem_len > 0):
mems = []
param = next(self.parameters())
for _ in range(self.num_hidden_layers+1):
empty = torch.zeros(bsz, self.mem_len, self.hidden_size, dtype=param.dtype, device=param.device)
mems.append(empty)
return mems
else:
return None
def _update_mems(self, hids, mlen, qlen):
        '''Update mems.
        '''
# does not deal with None
if self.mems is None:
return None
# mems is not None
assert len(hids) == len(self.mems), "len(hids) != len(mems)"
# There are `mlen + qlen` steps that can be cached into mems
with torch.no_grad():
new_mems = []
end_idx = mlen + max(0, qlen)
beg_idx = max(0, end_idx - self.mem_len)
for i in range(len(hids)):
cat = torch.cat([self.mems[i], hids[i]], dim=1)
new_mems.append(cat[:, beg_idx:end_idx].detach())
self.mems = new_mems
def relative_positional_encoding(self, qlen, klen, device):
        # build pos_emb with a sin/cos position encoding, keeping the signature consistent with xlnet
pos_seq = torch.arange(klen-1, -1, -1.0, device=device, dtype=torch.long)
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)
        pos_emb = self.dropout(self.pos_embeddings(pos_seq))  # reuse word_emb's dropout
return pos_emb
def create_mask(self, word_emb, qlen, klen, mlen):
        # build the attention_mask: the mlen memory positions are fully visible, while each query position
        # can only attend to steps <= t; similar to UniLM's mask, except UniLM controls it via segment_ids
        if self.same_length:  # each position may only attend to a fixed-length window
all_ones = word_emb.new_ones(qlen, klen)
mask_len = klen - self.mem_len
mask_shift_len = qlen - mask_len if mask_len > 0 else qlen
attention_mask = 1-(torch.triu(all_ones, 1+mlen) + torch.tril(all_ones, -mask_shift_len)).byte() # -1
else:
            attention_mask = torch.tril(word_emb.new_ones(qlen, klen), diagonal=mlen).byte()  # [q_len, k_len], lower triangle (shifted by mlen) set to 1
attention_mask = attention_mask[None, None, :, :]
return attention_mask
def apply_embeddings(self, inputs):
        '''Expected inputs: [token_ids, segment_ids]; conditional LayerNorm inputs are not supported yet.
        '''
        self.mems = self.init_mems(inputs[0].size(0))  # initialize mems
        # after simplification, the embedding step only computes word_embedding
word_emb = self.dropout(self.embeddings(inputs[0]))
index_ = 1
        btz, qlen = inputs[0].shape[:2]  # query length
mlen = self.mems[0].size(1) if self.mems is not None else 0
klen = mlen + qlen
        # relative position encoding
pos_emb = self.relative_positional_encoding(qlen, klen, word_emb.device)
# segment embedding
if self.segment_vocab_size > 0:
segment_ids = inputs[index_]
if mlen > 0:
mem_pad = torch.zeros([btz, mlen], dtype=torch.long, device=word_emb.device)
cat_ids = torch.cat([mem_pad, segment_ids], dim=1)
else:
cat_ids = segment_ids
# `1` indicates not in the same segment [qlen x klen x bsz]
segment_ids = (segment_ids[:, :, None] != cat_ids[:, None]).long()
index_ += 1
else:
segment_ids = None
        if self.attn_type in {'uni', 0}:  # compatible with transformer_xl's setting: 0
            non_tgt_mask = self.create_mask(word_emb, qlen, klen, mlen)  # the causal mask is returned directly
        elif self.attn_type == 'bi':
            attention_mask = (inputs[0] != self.token_pad_ids).long().unsqueeze(1).unsqueeze(2)
            non_tgt_mask = torch.eye(qlen).to(attention_mask)[None, None, :, :]
            non_tgt_mask = ((1 - attention_mask - non_tgt_mask) <= 0).long()
        return [word_emb, segment_ids, pos_emb, non_tgt_mask, None]
def apply_main_layers(self, inputs):
hidden_states, segment_ids, pos_emb, attention_mask, conditional_emb = inputs[:5]
        encoded_layers = [hidden_states]  # include the embedding output
layer_inputs = [hidden_states, segment_ids, pos_emb, attention_mask, None, conditional_emb]
for i, layer_module in enumerate(self.encoderLayer):
mems_i = None if self.mems is None else self.mems[i]
layer_inputs[-2] = mems_i
layer_inputs = self.apply_on_layer_begin(i, layer_inputs)
hidden_states = layer_module(*layer_inputs)
layer_inputs[0] = hidden_states
layer_inputs = self.apply_on_layer_end(i, layer_inputs)
encoded_layers.append(hidden_states)
        # in the original implementation word_emb, pos_emb and core_out (hidden_states) share the same dropout
hidden_states = self.dropout(hidden_states)
        qlen = inputs[0].size(1)  # query length
        mlen = self.mems[0].size(1) if self.mems is not None else 0  # memory length (dim 1 of [btz, mem_len, hidden_size])
self._update_mems(encoded_layers, mlen, qlen)
if not self.output_all_encoded_layers:
            # if not returning all layers, return only the top layer
encoded_layers = encoded_layers[:1] + [hidden_states]
return [encoded_layers, conditional_emb]
def load_variable(self, state_dict, name, prefix=''):
        # not supported yet because the pretrained model uses AdaptiveEmbedding
        if (self.keep_tokens is not None) or (self.compound_tokens is not None):
            raise ValueError('Custom keep_tokens and compound_tokens are not yet supported in Transformer_XL')
return state_dict[name]
def variable_mapping(self, prefix=''):
return {k:k for k, v in self.named_parameters()}
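# The recurrence memory in one line: concatenate the cached states with the new
# chunk's hidden states and keep (detached) only the most recent mem_len steps,
# which is what _update_mems does per layer via its beg_idx/end_idx slicing.
# A simplified stand-alone sketch:
def _demo_mem_update(mem, new_hidden, mem_len):
    # mem, new_hidden: [btz, steps, hidden_size]
    return torch.cat([mem, new_hidden], dim=1)[:, -mem_len:].detach()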
class XLNET(Transformer_XL):
    '''XLNet model, simplified here for fine-tuning only, i.e. without the perm_mask and target_mapping inputs.
    Expected inputs: [token_ids, segment_ids]
    '''
def __init__(self, *args, bi_data=False, **kwargs):
self.attn_type = kwargs.get('attn_type', 'bi')
self.bi_data = bi_data
kwargs['rel_shift_opt'] = 'xlnet'
super().__init__(*args, **kwargs)
def relative_positional_encoding(self, qlen, klen, device):
        # build pos_emb with a sin/cos position encoding; transformer_xl uses a -1 offset here
if self.attn_type == 'bi':
beg, end = klen, -qlen
elif self.attn_type == "uni":
beg, end = klen, -1
else:
raise ValueError(f"Unknown `attn_type` {self.attn_type}.")
        # forward-direction embedding
pos_seq = torch.arange(beg, end, -1.0, device=device, dtype=torch.long)
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)
fwd_pos_emb = self.pos_embeddings(pos_seq)
        # bidirectional data
if self.bi_data:
pos_seq = torch.arange(-beg, -end, -1.0, device=device, dtype=torch.long)
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)
bwd_pos_emb = self.pos_embeddings(pos_seq)
pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=0)
else:
pos_emb = fwd_pos_emb
        pos_emb = self.dropout(pos_emb)  # reuse word_emb's dropout
return pos_emb
def apply_final_layers(self, inputs):
hidden_state = super().apply_final_layers(inputs)
if self.with_lm:
return [hidden_state, self.dense(hidden_state)]
else:
return hidden_state
def load_variable(self, state_dict, name, prefix='transformer'):
"""加载单个变量的函数
"""
variable = state_dict[name]
if name in {f'{prefix}.word_embedding.weight', 'lm_loss.weight', 'lm_loss.bias'}:
return self.load_embeddings(variable)
elif re.search('rel_attn\.(q|k|v|r)$', name):
return variable.reshape(variable.shape[0], -1).T
# elif re.search('rel_attn\.(o|seg_embed)$', name):
elif re.search('rel_attn\.(o)$', name):
return variable.reshape(variable.shape[0], -1)
else:
return variable
def variable_mapping(self, prefix='transformer'):
mapping = {
'embeddings.weight': f'{prefix}.word_embedding.weight',
'dense.weight': 'lm_loss.weight',
'dense.bias': 'lm_loss.bias',
}
for i in range(self.num_hidden_layers):
prefix_i = f'{prefix}.layer.%d.' % i
mapping.update({f'encoderLayer.{i}.multiHeadAttention.q.weight': prefix_i + 'rel_attn.q',
f'encoderLayer.{i}.multiHeadAttention.k.weight': prefix_i + 'rel_attn.k',
f'encoderLayer.{i}.multiHeadAttention.v.weight': prefix_i + 'rel_attn.v',
f'encoderLayer.{i}.multiHeadAttention.o.weight': prefix_i + 'rel_attn.o',
f'encoderLayer.{i}.multiHeadAttention.r.weight': prefix_i + 'rel_attn.r',
f'encoderLayer.{i}.multiHeadAttention.r_r_bias': prefix_i + 'rel_attn.r_r_bias',
f'encoderLayer.{i}.multiHeadAttention.r_s_bias': prefix_i + 'rel_attn.r_s_bias',
f'encoderLayer.{i}.multiHeadAttention.r_w_bias': prefix_i + 'rel_attn.r_w_bias',
# f'encoderLayer.{i}.multiHeadAttention.seg_embed.weight': prefix_i + 'rel_attn.seg_embed',
f'encoderLayer.{i}.multiHeadAttention.seg_embed': prefix_i + 'rel_attn.seg_embed',
f'encoderLayer.{i}.layerNorm1.weight': prefix_i + 'rel_attn.layer_norm.weight',
f'encoderLayer.{i}.layerNorm1.bias': prefix_i + 'rel_attn.layer_norm.bias',
f'encoderLayer.{i}.feedForward.intermediateDense.weight': prefix_i + 'ff.layer_1.weight',
f'encoderLayer.{i}.feedForward.intermediateDense.bias': prefix_i + 'ff.layer_1.bias',
f'encoderLayer.{i}.feedForward.outputDense.weight': prefix_i + 'ff.layer_2.weight',
f'encoderLayer.{i}.feedForward.outputDense.bias': prefix_i + 'ff.layer_2.bias',
f'encoderLayer.{i}.layerNorm2.weight': prefix_i + 'ff.layer_norm.weight',
f'encoderLayer.{i}.layerNorm2.bias': prefix_i + 'ff.layer_norm.bias'
})
return mapping
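# The sinusoidal encoding used by Transformer_XL/XLNET here places all sin
# components first and all cos components second (rather than interleaving
# them), as noted in the Transformer_XL docstring. A minimal stand-alone sketch:
def _demo_sincos_positions(pos_seq, d_model):
    inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2, dtype=torch.float) / d_model))
    ang = pos_seq.float()[:, None] * inv_freq[None, :]  # [len, d_model/2]
    return torch.cat([ang.sin(), ang.cos()], dim=-1)  # [len, d_model]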
def build_transformer_model(
config_path=None,
checkpoint_path=None,
model='bert',
application='encoder',
**kwargs
):
"""根据配置文件构建模型,可选加载checkpoint权重
"""
configs = {}
if config_path is not None:
configs.update(json.load(open(config_path)))
configs.update(kwargs)
if 'max_position' not in configs:
configs['max_position'] = configs.get('max_position_embeddings', 512)
if 'dropout_rate' not in configs:
configs['dropout_rate'] = configs.get('hidden_dropout_prob')
if 'segment_vocab_size' not in configs:
configs['segment_vocab_size'] = configs.get('type_vocab_size', 2)
models = {
'bert': BERT,
'roberta': BERT,
'albert': ALBERT,
'albert_unshared': ALBERT_Unshared,
'nezha': NEZHA,
'roformer': RoFormer,
'roformer_v2': RoFormerV2,
'gau_alpha': GAU_alpha,
'electra': ELECTRA,
'encoder': Encoder,
'decoder': Decoder,
'transformer': Transformer,
'bart': BART,
'gpt': GPT,
'gpt2': GPT2,
'gpt2_ml': GPT2_ML,
't5': T5,
't5_encoder': T5_Encoder,
't5_decoder': T5_Decoder,
't5.1.0': T5,
't5.1.0_encoder': T5_Encoder,
't5.1.0_decoder': T5_Decoder,
't5.1.1': T5,
't5.1.1_encoder': T5_Encoder,
't5.1.1_decoder': T5_Decoder,
'mt5.1.1': T5,
'mt5.1.1_encoder': T5_Encoder,
'mt5.1.1_decoder': T5_Decoder,
'transformer_xl': Transformer_XL,
'xlnet': XLNET,
}
    if isinstance(model, str):  # a string selects one of the built-in models
MODEL = models[model.lower()]
if model.endswith('t5.1.1'):
configs['version'] = model
    elif isinstance(model, type) and issubclass(model, BERT_BASE):  # a BERT_BASE subclass selects a custom model
MODEL = model
else:
raise ValueError('"model" args type should be string or nn.Module')
application = application.lower()
if application in ['lm', 'unilm'] and model in ['electra', 't5', ]:
raise ValueError(f'"{model}" model can not be used as "{application}" application.\n')
if application == 'lm':
MODEL = extend_with_language_model(MODEL)
elif application == 'unilm':
MODEL = extend_with_unified_language_model(MODEL)
transformer = MODEL(**configs)
transformer.build(**configs)
    transformer.apply(transformer.init_model_weights)  # initialize weights
if checkpoint_path is not None:
transformer.load_weights_from_pytorch_checkpoint(checkpoint_path)
transformer.configs = configs
return transformer
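# A hedged usage sketch: build a BERT from an HF-style config and checkpoint.
# The paths follow the dataset layout described in the README and are placeholders:
def _demo_build_transformer_model():
    model = build_transformer_model(
        config_path='/datasets/bert-base-chinese/config.json',
        checkpoint_path='/datasets/bert-base-chinese/pytorch_model.bin',
        model='bert',
    )
    return model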
from torch.optim.lr_scheduler import LambdaLR
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
"""带warmup的schedule, 源自transformers包optimization.py中
参数
num_warmup_steps:
需要warmup的步数, 一般为 num_training_steps * warmup_proportion(warmup的比例, 建议0.05-0.15)
num_training_steps:
总的训练步数, 一般为 train_batches * num_epoch
"""
def lr_lambda(current_step: int):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
return LambdaLR(optimizer, lr_lambda, last_epoch)
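# A hedged usage sketch: warm up over the first 10% of training, then decay
# linearly to zero; scheduler.step() is expected once per optimizer step:
def _demo_linear_warmup(model, train_batches, num_epoch):
    import torch
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
    num_training_steps = train_batches * num_epoch
    num_warmup_steps = int(0.1 * num_training_steps)
    return optimizer, get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)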
def extend_with_exponential_moving_average(model, decay=0.999):
class ExponentialMovingAverage():
        ''' Exponential moving average (EMA) of the model weights. It takes no part in gradient updates
        and only tracks the averaged parameters for use at prediction time.
        Note this differs from adaptive-lr optimizers such as Adam, whose EMAs are over the first and
        second moments of the gradients; the two are completely different things.
        Example:
            # initialization
            ema = ExponentialMovingAverage(model, 0.999)
            # during training, update the EMA weights right after the parameter update
            def train():
                optimizer.step()
                ema.step()
            # call apply_ema_weights() before eval; call restore_raw_weights() afterwards to restore the raw weights
            def evaluate():
                ema.apply_ema_weights()
                # evaluate
                # to save the EMA model, call torch.save() before restore_raw_weights()
                ema.restore_raw_weights()
        '''
def __init__(self, model, decay):
self.model = model
self.decay = decay
            # the EMA weights (the running average of every parameter at the current step)
            self.ema_weights = {}
            # backup of the raw model weights while evaluating; restored after evaluate finishes
            self.model_weights = {}
            # initialize ema_weights from the current model weights
for name, param in self.model.named_parameters():
if param.requires_grad:
self.ema_weights[name] = param.data.clone()
def step(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.ema_weights
new_average = (1.0 - self.decay) * param.data + self.decay * self.ema_weights[name]
self.ema_weights[name] = new_average.clone()
def apply_ema_weights(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.ema_weights
self.model_weights[name] = param.data
param.data = self.ema_weights[name]
def restore_raw_weights(self):
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.model_weights
param.data = self.model_weights[name]
self.model_weights = {}
return ExponentialMovingAverage(model, decay)
#! -*- coding: utf-8 -*-
# collection of miscellaneous utilities
import unicodedata
import six
import numpy as np
import re
import torch
from torch.nn.utils.rnn import pad_sequence
import time
import sys
import collections
import torch.nn as nn
from torch.utils.data import Dataset, IterableDataset
import math
import gc
import inspect
import json
import torch.nn.functional as F
import random
import warnings
import os
is_py2 = six.PY2
if not is_py2:
basestring = str
def take_along_dim(input_tensor, indices, dim=None):
    '''Compatibility shim for older pytorch versions without torch.take_along_dim.
    '''
if torch.__version__ >= '1.9.0':
return torch.take_along_dim(input_tensor, indices, dim)
else:
        # this fallback has only been tested on a small amount of data; please report any bugs
if dim is None:
res = input_tensor.flatten()[indices]
else:
res = np.take_along_axis(input_tensor.cpu().numpy(), indices.cpu().numpy(), axis=dim)
res = torch.from_numpy(res).to(input_tensor.device)
# assert res.equal(torch.take_along_dim(input_tensor, indices, dim))
return res
def is_string(s):
"""判断是否是字符串
"""
return isinstance(s, basestring)
def truncate_sequences(maxlen, indices, *sequences):
"""截断总长度至不超过maxlen
"""
sequences = [s for s in sequences if s]
if not isinstance(indices, (list, tuple)):
indices = [indices] * len(sequences)
while True:
lengths = [len(s) for s in sequences]
if sum(lengths) > maxlen:
i = np.argmax(lengths)
sequences[i].pop(indices[i])
else:
return sequences
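# A small demo of the trimming loop: tokens are popped (here from the end,
# index -1) off whichever sequence is currently longest until the total fits:
def _demo_truncate_sequences():
    a, b = [1, 2, 3, 4, 5], [6, 7, 8]
    truncate_sequences(6, -1, a, b)  # trims in place -> a == [1, 2, 3], b == [6, 7, 8]
    return a, b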
def text_segmentate(text, maxlen, seps='\n', strips=None, truncate=True):
"""将文本按照标点符号划分为若干个短句
truncate: True表示标点符号切分后仍然超长时, 按照maxlen硬截断分成若干个短句
"""
text = text.strip().strip(strips)
if seps and len(text) > maxlen:
pieces = text.split(seps[0])
text, texts = '', []
for i, p in enumerate(pieces):
if text and p and len(text) + len(p) > maxlen - 1:
texts.extend(text_segmentate(text, maxlen, seps[1:], strips, truncate))
text = ''
if i + 1 == len(pieces):
text = text + p
else:
text = text + p + seps[0]
if text:
texts.extend(text_segmentate(text, maxlen, seps[1:], strips, truncate))
return texts
elif truncate and (not seps) and (len(text) > maxlen):
        # separators exhausted, still too long, and truncate=True was requested
return [text[i*maxlen:(i+1)*maxlen] for i in range(0, int(np.ceil(len(text)/maxlen)))]
else:
return [text]
def merge_segmentate(sequences, maxlen, sep=''):
    '''Merge m sentences into n sentences of at most maxlen characters; mainly used to merge fragments.
    '''
sequences_new = []
text = ''
for t in sequences:
if text and len(text + sep + t) <= maxlen:
text = text + sep + t
elif text:
sequences_new.append(text)
text = t
        elif len(t) < maxlen:  # text is empty
text = t
else:
sequences_new.append(t)
text = ''
if text:
sequences_new.append(text)
return sequences_new
def text_augmentation(texts, noise_dict=None, noise_len=0, noise_p=0.0, skip_words=None, strategy='random', allow_dup=True):
    '''Simple EDA strategies: insert, delete, replace.
    texts: the text or list of texts to augment
    noise_dict: noise tokens, a list/tuple/set of str
    noise_len: number of noise positions; takes precedence over noise_p
    noise_p: noise ratio
    skip_words: phrases to leave untouched, string/list
    strategy: edit strategy, one of insert, delete, replace, random
    allow_dup: whether the same position may be edited more than once
    '''
def insert(text, insert_idx, noise_dict):
text = list(text)
for i in insert_idx:
text[i] = text[i] + random.choice(noise_dict)
return ''.join(text)
def delete(text, delete_idx):
text = list(text)
for i in delete_idx:
text[i] = ''
return ''.join(text)
def replace(text, replace_idx, noise_dict):
text = list(text)
for i in replace_idx:
text[i] = random.choice(noise_dict)
return ''.join(text)
def search(pattern, sequence, keep_last=True):
"""从sequence中寻找子串pattern, 返回符合pattern的id集合
"""
n = len(pattern)
pattern_idx_set = set()
for i in range(len(sequence)):
if sequence[i:i + n] == pattern:
pattern_idx_set = pattern_idx_set.union(set(range(i, i+n))) if keep_last else pattern_idx_set.union(set(range(i, i+n-1)))
return pattern_idx_set
if (noise_len==0) and (noise_p==0):
return texts
assert strategy in {'insert', 'delete', 'replace', 'random'}, 'EDA strategy only support insert, delete, replace, random'
if isinstance(texts, str):
texts = [texts]
if skip_words is None:
skip_words = []
elif isinstance(skip_words, str):
skip_words = [skip_words]
for id, text in enumerate(texts):
        sel_len = noise_len if noise_len > 0 else int(len(text)*noise_p)  # number of noise positions
        skip_idx = set()  # indices that must not be modified
        for item in skip_words:
            # for insert, the last position of a skip word may still have noise inserted after it
            skip_idx = skip_idx.union(search(item, text, strategy!='insert'))
        sel_idxs = [i for i in range(len(text)) if i not in skip_idx]  # candidate indices
        sel_len = sel_len if allow_dup else min(sel_len, len(sel_idxs))  # sampling without replacement needs sample size <= population
        if (sel_len == 0) or (len(sel_idxs) == 0):  # skip if nothing can be sampled
            continue
continue
sel_idx = np.random.choice(sel_idxs, sel_len, replace=allow_dup)
if strategy == 'insert':
texts[id] = insert(text, sel_idx, noise_dict)
elif strategy == 'delete':
texts[id] = delete(text, sel_idx)
elif strategy == 'replace':
texts[id] = replace(text, sel_idx, noise_dict)
elif strategy == 'random':
if random.random() < 0.333:
                skip_idx = set()  # indices that must not be modified
                for item in skip_words:
                    # for insert, the last position of a skip word may still have noise inserted after it
                    skip_idx = skip_idx.union(search(item, text, keep_last=False))
texts[id] = insert(text, sel_idx, noise_dict)
elif random.random() < 0.667:
texts[id] = delete(text, sel_idx)
else:
texts[id] = replace(text, sel_idx, noise_dict)
return texts if len(texts) > 1 else texts[0]
def lowercase_and_normalize(text, never_split=()):
"""转小写,并进行简单的标准化
"""
if is_py2:
text = unicode(text)
    if never_split:
        # convert non-special tokens to lowercase
        escaped_special_toks = [re.escape(s_tok) for s_tok in never_split]
        pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
        text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
    else:
        text = text.lower()
text = unicodedata.normalize('NFD', text)
text = ''.join([ch for ch in text if unicodedata.category(ch) != 'Mn'])
return text
def sequence_padding(inputs, length=None, value=0, seq_dims=1, mode='post'):
"""将序列padding到同一长度
"""
if isinstance(inputs[0], (np.ndarray, list)):
if length is None:
length = np.max([np.shape(x)[:seq_dims] for x in inputs], axis=0)
elif not hasattr(length, '__getitem__'):
length = [length]
slices = [np.s_[:length[i]] for i in range(seq_dims)]
slices = tuple(slices) if len(slices) > 1 else slices[0]
pad_width = [(0, 0) for _ in np.shape(inputs[0])]
outputs = []
for x in inputs:
x = x[slices]
for i in range(seq_dims):
if mode == 'post':
pad_width[i] = (0, length[i] - np.shape(x)[i])
elif mode == 'pre':
pad_width[i] = (length[i] - np.shape(x)[i], 0)
else:
raise ValueError('"mode" argument must be "post" or "pre".')
x = np.pad(x, pad_width, 'constant', constant_values=value)
outputs.append(x)
return np.array(outputs)
elif isinstance(inputs[0], torch.Tensor):
assert mode == 'post', '"mode" argument must be "post" when element is torch.Tensor'
if length is not None:
inputs = [i[:length] for i in inputs]
return pad_sequence(inputs, padding_value=value, batch_first=True)
else:
raise ValueError('"input" argument must be tensor/list/ndarray.')
def insert_arguments(**arguments):
"""装饰器,为类方法增加参数(主要用于类的__init__方法)
"""
def actual_decorator(func):
def new_func(self, *args, **kwargs):
for k, v in arguments.items():
if k in kwargs:
v = kwargs.pop(k)
setattr(self, k, v)
return func(self, *args, **kwargs)
return new_func
return actual_decorator
def delete_arguments(*arguments):
"""装饰器,为类方法删除参数(主要用于类的__init__方法)
"""
def actual_decorator(func):
def new_func(self, *args, **kwargs):
for k in arguments:
if k in kwargs:
raise TypeError(
'%s got an unexpected keyword argument \'%s\'' %
(self.__class__.__name__, k)
)
return func(self, *args, **kwargs)
return new_func
return actual_decorator
class Progbar(object):
"""Displays a progress bar.
# Arguments
target: Total number of steps expected, None if unknown.
width: Progress bar width on screen.
verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
stateful_metrics: Iterable of string names of metrics that
should *not* be averaged over time. Metrics in this list
will be displayed as-is. All others will be averaged
by the progbar before display.
interval: Minimum visual progress update interval (in seconds).
"""
def __init__(self, target, width=30, verbose=1, interval=0.05,
stateful_metrics=None):
self.target = target
self.width = width
self.verbose = verbose
self.interval = interval
if stateful_metrics:
self.stateful_metrics = set(stateful_metrics)
else:
self.stateful_metrics = set()
self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
sys.stdout.isatty()) or
'ipykernel' in sys.modules)
self._total_width = 0
self._seen_so_far = 0
self._values = collections.OrderedDict()
self._start = time.time()
self._last_update = 0
def update(self, current, values=None):
"""Updates the progress bar.
# Arguments
current: Index of current step.
values: List of tuples:
`(name, value_for_last_step)`.
If `name` is in `stateful_metrics`,
`value_for_last_step` will be displayed as-is.
Else, an average of the metric over time will be displayed.
"""
values = values or []
for k, v in values:
if k not in self.stateful_metrics:
if k not in self._values:
self._values[k] = [v * (current - self._seen_so_far),
current - self._seen_so_far]
else:
self._values[k][0] += v * (current - self._seen_so_far)
self._values[k][1] += (current - self._seen_so_far)
else:
# Stateful metrics output a numeric value. This representation
# means "take an average from a single value" but keeps the
# numeric formatting.
self._values[k] = [v, 1]
self._seen_so_far = current
now = time.time()
info = ' - %.0fs' % (now - self._start)
if self.verbose == 1:
if (now - self._last_update < self.interval and
self.target is not None and current < self.target):
return
prev_total_width = self._total_width
if self._dynamic_display:
sys.stdout.write('\b' * prev_total_width)
sys.stdout.write('\r')
else:
sys.stdout.write('\n')
if self.target is not None:
numdigits = int(np.floor(np.log10(self.target))) + 1
barstr = '%%%dd/%d [' % (numdigits, self.target)
bar = barstr % current
prog = float(current) / self.target
prog_width = int(self.width * prog)
if prog_width > 0:
bar += ('=' * (prog_width - 1))
if current < self.target:
bar += '>'
else:
bar += '='
bar += ('.' * (self.width - prog_width))
bar += ']'
else:
bar = '%7d/Unknown' % current
self._total_width = len(bar)
sys.stdout.write(bar)
if current:
time_per_unit = (now - self._start) / current
else:
time_per_unit = 0
if self.target is not None and current < self.target:
eta = time_per_unit * (self.target - current)
if eta > 3600:
eta_format = ('%d:%02d:%02d' %
(eta // 3600, (eta % 3600) // 60, eta % 60))
elif eta > 60:
eta_format = '%d:%02d' % (eta // 60, eta % 60)
else:
eta_format = '%ds' % eta
info = ' - ETA: %s' % eta_format
else:
if time_per_unit >= 1:
info += ' %.0fs/step' % time_per_unit
elif time_per_unit >= 1e-3:
info += ' %.0fms/step' % (time_per_unit * 1e3)
else:
info += ' %.0fus/step' % (time_per_unit * 1e6)
for k in self._values:
info += ' - %s:' % k
if isinstance(self._values[k], list):
avg = np.mean(
self._values[k][0] / max(1, self._values[k][1]))
if abs(avg) > 1e-3:
info += ' %.4f' % avg
else:
info += ' %.4e' % avg
else:
info += ' %s' % self._values[k]
self._total_width += len(info)
if prev_total_width > self._total_width:
info += (' ' * (prev_total_width - self._total_width))
if self.target is not None and current >= self.target:
info += '\n'
sys.stdout.write(info)
sys.stdout.flush()
elif self.verbose == 2:
if self.target is None or current >= self.target:
for k in self._values:
info += ' - %s:' % k
avg = np.mean(
self._values[k][0] / max(1, self._values[k][1]))
if abs(avg) > 1e-3:
info += ' %.4f' % avg
else:
info += ' %.4e' % avg
info += '\n'
sys.stdout.write(info)
sys.stdout.flush()
self._last_update = now
def add(self, n, values=None):
self.update(self._seen_so_far + n, values)
class Callback(object):
'''Base class for callbacks
'''
def __init__(self):
pass
def on_train_begin(self, logs=None):
pass
def on_train_end(self, logs=None):
pass
def on_epoch_begin(self, global_step, epoch, logs=None):
pass
def on_epoch_end(self, global_step, epoch, logs=None):
pass
def on_batch_begin(self, global_step, batch, logs=None):
pass
def on_batch_end(self, global_step, batch, logs=None):
pass
def on_dataloader_end(self, logs=None):
pass
class ProgbarLogger(Callback):
"""Callback that prints metrics to stdout.
# Arguments
count_mode: One of "steps" or "samples".
Whether the progress bar should
count samples seen or steps (batches) seen.
stateful_metrics: Iterable of string names of metrics that
should *not* be averaged over an epoch.
Metrics in this list will be logged as-is.
All others will be averaged over time (e.g. loss, etc).
# Raises
ValueError: In case of invalid `count_mode`.
"""
def __init__(self, epochs, steps, metrics, stateful_metrics=None, verbose=1):
super(ProgbarLogger, self).__init__()
if stateful_metrics:
self.stateful_metrics = set(stateful_metrics)
else:
self.stateful_metrics = set()
self.params = {'epochs': epochs, 'steps': steps, 'verbose': verbose, 'metrics': metrics}
self.verbose = verbose
self.epochs = epochs
def add_metrics(self, metrics, add_position=None):
if add_position is None:
add_position = len(self.params['metrics'])
if isinstance(metrics, str):
metrics = [metrics]
add_metrics = []
for metric in metrics:
if metric not in self.params['metrics']:
add_metrics.append(metric)
self.params['metrics'] = self.params['metrics'][:add_position] + add_metrics + self.params['metrics'][add_position:]
def on_train_begin(self, logs=None):
if self.verbose:
print('Start Training'.center(40, '='))
def on_epoch_begin(self, global_step=None, epoch=None, logs=None):
if self.verbose:
print('Epoch %d/%d' % (epoch + 1, self.epochs))
self.target = self.params['steps']
self.progbar = Progbar(target=self.target, verbose=self.verbose, stateful_metrics=self.stateful_metrics)
self.seen = 0
def on_batch_begin(self, global_step=None, batch=None, logs=None):
if self.seen < self.target:
self.log_values = []
def on_batch_end(self, global_step=None, batch=None, logs=None):
logs = logs or {}
self.seen += 1
for k in self.params['metrics']:
if k in logs:
self.log_values.append((k, logs[k]))
# Skip progbar update for the last batch;
# will be handled by on_epoch_end.
if self.verbose and self.seen < self.target:
self.progbar.update(self.seen, self.log_values)
def on_epoch_end(self, global_step=None, epoch=None, logs=None):
logs = logs or {}
for k in self.params['metrics']:
if k in logs:
self.log_values.append((k, logs[k]))
if self.verbose:
self.progbar.update(self.seen, self.log_values)
def on_train_end(self, logs=None):
if self.verbose:
print('Finish Training'.center(40, '='))
class EarlyStopping(Callback):
'''Early-stopping strategy, ported from keras
'''
def __init__(self, monitor='loss', min_delta=0, patience=0, verbose=0, mode='auto', baseline=None):
super(EarlyStopping, self).__init__()
self.monitor = monitor
self.baseline = baseline
self.patience = patience
self.verbose = verbose
self.min_delta = min_delta
self.wait = 0
self.stopped_epoch = 0
if mode not in ['auto', 'min', 'max']:
warnings.warn('EarlyStopping mode %s is unknown, fallback to auto mode.' % mode, RuntimeWarning)
mode = 'auto'
if mode == 'min':
self.monitor_op = np.less
elif mode == 'max':
self.monitor_op = np.greater
else:
self.monitor_op = np.greater if 'acc' in self.monitor else np.less
self.min_delta = self.min_delta if self.monitor_op == np.greater else -self.min_delta
def on_train_begin(self, logs=None):
# Allow instances to be re-used
self.wait = 0
self.stopped_epoch = 0
if self.baseline is not None:
self.best = self.baseline
else:
self.best = np.Inf if self.monitor_op == np.less else -np.Inf
def on_epoch_end(self, steps, epoch, logs=None):
current = self.get_monitor_value(logs)
if current is None:
return
if self.monitor_op(current - self.min_delta, self.best):
self.best = current
self.wait = 0
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = epoch
def on_train_end(self, logs=None):
if self.stopped_epoch > 0 and self.verbose > 0:
print(f'Epoch {self.stopped_epoch+1}: early stopping\n')
def get_monitor_value(self, logs):
monitor_value = logs.get(self.monitor)
if monitor_value is None:
warnings.warn('Early stopping conditioned on metric `%s` '
'which is not available. Available metrics are: %s' %
(self.monitor, ','.join(list(logs.keys()))), RuntimeWarning)
return monitor_value
def metric_mapping(metric, y_pred, y_true):
if metric == 'accuracy':
if isinstance(y_pred, (list, tuple)):
y_pred = y_pred[0]
y_pred = torch.argmax(y_pred, dim=-1)
acc = torch.sum(y_pred.eq(y_true)).item() / y_true.size(0)
return acc
return None
def softmax(x, axis=-1):
"""numpy版softmax
"""
x = x - x.max(axis=axis, keepdims=True)
x = np.exp(x)
return x / x.sum(axis=axis, keepdims=True)
class AutoRegressiveDecoder(object):
"""通用自回归生成模型解码基类
包含beam search和random sample两种策略
"""
def __init__(self, start_id, end_id, maxlen, minlen=1, device='cpu'):
self.start_id = start_id
self.end_id = end_id
self.maxlen = maxlen
self.minlen = minlen
self.models = {}
self.device = device
if start_id is None:
self.first_output_ids = torch.empty((1, 0), dtype=int, device=device)
else:
self.first_output_ids = torch.tensor([[self.start_id]], device=device)
@staticmethod
def wraps(default_rtype='probas', use_states=False):
"""用来进一步完善predict函数
目前包含: 1. 设置rtype参数,并做相应处理;
2. 确定states的使用,并做相应处理;
3. 设置温度参数,并做相应处理。
"""
def actual_decorator(predict):
def new_predict(
self,
inputs,
output_ids,
states,
temperature=1,
rtype=default_rtype
):
assert rtype in ['probas', 'logits']
prediction = predict(self, inputs, output_ids, states)
if not use_states:
prediction = (prediction, None)
if default_rtype == 'logits':
prediction = (nn.Softmax(dim=-1)(prediction[0] / temperature), prediction[1])
elif temperature != 1:
probas = torch.pow(prediction[0], 1.0 / temperature)  # torch.pow (torch.power does not exist)
probas = probas / probas.sum(axis=-1, keepdims=True)
prediction = (probas, prediction[1])
if rtype == 'probas':
return prediction
else:
return torch.log(prediction[0] + 1e-12), prediction[1]
return new_predict
return actual_decorator
# def last_token(self, model):
#     """Build a new model that only returns the last token's output (kept, commented out, from the bert4keras version)
# """
# if model not in self.models:
# outputs = [
# keras.layers.Lambda(lambda x: x[:, -1])(output)
# for output in model.outputs
# ]
# self.models[model] = keras.models.Model(model.inputs, outputs)
# return self.models[model]
def predict(self, inputs, output_ids, states=None):
"""用户需自定义递归预测函数
说明: 定义的时候,需要用wraps方法进行装饰,传入default_rtype和use_states,
其中default_rtype为字符串logits或probas,probas时返回归一化的概率,
rtype=logits时则返回softmax前的结果或者概率对数。
返回: 二元组 (得分或概率, states)
"""
raise NotImplementedError
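# Usage sketch (illustrative only; `model` and `token_ids` are assumed to exist,
# `model.predict` is a hypothetical call returning [btz, seq_len, vocab_size] logits,
# and end_id=102 assumes a BERT vocab where 102 is [SEP]):
#   class MyDecoder(AutoRegressiveDecoder):
#       @AutoRegressiveDecoder.wraps(default_rtype='logits')
#       def predict(self, inputs, output_ids, states):
#           curr_ids = torch.cat([inputs[0], output_ids], 1)
#           logits = model.predict([curr_ids])
#           return logits[:, -1, :]  # scores of the last position only
#   decoder = MyDecoder(start_id=None, end_id=102, maxlen=32)
#   best = decoder.beam_search([token_ids], topk=3)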
def beam_search(self, inputs_raw, topk, states=None, temperature=1, min_ends=1, add_btz_dim=True):
"""beam search解码
说明: 这里的topk即beam size;
返回: 最优解码序列。
"""
inputs = []
for i in inputs_raw:
if isinstance(i, torch.Tensor):
pass
elif isinstance(i, (list, tuple, np.ndarray)) and add_btz_dim:
i = torch.tensor([i], device=self.device)
elif isinstance(i, (list, tuple, np.ndarray)) and not add_btz_dim:
i = torch.tensor(i, device=self.device)
else:
raise ValueError('Beam search input elements only support tensor/array/list/tuple')
inputs.append(i)
output_ids, output_scores = self.first_output_ids, torch.zeros(1, device=self.device)
for step in range(self.maxlen):
scores, states = self.predict(inputs, output_ids, states, temperature, 'logits')  # scores for the current step
if step == 0:  # after the first prediction, repeat the inputs topk times
inputs = [i.repeat([topk]+[1]*(len(i.shape)-1)) for i in inputs]
scores = output_scores.reshape((-1, 1)) + scores  # accumulated scores
indices = scores.flatten().argsort(dim=-1, descending=True)[:topk]  # keep only the topk
indices_1 = torch.div(indices, scores.shape[1], rounding_mode='trunc')  # row indices
indices_2 = (indices % scores.shape[1]).reshape((-1, 1))  # column indices
output_ids = torch.cat([output_ids[indices_1], indices_2], 1)  # update the outputs
output_scores = take_along_dim(scores, indices, dim=None)  # update the scores
is_end = output_ids[:, -1] == self.end_id  # mark sequences ending with the end token
end_counts = (output_ids == self.end_id).sum(1)  # count end tokens per sequence
if output_ids.shape[1] >= self.minlen:  # minimum-length check
best = output_scores.argmax()  # candidate with the highest score
if is_end[best] and end_counts[best] >= min_ends:  # already terminated
return output_ids[best]  # output it directly
else:  # otherwise keep only the unfinished candidates
flag = ~is_end | (end_counts < min_ends)  # mark unfinished sequences
if not flag.all():  # some candidates are finished
inputs = [i[flag] for i in inputs]  # drop finished sequences
output_ids = output_ids[flag]  # drop finished sequences
output_scores = output_scores[flag]  # drop finished scores
end_counts = end_counts[flag]  # drop finished end counts
topk = flag.sum()  # shrink topk accordingly
# maximum length reached, output directly
def random_sample(
self,
inputs,
n,
topk=None,
topp=None,
states=None,
temperature=1,
min_ends=1
):
"""随机采样n个结果
说明: 非None的topk表示每一步只从概率最高的topk个中采样;而非None的topp
表示每一步只从概率最高的且概率之和刚好达到topp的若干个token中采样。
返回: n个解码序列组成的list。
"""
inputs = [torch.tensor([i], device=self.device) for i in inputs]
output_ids = self.first_output_ids
results = []
for step in range(self.maxlen):
probas, states = self.predict(inputs, output_ids, states, temperature, 'probas')  # probabilities for the current step
probas /= probas.sum(dim=-1, keepdims=True)  # ensure normalization
if step == 0:  # after the first prediction, repeat the results n times
probas = probas.repeat([n]+[1]*(len(probas.shape)-1))
inputs = [i.repeat([n]+[1]*(len(i.shape)-1)) for i in inputs]
output_ids = output_ids.repeat([n]+[1]*(len(output_ids.shape)-1))
if topk is not None:
k_indices = probas.argsort(dim=-1, descending=True)[:, :topk]  # keep only the topk
probas = take_along_dim(probas, k_indices, dim=1)  # topk probabilities
probas /= probas.sum(dim=1, keepdims=True)  # renormalize
if topp is not None:
p_indices = probas.argsort(dim=-1, descending=True)  # sort in descending order
probas = take_along_dim(probas, p_indices, dim=-1)  # sorted probabilities
cumsum_probas = torch.cumsum(probas, dim=-1)  # cumulative probabilities
flag = torch.roll(cumsum_probas >= topp, 1, dims=1)  # mark the part beyond topp
flag[:, 0] = False  # together with torch.roll above, this shifts the mask by one position
probas[flag] = 0  # zero out everything beyond topp
probas /= probas.sum(dim=1, keepdims=True)  # renormalize
sample_func = lambda p: torch.multinomial(p, 1)  # sample according to the probabilities
sample_ids = torch.stack([sample_func(p) for p in probas])
sample_ids = sample_ids.reshape((-1, 1))  # align shapes
if topp is not None:
sample_ids = take_along_dim(p_indices, sample_ids, dim=1)  # map back to the original ids
if topk is not None:
sample_ids = take_along_dim(k_indices, sample_ids, dim=1)  # map back to the original ids
output_ids = torch.cat([output_ids, sample_ids], 1)  # update the outputs
is_end = output_ids[:, -1] == self.end_id  # mark sequences ending with the end token
end_counts = (output_ids == self.end_id).sum(1)  # count end tokens per sequence
if output_ids.shape[1] >= self.minlen:  # minimum-length check
flag = is_end & (end_counts >= min_ends)  # mark finished sequences
if flag.any():  # some sequences are finished
for ids in output_ids[flag]:  # store the finished sequences
results.append(ids)
flag = ~flag  # mark unfinished sequences
inputs = [i[flag] for i in inputs]  # keep only unfinished inputs
output_ids = output_ids[flag]  # keep only unfinished candidates
end_counts = end_counts[flag]  # keep only unfinished end counts
if len(output_ids) == 0:
break
# any sequences still unfinished go straight into the results
for ids in output_ids:
results.append(ids)
# return the results
return results
def search_layer(model, layer_name, retrun_first=True):
'''Return the first (or all) trainable parameters whose name contains layer_name
'''
return_list = []
for name, param in model.named_parameters():
if param.requires_grad and layer_name in name:
return_list.append(param)
if len(return_list) == 0:
return None
if retrun_first:
return return_list[0]
else:
return return_list
class ListDataset(Dataset):
def __init__(self, file_path=None, data=None, **kwargs):
self.kwargs = kwargs
if isinstance(file_path, (str, list)):
self.data = self.load_data(file_path)
elif isinstance(data, list):
self.data = data
else:
raise ValueError('Input should be either a str/list file_path or a list dataset')
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
@staticmethod
def load_data(file_path):
return file_path
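# Usage sketch (illustrative only; the file format below is assumed):
#   class MyDataset(ListDataset):
#       @staticmethod
#       def load_data(file_path):
#           samples = []
#           with open(file_path, encoding='utf-8') as f:
#               for line in f:
#                   samples.append(line.strip())
#           return samples
#   train_dataset = MyDataset('train.txt')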
class IterDataset(IterableDataset):
'''Stream samples from file
'''
def __init__(self, file_path=None, **kwargs):
self.kwargs = kwargs
if isinstance(file_path, (str, list)):
self.file_path = file_path
else:
raise ValueError('Input should be a str or list file_path')
def __iter__(self):
return self.load_data(self.file_path)
@staticmethod
def load_data(file_path):
return file_path
# sinusoid position encoding
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
'''Returns: [seq_len, d_hid]
'''
position = torch.arange(0, n_position, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_hid, 2).float() * (-math.log(10000.0) / d_hid))
embeddings_table = torch.zeros(n_position, d_hid)
embeddings_table[:, 0::2] = torch.sin(position * div_term)
embeddings_table[:, 1::2] = torch.cos(position * div_term)
return embeddings_table
# A second, equivalent implementation, kept commented out for reference:
# position_ids = torch.arange(0, n_position).unsqueeze(1)
# position_ids = position_ids.expand(-1, d_hid)
# indices = torch.arange(0, d_hid)
# position_ids = position_ids * torch.pow(10000, -2 * torch.true_divide(torch.floor_divide(indices, 2), d_hid))
# position_ids[:, ::2] = torch.sin(position_ids[:, ::2])
# position_ids[:, 1::2] = torch.cos(position_ids[:, 1::2])
# return position_ids
def cal_ts_num(tensor_shape):
'''Count how many cuda tensors of a given shape are tracked by gc (for debugging)
'''
cal_num = 0
for obj in gc.get_objects():
try:
if torch.is_tensor(obj): # or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
tensor = obj
else:
continue
if tensor.is_cuda and tensor.size() == tensor_shape:
print(tensor.shape)
cal_num+=1
except Exception as e:
print('A trivial exception occurred: {}'.format(e))
print(cal_num)
def get_kw(cls, kwargs):
'''Keep only the kwargs that are not arguments of cls
'''
kwargs_new = {}
for k in kwargs:
if k not in set(inspect.getfullargspec(cls)[0]):  # getfullargspec: getargspec was removed in python 3.11
kwargs_new[k] = kwargs[k]
return kwargs_new
class FGM():
'''Adversarial training with FGM (Fast Gradient Method)
'''
def __init__(self, model):
self.model = model
self.backup = {}
def attack(self, epsilon=1., emb_name='word_embeddings', **kwargs):
# emb_name should be set to the name of the embedding parameter in your model,
# e.g. for self.emb = nn.Embedding(5000, 100) use emb_name='emb'
for name, param in self.model.named_parameters():
if param.requires_grad and emb_name in name:
self.backup[name] = param.data.clone()
norm = torch.norm(param.grad)  # defaults to the 2-norm
if norm != 0 and not torch.isnan(norm):  # the nan check guards apex mixed-precision training
r_at = epsilon * param.grad / norm
param.data.add_(r_at)
def restore(self, emb_name='word_embeddings', **kwargs):
# emb_name should match the one passed to attack()
for name, param in self.model.named_parameters():
if param.requires_grad and emb_name in name:
assert name in self.backup
param.data = self.backup[name]
self.backup = {}
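# Usage sketch of an FGM training loop (illustrative only; `model`, `loss_fn`,
# `optimizer` and `dataloader` are assumed to exist):
#   fgm = FGM(model)
#   for x, y in dataloader:
#       loss = loss_fn(model(x), y)
#       loss.backward()                      # gradients on the clean input
#       fgm.attack()                         # perturb the embeddings
#       loss_adv = loss_fn(model(x), y)
#       loss_adv.backward()                  # accumulate adversarial gradients
#       fgm.restore()                        # restore the embeddings
#       optimizer.step()
#       optimizer.zero_grad()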
class PGD():
'''Adversarial training with PGD (Projected Gradient Descent)
'''
def __init__(self, model):
self.model = model
self.emb_backup = {}
self.grad_backup = {}
def attack(self, epsilon=1., alpha=0.3, emb_name='word_embeddings', is_first_attack=False, **kwargs):
# emb_name should be set to the name of the embedding parameter in your model
for name, param in self.model.named_parameters():
if param.requires_grad and emb_name in name:
if is_first_attack:
self.emb_backup[name] = param.data.clone()
norm = torch.norm(param.grad)
if norm != 0 and not torch.isnan(norm):  # the nan check guards apex mixed-precision training
r_at = alpha * param.grad / norm
param.data.add_(r_at)
param.data = self.project(name, param.data, epsilon)
def restore(self, emb_name='word_embeddings', **kwargs):
# emb_name should match the one passed to attack()
for name, param in self.model.named_parameters():
if param.requires_grad and emb_name in name:
assert name in self.emb_backup
param.data = self.emb_backup[name]
self.emb_backup = {}
def project(self, param_name, param_data, epsilon):
r = param_data - self.emb_backup[param_name]
if torch.norm(r) > epsilon:
r = epsilon * r / torch.norm(r)
return self.emb_backup[param_name] + r
def backup_grad(self):
for name, param in self.model.named_parameters():
# fix for layers (e.g. pooling) that join the forward pass but not the backward pass, whose grad is None
if param.requires_grad and (param.grad is not None):
self.grad_backup[name] = param.grad.clone()
def restore_grad(self):
for name, param in self.model.named_parameters():
if param.requires_grad and (param.grad is not None):
param.grad = self.grad_backup[name]
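# Usage sketch of a PGD training loop with K attack steps (illustrative only;
# `model`, `loss_fn`, `optimizer` and `dataloader` are assumed to exist):
#   pgd, K = PGD(model), 3
#   for x, y in dataloader:
#       loss = loss_fn(model(x), y)
#       loss.backward()
#       pgd.backup_grad()                        # save the clean gradients
#       for t in range(K):
#           pgd.attack(is_first_attack=(t == 0)) # perturb and project each step
#           if t != K - 1:
#               model.zero_grad()
#           else:
#               pgd.restore_grad()               # restore clean gradients before the last step
#           loss_adv = loss_fn(model(x), y)
#           loss_adv.backward()
#       pgd.restore()
#       optimizer.step()
#       optimizer.zero_grad()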
class VAT():
'''Virtual adversarial training, ported from https://github.com/namisan/mt-dnn/blob/v0.2/alum/adv_masked_lm.py
'''
def __init__(self, model, emb_name='word_embeddings', noise_var=1e-5, noise_gamma=1e-6, adv_step_size=1e-3,
adv_alpha=1, norm_type='l2', **kwargs):
self.model = model
self.noise_var = noise_var  # variance of the initial noise
self.noise_gamma = noise_gamma  # eps of the projection
self.adv_step_size = adv_step_size  # step size of the inner ascent
self.adv_alpha = adv_alpha  # weight of the adversarial loss
self.norm_type = norm_type  # norm used for the projection
self.embed = None
for (name, module) in self.model.named_modules():
if emb_name in name:
module.register_forward_hook(hook=self.hook)
def hook(self, module, fea_in, fea_out):
self.embed = fea_out
return None
def forward_(self, train_X, new_embed):
# replace the token_ids in the original train_X with their embedding form
if isinstance(train_X, (tuple, list)):
new_train_X = [new_embed] + train_X[1:]
adv_output = self.model.forward(*new_train_X) if self.model.forward.__code__.co_argcount >= 3 else self.model.forward(new_train_X)
elif isinstance(train_X, torch.Tensor):
adv_output = self.model.forward(new_embed)
return adv_output
def virtual_adversarial_training(self, train_X, logits):
# initial perturbation r
noise = self.embed.data.new(self.embed.size()).normal_(0, 1) * self.noise_var
noise.requires_grad_()
# x + r
new_embed = self.embed.data.detach() + noise
adv_output = self.forward_(train_X, new_embed)  # first forward pass
adv_logits = adv_output[0] if isinstance(adv_output, (list, tuple)) else adv_output
adv_loss = self.kl(adv_logits, logits.detach(), reduction="batchmean")
delta_grad, = torch.autograd.grad(adv_loss, noise, only_inputs=True)
norm = delta_grad.norm()
# bail out on vanishing/exploding gradients
if torch.isnan(norm) or torch.isinf(norm):
return None
# inner ascent step
noise = noise + delta_grad * self.adv_step_size
# projection
noise = self.adv_project(noise, norm_type=self.norm_type, eps=self.noise_gamma)
new_embed = self.embed.data.detach() + noise
new_embed = new_embed.detach()
# run the forward pass once more on the perturbed embeddings
adv_output = self.forward_(train_X, new_embed)  # second forward pass
adv_logits = adv_output[0] if isinstance(adv_output, (list, tuple)) else adv_output
adv_loss_f = self.kl(adv_logits, logits.detach())
adv_loss_b = self.kl(logits, adv_logits.detach())
# adv_alpha is typically set to 10 for pre-training and 1 for downstream tasks
adv_loss = (adv_loss_f + adv_loss_b) * self.adv_alpha
return adv_loss
@staticmethod
def kl(inputs, targets, reduction="sum"):
"""
计算kl散度
inputs:tensor,logits
targets:tensor,logits
"""
loss = F.kl_div(F.log_softmax(inputs, dim=-1), F.softmax(targets, dim=-1), reduction=reduction)
return loss
@staticmethod
def adv_project(grad, norm_type='inf', eps=1e-6):
"""
L0,L1,L2正则,对于扰动计算
"""
if norm_type == 'l2':
direction = grad / (torch.norm(grad, dim=-1, keepdim=True) + eps)
elif norm_type == 'l1':
direction = grad.sign()
else:
direction = grad / (grad.abs().max(-1, keepdim=True)[0] + eps)
return direction
class WebServing(object):
"""简单的Web接口
用法:
arguments = {'text': (None, True), 'n': (int, False)}
web = WebServing(port=8864)
web.route('/gen_synonyms', gen_synonyms, arguments)
web.start()
# 然后访问 http://127.0.0.1:8864/gen_synonyms?text=你好
说明:
基于bottlepy简单封装,仅作为临时测试使用,不保证性能。
目前仅保证支持 Tensorflow 1.x + Keras <= 2.3.1。
欢迎有经验的开发者帮忙改进。
依赖:
pip install bottle
pip install paste
(如果不用 server='paste' 的话,可以不装paste库)
"""
def __init__(self, host='0.0.0.0', port=8000, server='paste'):
import bottle
self.host = host
self.port = port
self.server = server
self.bottle = bottle
def wraps(self, func, arguments, method='GET'):
"""封装为接口函数
参数:
func:要转换为接口的函数,需要保证输出可以json化,即需要
保证 json.dumps(func(inputs)) 能被执行成功;
arguments:声明func所需参数,其中key为参数名,value[0]为
对应的转换函数(接口获取到的参数值都是字符串
型),value[1]为该参数是否必须;
method:GET或者POST。
"""
def new_func():
outputs = {'code': 0, 'desc': u'succeeded', 'data': {}}
kwargs = {}
for key, value in arguments.items():
if method == 'GET':
result = self.bottle.request.GET.getunicode(key)
else:
result = self.bottle.request.POST.getunicode(key)
if result is None:
if value[1]:
outputs['code'] = 1
outputs['desc'] = 'lack of "%s" argument' % key
return json.dumps(outputs, ensure_ascii=False)
else:
if value[0] is not None:
result = value[0](result)
kwargs[key] = result
try:
outputs['data'] = func(**kwargs)
except Exception as e:
outputs['code'] = 2
outputs['desc'] = str(e)
return json.dumps(outputs, ensure_ascii=False)
return new_func
def route(self, path, func, arguments, method='GET'):
"""添加接口
"""
func = self.wraps(func, arguments, method)
self.bottle.route(path, method=method)(func)
def start(self):
"""启动服务
"""
self.bottle.run(host=self.host, port=self.port, server=self.server)
def get_pool_emb(hidden_state=None, pooler=None, attention_mask=None, pool_strategy='cls', custom_layer=None):
''' Pool token representations into a single sentence embedding
'''
if pool_strategy == 'pooler':
return pooler
elif pool_strategy == 'cls':
if isinstance(hidden_state, (list, tuple)):
hidden_state = hidden_state[-1]
assert isinstance(hidden_state, torch.Tensor), f'{pool_strategy} strategy requires tensor hidden_state'
return hidden_state[:, 0]
elif pool_strategy in {'last-avg', 'mean'}:
if isinstance(hidden_state, (list, tuple)):
hidden_state = hidden_state[-1]
assert isinstance(hidden_state, torch.Tensor), f'{pool_strategy} pooling strategy requires tensor hidden_state'
hid = torch.sum(hidden_state * attention_mask[:, :, None], dim=1)
attention_mask = torch.sum(attention_mask, dim=1)[:, None]
return hid / attention_mask
elif pool_strategy in {'last-max', 'max'}:
if isinstance(hidden_state, (list, tuple)):
hidden_state = hidden_state[-1]
assert isinstance(hidden_state, torch.Tensor), f'{pool_strategy} pooling strategy requires tensor hidden_state'
hid = hidden_state * attention_mask[:, :, None]
return torch.max(hid, dim=1).values  # torch.max with dim returns (values, indices); keep only the values
elif pool_strategy == 'first-last-avg':
assert isinstance(hidden_state, list), f'{pool_strategy} pooling strategy requires list hidden_state'
hid = torch.sum(hidden_state[1] * attention_mask[:, :, None], dim=1)  # layer 0 (the embedding output) is deliberately skipped
hid += torch.sum(hidden_state[-1] * attention_mask[:, :, None], dim=1)
attention_mask = torch.sum(attention_mask, dim=1)[:, None]
return hid / (2 * attention_mask)
elif pool_strategy == 'custom':
# pool over user-specified layers
assert isinstance(hidden_state, list), f'{pool_strategy} pooling strategy requires list hidden_state'
assert isinstance(custom_layer, (int, list, tuple)), f'{pool_strategy} pooling strategy requires int/list/tuple custom_layer'
custom_layer = [custom_layer] if isinstance(custom_layer, int) else custom_layer
hid = 0
for i, layer in enumerate(custom_layer, start=1):
hid += torch.sum(hidden_state[layer] * attention_mask[:, :, None], dim=1)
attention_mask = torch.sum(attention_mask, dim=1)[:, None]
return hid / (i * attention_mask)
else:
raise ValueError(f'Illegal pool_strategy: {pool_strategy}')
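# Usage sketch (illustrative only; it assumes a hypothetical model call that
# returns all hidden layers plus the pooler output):
#   hidden_states, pooler = model([token_ids, segment_ids])
#   sent_emb = get_pool_emb(hidden_states, pooler, attention_mask, pool_strategy='first-last-avg')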
def seed_everything(seed=None):
'''Fix the random seed of random/numpy/torch for reproducibility
'''
max_seed_value = np.iinfo(np.uint32).max
min_seed_value = np.iinfo(np.uint32).min
if (seed is None) or not (min_seed_value <= seed <= max_seed_value):
seed = random.randint(min_seed_value, max_seed_value)  # keep the drawn seed (the original discarded the return value)
print(f"Global seed set to {seed}")
os.environ["PYTHONHASHSEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
return seed
# coding=utf-8
"""Tokenization classes."""
from __future__ import absolute_import, division, print_function, unicode_literals
import collections
import logging
import unicodedata
from io import open
from bert4torch.snippets import truncate_sequences, is_string, lowercase_and_normalize
import re
import six
from collections import OrderedDict
logger = logging.getLogger(__name__)
is_py2 = six.PY2
def load_vocab(dict_path, encoding="utf-8", simplified=False, startswith=None):
"""加载词典文件到dict"""
token_dict = collections.OrderedDict()
index = 0
with open(dict_path, "r", encoding=encoding) as reader:
while True:
token = reader.readline()
if not token:
break
token = token.strip()
token_dict[token] = index
index += 1
if simplified:  # filter out redundant tokens such as [unused1]
new_token_dict, keep_tokens = {}, []
startswith = startswith or []
for t in startswith:
new_token_dict[t] = len(new_token_dict)
keep_tokens.append(token_dict[t])
for t, _ in sorted(token_dict.items(), key=lambda s: s[1]):
if t not in new_token_dict and not Tokenizer._is_redundant(t):
new_token_dict[t] = len(new_token_dict)
keep_tokens.append(token_dict[t])
return new_token_dict, keep_tokens
else:
return token_dict
def whitespace_tokenize(text):
"""去除文本中的空白符"""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class TokenizerBase(object):
"""分词器基类
"""
def __init__(self, token_start='[CLS]', token_end='[SEP]', token_unk='[UNK]', token_pad='[PAD]', token_mask='[MASK]',
add_special_tokens=None, pre_tokenize=None, token_translate=None):
"""参数说明:
token_unk: 未知词标记
token_end: 句子切分标记,当只有一句话作为输入时,此标记知识作为结束符;当有两句话作为输入时,此标记作为分隔符、最后一句话的结束符
pad_token: padding填充标记
token_start: 分类标记,位于整个序列的第一个
mask_token: mask标记
pre_tokenize: 外部传入的分词函数,用作对文本进行预分词。如果传入pre_tokenize,则先执行pre_tokenize(text),然后在它的基础上执行原本的tokenize函数;
token_translate: 映射字典,主要用在tokenize之后,将某些特殊的token替换为对应的token。
"""
self._token_pad = token_pad
self._token_unk = token_unk
self._token_mask = token_mask
self._token_start = token_start
self._token_end = token_end
self.never_split = [self._token_unk, self._token_end, self._token_pad, self._token_start, self._token_mask]
if add_special_tokens is not None:
if isinstance(add_special_tokens, (tuple, list)):
self.never_split.extend(add_special_tokens)
elif isinstance(add_special_tokens, str):
self.never_split.append(add_special_tokens)
self.tokens_trie = self._create_trie(self.never_split)  # the trie is mainly used to split off special_tokens
self._pre_tokenize = pre_tokenize
self._token_translate = token_translate or {}
self._token_translate_inv = {v: k for k, v in self._token_translate.items()}
def _create_trie(self, unique_no_split_tokens):
trie = Trie()
for token in unique_no_split_tokens:
trie.add(token)
return trie
def tokenize(self, text, maxlen=None):
"""分词函数
"""
tokens = [self._token_translate.get(token) or token for token in self._tokenize(text)]
if self._token_start is not None:
tokens.insert(0, self._token_start)
if self._token_end is not None:
tokens.append(self._token_end)
if maxlen is not None:
index = int(self._token_end is not None) + 1
truncate_sequences(maxlen, -index, tokens)
return tokens
def token_to_id(self, token):
"""token转换为对应的id
"""
raise NotImplementedError
def tokens_to_ids(self, tokens):
"""token序列转换为对应的id序列
"""
return [self.token_to_id(token) for token in tokens]
def _encode(self, first_text, second_text=None, maxlen=None, pattern='S*E*E', truncate_from='right', return_offsets=False):
"""输出文本对应token id和segment id
"""
first_tokens = self.tokenize(first_text) if is_string(first_text) else first_text
if second_text is None:
second_tokens = None
elif is_string(second_text):
second_tokens = self.tokenize(second_text)
else:
second_tokens = second_text
if maxlen is not None:
# the truncation strategy here is to truncate the longest sub-sentence first
if truncate_from == 'right':
index = -int(self._token_end is not None) - 1
elif truncate_from == 'left':
index = int(self._token_start is not None)
else:
index = truncate_from
if second_text is not None and pattern == 'S*E*E':
maxlen += 1
truncate_sequences(maxlen, index, first_tokens, second_tokens)
first_token_ids = self.tokens_to_ids(first_tokens)
first_segment_ids = [0] * len(first_token_ids)
if second_text is not None:
if pattern == 'S*E*E':
idx = int(bool(self._token_start))
second_tokens = second_tokens[idx:]
second_token_ids = self.tokens_to_ids(second_tokens)
second_segment_ids = [1] * len(second_token_ids)
first_token_ids.extend(second_token_ids)
first_segment_ids.extend(second_segment_ids)
encode_output = [first_token_ids, first_segment_ids]
if return_offsets != False:
offset = self.rematch(first_text, first_tokens)
if second_text is not None:  # guard: rematch on None would fail
offset += self.rematch(second_text, second_tokens)
if return_offsets == 'transformers':  # the offset format used by the transformers tokenizers
encode_output.append([[0, 0] if not k else [k[0], k[-1]+1] for k in offset])
else:
encode_output.append(offset)
return encode_output
def encode(self, first_texts, second_texts=None, maxlen=None, pattern='S*E*E', truncate_from='right', return_offsets=False):
'''Encode a single sample or a batch of samples
'''
return_list = False if isinstance(first_texts, str) else True
first_texts = [first_texts] if isinstance(first_texts, str) else first_texts
second_texts = [second_texts] if isinstance(second_texts, str) else second_texts
first_token_ids, first_segment_ids, offsets = [], [], []
if second_texts is None:
second_texts = [None] * len(first_texts)
assert len(first_texts) == len(second_texts), 'first_texts and second_texts should be same length'
# process each sample in turn
for first_text, second_text in zip(first_texts, second_texts):
outputs = self._encode(first_text, second_text, maxlen, pattern, truncate_from, return_offsets)
first_token_ids.append(outputs[0])
first_segment_ids.append(outputs[1])
if len(outputs) >= 3:
offsets.append(outputs[2])
encode_outputs = [first_token_ids, first_segment_ids]
if return_offsets:
encode_outputs.append(offsets)
if not return_list:  # the input was a single string
encode_outputs = [item[0] for item in encode_outputs]
return encode_outputs
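# Usage sketch (illustrative only; 'vocab.txt' stands for any BERT vocab file):
#   tokenizer = Tokenizer('vocab.txt', do_lower_case=True)
#   token_ids, segment_ids = tokenizer.encode('今天天气不错', maxlen=32)
#   batch_ids, batch_segs = tokenizer.encode(['句子一', '句子二'])   # batched input returns lists of lists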
def id_to_token(self, i):
"""id序列为对应的token
"""
raise NotImplementedError
def ids_to_tokens(self, ids):
"""id序列转换为对应的token序列
"""
return [self.id_to_token(i) for i in ids]
def decode(self, ids):
"""转为可读文本
"""
raise NotImplementedError
def _tokenize(self, text):
"""基本分词函数
"""
raise NotImplementedError
def rematch(self, text, tokens):
"""Build the mapping between text and tokens (implemented by subclasses)
"""
pass
class Tokenizer(TokenizerBase):
"""Bert原生分词器
"""
def __init__(self, token_dict, do_lower_case=True, do_basic_tokenize=True, do_tokenize_unk=False, **kwargs):
"""
参数:
token_dict:
词典文件
do_lower_case:
是否转换成小写
do_basic_tokenize:
分词前,是否进行基础的分词
do_tokenize_unk:
分词后,是否生成[UNK]标记,还是在encode阶段生成
"""
super(Tokenizer, self).__init__(**kwargs)
if is_string(token_dict):
token_dict = load_vocab(token_dict)
self._do_lower_case = do_lower_case
self._vocab_size = len(token_dict)
self._token_dict = token_dict
self._token_dict_inv = {v: k for k, v in token_dict.items()}
self.do_basic_tokenize = do_basic_tokenize
if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, never_split=self.never_split)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self._token_dict, unk_token=self._token_unk, do_tokenize_unk=do_tokenize_unk)
for token in ['pad', 'unk', 'mask', 'start', 'end']:
try:
_token_id = token_dict[getattr(self, '_token_%s' % token)]
setattr(self, '_token_%s_id' % token, _token_id)
except KeyError:  # skip special tokens missing from the vocab
pass
def _tokenize(self, text, pre_tokenize=True):
"""Core tokenization function
"""
# the pre_tokenize logic below follows bert4keras
if self._do_lower_case:
text = lowercase_and_normalize(text, never_split=self.never_split)
if pre_tokenize and self._pre_tokenize is not None:
tokens = []
for token in self._pre_tokenize(text):
if token in self._token_dict:
tokens.append(token)
else:
tokens.extend(self._tokenize(token, False))
return tokens
# the logic below follows the pytorch BERT tokenizer
text_pieces = self.tokens_trie.split(text)  # added logic, mainly to split off special_tokens
split_tokens = []
for text_piece in text_pieces:
if not text_piece:
continue
elif text_piece in self._token_dict:
split_tokens.append(text_piece)
elif self.do_basic_tokenize:
for token in self.basic_tokenizer.tokenize(text_piece):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
else:
split_tokens.extend(self.wordpiece_tokenizer.tokenize(text_piece))
return split_tokens
def token_to_id(self, token):
"""token转为vocab中的id"""
return self._token_dict.get(token, self._token_unk_id)
def id_to_token(self, id):
"""id转为词表中的token"""
return self._token_dict_inv[id]
def decode(self, ids, tokens=None):
"""转为可读文本
"""
tokens = tokens or self.ids_to_tokens(ids)
tokens = [token for token in tokens if not self._is_special(token)]
text, flag = '', False
for i, token in enumerate(tokens):
if token[:2] == '##':
text += token[2:]
elif len(token) == 1 and self._is_cjk_character(token):
text += token
elif len(token) == 1 and self._is_punctuation(token):
text += token
text += ' '
elif i > 0 and self._is_cjk_character(text[-1]):
text += token
else:
text += ' '
text += token
text = re.sub(' +', ' ', text)
text = re.sub('\' (re|m|s|t|ve|d|ll) ', '\'\\1 ', text)
punctuation = self._cjk_punctuation() + '+-/={(<['
punctuation_regex = '|'.join([re.escape(p) for p in punctuation])
punctuation_regex = '(%s) ' % punctuation_regex
text = re.sub(punctuation_regex, '\\1', text)
text = re.sub(r'(\d\.) (\d)', '\\1\\2', text)
return text.strip()
@staticmethod
def stem(token):
"""获取token的“词干”(如果是##开头,则自动去掉##)
"""
if token[:2] == '##':
return token[2:]
else:
return token
@staticmethod
def _is_space(ch):
"""空格类字符判断
"""
return ch == ' ' or ch == '\n' or ch == '\r' or ch == '\t' or \
unicodedata.category(ch) == 'Zs'
@staticmethod
def _is_punctuation(ch):
"""标点符号类字符判断(全/半角均在此内)
提醒:unicodedata.category这个函数在py2和py3下的
表现可能不一样,比如u'§'字符,在py2下的结果为'So',
在py3下的结果是'Po'。
"""
code = ord(ch)
return 33 <= code <= 47 or \
58 <= code <= 64 or \
91 <= code <= 96 or \
123 <= code <= 126 or \
unicodedata.category(ch).startswith('P')
@staticmethod
def _cjk_punctuation():
return u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\u00b7\uff01\uff1f\uff61\u3002'
@staticmethod
def _is_cjk_character(ch):
"""CJK类字符判断(包括中文字符也在此列)
参考:https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
"""
code = ord(ch)
return 0x4E00 <= code <= 0x9FFF or \
0x3400 <= code <= 0x4DBF or \
0x20000 <= code <= 0x2A6DF or \
0x2A700 <= code <= 0x2B73F or \
0x2B740 <= code <= 0x2B81F or \
0x2B820 <= code <= 0x2CEAF or \
0xF900 <= code <= 0xFAFF or \
0x2F800 <= code <= 0x2FA1F
@staticmethod
def _is_control(ch):
"""控制类字符判断
"""
return unicodedata.category(ch) in ('Cc', 'Cf')
@staticmethod
def _is_special(ch):
"""判断是不是有特殊含义的符号
"""
return bool(ch) and (ch[0] == '[') and (ch[-1] == ']')
@staticmethod
def _is_redundant(token):
"""判断该token是否冗余(默认情况下不可能分出来)
"""
if len(token) > 1:
for ch in Tokenizer.stem(token):
if (
Tokenizer._is_cjk_character(ch) or
Tokenizer._is_punctuation(ch)
):
return True
def rematch(self, text, tokens):
"""给出原始的text和tokenize后的tokens的映射关系
"""
if is_py2:
text = unicode(text)
if self._do_lower_case:
text = text.lower()
normalized_text, char_mapping = '', []
for i, ch in enumerate(text):
if self._do_lower_case:
ch = lowercase_and_normalize(ch, self.never_split)
ch = ''.join([
c for c in ch
if not (ord(c) == 0 or ord(c) == 0xfffd or self._is_control(c))
])
normalized_text += ch
char_mapping.extend([i] * len(ch))
text, token_mapping, offset = normalized_text, [], 0
for token in tokens:
if self._is_special(token):
token_mapping.append([])
else:
token = self.stem(token)
start = text[offset:].index(token) + offset
end = start + len(token)
token_mapping.append(char_mapping[start:end])
offset = end
return token_mapping
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True, never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
self.never_split = never_split
def tokenize(self, text):
"""文本切分成token"""
text = self._clean_text(text)
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case and token not in self.never_split:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
if text in self.never_split:
return [text]
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenization."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100, do_tokenize_unk=False):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
self.do_tokenize_unk = do_tokenize_unk
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token if self.do_tokenize_unk else token)  # token too long
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if (substr in self.vocab) or (not self.do_tokenize_unk):
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if self.do_tokenize_unk and is_bad:  # whether to emit UNK at the tokenize stage
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat.startswith("C"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
class SpTokenizer(TokenizerBase):
"""基于SentencePiece模型的封装,使用上跟Tokenizer基本一致。
"""
def __init__(self, sp_model_path, **kwargs):
super(SpTokenizer, self).__init__(**kwargs)
import sentencepiece as spm
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(sp_model_path)
self._token_pad = self.sp_model.id_to_piece(self.sp_model.pad_id())
self._token_unk = self.sp_model.id_to_piece(self.sp_model.unk_id())
self._vocab_size = self.sp_model.get_piece_size()
for token in ['pad', 'unk', 'mask', 'start', 'end']:
try:
_token = getattr(self, '_token_%s' % token)
_token_id = self.sp_model.piece_to_id(_token)
setattr(self, '_token_%s_id' % token, _token_id)
except Exception:  # skip special tokens missing from the sp model
pass
def token_to_id(self, token):
"""token转换为对应的id
"""
return self.sp_model.piece_to_id(token)
def id_to_token(self, i):
"""id转换为对应的token
"""
if i < self._vocab_size:
return self.sp_model.id_to_piece(i)
else:
return ''
def decode(self, ids):
"""转为可读文本
"""
tokens = [self._token_translate_inv.get(token) or token for token in self.ids_to_tokens(ids)]
text = self.sp_model.decode_pieces(tokens)
return convert_to_unicode(text)
def _tokenize(self, text):
"""基本分词函数
"""
if self._pre_tokenize is not None:
text = ' '.join(self._pre_tokenize(text))
tokens = self.sp_model.encode_as_pieces(text)
return tokens
def _is_special(self, i):
"""判断是不是有特殊含义的符号
"""
return self.sp_model.is_control(i) or \
self.sp_model.is_unknown(i) or \
self.sp_model.is_unused(i)
def _is_decodable(self, i):
"""判断是否应该被解码输出
"""
return (i < self._vocab_size) and not self._is_special(i)
class Trie:
"""直接从transformer的tokenization_utils.py中移植, 主要是为了special_tokens分词
"""
def __init__(self):
self.data = {}
def add(self, word: str):
if not word:
# Prevent empty string
return
ref = self.data
for char in word:
ref[char] = char in ref and ref[char] or {}
ref = ref[char]
ref[""] = 1
def split(self, text: str):
states = OrderedDict()
# This will contain every indices where we need
# to cut.
# We force to cut at offset 0 and len(text) (added later)
offsets = [0]
# This is used by the lookahead which needs to skip over
# some text where the full match exceeded the place in the initial
# for loop
skip = 0
# Main loop, Giving this algorithm O(n) complexity
for current, current_char in enumerate(text):
if skip and current < skip:
# Prevents the lookahead for matching twice
# like extra_id_100 and id_100
continue
# This will track every state
# that stop matching, we need to stop tracking them.
# If we look at "lowball", we're going to match "l" (add it to states), "o", "w", then
# fail on "b", we need to remove 0 from the valid states.
to_remove = set()
# Whenever we found a match, we need to drop everything
# this is a greedy algorithm, it will match on the first found token
reset = False
# In this case, we already have partial matches (But unfinished)
for start, trie_pointer in states.items():
if "" in trie_pointer:
# This is a final match, we need to reset and
# store the results in `offsets`.
# Lookahead to match longest first
# Important in case of extra_id_1 vs extra_id_100
# Here we are also actively looking for other earlier partial
# matches
# "[CLS]", "L", we need to match CLS even if L is special
for lookstart, looktrie_pointer in states.items():
if lookstart > start:
# This partial match is later, we can stop looking
break
elif lookstart < start:
# This partial match is earlier, the trie pointer
# was already updated, so index is + 1
lookahead_index = current + 1
end = current + 1
else:
# Here lookstart == start and
# looktrie_pointer == trie_pointer
# It wasn't updated yet so indices are current ones
lookahead_index = current
end = current
next_char = text[lookahead_index] if lookahead_index < len(text) else None
if "" in looktrie_pointer:
start = lookstart
end = lookahead_index
skip = lookahead_index
while next_char in looktrie_pointer:
looktrie_pointer = looktrie_pointer[next_char]
lookahead_index += 1
if "" in looktrie_pointer:
start = lookstart
end = lookahead_index
skip = lookahead_index
if lookahead_index == len(text):
# End of string
break
next_char = text[lookahead_index]
# End lookahead
# Storing and resetting
offsets.append(start)
offsets.append(end)
reset = True
break
elif current_char in trie_pointer:
# The current character being looked at has a match within the trie
# update the pointer (it will be stored back into states later).
trie_pointer = trie_pointer[current_char]
# Storing back the new pointer into the states.
# Partial matches got longer by one.
states[start] = trie_pointer
else:
# The new character has no match in the trie, we need
# to stop keeping track of this partial match.
# We can't do it directly within the loop because of how
# python iteration works
to_remove.add(start)
# Either clearing the full start (we found a real match)
# Or clearing only the partial matches that didn't work.
if reset:
states = {}
else:
for start in to_remove:
del states[start]
# If this character is a starting character within the trie
# start keeping track of this partial match.
if current >= skip and current_char in self.data:
states[current] = self.data[current_char]
# We have a cut at the end with states.
for start, trie_pointer in states.items():
if "" in trie_pointer:
# This is a final match, we need to reset and
# store the results in `offsets`.
end = len(text)
offsets.append(start)
offsets.append(end)
# Longest cut is always the one with lower start so the first
# item so we need to break.
break
return self.cut_text(text, offsets)
def cut_text(self, text, offsets):
# We have all the offsets now, we just need to do the actual splitting.
# We need to eventually add the first part of the string and the eventual
# last part.
offsets.append(len(text))
tokens = []
start = 0
for end in offsets:
if start > end:
logger.error(
"There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it anyway."
)
continue
elif start == end:
# This might happen if there's a match at index 0
# we're also preventing zero-width cuts in case of two
# consecutive matches
continue
tokens.append(text[start:end])
start = end
return tokens
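# Usage sketch (illustrative only):
#   trie = Trie()
#   trie.add('[CLS]')
#   trie.add('[SEP]')
#   trie.split('[CLS]今天天气不错[SEP]')   # -> ['[CLS]', '今天天气不错', '[SEP]']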
#! -*- coding: utf-8 -*-
__version__ = '0.1.9'
# activations ported from transformers; the original bert4keras does not have this module
import math
import torch
from packaging import version
from torch import nn
def _gelu_python(x):
"""
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional
Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def gelu_new(x):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
if version.parse(torch.__version__) < version.parse("1.4"):
gelu = _gelu_python
else:
gelu = nn.functional.gelu
def gelu_fast(x):
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
def quick_gelu(x):
return x * torch.sigmoid(1.702 * x)
def _silu_python(x):
"""
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
later.
"""
return x * torch.sigmoid(x)
if version.parse(torch.__version__) < version.parse("1.7"):
silu = _silu_python
else:
silu = nn.functional.silu
def _mish_python(x):
"""
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
visit the official repository for the paper: https://github.com/digantamisra98/Mish
"""
return x * torch.tanh(nn.functional.softplus(x))
if version.parse(torch.__version__) < version.parse("1.9"):
mish = _mish_python
else:
mish = nn.functional.mish
def linear_act(x):
return x
ACT2FN = {
"relu": nn.functional.relu,
"silu": silu,
"swish": silu,
"gelu": gelu,
"tanh": torch.tanh,
"gelu_new": gelu_new,
"gelu_fast": gelu_fast,
"quick_gelu": quick_gelu,
"mish": mish,
"linear": linear_act,
"sigmoid": torch.sigmoid,
"softmax": nn.Softmax(dim=-1)
}
def get_activation(activation_string):
if activation_string in ACT2FN:
return ACT2FN[activation_string]
else:
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
import torch
from torch.functional import Tensor
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from bert4torch.snippets import get_sinusoid_encoding_table, take_along_dim
from bert4torch.activations import get_activation
from typing import List, Optional
import random
import warnings
class LayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12, conditional_size=False, weight=True, bias=True, norm_mode='normal', **kwargs):
"""A hand-rolled layernorm layer, implemented here so that conditional layernorm is supported,
which enables conditional text generation, conditional classification, etc.
Conditional layernorm comes from an idea of Su Jianlin, see: https://spaces.ac.cn/archives/7124
"""
super(LayerNorm, self).__init__()
# roformer_v2 compatibility: no weight
if weight:
self.weight = nn.Parameter(torch.ones(hidden_size))
# t5 compatibility: no bias, and t5 uses RMSnorm
if bias:
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.norm_mode = norm_mode
self.eps = eps
self.conditional_size = conditional_size
if conditional_size:
# conditional layernorm for conditional text generation;
# zero-initialized so that initially it does not disturb the pretrained weights
self.dense1 = nn.Linear(conditional_size, hidden_size, bias=False)
self.dense1.weight.data.uniform_(0, 0)
self.dense2 = nn.Linear(conditional_size, hidden_size, bias=False)
self.dense2.weight.data.uniform_(0, 0)
def forward(self, x):
inputs = x[0]
if self.norm_mode == 'rmsnorm':
# t5 uses RMSnorm
variance = inputs.to(torch.float32).pow(2).mean(-1, keepdim=True)
o = inputs * torch.rsqrt(variance + self.eps)
else:
u = inputs.mean(-1, keepdim=True)
s = (inputs - u).pow(2).mean(-1, keepdim=True)
o = (inputs - u) / torch.sqrt(s + self.eps)
if not hasattr(self, 'weight'):
self.weight = 1
if not hasattr(self, 'bias'):
self.bias = 0
if self.conditional_size:
cond = x[1]
for _ in range(len(inputs.shape) - len(cond.shape)):
cond = cond.unsqueeze(dim=1)
return (self.weight + self.dense1(cond)) * o + (self.bias + self.dense2(cond))
else:
return self.weight * o + self.bias
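# Usage sketch (illustrative only): the forward expects a list/tuple input,
# where x[0] is the hidden state and, for conditional layernorm, x[1] is the condition.
#   ln = LayerNorm(768, conditional_size=128)
#   out = ln([torch.randn(2, 10, 768), torch.randn(2, 128)])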
class MultiHeadAttentionLayer(nn.Module):
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, attention_scale=True,
return_attention_scores=False, bias=True, **kwargs):
super(MultiHeadAttentionLayer, self).__init__()
assert hidden_size % num_attention_heads == 0
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.attention_head_size = int(hidden_size / num_attention_heads)
self.attention_scale = attention_scale
self.return_attention_scores = return_attention_scores
self.bias = bias
self.q = nn.Linear(hidden_size, hidden_size, bias=bias)
self.k = nn.Linear(hidden_size, hidden_size, bias=bias)
self.v = nn.Linear(hidden_size, hidden_size, bias=bias)
self.o = nn.Linear(hidden_size, hidden_size, bias=bias)
self.dropout = nn.Dropout(attention_probs_dropout_prob)
self.a_bias, self.p_bias = kwargs.get('a_bias'), kwargs.get('p_bias')
if self.p_bias == 'typical_relative': # nezha
self.relative_positions_encoding = RelativePositionsEncoding(qlen=kwargs.get('max_position'),
klen=kwargs.get('max_position'),
embedding_size=self.attention_head_size,
max_relative_position=kwargs.get('max_relative_position'))
elif self.p_bias == 'rotary': # roformer
self.relative_positions_encoding = RoPEPositionEncoding(max_position=kwargs.get('max_position'), embedding_size=self.attention_head_size)
elif self.p_bias == 't5_relative': # t5
self.relative_positions = RelativePositionsEncodingT5(qlen=kwargs.get('max_position'),
klen=kwargs.get('max_position'),
relative_attention_num_buckets=kwargs.get('relative_attention_num_buckets'),
is_decoder=kwargs.get('is_decoder'))
self.relative_positions_encoding = nn.Embedding(kwargs.get('relative_attention_num_buckets'), self.num_attention_heads)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
# hidden_states shape: [batch_size, seq_q, hidden_size]
# attention_mask shape: [batch_size, 1, 1, seq_q] 或者 [batch_size, 1, seq_q, seq_q]
# encoder_hidden_states shape: [batch_size, seq_k, hidden_size]
# encoder_attention_mask shape: [batch_size, 1, 1, seq_k]
mixed_query_layer = self.q(hidden_states)
if encoder_hidden_states is not None:
mixed_key_layer = self.k(encoder_hidden_states)
mixed_value_layer = self.v(encoder_hidden_states)
attention_mask = encoder_attention_mask
else:
mixed_key_layer = self.k(hidden_states)
mixed_value_layer = self.v(hidden_states)
        # mixed_query_layer shape: [batch_size, query_len, hidden_size]
        # mixed_key_layer shape: [batch_size, key_len, hidden_size]
        # mixed_value_layer shape: [batch_size, value_len, hidden_size]
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# query_layer shape: [batch_size, num_attention_heads, query_len, attention_head_size]
# key_layer shape: [batch_size, num_attention_heads, key_len, attention_head_size]
# value_layer shape: [batch_size, num_attention_heads, value_len, attention_head_size]
if self.p_bias == 'rotary':
query_layer = self.relative_positions_encoding(query_layer)
key_layer = self.relative_positions_encoding(key_layer)
# 交换k的最后两个维度,然后q和k执行点积, 获得attention score
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
# attention_scores shape: [batch_size, num_attention_heads, query_len, key_len]
if (self.p_bias == 'typical_relative') and hasattr(self, 'relative_positions_encoding'):
relations_keys = self.relative_positions_encoding(attention_scores.shape[-1], attention_scores.shape[-1]) # [to_seq_len, to_seq_len, d_hid]
# 旧实现,方便读者理解维度转换
# query_layer_t = query_layer.permute(2, 0, 1, 3)
# query_layer_r = query_layer_t.contiguous().view(from_seq_length, batch_size * num_attention_heads, self.attention_head_size)
# key_position_scores = torch.matmul(query_layer_r, relations_keys.permute(0, 2, 1))
# key_position_scores_r = key_position_scores.view(from_seq_length, batch_size, num_attention_heads, from_seq_length)
# key_position_scores_r_t = key_position_scores_r.permute(1, 2, 0, 3)
# 新实现
key_position_scores_r_t = torch.einsum('bnih,ijh->bnij', query_layer, relations_keys)
attention_scores = attention_scores + key_position_scores_r_t
elif (self.p_bias == 't5_relative') and hasattr(self, 'relative_positions_encoding'):
relations_keys = self.relative_positions(attention_scores.shape[-1], attention_scores.shape[-1])
key_position_scores_r_t = self.relative_positions_encoding(relations_keys).permute([2, 0, 1]).unsqueeze(0)
attention_scores = attention_scores + key_position_scores_r_t
# 是否进行attention scale
if self.attention_scale:
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask: positions where the mask is 0 receive a large negative bias (-10000.0),
        # so after softmax their attention_probs are ~0 and masked positions are not attended to
if attention_mask is not None:
# attention_scores = attention_scores.masked_fill(attention_mask == 0, -1e10)
attention_mask = (1.0 - attention_mask) * -10000.0 # 所以传入的mask的非padding部分为1, padding部分为0
attention_scores = attention_scores + attention_mask
# 将attention score 归一化到0-1
attention_probs = F.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer) # [batch_size, num_attention_heads, query_len, attention_head_size]
if (self.p_bias == 'typical_relative') and hasattr(self, 'relative_positions_encoding'):
relations_values = self.relative_positions_encoding(attention_scores.shape[-1], attention_scores.shape[-1])
# 旧实现,方便读者理解维度转换
# attention_probs_t = attention_probs.permute(2, 0, 1, 3)
# attentions_probs_r = attention_probs_t.contiguous().view(from_seq_length, batch_size * num_attention_heads, to_seq_length)
# value_position_scores = torch.matmul(attentions_probs_r, relations_values)
# value_position_scores_r = value_position_scores.view(from_seq_length, batch_size, num_attention_heads, self.attention_head_size)
# value_position_scores_r_t = value_position_scores_r.permute(1, 2, 0, 3)
# 新实现
value_position_scores_r_t = torch.einsum('bnij,ijh->bnih', attention_probs, relations_values)
context_layer = context_layer + value_position_scores_r_t
# context_layer shape: [batch_size, query_len, num_attention_heads, attention_head_size]
# transpose、permute等维度变换操作后,tensor在内存中不再是连续存储的,而view操作要求tensor的内存连续存储,
# 所以在调用view之前,需要contiguous来返回一个contiguous copy;
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
context_layer = context_layer.view(*new_context_layer_shape)
# 是否返回attention scores
if self.return_attention_scores:
# 这里返回的attention_scores没有经过softmax, 可在外部进行归一化操作
return self.o(context_layer), attention_scores
else:
return self.o(context_layer)
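# A minimal self-attention usage sketch (sizes are illustrative assumptions): the mask
# follows the convention above, 1 for real tokens and 0 for padding, broadcastable to
# [batch_size, 1, 1, seq_len].
def _demo_multi_head_attention():
    attn = MultiHeadAttentionLayer(hidden_size=8, num_attention_heads=2, attention_probs_dropout_prob=0.1)
    hidden_states = torch.randn(2, 5, 8)
    attention_mask = torch.ones(2, 1, 1, 5)
    out = attn(hidden_states, attention_mask)
    assert out.shape == (2, 5, 8)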
class PositionWiseFeedForward(nn.Module):
def __init__(self, hidden_size, intermediate_size, dropout_rate=0.5, hidden_act='gelu', is_dropout=False, bias=True, **kwargs):
# 原生的tf版本的bert在激活函数后,没有添加dropout层,但是在google AI的bert-pytorch开源项目中,多了一层dropout;
# 并且在pytorch官方的TransformerEncoderLayer的实现中,也有一层dropout层,就像这样:self.linear2(self.dropout(self.activation(self.linear1(src))));
# 这样不统一做法的原因不得而知,不过有没有这一层,差别可能不会很大;
# 为了适配是否dropout,用is_dropout,dropout_rate两个参数控制;如果是实现原始的transformer,直接使用默认参数即可;如果是实现bert,则is_dropout为False,此时的dropout_rate参数并不会使用.
super(PositionWiseFeedForward, self).__init__()
self.is_dropout = is_dropout
self.intermediate_act_fn = get_activation(hidden_act)
self.intermediateDense = nn.Linear(hidden_size, intermediate_size, bias=bias)
self.outputDense = nn.Linear(intermediate_size, hidden_size, bias=bias)
if self.is_dropout:
self.dropout = nn.Dropout(dropout_rate)
def forward(self, x):
# x shape: (batch size, seq len, hidden_size)
if self.is_dropout:
x = self.dropout(self.intermediate_act_fn(self.intermediateDense(x)))
else:
x = self.intermediate_act_fn(self.intermediateDense(x))
# x shape: (batch size, seq len, intermediate_size)
x = self.outputDense(x)
# x shape: (batch size, seq len, hidden_size)
return x
class GatedAttentionUnit(nn.Module):
'''门控注意力单元,
链接:https://arxiv.org/abs/2202.10447
介绍:https://kexue.fm/archives/8934
说明:没有加入加性相对位置编码
参考pytorch项目:https://github.com/lucidrains/FLASH-pytorch
'''
def __init__(self, hidden_size, attention_key_size, intermediate_size, attention_probs_dropout_prob, hidden_act,
is_dropout=False, attention_scale=True, bias=True, normalization='softmax_plus', **kwargs):
super().__init__()
self.intermediate_size = intermediate_size
self.attention_head_size = attention_key_size
self.attention_scale = attention_scale
self.is_dropout = is_dropout
self.normalization = normalization
self.hidden_fn = get_activation(hidden_act)
self.dropout = nn.Dropout(attention_probs_dropout_prob)
self.i_dense = nn.Linear(hidden_size, self.intermediate_size*2+attention_key_size, bias=bias)
self.offsetscale = self.OffsetScale(attention_key_size, heads=2, bias=bias)
self.o_dense = nn.Linear(self.intermediate_size, hidden_size, bias=bias)
self.a_bias, self.p_bias = kwargs.get('a_bias'), kwargs.get('p_bias')
if self.p_bias == 'rotary': # RoPE
self.relative_positions_encoding = RoPEPositionEncoding(max_position=kwargs.get('max_position'), embedding_size=self.attention_head_size)
def forward(self, hidden_states, attention_mask):
# 投影变换
hidden_states = self.hidden_fn(self.i_dense(hidden_states))
u, v, qk = hidden_states.split([self.intermediate_size, self.intermediate_size, self.attention_head_size], dim=-1)
q, k = self.offsetscale(qk) # 仿射变换
# 加入RoPE
if self.p_bias == 'rotary':
q = self.relative_positions_encoding(q)
k = self.relative_positions_encoding(k)
# Attention
attention_scores = torch.einsum('b i d, b j d -> b i j', q, k) # [btz, seq_len, seq_len]
if self.attention_scale:
# seq_len = hidden_states.shape[1]
# attention_scores = F.relu(attention_scores/seq_len) ** 2
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
attention_mask = (1.0 - attention_mask) * -1e12
attention_scores = attention_scores + attention_mask.squeeze(1)
# 归一化
attention_scores = self.attention_normalize(attention_scores, -1, self.normalization)
if self.is_dropout:
attention_scores = self.dropout(attention_scores)
# 计算输出
out = self.o_dense(u * torch.einsum('b i j, b j d -> b i d', attention_scores, v))
return out
def attention_normalize(self, a, dim=-1, method='softmax'):
"""不同的注意力归一化方案
softmax:常规/标准的指数归一化;
squared_relu:来自 https://arxiv.org/abs/2202.10447 ;
softmax_plus:来自 https://kexue.fm/archives/8823 。
"""
if method == 'softmax':
return F.softmax(a, dim=dim)
else:
mask = (a > -1e11).float()
l = torch.maximum(torch.sum(mask, dim=dim, keepdims=True), torch.tensor(1).to(mask))
if method == 'squared_relu':
return F.relu(a)**2 / l
elif method == 'softmax_plus':
return F.softmax(a * torch.log(l) / torch.log(torch.tensor(512)).to(mask), dim=dim)
return a
class OffsetScale(nn.Module):
'''仿射变换
'''
def __init__(self, head_size, heads=1, bias=True):
super().__init__()
self.gamma = nn.Parameter(torch.ones(heads, head_size))
self.bias = bias
if bias:
self.beta = nn.Parameter(torch.zeros(heads, head_size))
nn.init.normal_(self.gamma, std = 0.02)
def forward(self, x):
out = torch.einsum('... d, h d -> ... h d', x, self.gamma)
if self.bias:
out = out + self.beta
return out.unbind(dim = -2)
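# A minimal GAU usage sketch (hyperparameters are illustrative assumptions; the FLASH
# paper uses swish, 'gelu' is used here only because it is the default activation
# elsewhere in this file).
def _demo_gated_attention_unit():
    gau = GatedAttentionUnit(hidden_size=8, attention_key_size=4, intermediate_size=16,
                             attention_probs_dropout_prob=0.1, hidden_act='gelu')
    hidden_states = torch.randn(2, 5, 8)
    attention_mask = torch.ones(2, 1, 1, 5)   # 1 for real tokens, 0 for padding
    out = gau(hidden_states, attention_mask)
    assert out.shape == (2, 5, 8)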
class BertEmbeddings(nn.Module):
"""
embeddings层
构造word, position and token_type embeddings.
"""
def __init__(self, vocab_size, embedding_size, hidden_size, max_position, segment_vocab_size, shared_segment_embeddings, drop_rate, conditional_size=False, **kwargs):
super(BertEmbeddings, self).__init__()
self.shared_segment_embeddings = shared_segment_embeddings
self.word_embeddings = nn.Embedding(vocab_size, embedding_size, padding_idx=0)
# 位置编码
if kwargs.get('p_bias') == 'sinusoid':
self.position_embeddings = SinusoidalPositionEncoding(max_position, embedding_size)
elif kwargs.get('p_bias') in {'rotary', 'typical_relative', 't5_relative', 'other_relative'}:
# 如果使用相对位置编码,则不声明PositionEmbeddings
pass
elif max_position > 0:
self.position_embeddings = nn.Embedding(max_position, embedding_size)
        # segment embeddings
if (segment_vocab_size > 0) and (not shared_segment_embeddings):
self.segment_embeddings = nn.Embedding(segment_vocab_size, embedding_size)
# emb_scale
self.emb_scale = kwargs.get('emb_scale', 1) # transform_xl, xlnet特有
# LayerNorm
self.layerNorm = LayerNorm(embedding_size, eps=1e-12, conditional_size=conditional_size, **kwargs)
self.dropout = nn.Dropout(drop_rate)
# 如果embedding_size != hidden_size,则再有一个linear(适用于albert矩阵分解)
if embedding_size != hidden_size:
self.embedding_hidden_mapping_in = nn.Linear(embedding_size, hidden_size)
def forward(self, token_ids, segment_ids=None, conditional_emb=None, additional_embs=None):
if (not token_ids.requires_grad) and (token_ids.dtype in {torch.long, torch.int}):
words_embeddings = self.word_embeddings(token_ids)
else:
words_embeddings = token_ids # 自定义word_embedding,目前仅有VAT中使用
if hasattr(self, 'segment_embeddings'):
segment_ids = torch.zeros_like(token_ids) if segment_ids is None else segment_ids
segment_embeddings = self.segment_embeddings(segment_ids)
embeddings = words_embeddings + segment_embeddings
elif self.shared_segment_embeddings: # segment和word_embedding共享权重
segment_ids = torch.zeros_like(token_ids) if segment_ids is None else segment_ids
segment_embeddings = self.word_embeddings(segment_ids)
embeddings = words_embeddings + segment_embeddings
else:
embeddings = words_embeddings
# 额外的embedding,如词性等
if additional_embs is not None:
for emb in additional_embs:
embeddings += emb
if hasattr(self, 'position_embeddings'):
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device)
position_ids = position_ids.unsqueeze(0).repeat(token_ids.shape[0], 1)
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
if self.emb_scale != 1:
embeddings = embeddings * self.emb_scale # transform_xl, xlnet特有
if hasattr(self, 'layerNorm'):
embeddings = self.layerNorm((embeddings, conditional_emb))
embeddings = self.dropout(embeddings)
if hasattr(self, 'embedding_hidden_mapping_in'):
embeddings = self.embedding_hidden_mapping_in(embeddings)
return embeddings
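# A minimal embeddings usage sketch (vocab/sizes are illustrative assumptions):
# word + segment + absolute position embeddings, followed by LayerNorm and dropout.
def _demo_bert_embeddings():
    emb = BertEmbeddings(vocab_size=100, embedding_size=16, hidden_size=16, max_position=32,
                         segment_vocab_size=2, shared_segment_embeddings=False, drop_rate=0.1)
    token_ids = torch.randint(1, 100, (2, 5))
    out = emb(token_ids)   # segment_ids defaults to all zeros
    assert out.shape == (2, 5, 16)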
class BertLayer(nn.Module):
"""
Transformer层:
顺序为: Attention --> Add --> LayerNorm --> Feed Forward --> Add --> LayerNorm
注意: 1、以上都不计dropout层,并不代表没有dropout,每一层的dropout使用略有不同,注意区分
2、原始的Transformer的encoder中的Feed Forward层一共有两层linear,
config.intermediate_size的大小不仅是第一层linear的输出尺寸,也是第二层linear的输入尺寸
"""
def __init__(self, hidden_size, num_attention_heads, dropout_rate, attention_probs_dropout_prob, intermediate_size, hidden_act,
is_dropout=False, conditional_size=False, **kwargs):
super(BertLayer, self).__init__()
self.multiHeadAttention = MultiHeadAttentionLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, **kwargs)
self.dropout1 = nn.Dropout(dropout_rate)
self.layerNorm1 = LayerNorm(hidden_size, eps=1e-12, conditional_size=conditional_size, **kwargs)
self.feedForward = PositionWiseFeedForward(hidden_size, intermediate_size, dropout_rate, hidden_act, is_dropout=is_dropout, **kwargs)
self.dropout2 = nn.Dropout(dropout_rate)
self.layerNorm2 = LayerNorm(hidden_size, eps=1e-12, conditional_size=conditional_size, **kwargs)
self.is_decoder = kwargs.get('is_decoder')
if self.is_decoder:
self.crossAttention = MultiHeadAttentionLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, **kwargs)
self.dropout3 = nn.Dropout(dropout_rate)
self.layerNorm3 = LayerNorm(hidden_size, eps=1e-12, conditional_size=conditional_size, **kwargs)
def forward(self, hidden_states, attention_mask, conditional_emb=None, encoder_hidden_states=None, encoder_attention_mask=None):
        self_attn_output = self.multiHeadAttention(hidden_states, attention_mask) # when is_decoder=True, attention_mask here is the lower-triangular (causal) mask
hidden_states = hidden_states + self.dropout1(self_attn_output)
hidden_states = self.layerNorm1((hidden_states, conditional_emb))
# cross attention
if self.is_decoder and encoder_hidden_states is not None:
cross_attn_output = self.crossAttention(hidden_states, None, encoder_hidden_states, encoder_attention_mask)
hidden_states = hidden_states + self.dropout3(cross_attn_output)
hidden_states = self.layerNorm3((hidden_states, conditional_emb))
self_attn_output2 = self.feedForward(hidden_states)
hidden_states = hidden_states + self.dropout2(self_attn_output2)
hidden_states = self.layerNorm2((hidden_states, conditional_emb))
return hidden_states
class T5Layer(BertLayer):
"""T5的Encoder的主体是基于Self-Attention的模块
顺序:LN --> Att --> Add --> LN --> FFN --> Add
"""
def __init__(self, *args, version='t5.1.0', **kwargs):
super().__init__(*args, **kwargs)
# 如果是t5.1.1结构,则FFN层需要变更
if version.endswith('t5.1.1'):
kwargs['dropout_rate'] = args[2]
kwargs['hidden_act'] = args[5]
self.feedForward = self.T5PositionWiseFeedForward(hidden_size=args[0], intermediate_size=args[4], **kwargs)
# decoder中间有crossAttention
if self.is_decoder and hasattr(self.crossAttention, 'relative_positions_encoding'):
del self.crossAttention.relative_positions_encoding
del self.crossAttention.relative_positions
def forward(self, hidden_states, attention_mask, conditional_emb=None, encoder_hidden_states=None, encoder_attention_mask=None):
        # BERT applies LayerNorm after attn/FFN (post-LN), whereas OpenAI GPT-2 applies it before (pre-LN)
x = self.layerNorm1((hidden_states, conditional_emb))
self_attn_output = self.multiHeadAttention(x, attention_mask)
hidden_states = hidden_states + self.dropout1(self_attn_output)
# cross attention
if self.is_decoder and encoder_hidden_states is not None:
x = self.layerNorm3((hidden_states, conditional_emb))
cross_attn_output = self.crossAttention(x, None, encoder_hidden_states, encoder_attention_mask)
hidden_states = hidden_states + self.dropout3(cross_attn_output)
x = self.layerNorm2((hidden_states, conditional_emb))
ffn_output = self.feedForward(x)
hidden_states = hidden_states + self.dropout2(ffn_output)
return hidden_states
class T5PositionWiseFeedForward(PositionWiseFeedForward):
        '''Reference: the transformers package, https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py
'''
def __init__(self, hidden_size, intermediate_size, **kwargs):
super().__init__(hidden_size, intermediate_size, **kwargs)
self.intermediateDense = nn.Linear(hidden_size, intermediate_size, bias=False)
self.intermediateDense1 = nn.Linear(hidden_size, intermediate_size, bias=False)
self.outputDense = nn.Linear(intermediate_size, hidden_size, bias=False)
def forward(self, x):
# x shape: (batch size, seq len, hidden_size)
x_gelu = self.intermediate_act_fn(self.intermediateDense(x))
x_linear = self.intermediateDense1(x)
x = x_gelu * x_linear
if self.is_dropout:
x = self.dropout(x)
# x shape: (batch size, seq len, intermediate_size)
x = self.outputDense(x)
# x shape: (batch size, seq len, hidden_size)
return x
class XlnetLayer(BertLayer):
'''Transformer_XL层
顺序为: Attention --> Add --> LayerNorm --> Feed Forward --> Add --> LayerNorm
'''
def __init__(self, hidden_size, num_attention_heads, dropout_rate, attention_probs_dropout_prob, intermediate_size, hidden_act, **kwargs):
super().__init__(hidden_size, num_attention_heads, dropout_rate, attention_probs_dropout_prob, intermediate_size, hidden_act, **kwargs)
self.pre_lnorm = kwargs.get('pre_lnorm')
# multiattn层无bias
self.multiHeadAttention = self.RelPartialLearnableMultiHeadAttn(hidden_size, num_attention_heads, attention_probs_dropout_prob, bias=False, **kwargs)
def forward(self, hidden_states, segment_ids, pos_emb, attention_mask, mems_i, conditional_emb=None):
# 拼接mems和query,mems_i: [btz, m_len, hdsz], w: [btz, q_len, hdsz] = [btz, k_len, hdsz]
hidden_states_cat = torch.cat([mems_i, hidden_states], 1) if mems_i is not None else hidden_states
# Attn
if self.pre_lnorm:
hidden_states_cat = self.layerNorm1((hidden_states_cat, conditional_emb))
self_attn_output = self.multiHeadAttention(hidden_states, hidden_states_cat, pos_emb, attention_mask, segment_ids)
hidden_states = hidden_states + self.dropout1(self_attn_output)
if not self.pre_lnorm: # post_lnorm
hidden_states = self.layerNorm1((hidden_states, conditional_emb))
# FFN
x = self.layerNorm2((hidden_states, conditional_emb)) if self.pre_lnorm else hidden_states
self_attn_output2 = self.feedForward(x)
hidden_states = hidden_states + self.dropout2(self_attn_output2)
if not self.pre_lnorm: # post_lnorm
hidden_states = self.layerNorm2((hidden_states, conditional_emb))
return hidden_states
class RelPartialLearnableMultiHeadAttn(MultiHeadAttentionLayer):
'''Transformer_XL式相对位置编码, 这里修改成了MultiHeadAttentionLayer的batch_first代码格式
'''
def __init__(self, *args, r_w_bias=None, r_r_bias=None, r_s_bias=None, **kwargs):
super().__init__(*args, **kwargs)
segment_vocab_size = kwargs.get('segment_vocab_size')
if r_r_bias is None or r_w_bias is None: # Biases are not shared
self.r_r_bias = nn.Parameter(torch.FloatTensor(self.num_attention_heads, self.attention_head_size)) # 全局内容偏置
self.r_w_bias = nn.Parameter(torch.FloatTensor(self.num_attention_heads, self.attention_head_size)) # 全局位置偏置
if segment_vocab_size > 0:
self.r_s_bias = nn.Parameter(torch.FloatTensor(self.num_attention_heads, self.attention_head_size)) # 全局segment偏置
            else:  # all layers share one set of biases
self.r_r_bias = r_r_bias
self.r_w_bias = r_w_bias
self.r_s_bias = r_s_bias
if segment_vocab_size > 0:
# self.seg_embed = nn.Embedding(segment_vocab_size, self.hidden_size)
self.seg_embed = nn.Parameter(torch.FloatTensor(segment_vocab_size, self.num_attention_heads, self.attention_head_size))
self.r = nn.Linear(self.hidden_size, self.hidden_size, bias=self.bias)
self.rel_shift_opt = kwargs.get('rel_shift_opt')
@staticmethod
def rel_shift(x, zero_triu=False):
'''transformer_xl使用, 向左shift让右上角都是0, 对角线是同一个值, x: [btz, n_head, q_len, k_len]
'''
q_len, k_len = x.size(2), x.size(-1)
zero_pad = torch.zeros((*x.size()[:2], q_len, 1), device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(*x.size()[:2], k_len + 1, q_len)
x = x_padded[:,:,1:,:].view_as(x)
if zero_triu:
ones = torch.ones((q_len, k_len), device=x.device)
x = x * torch.tril(ones, k_len - q_len)[None,None,:,:]
return x
@staticmethod
def rel_shift_bnij(x, klen=-1):
''' xlnet使用
'''
x_size = x.shape
x = x.reshape(x_size[0], x_size[1], x_size[3], x_size[2])
x = x[:, :, 1:, :]
x = x.reshape(x_size[0], x_size[1], x_size[2], x_size[3] - 1)
x = torch.index_select(x, 3, torch.arange(klen, device=x.device, dtype=torch.long))
# x = x[:, :, :, :klen]
return x
def forward(self, w, cat, r, attention_mask=None, seg_mat=None):
# w: 词向量[btz, q_len, hdsz], cat: w和mem_i拼接后向量[btz, k_len, hdsz], r:相对位置向量[r_len, hdsz]
qlen, rlen, bsz = w.size(1), r.size(0), w.size(0)
            mixed_query_layer = self.q(cat)[:, -qlen:, :]  # only take the query part; the mem part is not used
mixed_key_layer = self.k(cat)
mixed_value_layer = self.v(cat)
w_head_q = self.transpose_for_scores(mixed_query_layer) # [btz, n_head, q_len, d_head]
w_head_k = self.transpose_for_scores(mixed_key_layer) # [btz, n_head, k_len, d_head]
w_head_v = self.transpose_for_scores(mixed_value_layer) # [btz, n_head, k_len, d_head]
            r_head_k = self.r(r)  # [r_len, n_head*d_head]
r_head_k = r_head_k.view(rlen, self.num_attention_heads, self.attention_head_size) # rlen x n_head x d_head
#### compute attention score
rw_head_q = w_head_q + self.r_w_bias.unsqueeze(1) # [btz, n_head, q_len, d_head]
AC = torch.einsum('bnid,bnjd->bnij', (rw_head_q, w_head_k)) # [btz, n_head, q_len, k_len]
rr_head_q = w_head_q + self.r_r_bias.unsqueeze(1) # [btz, n_head, q_len, d_head]
BD = torch.einsum('bnid,jnd->bnij', (rr_head_q, r_head_k)) # [btz, n_head, q_len, k_len]
BD = self.rel_shift_bnij(BD, klen=AC.shape[3]) if self.rel_shift_opt == 'xlnet' else self.rel_shift(BD)
if hasattr(self, 'seg_embed') and (self.r_r_bias is not None):
# # 之前的方式,需要配合Embedding,以及load_variable和variable_mapping,显存容易爆炸
# w_head_s = self.seg_embed(seg_mat) # [btz, q_len, klen, hdsz]
# w_head_s = w_head_s.reshape(*w_head_s.shape[:3], self.num_attention_heads, self.attention_head_size)
# rs_head_q = w_head_q + self.r_s_bias.unsqueeze(1)
# EF = torch.einsum('bnid,bijnd->bnij', (rs_head_q, w_head_s)) # [btz, n_head, q_len, k_len]
seg_mat = F.one_hot(seg_mat, 2).float()
EF = torch.einsum("bnid,snd->ibns", w_head_q + self.r_s_bias.unsqueeze(1), self.seg_embed)
EF = torch.einsum("bijs,ibns->bnij", seg_mat, EF)
else:
EF = 0
# # [btz, n_head, q_len, k_len]
attention_scores = AC + BD + EF
if self.attention_scale:
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
#### compute attention probability
if attention_mask is not None and attention_mask.any().item():
# attention_mask = (1.0 - attention_mask) * -10000.0
# attention_scores = attention_scores + attention_mask # 这里修改了下,原有的-10000不够接近-inf
attention_mask = (1.0 - attention_mask)
attention_scores = attention_scores.float().masked_fill(attention_mask.bool(), -1e30).type_as(attention_mask)
# [btz, n_head, q_len, k_len]
attention_probs = F.softmax(attention_scores, dim=-1)
attention_probs = self.dropout(attention_probs)
context_layer = torch.matmul(attention_probs, w_head_v) # [batch_size, num_attention_heads, query_len, attention_head_size]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
context_layer = context_layer.view(*new_context_layer_shape)
# 是否返回attention scores
if self.return_attention_scores:
# 这里返回的attention_scores没有经过softmax, 可在外部进行归一化操作
return self.o(context_layer), attention_scores
else:
return self.o(context_layer)
class AdaptiveEmbedding(nn.Module):
'''Transformer_XL的自适应embedding, 实现不同区间使用不同的维度
可以实现如高频词用比如1024或512维,低频词用256或64维, 再用Linear层project到相同的维数
'''
def __init__(self, vocab_size, embedding_size, hidden_size, cutoffs, div_val=1, sample_softmax=False, **kwargs):
super().__init__()
self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.cutoffs = cutoffs + [vocab_size]
self.div_val = div_val
self.hidden_size = hidden_size
self.emb_scale = hidden_size ** 0.5
self.cutoff_ends = [0] + self.cutoffs
self.emb_layers = nn.ModuleList()
self.emb_projs = nn.ParameterList()
if div_val == 1:
self.emb_layers.append(nn.Embedding(vocab_size, embedding_size, sparse=sample_softmax > 0))
if hidden_size != embedding_size:
self.emb_projs.append(nn.Parameter(torch.FloatTensor(hidden_size, embedding_size)))
else:
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
d_emb_i = embedding_size // (div_val ** i)
self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i))
self.emb_projs.append(nn.Parameter(torch.FloatTensor(hidden_size, d_emb_i)))
def forward(self, token_ids):
if self.div_val == 1: # 仅有一个embedding
embed = self.emb_layers[0](token_ids) # [btz, seq_len, embedding_size]
if self.hidden_size != self.embedding_size:
embed = nn.functional.linear(embed, self.emb_projs[0])
else:
param = next(self.parameters())
inp_flat = token_ids.view(-1)
emb_flat = torch.zeros([inp_flat.size(0), self.hidden_size], dtype=param.dtype, device=param.device)
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
indices_i = mask_i.nonzero().squeeze()
if indices_i.numel() == 0:
continue
inp_i = inp_flat.index_select(0, indices_i) - l_idx
emb_i = self.emb_layers[i](inp_i)
emb_i = nn.functional.linear(emb_i, self.emb_projs[i])
emb_flat.index_copy_(0, indices_i, emb_i)
embed_shape = token_ids.size() + (self.hidden_size,)
embed = emb_flat.view(embed_shape)
embed.mul_(self.emb_scale)
return embed
class Identity(nn.Module):
def __init__(self, *args, **kwargs):
super(Identity, self).__init__()
def forward(self, *args):
return args[0]
class XlnetPositionsEncoding(nn.Module):
'''Xlnet, transformer_xl使用的相对位置编码
和SinusoidalPositionEncoding区别是一个是间隔排列, 一个是前后排列
'''
def __init__(self, embedding_size):
super().__init__()
self.demb = embedding_size
inv_freq = 1 / (10000 ** (torch.arange(0.0, embedding_size, 2.0) / embedding_size))
self.register_buffer("inv_freq", inv_freq)
def forward(self, pos_seq):
sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
return pos_emb
class RelativePositionsEncoding(nn.Module):
"""nezha用的google相对位置编码
来自论文:https://arxiv.org/abs/1803.02155
"""
def __init__(self, qlen, klen, embedding_size, max_relative_position=127):
super(RelativePositionsEncoding, self).__init__()
# 生成相对位置矩阵
vocab_size = max_relative_position * 2 + 1
distance_mat = torch.arange(klen)[None, :] - torch.arange(qlen)[:, None] # 列数-行数, [query_len, key_len]
distance_mat_clipped = torch.clamp(distance_mat, -max_relative_position, max_relative_position)
final_mat = distance_mat_clipped + max_relative_position
# sinusoid_encoding编码的位置矩阵
embeddings_table = get_sinusoid_encoding_table(vocab_size, embedding_size)
# 实现方式1
# flat_relative_positions_matrix = final_mat.view(-1)
# one_hot_relative_positions_matrix = torch.nn.functional.one_hot(flat_relative_positions_matrix, num_classes=vocab_size).float()
# position_embeddings = torch.matmul(one_hot_relative_positions_matrix, embeddings_table)
# my_shape = list(final_mat.size())
# my_shape.append(embedding_size)
# position_embeddings = position_embeddings.view(my_shape)
# 实现方式2
# position_embeddings = take_along_dim(embeddings_table, final_mat.flatten().unsqueeze(1), dim=0)
# position_embeddings = position_embeddings.reshape(*final_mat.shape, embeddings_table.shape[-1]) # [seq_len, seq_len, hdsz]
# self.register_buffer('position_embeddings', position_embeddings)
# 实现方式3
position_embeddings = nn.Embedding.from_pretrained(embeddings_table, freeze=True)(final_mat)
self.register_buffer('position_embeddings', position_embeddings)
def forward(self, qlen, klen):
return self.position_embeddings[:qlen, :klen, :]
class RelativePositionsEncodingT5(nn.Module):
"""Google T5的相对位置编码
来自论文:https://arxiv.org/abs/1910.10683
"""
def __init__(self, qlen, klen, relative_attention_num_buckets, is_decoder=False):
super(RelativePositionsEncodingT5, self).__init__()
# 生成相对位置矩阵
context_position = torch.arange(qlen, dtype=torch.long)[:, None]
memory_position = torch.arange(klen, dtype=torch.long)[None, :]
relative_position = memory_position - context_position # shape (qlen, klen)
relative_position = self._relative_position_bucket(
relative_position, # shape (qlen, klen)
bidirectional=not is_decoder,
num_buckets=relative_attention_num_buckets,
)
self.register_buffer('relative_position', relative_position)
def forward(self, qlen, klen):
return self.relative_position[:qlen, :klen]
@staticmethod
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
'''直接来源于transformer
'''
ret = 0
n = -relative_position
if bidirectional:
num_buckets //= 2
ret += (n < 0).to(torch.long) * num_buckets # mtf.to_int32(mtf.less(n, 0)) * num_buckets
n = torch.abs(n)
else:
n = torch.max(n, torch.zeros_like(n))
# now n is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = n < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
).to(torch.long)
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
ret += torch.where(is_small, n, val_if_large)
return ret
class SinusoidalPositionEncoding(nn.Module):
"""定义Sin-Cos位置Embedding
"""
def __init__(self, max_position, embedding_size):
super(SinusoidalPositionEncoding, self).__init__()
self.position_embeddings = nn.Embedding.from_pretrained(get_sinusoid_encoding_table(max_position, embedding_size), freeze=True)
def forward(self, position_ids):
return self.position_embeddings(position_ids)
class RoPEPositionEncoding(nn.Module):
"""旋转式位置编码: https://kexue.fm/archives/8265
"""
def __init__(self, max_position, embedding_size):
super(RoPEPositionEncoding, self).__init__()
position_embeddings = get_sinusoid_encoding_table(max_position, embedding_size) # [seq_len, hdsz]
cos_position = position_embeddings[:, 1::2].repeat_interleave(2, dim=-1)
sin_position = position_embeddings[:, ::2].repeat_interleave(2, dim=-1)
# register_buffer是为了最外层model.to(device),不用内部指定device
self.register_buffer('cos_position', cos_position)
self.register_buffer('sin_position', sin_position)
def forward(self, qw, seq_dim=-2):
# 默认最后两个维度为[seq_len, hdsz]
seq_len = qw.shape[seq_dim]
qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], dim=-1).reshape_as(qw)
return qw * self.cos_position[:seq_len] + qw2 * self.sin_position[:seq_len]
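# A small numeric check of RoPE (shapes are illustrative assumptions): each even/odd
# pair of dims is rotated by a position-dependent angle, so the transform changes
# phase but preserves vector norms.
def _demo_rope():
    rope = RoPEPositionEncoding(max_position=32, embedding_size=4)
    qw = torch.randn(2, 8, 5, 4)   # [btz, n_head, seq_len, head_size]
    qw_rot = rope(qw)
    assert torch.allclose(qw.norm(dim=-1), qw_rot.norm(dim=-1), atol=1e-5)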
class CRF(nn.Module):
'''Conditional random field: https://github.com/lonePatient/BERT-NER-Pytorch/blob/master/models/layers/crf.py
'''
def __init__(self, num_tags: int, init_transitions: Optional[List[np.ndarray]] = None, freeze=False) -> None:
if num_tags <= 0:
raise ValueError(f'invalid number of tags: {num_tags}')
super().__init__()
self.num_tags = num_tags
if (init_transitions is None) and (not freeze):
self.start_transitions = nn.Parameter(torch.empty(num_tags))
self.end_transitions = nn.Parameter(torch.empty(num_tags))
self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
nn.init.uniform_(self.start_transitions, -0.1, 0.1)
nn.init.uniform_(self.end_transitions, -0.1, 0.1)
nn.init.uniform_(self.transitions, -0.1, 0.1)
elif init_transitions is not None:
transitions = torch.tensor(init_transitions[0], dtype=torch.float)
start_transitions = torch.tensor(init_transitions[1], dtype=torch.float)
end_transitions = torch.tensor(init_transitions[2], dtype=torch.float)
if not freeze:
self.transitions = nn.Parameter(transitions)
self.start_transitions = nn.Parameter(start_transitions)
self.end_transitions = nn.Parameter(end_transitions)
else:
self.register_buffer('transitions', transitions)
self.register_buffer('start_transitions', start_transitions)
self.register_buffer('end_transitions', end_transitions)
def __repr__(self) -> str:
return f'{self.__class__.__name__}(num_tags={self.num_tags})'
def forward(self, emissions: torch.Tensor, mask: torch.ByteTensor,
tags: torch.LongTensor, reduction: str = 'mean') -> torch.Tensor:
"""Compute the conditional log likelihood of a sequence of tags given emission scores.
emissions: [btz, seq_len, num_tags]
mask: [btz, seq_len]
tags: [btz, seq_len]
"""
if reduction not in ('none', 'sum', 'mean', 'token_mean'):
raise ValueError(f'invalid reduction: {reduction}')
if mask.dtype != torch.uint8:
mask = mask.byte()
self._validate(emissions, tags=tags, mask=mask)
# shape: (batch_size,)
numerator = self._compute_score(emissions, tags, mask)
# shape: (batch_size,)
denominator = self._compute_normalizer(emissions, mask)
# shape: (batch_size,)
llh = denominator - numerator
if reduction == 'none':
return llh
if reduction == 'sum':
return llh.sum()
if reduction == 'mean':
return llh.mean()
return llh.sum() / mask.float().sum()
def decode(self, emissions: torch.Tensor, mask: Optional[torch.ByteTensor] = None,
nbest: Optional[int] = None, pad_tag: Optional[int] = None) -> List[List[List[int]]]:
"""Find the most likely tag sequence using Viterbi algorithm.
"""
if nbest is None:
nbest = 1
if mask is None:
mask = torch.ones(emissions.shape[:2], dtype=torch.uint8, device=emissions.device)
if mask.dtype != torch.uint8:
mask = mask.byte()
self._validate(emissions, mask=mask)
best_path = self._viterbi_decode_nbest(emissions, mask, nbest, pad_tag)
return best_path[0] if nbest == 1 else best_path
def _validate(self, emissions: torch.Tensor, tags: Optional[torch.LongTensor] = None,
mask: Optional[torch.ByteTensor] = None) -> None:
if emissions.dim() != 3:
raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
if emissions.size(2) != self.num_tags:
raise ValueError(f'expected last dimension of emissions is {self.num_tags}, '
f'got {emissions.size(2)}')
if tags is not None:
if emissions.shape[:2] != tags.shape:
raise ValueError('the first two dimensions of emissions and tags must match, '
f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')
if mask is not None:
if emissions.shape[:2] != mask.shape:
raise ValueError('the first two dimensions of emissions and mask must match, '
f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
no_empty_seq_bf = mask[:, 0].all()
if not no_empty_seq_bf:
raise ValueError('mask of the first timestep must all be on')
def _compute_score(self, emissions: torch.Tensor, tags: torch.LongTensor, mask: torch.ByteTensor) -> torch.Tensor:
# emissions: (batch_size, seq_length, num_tags)
# tags: (batch_size, seq_length)
# mask: (batch_size, seq_length)
batch_size, seq_length = tags.shape
mask = mask.float()
# Start transition score and first emission
# shape: (batch_size,)
score = self.start_transitions[tags[:, 0]]
score += emissions[torch.arange(batch_size), 0, tags[:, 0]]
for i in range(1, seq_length):
# Transition score to next tag, only added if next timestep is valid (mask == 1)
# shape: (batch_size,)
score += self.transitions[tags[:, i - 1], tags[:, i]] * mask[:, i]
# Emission score for next tag, only added if next timestep is valid (mask == 1)
# shape: (batch_size,)
score += emissions[torch.arange(batch_size), i, tags[:, i]] * mask[:, i]
# End transition score
# shape: (batch_size,)
seq_ends = mask.long().sum(dim=1) - 1
# shape: (batch_size,)
last_tags = tags[torch.arange(batch_size), seq_ends]
# shape: (batch_size,)
score += self.end_transitions[last_tags]
return score
def _compute_normalizer(self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor:
# emissions: (batch_size, seq_length, num_tags)
# mask: (batch_size, seq_length)
seq_length = emissions.size(1)
# Start transition score and first emission; score has size of
# (batch_size, num_tags) where for each batch, the j-th column stores
# the score that the first timestep has tag j
# shape: (batch_size, num_tags)
score = self.start_transitions + emissions[:, 0]
for i in range(1, seq_length):
# Broadcast score for every possible next tag
# shape: (batch_size, num_tags, 1)
broadcast_score = score.unsqueeze(2)
# Broadcast emission score for every possible current tag
# shape: (batch_size, 1, num_tags)
broadcast_emissions = emissions[:, i].unsqueeze(1)
# Compute the score tensor of size (batch_size, num_tags, num_tags) where
# for each sample, entry at row i and column j stores the sum of scores of all
# possible tag sequences so far that end with transitioning from tag i to tag j
# and emitting
# shape: (batch_size, num_tags, num_tags)
next_score = broadcast_score + self.transitions + broadcast_emissions
# Sum over all possible current tags, but we're in score space, so a sum
# becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
# all possible tag sequences so far, that end in tag i
# shape: (batch_size, num_tags)
next_score = torch.logsumexp(next_score, dim=1)
# Set score to the next score if this timestep is valid (mask == 1)
# shape: (batch_size, num_tags)
score = torch.where(mask[:, i].unsqueeze(1).bool(), next_score, score)
# End transition score
# shape: (batch_size, num_tags)
score += self.end_transitions
# Sum (log-sum-exp) over all possible tags
# shape: (batch_size,)
return torch.logsumexp(score, dim=1)
def _viterbi_decode_nbest(self, emissions: torch.FloatTensor, mask: torch.ByteTensor,
nbest: int, pad_tag: Optional[int] = None) -> List[List[List[int]]]:
# emissions: (batch_size, seq_length, num_tags)
# mask: (batch_size, seq_length)
# return: (nbest, batch_size, seq_length)
if pad_tag is None:
pad_tag = 0
device = emissions.device
batch_size, seq_length = mask.shape
# Start transition and first emission
# shape: (batch_size, num_tags)
score = self.start_transitions + emissions[:, 0]
history_idx = torch.zeros((batch_size, seq_length, self.num_tags, nbest), dtype=torch.long, device=device)
oor_idx = torch.zeros((batch_size, self.num_tags, nbest), dtype=torch.long, device=device)
oor_tag = torch.full((batch_size, seq_length, nbest), pad_tag, dtype=torch.long, device=device)
# - score is a tensor of size (batch_size, num_tags) where for every batch,
# value at column j stores the score of the best tag sequence so far that ends
# with tag j
# - history_idx saves where the best tags candidate transitioned from; this is used
# when we trace back the best tag sequence
# - oor_idx saves the best tags candidate transitioned from at the positions
# where mask is 0, i.e. out of range (oor)
# Viterbi algorithm recursive case: we compute the score of the best tag sequence
# for every possible next tag
for i in range(1, seq_length):
if i == 1:
broadcast_score = score.unsqueeze(-1)
broadcast_emission = emissions[:, i].unsqueeze(1)
# shape: (batch_size, num_tags, num_tags)
next_score = broadcast_score + self.transitions + broadcast_emission
else:
broadcast_score = score.unsqueeze(-1)
broadcast_emission = emissions[:, i].unsqueeze(1).unsqueeze(2)
# shape: (batch_size, num_tags, nbest, num_tags)
next_score = broadcast_score + self.transitions.unsqueeze(1) + broadcast_emission
# Find the top `nbest` maximum score over all possible current tag
# shape: (batch_size, nbest, num_tags)
next_score, indices = next_score.view(batch_size, -1, self.num_tags).topk(nbest, dim=1)
if i == 1:
score = score.unsqueeze(-1).expand(-1, -1, nbest)
indices = indices * nbest
# convert to shape: (batch_size, num_tags, nbest)
next_score = next_score.transpose(2, 1)
indices = indices.transpose(2, 1)
# Set score to the next score if this timestep is valid (mask == 1)
# and save the index that produces the next score
# shape: (batch_size, num_tags, nbest)
score = torch.where(mask[:, i].unsqueeze(-1).unsqueeze(-1).bool(), next_score, score)
indices = torch.where(mask[:, i].unsqueeze(-1).unsqueeze(-1).bool(), indices, oor_idx)
history_idx[:, i - 1] = indices
# End transition score shape: (batch_size, num_tags, nbest)
end_score = score + self.end_transitions.unsqueeze(-1)
_, end_tag = end_score.view(batch_size, -1).topk(nbest, dim=1)
# shape: (batch_size,)
seq_ends = mask.long().sum(dim=1) - 1
# insert the best tag at each sequence end (last position with mask == 1)
history_idx.scatter_(1, seq_ends.view(-1, 1, 1, 1).expand(-1, 1, self.num_tags, nbest),
end_tag.view(-1, 1, 1, nbest).expand(-1, 1, self.num_tags, nbest))
# The most probable path for each sequence
best_tags_arr = torch.zeros((batch_size, seq_length, nbest), dtype=torch.long, device=device)
best_tags = torch.arange(nbest, dtype=torch.long, device=device).view(1, -1).expand(batch_size, -1)
for idx in range(seq_length - 1, -1, -1):
best_tags = torch.gather(history_idx[:, idx].view(batch_size, -1), 1, best_tags)
best_tags_arr[:, idx] = torch.div(best_tags.data.view(batch_size, -1), nbest, rounding_mode='floor')
return torch.where(mask.unsqueeze(-1).bool(), best_tags_arr, oor_tag).permute(2, 0, 1)
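# A minimal CRF usage sketch (tag/sequence sizes are illustrative assumptions):
# forward() returns the negative log-likelihood used as the training loss, and
# decode() runs Viterbi to return the best tag sequence per sample.
def _demo_crf():
    crf = CRF(num_tags=5)
    emissions = torch.randn(2, 7, 5)           # [btz, seq_len, num_tags]
    tags = torch.randint(0, 5, (2, 7))
    mask = torch.ones(2, 7, dtype=torch.uint8)
    loss = crf(emissions, mask, tags)          # scalar NLL (reduction='mean')
    best_paths = crf.decode(emissions, mask)   # [btz, seq_len] tag ids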
class BERT_WHITENING():
def __init__(self):
self.kernel = None
self.bias = None
def compute_kernel_bias(self, sentence_vec):
'''bert-whitening的torch实现
'''
vecs = torch.cat(sentence_vec, dim=0)
self.bias = -vecs.mean(dim=0, keepdims=True)
cov = torch.cov(vecs.T) # 协方差
u, s, vh = torch.linalg.svd(cov)
W = torch.matmul(u, torch.diag(s**0.5))
self.kernel = torch.linalg.inv(W.T)
def save_whiten(self, path):
whiten = {'kernel': self.kernel, 'bias': self.bias}
        torch.save(whiten, path)
def load_whiten(self, path):
whiten = torch.load(path)
self.kernel = whiten['kernel']
self.bias = whiten['bias']
def transform_and_normalize(self, vecs):
"""应用变换,然后标准化
"""
if not (self.kernel is None or self.bias is None):
vecs = (vecs + self.bias).mm(self.kernel)
return vecs / (vecs**2).sum(axis=1, keepdims=True)**0.5
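# A minimal whitening usage sketch (dimensions are illustrative assumptions): fit the
# kernel/bias on a list of sentence-vector batches, then whiten and L2-normalize new
# vectors before computing cosine similarity.
def _demo_bert_whitening():
    bw = BERT_WHITENING()
    sentence_vec = [torch.randn(10, 8) for _ in range(3)]   # batches of sentence embeddings
    bw.compute_kernel_bias(sentence_vec)
    vecs = bw.transform_and_normalize(torch.randn(4, 8))    # rows now have unit norm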
class GlobalPointer(nn.Module):
"""全局指针模块
将序列的每个(start, end)作为整体来进行判断
参考:https://kexue.fm/archives/8373
"""
def __init__(self, hidden_size, heads, head_size, RoPE=True, max_len=512, use_bias=True, tril_mask=True):
super().__init__()
self.heads = heads
self.head_size = head_size
self.RoPE = RoPE
self.tril_mask = tril_mask
self.dense = nn.Linear(hidden_size, heads * head_size * 2, bias=use_bias)
if self.RoPE:
self.position_embedding = RoPEPositionEncoding(max_len, head_size)
def forward(self, inputs, mask=None):
''' inputs: [..., hdsz]
            mask: [btz, seq_len], padding positions are 0
'''
sequence_output = self.dense(inputs) # [..., heads*head_size*2]
sequence_output = torch.stack(torch.chunk(sequence_output, self.heads, dim=-1), dim=-2) # [..., heads, head_size*2]
qw, kw = sequence_output[..., :self.head_size], sequence_output[..., self.head_size:] # [..., heads, head_size]
# ROPE编码
if self.RoPE:
qw = self.position_embedding(qw)
kw = self.position_embedding(kw)
# 计算内积
logits = torch.einsum('bmhd,bnhd->bhmn', qw, kw) # [btz, heads, seq_len, seq_len]
# 排除padding
if mask is not None:
attention_mask1 = 1 - mask.unsqueeze(1).unsqueeze(3) # [btz, 1, seq_len, 1]
attention_mask2 = 1 - mask.unsqueeze(1).unsqueeze(2) # [btz, 1, 1, seq_len]
logits = logits.masked_fill(attention_mask1.bool(), value=-float('inf'))
logits = logits.masked_fill(attention_mask2.bool(), value=-float('inf'))
# 排除下三角
if self.tril_mask:
logits = logits - torch.tril(torch.ones_like(logits), -1) * 1e12
# scale返回
return logits / self.head_size**0.5
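# A minimal GlobalPointer usage sketch (sizes are illustrative assumptions): the output
# scores every (start, end) span per entity head; spans covering padding or with
# start > end are pushed towards -inf.
def _demo_global_pointer():
    gp = GlobalPointer(hidden_size=8, heads=2, head_size=4)
    hidden_states = torch.randn(2, 5, 8)   # e.g. BERT last hidden state
    mask = torch.ones(2, 5)                # 1 for real tokens, 0 for padding
    logits = gp(hidden_states, mask)       # [btz, heads, seq_len, seq_len]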
class EfficientGlobalPointer(nn.Module):
"""更加参数高效的GlobalPointer
参考:https://kexue.fm/archives/8877
"""
def __init__(self, hidden_size, heads, head_size, RoPE=True, max_len=512, use_bias=True, tril_mask=True):
super().__init__()
self.heads = heads
self.head_size = head_size
self.RoPE = RoPE
self.tril_mask = tril_mask
self.p_dense = nn.Linear(hidden_size, head_size * 2, bias=use_bias)
self.q_dense = nn.Linear(head_size * 2, heads * 2, bias=use_bias)
if self.RoPE:
self.position_embedding = RoPEPositionEncoding(max_len, head_size)
def forward(self, inputs, mask=None):
''' inputs: [..., hdsz]
            mask: [btz, seq_len], padding positions are 0
'''
sequence_output = self.p_dense(inputs) # [..., head_size*2]
qw, kw = sequence_output[..., :self.head_size], sequence_output[..., self.head_size:] # [..., head_size]
# ROPE编码
if self.RoPE:
qw = self.position_embedding(qw)
kw = self.position_embedding(kw)
# 计算内积
logits = torch.einsum('bmd,bnd->bmn', qw, kw) / self.head_size**0.5 # [btz, seq_len, seq_len], 是否是实体的打分
bias_input = self.q_dense(sequence_output) # [..., heads*2]
bias = torch.stack(torch.chunk(bias_input, self.heads, dim=-1), dim=-2).transpose(1,2) # [btz, heads, seq_len, 2]
logits = logits.unsqueeze(1) + bias[..., :1] + bias[..., 1:].transpose(2, 3) # [btz, heads, seq_len, seq_len]
# 排除padding
if mask is not None:
attention_mask1 = 1 - mask.unsqueeze(1).unsqueeze(3) # [btz, 1, seq_len, 1]
attention_mask2 = 1 - mask.unsqueeze(1).unsqueeze(2) # [btz, 1, 1, seq_len]
logits = logits.masked_fill(attention_mask1.bool(), value=-float('inf'))
logits = logits.masked_fill(attention_mask2.bool(), value=-float('inf'))
# 排除下三角
if self.tril_mask:
logits = logits - torch.tril(torch.ones_like(logits), -1) * 1e12
return logits
class TplinkerHandshakingKernel(nn.Module):
'''Tplinker的HandshakingKernel实现
'''
def __init__(self, hidden_size, shaking_type, inner_enc_type=''):
super().__init__()
self.shaking_type = shaking_type
if shaking_type == "cat":
self.combine_fc = nn.Linear(hidden_size * 2, hidden_size)
elif shaking_type == "cat_plus":
self.combine_fc = nn.Linear(hidden_size * 3, hidden_size)
elif shaking_type == "cln":
self.tp_cln = LayerNorm(hidden_size, conditional_size=hidden_size)
elif shaking_type == "cln_plus":
self.tp_cln = LayerNorm(hidden_size, conditional_size=hidden_size)
self.inner_context_cln = LayerNorm(hidden_size, conditional_size=hidden_size)
self.inner_enc_type = inner_enc_type
if inner_enc_type == "mix_pooling":
self.lamtha = nn.Parameter(torch.rand(hidden_size))
elif inner_enc_type == "lstm":
self.inner_context_lstm = nn.LSTM(hidden_size, hidden_size, num_layers=1, bidirectional=False, batch_first=True)
# 自行实现的用torch.gather方式来做,避免循环,目前只实现了cat方式
# tag_ids = [(i, j) for i in range(maxlen) for j in range(maxlen) if j >= i]
# gather_idx = torch.tensor(tag_ids, dtype=torch.long).flatten()[None, :, None]
# self.register_buffer('gather_idx', gather_idx)
def enc_inner_hiddens(self, seq_hiddens, inner_enc_type="lstm"):
# seq_hiddens: (batch_size, seq_len, hidden_size)
        def pool(sequence, pooling_type):
            if pooling_type == "mean_pooling":
                pooling = torch.mean(sequence, dim=-2)
            elif pooling_type == "max_pooling":
                pooling, _ = torch.max(sequence, dim=-2)
            elif pooling_type == "mix_pooling":
                pooling = self.lamtha * torch.mean(sequence, dim=-2) + (1 - self.lamtha) * torch.max(sequence, dim=-2)[0]
            return pooling
if "pooling" in inner_enc_type:
inner_context = torch.stack([pool(seq_hiddens[:, :i+1, :], inner_enc_type) for i in range(seq_hiddens.size()[1])], dim = 1)
elif inner_enc_type == "lstm":
inner_context, _ = self.inner_context_lstm(seq_hiddens)
return inner_context
def forward(self, seq_hiddens):
'''
seq_hiddens: (batch_size, seq_len, hidden_size)
return:
            shaking_hiddens: (batch_size, (1 + seq_len) * seq_len / 2, hidden_size) (32, 5+4+3+2+1, 5)
'''
seq_len = seq_hiddens.size()[-2]
shaking_hiddens_list = []
for ind in range(seq_len):
hidden_each_step = seq_hiddens[:, ind, :]
visible_hiddens = seq_hiddens[:, ind:, :] # ind: only look back
repeat_hiddens = hidden_each_step[:, None, :].repeat(1, seq_len - ind, 1)
if self.shaking_type == "cat":
shaking_hiddens = torch.cat([repeat_hiddens, visible_hiddens], dim = -1)
shaking_hiddens = torch.tanh(self.combine_fc(shaking_hiddens))
elif self.shaking_type == "cat_plus":
inner_context = self.enc_inner_hiddens(visible_hiddens, self.inner_enc_type)
shaking_hiddens = torch.cat([repeat_hiddens, visible_hiddens, inner_context], dim = -1)
shaking_hiddens = torch.tanh(self.combine_fc(shaking_hiddens))
elif self.shaking_type == "cln":
shaking_hiddens = self.tp_cln([visible_hiddens, repeat_hiddens])
elif self.shaking_type == "cln_plus":
inner_context = self.enc_inner_hiddens(visible_hiddens, self.inner_enc_type)
shaking_hiddens = self.tp_cln([visible_hiddens, repeat_hiddens])
shaking_hiddens = self.inner_context_cln([shaking_hiddens, inner_context])
shaking_hiddens_list.append(shaking_hiddens)
long_shaking_hiddens = torch.cat(shaking_hiddens_list, dim = 1)
return long_shaking_hiddens
# def handshaking_kernel(self, last_hidden_state):
# '''获取(0,0),(0,1),...,(99,99))对应的序列id
# '''
# btz, _, hdsz = last_hidden_state.shape
# gather_idx = self.gather_idx.repeat(btz, 1, hdsz)
# concat_hidden_states = torch.gather(last_hidden_state, dim=1, index=gather_idx) # [btz, pair_len*2, hdsz]
# concat_hidden_states = concat_hidden_states.reshape(btz, -1, 2, hdsz) # concat方式 [btz, pair_len, 2, hdsz]
# shaking_hiddens = torch.cat(torch.chunk(concat_hidden_states, chunks=2, dim=-2), dim=-1).squeeze(-2) # [btz, pair_len, hdsz*2]
# return shaking_hiddens
class MixUp(nn.Module):
    '''Implementation of mixup
    method: 'embed' and 'encoder' apply mixup at the embedding and encoder-output level respectively,
            None returns both outputs for downstream mixing, and 'hidden' applies mixup to a hidden layer
    '''
def __init__(self, method='encoder', alpha=1.0, layer_mix=None):
super().__init__()
assert method in {'embed', 'encoder', 'hidden', None}
self.method = method
self.alpha = alpha
self.perm_index = None
self.lam = 0
self.layer_mix = layer_mix # 需要mix的隐含层index
def get_perm(self, inputs):
if isinstance(inputs, torch.Tensor):
return inputs[self.perm_index]
elif isinstance(inputs, (list, tuple)):
return [inp[self.perm_index] if isinstance(inp, torch.Tensor) else inp for inp in inputs]
def mix_up(self, output, output1):
if isinstance(output, torch.Tensor):
return self.lam * output + (1.0-self.lam) * output1
elif isinstance(output, (list, tuple)):
output_final = []
for i in range(len(output)):
if output[i] is None: # conditional_emb=None
output_final.append(output[i])
elif (not output[i].requires_grad) and (output[i].dtype in {torch.long, torch.int}):
# 不是embedding形式的
output_final.append(torch.max(output[i], output1[i]))
else:
output_final.append(self.lam * output[i] + (1.0-self.lam) * output1[i])
return output_final
else:
raise ValueError('Illegal model output')
def encode(self, model, inputs):
batch_size = inputs[0].shape[0]
device = inputs[0].device
self.lam = np.random.beta(self.alpha, self.alpha)
self.perm_index = torch.randperm(batch_size).to(device)
if self.method is None:
output = model(inputs)
output1 = self.get_perm(output)
return [output, output1]
elif self.method == 'encoder':
output = model(inputs)
output1 = self.get_perm(output)
output_final = self.mix_up(output, output1)
elif self.method == 'embed':
output = model.apply_embeddings(inputs)
output1 = self.get_perm(output)
output_final = self.mix_up(output, output1)
# Main
output_final = model.apply_main_layers(output_final)
# Final
output_final = model.apply_final_layers(output_final)
elif self.method == 'hidden':
if self.layer_mix is None:
# 这里暂时只考虑encoderLayer, 不考虑decoderLayer和seq2seq模型结构
                try:
                    layer_mix = random.randint(0, len(model.encoderLayer))
                except Exception:
                    warnings.warn('LayerMix random failed')
                    layer_mix = 0
else:
layer_mix = self.layer_mix
def apply_on_layer_end(l_i, output):
if l_i == layer_mix:
output1 = self.get_perm(output)
return self.mix_up(output, output1)
else:
return output
model.apply_on_layer_end = apply_on_layer_end
output_final = model(inputs)
return output_final
def forward(self, criterion, y_pred, y_true):
'''计算loss
'''
y_true1 = y_true[self.perm_index]
return self.lam * criterion(y_pred, y_true) + (1 - self.lam) * criterion(y_pred, y_true1)
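# A minimal mixup usage sketch (the linear "encoder" below is an illustrative stand-in
# for a real model that takes a list of inputs): encode() mixes each output with a
# shuffled copy of the batch, and forward() mixes the two losses with the same lam.
def _demo_mixup():
    mixup = MixUp(method='encoder', alpha=1.0)
    model = nn.Linear(8, 3)
    inputs = [torch.randn(4, 8)]
    y_pred = mixup.encode(lambda inp: model(inp[0]), inputs)
    y_true = torch.randint(0, 3, (4,))
    loss = mixup(nn.CrossEntropyLoss(), y_pred, y_true)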
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
class FocalLoss(nn.Module):
'''Multi-class Focal loss implementation'''
    def __init__(self, gamma=2, weight=None, ignore_index=-100):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index
def forward(self, input, target):
"""
input: [N, C]
target: [N, ]
"""
logpt = F.log_softmax(input, dim=1)
pt = torch.exp(logpt)
logpt = (1-pt)**self.gamma * logpt
        loss = F.nll_loss(logpt, target, self.weight, ignore_index=self.ignore_index)
return loss
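# A minimal focal-loss usage sketch (class counts are illustrative assumptions):
# compared with plain cross entropy, well-classified samples (pt -> 1) are
# down-weighted by the (1 - pt) ** gamma factor.
def _demo_focal_loss():
    criterion = FocalLoss(gamma=2)
    logits = torch.randn(4, 3)             # [N, C] raw scores
    target = torch.randint(0, 3, (4,))     # [N] class ids
    loss = criterion(logits, target)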
class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, eps=0.1, reduction='mean', ignore_index=-100):
super(LabelSmoothingCrossEntropy, self).__init__()
self.eps = eps
self.reduction = reduction
self.ignore_index = ignore_index
def forward(self, output, target):
c = output.size()[-1]
log_preds = F.log_softmax(output, dim=-1)
if self.reduction=='sum':
loss = -log_preds.sum()
else:
loss = -log_preds.sum(dim=-1)
if self.reduction=='mean':
loss = loss.mean()
return loss*self.eps/c + (1-self.eps) * F.nll_loss(log_preds, target, reduction=self.reduction,
ignore_index=self.ignore_index)
class MultilabelCategoricalCrossentropy(nn.Module):
"""多标签分类的交叉熵
说明:y_true和y_pred的shape一致,y_true的元素非0即1, 1表示对应的类为目标类,0表示对应的类为非目标类。
警告:请保证y_pred的值域是全体实数,换言之一般情况下y_pred不用加激活函数,尤其是不能加sigmoid或者softmax!预测
阶段则输出y_pred大于0的类。如有疑问,请仔细阅读并理解本文。
参考:https://kexue.fm/archives/7359
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
def forward(self, y_pred, y_true):
""" y_true ([Tensor]): [..., num_classes]
y_pred ([Tensor]): [..., num_classes]
"""
y_pred = (1-2*y_true) * y_pred
y_pred_pos = y_pred - (1-y_true) * 1e12
y_pred_neg = y_pred - y_true * 1e12
y_pred_pos = torch.cat([y_pred_pos, torch.zeros_like(y_pred_pos[..., :1])], dim=-1)
y_pred_neg = torch.cat([y_pred_neg, torch.zeros_like(y_pred_neg[..., :1])], dim=-1)
pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
return (pos_loss + neg_loss).mean()
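# A minimal usage sketch (shapes are illustrative assumptions): y_pred are raw scores
# without sigmoid/softmax, y_true is a 0/1 multi-hot target; at inference time the
# positive classes are those with y_pred > 0.
def _demo_multilabel_cce():
    criterion = MultilabelCategoricalCrossentropy()
    y_pred = torch.randn(4, 10)
    y_true = (torch.rand(4, 10) > 0.7).float()
    loss = criterion(y_pred, y_true)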
class SparseMultilabelCategoricalCrossentropy(nn.Module):
"""稀疏版多标签分类的交叉熵
说明:
1. y_true.shape=[..., num_positive],
y_pred.shape=[..., num_classes];
2. 请保证y_pred的值域是全体实数,换言之一般情况下y_pred不用加激活函数,尤其是不能加sigmoid或者softmax;
3. 预测阶段则输出y_pred大于0的类;
4. 详情请看:https://kexue.fm/archives/7359 。
"""
def __init__(self, mask_zero=False, epsilon=1e-7, **kwargs):
super().__init__(**kwargs)
self.mask_zero = mask_zero
self.epsilon = epsilon
def forward(self, y_pred, y_true):
zeros = torch.zeros_like(y_pred[..., :1])
y_pred = torch.cat([y_pred, zeros], dim=-1)
if self.mask_zero:
infs = zeros + float('inf')
y_pred = torch.cat([infs, y_pred[..., 1:]], dim=-1)
y_pos_2 = torch.gather(y_pred, dim=-1, index=y_true)
y_pos_1 = torch.cat([y_pos_2, zeros], dim=-1)
if self.mask_zero:
y_pred = torch.cat([-infs, y_pred[..., 1:]], dim=-1)
y_pos_2 = torch.gather(y_pred, dim=-1, index=y_true)
pos_loss = torch.logsumexp(-y_pos_1, dim=-1)
all_loss = torch.logsumexp(y_pred, dim=-1) # a
aux_loss = torch.logsumexp(y_pos_2, dim=-1) - all_loss # b-a
aux_loss = torch.clamp(1 - torch.exp(aux_loss), self.epsilon, 1) # 1-exp(b-a)
neg_loss = all_loss + torch.log(aux_loss) # a + log[1-exp(b-a)]
return pos_loss + neg_loss
class ContrastiveLoss(nn.Module):
"""对比损失:减小正例之间的距离,增大正例和反例之间的距离
公式:labels * distance_matrix.pow(2) + (1-labels)*F.relu(margin-distance_matrix).pow(2)
https://www.sbert.net/docs/package_reference/losses.html
"""
def __init__(self, margin=0.5, size_average=True, online=False):
super(ContrastiveLoss, self).__init__()
self.margin = margin
self.size_average = size_average
self.online = online
def forward(self, distances, labels, pos_id=1, neg_id=0):
if not self.online:
losses = 0.5 * (labels.float() * distances.pow(2) + (1 - labels).float() * F.relu(self.margin - distances).pow(2))
return losses.mean() if self.size_average else losses.sum()
else:
negs = distances[labels == neg_id]
poss = distances[labels == pos_id]
# select hard positive and hard negative pairs
negative_pairs = negs[negs < (poss.max() if len(poss) > 1 else negs.mean())]
positive_pairs = poss[poss > (negs.min() if len(negs) > 1 else poss.mean())]
positive_loss = positive_pairs.pow(2).sum()
negative_loss = F.relu(self.margin - negative_pairs).pow(2).sum()
return positive_loss + negative_loss
class RDropLoss(nn.Module):
'''R-Drop的Loss实现,官方项目:https://github.com/dropreg/R-Drop
'''
def __init__(self, alpha=4, rank='adjacent'):
super().__init__()
self.alpha = alpha
# 支持两种方式,一种是奇偶相邻排列,一种是上下排列
        assert rank in {'adjacent', 'updown'}, "rank kwarg only supports 'adjacent' and 'updown'"
self.rank = rank
self.loss_sup = nn.CrossEntropyLoss()
self.loss_rdrop = nn.KLDivLoss(reduction='none')
def forward(self, *args):
'''支持两种方式: 一种是y_pred, y_true, 另一种是y_pred1, y_pred2, y_true
'''
        assert len(args) in {2, 3}, 'RDropLoss only supports 2 or 3 input args'
# y_pred是1个Tensor
if len(args) == 2:
y_pred, y_true = args
loss_sup = self.loss_sup(y_pred, y_true) # 两个都算
if self.rank == 'adjacent':
y_pred1 = y_pred[1::2]
y_pred2 = y_pred[::2]
elif self.rank == 'updown':
half_btz = y_true.shape[0] // 2
y_pred1 = y_pred[:half_btz]
y_pred2 = y_pred[half_btz:]
# y_pred是两个tensor
else:
y_pred1, y_pred2, y_true = args
loss_sup = self.loss_sup(y_pred1, y_true)
loss_rdrop1 = self.loss_rdrop(F.log_softmax(y_pred1, dim=-1), F.softmax(y_pred2, dim=-1))
loss_rdrop2 = self.loss_rdrop(F.log_softmax(y_pred2, dim=-1), F.softmax(y_pred1, dim=-1))
return loss_sup + torch.mean(loss_rdrop1 + loss_rdrop2) / 4 * self.alpha
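
# Usage sketch (illustrative addition, not in the original source): run the
# same batch through the model twice so dropout yields two different
# predictions, then use the three-argument call signature.
def _rdrop_example(model, x, y_true):  # hypothetical helper, for demonstration only
    loss_fn = RDropLoss(alpha=4)
    y_pred1 = model(x)  # first stochastic forward pass
    y_pred2 = model(x)  # second pass, different dropout mask
    return loss_fn(y_pred1, y_pred2, y_true)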

class UDALoss(nn.Module):
    '''UDA loss; subclass it before use, because forward() needs global_step and total_steps
    https://arxiv.org/abs/1904.12848
    '''
    def __init__(self, tsa_schedule=None, total_steps=None, start_p=0, end_p=1, return_all_loss=True):
        super().__init__()
        self.loss_sup = nn.CrossEntropyLoss()
        self.loss_unsup = nn.KLDivLoss(reduction='batchmean')
        self.tsa_schedule = tsa_schedule
        self.start = start_p
        self.end = end_p
        if self.tsa_schedule:
            assert self.tsa_schedule in {'linear_schedule', 'exp_schedule', 'log_schedule'}, 'tsa_schedule config illegal'
        self.return_all_loss = return_all_loss
    def forward(self, y_pred, y_true_sup, global_step, total_steps):
        sup_size = y_true_sup.size(0)
        unsup_size = (y_pred.size(0) - sup_size) // 2
        # supervised part: cross entropy
        y_pred_sup = y_pred[:sup_size]
        if self.tsa_schedule is None:
            loss_sup = self.loss_sup(y_pred_sup, y_true_sup)
        else:  # use TSA to drop supervised samples the model already predicts with high confidence
            threshold = self.get_tsa_threshold(self.tsa_schedule, global_step, total_steps, self.start, self.end)
            true_prob = torch.gather(F.softmax(y_pred_sup, dim=-1), dim=1, index=y_true_sup[:, None])
            sel_rows = true_prob.lt(threshold).sum(dim=-1).gt(0)  # keep only samples below the threshold
            loss_sup = self.loss_sup(y_pred_sup[sel_rows], y_true_sup[sel_rows]) if sel_rows.sum() > 0 else 0
        # unsupervised part: KL divergence here (cross entropy would also work)
        y_true_unsup = y_pred[sup_size:sup_size+unsup_size]
        y_true_unsup = F.softmax(y_true_unsup.detach(), dim=-1)
        y_pred_unsup = F.log_softmax(y_pred[sup_size+unsup_size:], dim=-1)
        loss_unsup = self.loss_unsup(y_pred_unsup, y_true_unsup)
        if self.return_all_loss:
            return loss_sup + loss_unsup, loss_sup, loss_unsup
        else:
            return loss_sup + loss_unsup
    @staticmethod
    def get_tsa_threshold(schedule, global_step, num_train_steps, start, end):
        training_progress = global_step / num_train_steps
        if schedule == "linear_schedule":
            threshold = training_progress
        elif schedule == "exp_schedule":
            scale = 5
            threshold = math.exp((training_progress - 1) * scale)
        elif schedule == "log_schedule":
            scale = 5
            threshold = 1 - math.exp((-training_progress) * scale)
        return threshold * (end - start) + start
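
# Usage sketch (illustrative addition, not in the original source): the
# docstring above says the class must be subclassed so that forward() can
# receive global_step/total_steps; one hypothetical way to wire them in,
# assuming a trainer object that exposes those counters:
class _MyUDALoss(UDALoss):  # hypothetical subclass, for demonstration only
    def __init__(self, trainer, **kwargs):
        super().__init__(**kwargs)
        self.trainer = trainer  # assumed to expose .global_step and .total_steps
    def forward(self, y_pred, y_true_sup):
        # expected batch layout: [supervised | unsupervised original | unsupervised augmented]
        return super().forward(y_pred, y_true_sup, self.trainer.global_step, self.trainer.total_steps)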

class TemporalEnsemblingLoss(nn.Module):
    '''Temporal Ensembling: adds an MSE consistency loss on top of the supervised loss
    official repo: https://github.com/s-laine/tempens
    third-party pytorch implementation: https://github.com/ferretj/temporal-ensembling
    when using it, train_dataloader must be created with shuffle=False
    '''
    def __init__(self, epochs, max_val=10.0, ramp_up_mult=-5.0, alpha=0.5, max_batch_num=100, hist_device='cpu'):
        super().__init__()
        self.loss_sup = nn.CrossEntropyLoss()
        self.max_epochs = epochs
        self.max_val = max_val
        self.ramp_up_mult = ramp_up_mult
        self.alpha = alpha
        self.max_batch_num = max_batch_num  # set to None to keep the full history; resource-heavy on large datasets
        self.hist_unsup = []  # history of unsupervised logits
        self.hist_sup = []  # history of supervised logits
        self.hist_device = hist_device
        self.hist_input_y = []  # history of supervised labels y
        assert (self.alpha >= 0) & (self.alpha < 1)  # alpha == 1 would make the denominator in update() zero
    def forward(self, y_pred_sup, y_pred_unsup, y_true_sup, epoch, bti):
        self.same_batch_check(y_pred_sup, y_pred_unsup, y_true_sup, bti)
        if (self.max_batch_num is None) or (bti < self.max_batch_num):
            self.init_hist(bti, y_pred_sup, y_pred_unsup)  # initialize the history
            sup_ratio = float(len(y_pred_sup)) / (len(y_pred_sup) + len(y_pred_unsup))  # fraction of supervised samples
            w = self.weight_schedule(epoch, sup_ratio)
            sup_loss, unsup_loss = self.temporal_loss(y_pred_sup, y_pred_unsup, y_true_sup, bti)
            # update the history logits
            self.hist_unsup[bti] = self.update(self.hist_unsup[bti], y_pred_unsup.detach(), epoch)
            self.hist_sup[bti] = self.update(self.hist_sup[bti], y_pred_sup.detach(), epoch)
            # if bti == 0:  # uncomment to check that the sample order matches across epochs
            #     print(w, sup_loss.item(), w * unsup_loss.item())
            #     print(y_true_sup)
            return sup_loss + w * unsup_loss, sup_loss, w * unsup_loss
        else:
            return self.loss_sup(y_pred_sup, y_true_sup)
    def same_batch_check(self, y_pred_sup, y_pred_unsup, y_true_sup, bti):
        '''Verify that the first few batches are identical across epochs (hard-coded to 10 batches)
        '''
        if bti >= 10:
            return
        if bti >= len(self.hist_input_y):
            self.hist_input_y.append(y_true_sup.to(self.hist_device))
        else:  # verify against the recorded labels
            err_msg = 'TemporalEnsemblingLoss requests the same sort dataloader, you may need to set train_dataloader shuffle=False'
            assert self.hist_input_y[bti].equal(y_true_sup.to(self.hist_device)), err_msg
    def update(self, hist, y_pred, epoch):
        '''Update the history logits, with alpha gating the mixing ratio
        '''
        Z = self.alpha * hist.to(y_pred) + (1. - self.alpha) * y_pred
        output = Z * (1. / (1. - self.alpha ** (epoch + 1)))
        return output.to(self.hist_device)
    def weight_schedule(self, epoch, sup_ratio):
        max_val = self.max_val * sup_ratio
        if epoch == 0:
            return 0.
        elif epoch >= self.max_epochs:
            return max_val
        return max_val * np.exp(self.ramp_up_mult * (1. - float(epoch) / self.max_epochs) ** 2)
    def temporal_loss(self, y_pred_sup, y_pred_unsup, y_true_sup, bti):
        # MSE between current and temporal outputs
        def mse_loss(out1, out2):
            quad_diff = torch.sum((F.softmax(out1, dim=1) - F.softmax(out2, dim=1)) ** 2)
            return quad_diff / out1.data.nelement()
        sup_loss = self.loss_sup(y_pred_sup, y_true_sup)
        # the original implementation treats sup and unsup as one tensor and computes
        # the loss over it as a whole; since they are split into two tensors here,
        # the two parts are computed separately
        unsup_loss = mse_loss(y_pred_unsup, self.hist_unsup[bti].to(y_pred_unsup))
        unsup_loss += mse_loss(y_pred_sup, self.hist_sup[bti].to(y_pred_sup))
        return sup_loss, unsup_loss
    def init_hist(self, bti, y_pred_sup, y_pred_unsup):
        if bti >= len(self.hist_sup):
            self.hist_sup.append(torch.zeros_like(y_pred_sup).to(self.hist_device))
            self.hist_unsup.append(torch.zeros_like(y_pred_unsup).to(self.hist_device))
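
# Usage sketch (illustrative addition, not in the original source): forward()
# needs the current epoch and the batch index bti, and the dataloader must not
# shuffle so the cached per-batch history stays aligned.
def _temporal_ensembling_step(loss_fn, model, x_sup, y_sup, x_unsup, epoch, bti):
    # hypothetical training step; loss_fn = TemporalEnsemblingLoss(epochs=30)
    out = loss_fn(model(x_sup), model(x_unsup), y_sup, epoch, bti)
    # within the first max_batch_num batches a (total, sup, weighted_unsup)
    # tuple is returned; afterwards only the supervised loss
    return out[0] if isinstance(out, tuple) else out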