Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
bert4torch_pytorch
Commits
66a1d0d0
Commit
66a1d0d0
authored
Aug 22, 2023
by
yangzhong
Browse files
提交初版bert4torch project
parents
Pipeline
#519
canceled with stages
Changes
160
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
5135 additions
and
0 deletions
+5135
-0
build/lib/bert4torch/models.py
build/lib/bert4torch/models.py
+1983
-0
build/lib/bert4torch/optimizers.py
build/lib/bert4torch/optimizers.py
+76
-0
build/lib/bert4torch/snippets.py
build/lib/bert4torch/snippets.py
+1190
-0
build/lib/bert4torch/tokenizers.py
build/lib/bert4torch/tokenizers.py
+872
-0
dist/bert4torch-0.1.9-py3.9.egg
dist/bert4torch-0.1.9-py3.9.egg
+0
-0
examples/Performance.md
examples/Performance.md
+102
-0
examples/README.md
examples/README.md
+119
-0
examples/basic/basic_extract_features.py
examples/basic/basic_extract_features.py
+39
-0
examples/basic/basic_gibbs_sampling_via_mlm.py
examples/basic/basic_gibbs_sampling_via_mlm.py
+55
-0
examples/basic/basic_language_model_CDial_GPT.py
examples/basic/basic_language_model_CDial_GPT.py
+58
-0
examples/basic/basic_language_model_GAU_alpha.py
examples/basic/basic_language_model_GAU_alpha.py
+33
-0
examples/basic/basic_language_model_cpm_lm.py
examples/basic/basic_language_model_cpm_lm.py
+121
-0
examples/basic/basic_language_model_gpt2_ml.py
examples/basic/basic_language_model_gpt2_ml.py
+57
-0
examples/basic/basic_language_model_nezha_gen_gpt.py
examples/basic/basic_language_model_nezha_gen_gpt.py
+60
-0
examples/basic/basic_language_model_nezha_gpt_dialog.py
examples/basic/basic_language_model_nezha_gpt_dialog.py
+54
-0
examples/basic/basic_language_model_simbert.py
examples/basic/basic_language_model_simbert.py
+130
-0
examples/basic/basic_language_model_t5.py
examples/basic/basic_language_model_t5.py
+50
-0
examples/basic/basic_language_model_transformer_xl.py
examples/basic/basic_language_model_transformer_xl.py
+41
-0
examples/basic/basic_language_model_xlnet.py
examples/basic/basic_language_model_xlnet.py
+28
-0
examples/basic/basic_make_uncased_model_cased.py
examples/basic/basic_make_uncased_model_cased.py
+67
-0
No files found.
build/lib/bert4torch/models.py
0 → 100644
View file @
66a1d0d0
import
torch
import
torch.nn
as
nn
import
copy
import
json
import
re
from
bert4torch.layers
import
LayerNorm
,
BertEmbeddings
,
BertLayer
,
Identity
,
T5Layer
,
GatedAttentionUnit
,
XlnetLayer
from
bert4torch.layers
import
AdaptiveEmbedding
,
XlnetPositionsEncoding
from
bert4torch.snippets
import
metric_mapping
,
search_layer
,
insert_arguments
,
delete_arguments
,
get_kw
from
bert4torch.snippets
import
ProgbarLogger
,
EarlyStopping
,
FGM
,
PGD
,
VAT
,
IterDataset
,
take_along_dim
from
bert4torch.activations
import
get_activation
import
warnings
class BaseModel(nn.Module):
    """Base training wrapper: adds a Keras-like compile/fit/predict API on top of nn.Module."""

    def __init__(self):
        super(BaseModel, self).__init__()
        # Exposed mainly for external callers (callbacks read these counters).
        self.global_step, self.local_step, self.total_steps, self.epoch, self.train_dataloader = 0, 0, 0, 0, None
        self.callbacks = []
    def compile(self, loss, optimizer, scheduler=None, max_grad_norm=None, use_amp=False, metrics=None, adversarial_train={'name': ''}):
        '''Configure loss, optimizer and metrics before training.

        loss: loss callable (criterion)
        optimizer: optimizer instance
        scheduler: optional LR scheduler
        max_grad_norm: gradient clipping threshold; disabled by default
        use_amp: whether to use mixed precision; disabled by default
        metrics: metric names to print during training; loss-related metrics
            are always printed, currently 'accuracy' is supported
        adversarial_train: dict configuring adversarial training, keyed by 'name'
        '''
        self.criterion = loss
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.max_grad_norm = max_grad_norm
        self.use_amp = use_amp
        if use_amp:
            # VAT/gradient penalty need retain_graph + extra backward passes that AMP does not support here.
            assert adversarial_train['name'] not in {'vat', 'gradient_penalty'}, 'Amp and adversarial_train both run is not supported in current version'
            from torch.cuda.amp import autocast
            self.autocast = autocast
            self.scaler = torch.cuda.amp.GradScaler()

        if metrics is None:
            metrics = []
        # 'loss' is always the first metric; avoid duplicating it.
        self.metrics = ['loss'] + [i for i in metrics if i != 'loss']

        # Adversarial training setup
        self.adversarial = adversarial_train
        self.adversarial_initialize()
    def adversarial_initialize(self):
        '''Initialize adversarial training: fill in per-mode default
        hyper-parameters and build the attack helper object.
        '''
        assert self.adversarial['name'] in {'', 'fgm', 'pgd', 'vat', 'gradient_penalty'}, 'adversarial_train support fgm, pgd, vat and gradient_penalty mode'
        self.adversarial['epsilon'] = self.adversarial.get('epsilon', 1.0)
        # Name of the embedding parameter to perturb.
        self.adversarial['emb_name'] = self.adversarial.get('emb_name', 'word_embeddings')

        if self.adversarial['name'] == 'fgm':
            self.ad_train = FGM(self)
        elif self.adversarial['name'] == 'pgd':
            self.adversarial['K'] = self.adversarial.get('K', 3)  # number of attack steps
            self.adversarial['alpha'] = self.adversarial.get('alpha', 0.3)  # attack step size (learning rate)
            self.ad_train = PGD(self)
        elif self.adversarial['name'] == 'gradient_penalty':
            pass
        elif self.adversarial['name'] == 'vat':
            self.adversarial['K'] = self.adversarial.get('K', 3)
            self.adversarial['noise_var'] = self.adversarial.get('noise_var', 1e-5)  # variance of the injected noise
            self.adversarial['noise_gamma'] = self.adversarial.get('noise_gamma', 1e-6)  # eps
            self.adversarial['adv_step_size'] = self.adversarial.get('adv_step_size', 1e-3)  # attack learning rate
            self.adversarial['adv_alpha'] = self.adversarial.get('adv_alpha', 1)  # weight of the adversarial loss
            self.adversarial['norm_type'] = self.adversarial.get('norm_type', 'l2')  # normalization mode
            self.ad_train = VAT(self, **self.adversarial)
    def adversarial_training(self, train_X, train_y, output, loss, loss_detail, grad_accumulation_steps):
        '''Run one adversarial-training pass after the normal backward pass.

        Returns the (possibly augmented) loss and loss_detail.
        '''
        if self.adversarial['name'] == 'fgm':
            self.ad_train.attack(**self.adversarial)  # the embedding is perturbed in place here
            output, loss, loss_detail = self.train_step(train_X, train_y, grad_accumulation_steps)
            loss.backward()  # backprop: adversarial gradients accumulate on top of the normal ones
            # Restore the embedding parameters: updates must happen on the clean
            # embedding, not on the adversarially perturbed one.
            self.ad_train.restore(**self.adversarial)
        elif self.adversarial['name'] == 'pgd':
            self.ad_train.backup_grad()  # back up the clean gradients
            for t in range(self.adversarial['K']):
                # Add an adversarial perturbation on the embedding; back up param.data on the first attack.
                self.ad_train.attack(**self.adversarial, is_first_attack=(t == 0))
                if t != self.adversarial['K'] - 1:
                    self.optimizer.zero_grad()  # accumulate the perturbation, not the gradient
                else:
                    self.ad_train.restore_grad()  # restore the clean gradients
                output, loss, loss_detail = self.train_step(train_X, train_y, grad_accumulation_steps)
                loss.backward()  # adversarial gradients accumulate on top of the normal ones
            self.ad_train.restore(**self.adversarial)  # restore the embedding parameters
        # Gradient penalty
        elif self.adversarial['name'] == 'gradient_penalty':
            # NOTE: 'retrun_first' is the (misspelled) keyword of search_layer in bert4torch.snippets.
            para = search_layer(self, self.adversarial['emb_name'], retrun_first=True)
            gp = (para.grad ** 2).sum()
            loss += 0.5 * gp * self.adversarial['epsilon']
            loss.backward()
        # Virtual adversarial training
        elif self.adversarial['name'] == 'vat':
            logit = output[0] if isinstance(output, (list, tuple)) else output
            adv_loss = self.ad_train.virtual_adversarial_training(train_X, logit)
            loss_detail.update({'loss_sup': loss.item(), 'loss_unsup': adv_loss})
            loss += (adv_loss if adv_loss else 0)
            loss.backward()

        return loss, loss_detail
    def train_step(self, train_X, train_y, grad_accumulation_steps):
        '''Forward pass and loss computation; returns (output, loss, loss_detail).
        '''
        def args_segmentate(train_X):
            '''Decide whether the inputs should be unpacked (forward(*X) vs forward(X)).
            '''
            if isinstance(train_X, torch.Tensor):  # a bare tensor is never unpacked
                pass
            elif isinstance(self, (BaseModelDP, BaseModelDDP)):
                # For DP/DDP the real forward lives on self.module.
                if self.module.forward.__code__.co_argcount >= 3:
                    return True
            elif self.forward.__code__.co_argcount >= 3:
                # argcount >= 3 means forward takes multiple inputs besides self.
                return True
            return False

        if self.use_amp:
            with self.autocast():
                output = self.forward(*train_X) if args_segmentate(train_X) else self.forward(train_X)
                loss_detail = self.criterion(output, train_y)
        else:
            output = self.forward(*train_X) if args_segmentate(train_X) else self.forward(train_X)
            loss_detail = self.criterion(output, train_y)

        # Normalize the criterion's return into (loss, loss_detail-dict).
        if isinstance(loss_detail, torch.Tensor):
            loss = loss_detail
            loss_detail = {}
        elif isinstance(loss_detail, dict):
            loss = loss_detail['loss']  # the other entries are auxiliary losses, used only for logging
            del loss_detail['loss']
        elif isinstance(loss_detail, (tuple, list)):
            loss = loss_detail[0]
            loss_detail = {f'loss{i}': v for i, v in enumerate(loss_detail[1:], start=1)}
        else:
            raise ValueError('Return loss only support Tensor/dict/tuple/list format')

        # Gradient accumulation: scale the loss down so accumulated grads average correctly.
        loss = loss / grad_accumulation_steps if grad_accumulation_steps > 1 else loss
        return output, loss, loss_detail
    def callback_fun(self, mode, logs={}):
        '''Dispatch one lifecycle event to all registered callbacks.

        mode: one of 'train_begin', 'epoch_begin', 'batch_begin', 'batch_end',
            'epoch_end', 'train_end', 'dataloader_end'.
        '''
        # Under distributed DDP training only the master rank runs callbacks.
        if isinstance(self, BaseModelDDP) and self.master_rank != torch.distributed.get_rank():
            return

        if mode == 'train_begin':
            for callback in self.callbacks:
                callback.on_train_begin()
        elif mode == 'epoch_begin':
            for callback in self.callbacks:
                callback.on_epoch_begin(self.global_step, self.epoch, logs)
        elif mode == 'batch_begin':
            for callback in self.callbacks:
                callback.on_batch_begin(self.global_step, self.local_step, logs)
        elif mode == 'batch_end':
            for callback in self.callbacks:
                callback.on_batch_end(self.global_step, self.local_step, logs)
        elif mode == 'epoch_end':
            for callback in self.callbacks:
                callback.on_epoch_end(self.global_step, self.epoch, logs)
        elif mode == 'train_end':
            for callback in self.callbacks:
                callback.on_train_end()
        elif mode == 'dataloader_end':
            for callback in self.callbacks:
                callback.on_dataloader_end()
    def fit(self, train_dataloader, steps_per_epoch=None, epochs=1, grad_accumulation_steps=1, callbacks=[]):
        '''Main training loop (Keras-style).

        train_dataloader: dataloader yielding (train_X, train_y) batches
        steps_per_epoch: steps per epoch; required for IterDataset, otherwise
            defaults to len(train_dataloader)
        epochs: number of epochs
        grad_accumulation_steps: optimizer stepping interval in batches
        callbacks: a callback or list of callbacks
        '''
        if isinstance(train_dataloader.dataset, IterDataset):
            assert steps_per_epoch is not None, 'IterDataset should specify steps_per_epoch'
        steps_per_epoch = len(train_dataloader) if steps_per_epoch is None else steps_per_epoch
        self.total_steps = steps_per_epoch * epochs
        self.global_step = 0
        self.train_dataloader = train_dataloader  # stored on the instance so external callbacks may replace it
        train_dataloader_iter = iter(self.train_dataloader)  # not regenerated on each epoch

        self.callbacks = [ProgbarLogger(epochs, steps_per_epoch, self.metrics)] + (callbacks if isinstance(callbacks, (list, tuple)) else [callbacks])
        self.callback_fun('train_begin')

        # epoch: current epoch
        # global_step: current global training step
        # local_step: step within the current epoch; with steps_per_epoch=None the
        #     same local_step maps to the same batch across epochs, otherwise not necessarily
        # bti: index into the dataloader; the same bti usually maps to the same batch
        #     across epochs unless the dataloader is regenerated
        self.bti = 0
        for epoch in range(epochs):
            self.epoch = epoch
            self.callback_fun('epoch_begin')
            for local_step in range(steps_per_epoch):
                self.local_step = local_step
                # Cycle the dataloader manually; itertools.cycle was observed to leak memory.
                try:
                    batch = next(train_dataloader_iter)
                except StopIteration:
                    self.callback_fun('dataloader_end')  # hook for regenerating the dataloader dynamically (e.g. pretraining on large corpora)
                    train_dataloader_iter = iter(self.train_dataloader)  # with shuffle=True this also reshuffles
                    self.bti = 0
                    batch = next(train_dataloader_iter)
                train_X, train_y = batch

                # Infer the batch size; at most two nesting levels are allowed,
                # e.g. ((token_ids1, mask1), (token_ids2, mask2)).
                if isinstance(train_X, (list, tuple)):
                    if isinstance(train_X[0], (list, tuple)):
                        btz = train_X[0][0].size(0)
                    else:
                        btz = train_X[0].size(0)
                elif isinstance(train_X, torch.Tensor):
                    btz = train_X.size(0)
                else:
                    raise ValueError('Input only support [list, tuple, tensor]')
                logs = {'batch': self.local_step, 'size': btz}
                self.callback_fun('batch_begin', logs)

                self.train()  # switch to train mode
                # train_step unpacks the inputs when forward takes multiple arguments.
                output, loss, loss_detail = self.train_step(train_X, train_y, grad_accumulation_steps)

                # gradient_penalty/vat run a second backward, so the graph must be retained.
                retain_graph = True if self.adversarial['name'] in {'gradient_penalty', 'vat'} else False
                if self.use_amp:  # mixed precision
                    scale_before_step = self.scaler.get_scale()
                    self.scaler.scale(loss).backward(retain_graph=retain_graph)
                else:
                    loss.backward(retain_graph=retain_graph)

                # Adversarial training (no-op when adversarial['name'] == '')
                loss, loss_detail = self.adversarial_training(train_X, train_y, output, loss, loss_detail, grad_accumulation_steps)

                # Parameter update: the real number of optimizer steps is divided by
                # grad_accumulation_steps — adjust total training steps accordingly.
                if (self.global_step + 1) % grad_accumulation_steps == 0:
                    skip_scheduler = False
                    # mixed precision
                    if self.use_amp:
                        self.scaler.unscale_(self.optimizer)
                        if self.max_grad_norm is not None:  # gradient clipping
                            torch.nn.utils.clip_grad_norm_(self.parameters(), self.max_grad_norm)
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                        # When the scaler lowers its scale the step was skipped; skip the scheduler too.
                        skip_scheduler = self.scaler.get_scale() != scale_before_step
                    else:
                        if self.max_grad_norm is not None:  # gradient clipping
                            torch.nn.utils.clip_grad_norm_(self.parameters(), self.max_grad_norm)
                        self.optimizer.step()

                    self.optimizer.zero_grad()  # clear gradients
                    if (self.scheduler is not None) and not skip_scheduler:
                        self.scheduler.step()

                # Logging
                logs.update({'loss': loss.item()})
                logs_loss_detail = {k: v.item() if isinstance(v, torch.Tensor) else v for k, v in loss_detail.items()}
                logs.update(logs_loss_detail)
                if self.global_step == 0:
                    # Register auxiliary loss names with the progress bar on the first step.
                    self.callbacks[0].add_metrics(list(logs_loss_detail.keys()), add_position=1)
                for metric in self.metrics:
                    tmp = metric_mapping(metric, output, train_y)  # built-in metrics such as accuracy
                    if tmp is not None:
                        logs[metric] = tmp
                self.callback_fun('batch_end', logs)

                self.bti += 1
                self.global_step += 1
            self.callback_fun('epoch_end', logs)
            # Early-stopping policy
            callback_tmp = [callback_tmp for callback_tmp in self.callbacks if isinstance(callback_tmp, EarlyStopping)]
            if callback_tmp and callback_tmp[0].stopped_epoch > 0:
                break
        self.callback_fun('train_end', logs)
    @torch.no_grad()
    def predict(self, input_tensor_list, return_all=None):
        '''Inference helper: runs forward in eval mode without gradients.

        input_tensor_list: model inputs; unpacked when forward takes multiple arguments
        return_all: None to return the raw output, or an int index into a
            tuple/list output
        '''
        self.eval()
        # argcount >= 3 means forward takes multiple inputs besides self.
        if self.forward.__code__.co_argcount >= 3:
            output = self.forward(*input_tensor_list)
        else:
            output = self.forward(input_tensor_list)
        if return_all is None:
            return output
        elif isinstance(output, (tuple, list)) and isinstance(return_all, int) and return_all < len(output):
            return output[return_all]
        else:
            raise ValueError('Return format error')
    def load_weights(self, load_path, strict=True, prefix=None):
        '''Load weights from a file saved by save_weights.

        load_path: checkpoint path
        strict: passed through to load_state_dict
        prefix: None for a plain state_dict; otherwise the checkpoint was saved
            with raw (variable_mapping) keys and is translated back here
        '''
        state_dict = torch.load(load_path, map_location='cpu')
        if prefix is None:
            self.load_state_dict(state_dict, strict=strict)
        else:
            # Loading the save_weights(to_raw_format=True) case: keys in the file
            # are the raw checkpoint names and must be mapped back to module names.
            # NOTE(review): eval() on a string built from `prefix` — safe only as
            # long as prefix is a trusted attribute name.
            eval_str = 'self.variable_mapping()' if prefix == '' else f'self.{prefix}.variable_mapping()'
            mapping = {v: k for k, v in eval(eval_str).items()}
            mapping = mapping if prefix == '' else {k: f'{prefix}.{v}' for k, v in mapping.items()}
            state_dict_raw = {}
            for k, v in state_dict.items():
                k = mapping.get(k, k)  # keys missing from the mapping pass through unchanged
                state_dict_raw[k] = v
            self.load_state_dict(state_dict_raw, strict=strict)
    def save_weights(self, save_path, prefix=None):
        '''Save model weights.

        save_path: destination path
        prefix: None to save the plain state_dict; otherwise keys are renamed to
            the raw variable_mapping() names so official third-party code can
            load the checkpoint
        '''
        if prefix is None:
            torch.save(self.state_dict(), save_path)
        else:
            # Save under the original checkpoint keys from variable_mapping() so
            # other (official) codebases can load this model.
            # NOTE(review): eval() on a string built from `prefix` — safe only as
            # long as prefix is a trusted attribute name.
            eval_str = 'self.variable_mapping()' if prefix == '' else f'self.{prefix}.variable_mapping()'
            mapping = eval(eval_str)
            mapping = mapping if prefix == '' else {f'{prefix}.{k}': v for k, v in mapping.items()}
            state_dict_raw = {}
            for k, v in self.state_dict().items():
                k = mapping.get(k, k)  # keys missing from the mapping pass through unchanged
                state_dict_raw[k] = v
            torch.save(state_dict_raw, save_path)
class BaseModelDP(BaseModel, nn.DataParallel):
    '''Multi-GPU training via DataParallel while keeping the BaseModel training API.
    '''
    def __init__(self, *args, **kwargs):
        # Only DataParallel.__init__ is called; BaseModel attributes were already
        # initialized on the wrapped module.
        nn.DataParallel.__init__(self, *args, **kwargs)
class BaseModelDDP(BaseModel, nn.parallel.DistributedDataParallel):
    '''Multi-GPU training via DistributedDataParallel while keeping the BaseModel training API.
    '''
    def __init__(self, *args, master_rank=0, **kwargs):
        # Rank that owns the progress bar / callbacks (see callback_fun).
        self.master_rank = master_rank
        nn.parallel.DistributedDataParallel.__init__(self, *args, **kwargs)
class BERT_BASE(BaseModel):
    """Transformer model base class.
    """
    def __init__(
            self,
            vocab_size,  # vocabulary size
            hidden_size,  # hidden (encoding) dimension
            num_hidden_layers,  # total number of transformer layers
            num_attention_heads,  # number of attention heads
            intermediate_size,  # feed-forward hidden dimension
            hidden_act,  # feed-forward activation function
            dropout_rate=None,  # dropout rate
            attention_probs_dropout_prob=None,  # dropout rate of the attention matrix
            embedding_size=None,  # explicit embedding_size; falls back to the config value
            attention_head_size=None,  # head size of V in attention
            attention_key_size=None,  # head size of Q and K in attention
            initializer_range=0.02,  # std of weight initialization
            sequence_length=None,  # fixed sequence length, if any
            keep_tokens=None,  # list of token ids to keep (vocabulary pruning)
            compound_tokens=None,  # extended embedding entries
            residual_attention_scores=False,  # add residual to the attention matrix
            ignore_invalid_weights=False,  # allow skipping weights missing from the checkpoint
            keep_hidden_layers=None,  # ids of hidden layers to keep
            hierarchical_position=None,  # hierarchically decomposed position encoding
            **kwargs
    ):
        super(BERT_BASE, self).__init__()
        if keep_tokens is not None:
            vocab_size = len(keep_tokens)
        if compound_tokens is not None:
            vocab_size += len(compound_tokens)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = attention_head_size or self.hidden_size // self.num_attention_heads
        self.attention_key_size = attention_key_size or self.attention_head_size
        self.intermediate_size = intermediate_size
        self.dropout_rate = dropout_rate or 0
        self.attention_probs_dropout_prob = attention_probs_dropout_prob or 0
        self.hidden_act = hidden_act
        self.embedding_size = embedding_size or hidden_size
        self.initializer_range = initializer_range
        self.sequence_length = sequence_length
        self.keep_tokens = keep_tokens
        self.compound_tokens = compound_tokens
        self.attention_bias = None
        self.position_bias = None
        self.attention_scores = None
        self.residual_attention_scores = residual_attention_scores
        self.ignore_invalid_weights = ignore_invalid_weights
        # None means keep every layer.
        self.keep_hidden_layers = set(range(num_hidden_layers)) if keep_hidden_layers is None else set(keep_hidden_layers)
        self.hierarchical_position = hierarchical_position
    def build(self, attention_caches=None, layer_norm_cond=None, layer_norm_cond_hidden_size=None, layer_norm_cond_hidden_act=None, additional_input_layers=None, **kwargs):
        """Model build hook.

        attention_caches: cached K/V sequences for attention, formatted as
            {attention layer name: [K cache, V cache]};
        layer_norm_* parameters: used for Conditional Layer Normalization, i.e.
            a conditional BERT conditioned on a fixed-length vector.
        """
        # additional_input
        # if additional_input_layers is not None:
        #     if not isinstance(additional_input_layers, list):
        #         self.additional_input_layers = [additional_input_layers]
        #     else:
        #         self.additional_input_layers = additional_input_layers

        # Other
        self.attention_caches = attention_caches or {}
        # self.layer_norm_conds = [
        #     layer_norm_cond,
        #     layer_norm_cond_hidden_size,
        #     layer_norm_cond_hidden_act or 'linear',
        # ]
        self.output_all_encoded_layers = kwargs.get('output_all_encoded_layers', False)
def
forward
(
self
,
inputs
):
"""定义模型的执行流程
"""
# Embedding
outputs
=
self
.
apply_embeddings
(
inputs
)
# Main
outputs
=
self
.
apply_main_layers
(
outputs
)
# Final
outputs
=
self
.
apply_final_layers
(
outputs
)
return
outputs
    def init_model_weights(self, module):
        """Weight initializer, intended for use with nn.Module.apply.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)) and (module.weight.requires_grad):
            # BERT initialization: the TF version uses a truncated normal for
            # Linear/Embedding layers; pytorch has no such function, and after
            # loading pretrained weights for finetuning this makes no difference.
            # cf https://github.com/pytorch/pytorch/pull/5617
            # Fixed relative position encodings (e.g. Sinusoidal) need no init
            # (they are excluded via requires_grad).
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
        elif isinstance(module, LayerNorm):
            # Models such as T5 use rmsnorm, which has no bias.
            if hasattr(module, 'bias') and module.bias.requires_grad:
                module.bias.data.zero_()
            if hasattr(module, 'weight') and module.weight.requires_grad:
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and (module.bias is not None) and (module.bias.requires_grad):
            module.bias.data.zero_()
def
variable_mapping
(
self
):
"""构建pytorch层与checkpoint的变量名之间的映射表
"""
return
{}
    def load_load_variable(self):
        # Abstract hook; subclasses implement checkpoint-variable loading.
        # NOTE(review): the name looks like a typo of `load_variable` (which
        # load_weights_from_pytorch_checkpoint calls); kept as-is because
        # renaming would change the public interface — confirm against callers.
        raise NotImplementedError
    def load_embeddings(self, embeddings):
        """Adjust a loaded embedding matrix according to keep_tokens and compound_tokens.
        """
        if self.keep_tokens is not None:
            # Keep only the selected vocabulary rows.
            embeddings = embeddings[self.keep_tokens]

        if self.compound_tokens is not None:
            ext_embeddings = []
            for item in self.compound_tokens:
                try:
                    # New token embedding = mean of its component token embeddings,
                    # broadcast to the component shape.
                    ext_embeddings.append(torch.mean(embeddings[item], 0) * torch.ones_like(embeddings[item]))
                except IndexError:
                    # Component ids fall outside the embedding table: fall back to
                    # the mean of the whole table.
                    ext_embeddings.append(torch.mean(embeddings, 0, keepdim=True))
                    warnings.warn(f'Initialize ext_embeddings from compound_tokens not in embedding index')
            embeddings = torch.cat([embeddings] + ext_embeddings, 0)

        return embeddings
    def load_pos_embeddings(self, embeddings):
        """Adjust loaded position embeddings according to hierarchical_position.

        Implements hierarchically decomposed position encodings so a model
        pretrained with shorter sequences can cover max_position positions.
        """
        if self.hierarchical_position is not None:
            # hierarchical_position may be True (use default alpha=0.4) or a float alpha.
            alpha = 0.4 if self.hierarchical_position is True else self.hierarchical_position
            embeddings = embeddings - alpha * embeddings[:1]
            embeddings = embeddings / (1 - alpha)
            position_index = torch.arange(self.max_position)[:, None]
            # take_along_dim from snippets for compatibility with older pytorch
            # versions that lack torch.take_along_dim.
            embeddings_x = take_along_dim(embeddings, torch.div(position_index, embeddings.size(0), rounding_mode='trunc'), dim=0)
            embeddings_y = take_along_dim(embeddings, position_index % embeddings.size(0), dim=0)
            # Position p = alpha * base[p // n] + (1 - alpha) * base[p % n]
            embeddings = alpha * embeddings_x + (1 - alpha) * embeddings_y

        return embeddings
def
load_weights_from_pytorch_checkpoint
(
self
,
checkpoint
,
mapping
=
None
):
"""根据mapping从checkpoint加载权重
"""
file_state_dict
=
torch
.
load
(
checkpoint
,
map_location
=
'cpu'
)
# 加载模型文件
mapping
=
mapping
or
self
.
variable_mapping
()
parameters_set
=
set
([
i
[
0
]
for
i
in
self
.
named_parameters
()])
# 可更新的变量
# 如果模型文件和模型结构中同时存在,且不在预设的mapping中,则更新mapping
# 主要是如为了在外部继承BERT后有其他layer,也能自动从checkpoint中加载进来
for
layer_name
in
parameters_set
:
if
(
layer_name
in
file_state_dict
)
and
(
layer_name
not
in
mapping
):
mapping
.
update
({
layer_name
:
layer_name
})
state_dict_new
=
{}
for
new_key
,
old_key
in
mapping
.
items
():
if
new_key
not
in
self
.
state_dict
():
continue
elif
old_key
in
file_state_dict
:
# mapping中包含,且模型结构中有
state_dict_new
[
new_key
]
=
self
.
load_variable
(
file_state_dict
,
old_key
)
elif
(
old_key
not
in
file_state_dict
)
and
(
not
self
.
ignore_invalid_weights
):
# mapping中包含,但模型文件中没有
print
(
f
'[WARNIMG]
{
old_key
}
not found in pretrain models'
)
if
new_key
in
parameters_set
:
parameters_set
.
remove
(
new_key
)
# 未能加载预训练权重的Parameter
if
not
self
.
ignore_invalid_weights
:
for
key
in
parameters_set
:
print
(
f
'[WARNIMG] Parameter
{
key
}
not loaded from pretrain models'
)
del
file_state_dict
# 将ckpt的权重load到模型结构中
self
.
load_state_dict
(
state_dict_new
,
strict
=
False
)
# def get_inputs(self):
# pass
# def set_inputs(self, inputs, additional_input_layers=None):
# """设置input和inputs属性
# """
# pass
    def apply_embeddings(self, inputs):
        # Embedding stage; must be implemented by subclasses.
        raise NotImplementedError

    def apply_main_layers(self, inputs):
        # Main transformer stage; must be implemented by subclasses.
        raise NotImplementedError

    def apply_final_layers(self, inputs):
        # Output/head stage; must be implemented by subclasses.
        raise NotImplementedError
    def apply_on_layer_begin(self, l_i, inputs):
        '''Hook for transforming the inputs of transformer block l_i; identity by default.
        '''
        return inputs

    def apply_on_layer_end(self, l_i, inputs):
        '''Hook for transforming the outputs of transformer block l_i; identity by default.
        '''
        return inputs
    def compute_attention_bias(self, inputs=None):
        """Per-layer attention bias (overridden e.g. by LM/UniLM mask mixins).
        """
        return self.attention_bias

    def compute_position_bias(self, inputs=None):
        """Per-layer position bias (generally used by relative position encodings).
        """
        return self.position_bias
def
set_outputs
(
self
,
outputs
):
"""设置output和oututs属性
"""
if
not
isinstance
(
outputs
,
list
):
outputs
=
[
outputs
]
outputs
=
outputs
[:]
self
.
outputs
=
outputs
if
len
(
outputs
)
>
1
:
self
.
output
=
outputs
else
:
self
.
output
=
outputs
[
0
]
class LM_Mask(object):
    """Lower-triangular attention mask mixin (for causal language models).
    """
    def compute_attention_bias(self, inputs=None):
        """Build the causal mask from the sequence length of inputs[0].
        """
        token_ids = inputs[0]
        n = token_ids.shape[1]
        ones = torch.ones(n, n, dtype=torch.long, device=token_ids.device)
        causal = torch.tril(ones, diagonal=0)
        # [seq, seq] -> [1, 1, seq, seq] so it broadcasts over batch and heads.
        self.attention_bias = causal[None, None, :, :]
        return self.attention_bias
def extend_with_language_model(InputModel):
    """Derive a causal-LM variant of *InputModel* by mixing in LM_Mask.
    """
    class LanguageModel(LM_Mask, InputModel):
        """InputModel with a lower-triangular attention mask.
        """
        def __init__(self, *args, **kwargs):
            # Force the MLM head on unless the caller supplied a truthy value.
            kwargs['with_mlm'] = kwargs.get('with_mlm') or True
            super(LanguageModel, self).__init__(*args, **kwargs)

    return LanguageModel
class UniLM_Mask(object):
    """UniLM attention mask mixin (for seq2seq models).

    The source/target split is given by segment_ids.
    UniLM: https://arxiv.org/abs/1905.03197
    """
    def compute_attention_bias(self, inputs=None):
        """Derive the mask by comparing cumulative segment ids.
        """
        segment_ids = inputs[1]
        cum = torch.cumsum(segment_ids, dim=1)
        # Position j is visible from position i iff cum[j] <= cum[i].
        visible = cum.unsqueeze(1) <= cum.unsqueeze(2)
        # Add a head axis and convert bool -> long.
        self.attention_bias = visible.unsqueeze(1).long()
        return self.attention_bias
def extend_with_unified_language_model(InputModel):
    """Derive a UniLM (seq2seq) variant of *InputModel* by mixing in UniLM_Mask.
    """
    class UnifiedLanguageModel(UniLM_Mask, InputModel):
        """InputModel with the UniLM attention mask.

        UniLM: https://arxiv.org/abs/1905.03197
        """
        def __init__(self, *args, **kwargs):
            # Force the MLM head on unless the caller supplied a truthy value.
            kwargs['with_mlm'] = kwargs.get('with_mlm') or True
            super(UnifiedLanguageModel, self).__init__(*args, **kwargs)

    return UnifiedLanguageModel
class BERT(BERT_BASE):
    """BERT model.
    """
    def __init__(
            self,
            max_position,  # maximum sequence length
            segment_vocab_size=2,  # number of segment ids
            with_pool=False,  # include the pooler head
            with_nsp=False,  # include the NSP head
            with_mlm=False,  # include the MLM head
            custom_position_ids=False,  # position ids are passed in by the caller
            custom_attention_mask=False,  # attention_mask is passed in by the caller
            shared_segment_embeddings=False,  # if True, segments share the token embedding
            layer_norm_cond=None,  # conditional layer_norm
            layer_add_embs=None,  # additional embeddings, e.g. POS/tone/word-level custom embeddings
            is_dropout=False,
            token_pad_ids=0,  # 0 is the default padding id, but note google's mt5 padding is not 0
            **kwargs  # remaining parameters
    ):
        super(BERT, self).__init__(**kwargs)
        self.max_position = max_position
        self.segment_vocab_size = segment_vocab_size
        self.with_pool = with_pool
        self.with_nsp = with_nsp
        self.with_mlm = with_mlm
        self.custom_position_ids = custom_position_ids
        self.custom_attention_mask = custom_attention_mask
        self.shared_segment_embeddings = shared_segment_embeddings
        self.is_dropout = is_dropout
        self.token_pad_ids = token_pad_ids
        # NSP consumes pooled_output, so with_pool is a prerequisite for with_nsp.
        if self.with_nsp and not self.with_pool:
            self.with_pool = True
        self.layer_norm_conds = layer_norm_cond
        self.layer_add_embs = layer_add_embs
        self.conditional_size = layer_norm_cond.weight.size(1) if layer_norm_cond is not None else None
        self.embeddings = BertEmbeddings(self.vocab_size, self.embedding_size, self.hidden_size, self.max_position, self.segment_vocab_size, self.shared_segment_embeddings, self.dropout_rate, self.conditional_size, **get_kw(BertEmbeddings, kwargs))
        kwargs['max_position'] = self.max_position  # needed by relative position encodings
        layer = BertLayer(self.hidden_size, self.num_attention_heads, self.dropout_rate, self.attention_probs_dropout_prob, self.intermediate_size, self.hidden_act, is_dropout=self.is_dropout, conditional_size=self.conditional_size, **get_kw(BertLayer, kwargs))
        # Layers not in keep_hidden_layers are replaced by Identity.
        self.encoderLayer = nn.ModuleList([copy.deepcopy(layer) if layer_id in self.keep_hidden_layers else Identity() for layer_id in range(self.num_hidden_layers)])
        if self.with_pool:
            # Pooler head (extracts the CLS vector).
            self.pooler = nn.Linear(self.hidden_size, self.hidden_size)
            self.pooler_activation = nn.Tanh() if self.with_pool is True else get_activation(self.with_pool)
            if self.with_nsp:
                # Next Sentence Prediction head.
                # NSP takes pooled_output, hence requires with_pool.
                self.nsp = nn.Linear(self.hidden_size, 2)
        else:
            self.pooler = None
            self.pooler_activation = None
        if self.with_mlm:
            self.mlmDense = nn.Linear(self.hidden_size, self.hidden_size)
            self.transform_act_fn = get_activation(self.hidden_act)
            self.mlmLayerNorm = LayerNorm(self.hidden_size, eps=1e-12, conditional_size=self.conditional_size)
            self.mlmDecoder = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
            if kwargs.get('tie_emb_prj_weight') is True:
                # Tie the MLM decoder weights to the input word embeddings.
                self.mlmDecoder.weight = self.embeddings.word_embeddings.weight
            self.mlmBias = nn.Parameter(torch.zeros(self.vocab_size))
            self.mlmDecoder.bias = self.mlmBias
        # Subclasses of BERT declare additional parameters; those cannot all be
        # initialized uniformly here.
    def apply_embeddings(self, inputs):
        """BERT's embedding is the sum of token, position and segment embeddings.

        Default input order: token_ids, segment_ids (if any), position_ids (if any),
        custom_attention_mask (if any), conditional_input (if any).
        """
        token_ids = inputs[0]
        index_ = 1  # cursor over the optional inputs
        if self.segment_vocab_size > 0:
            segment_ids = inputs[index_]
            index_ += 1
        else:
            segment_ids = None
        if self.custom_position_ids:
            # Not used yet, kept for compatibility.
            position_ids = inputs[index_]
            index_ += 1
        else:
            position_ids = None

        # Build a 3D attention-mask matrix from token_ids with shape
        # [batch_size, 1, 1, to_seq_length] so it can broadcast to
        # [batch_size, num_heads, from_seq_length, to_seq_length] for multi-head attention.
        if self.custom_attention_mask:
            attention_mask = inputs[index_].long().unsqueeze(1).unsqueeze(2)
            index_ += 1
        elif (not token_ids.requires_grad) and (token_ids.dtype in {torch.long, torch.int}):
            # Normal token_ids path.
            attention_mask = (token_ids != self.token_pad_ids).long().unsqueeze(1).unsqueeze(2)  # 0 is the default mask value
            if self.token_pad_ids < 0:
                # Negative pad ids must be zeroed out before embedding lookup.
                token_ids = token_ids * attention_mask[:, 0, 0, :]
        else:
            # Custom word_embedding input; currently only the VAT path uses this.
            attention_mask = self.attention_mask_cache
        self.attention_mask_cache = attention_mask  # cache the attention_mask used last time

        self.compute_attention_bias([token_ids, segment_ids])  # adjust the mask for lm/unilm if needed
        if self.attention_bias is not None:
            attention_mask = attention_mask * self.attention_bias  # padding is not attendable
            # attention_mask = self.attention_bias  # variant where padding is attendable

        # pytorch >= 1.5 can raise StopIteration here
        # https://github.com/huggingface/transformers/issues/3936
        # https://github.com/huggingface/transformers/issues/4189
        # https://github.com/huggingface/transformers/issues/3936
        try:
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        except StopIteration:
            attention_mask = attention_mask.to(dtype=torch.float32)

        # Zeros in the mask would be turned into large negatives so that masked
        # positions get ~0 attention after softmax (done inside the layers):
        # attention_mask = (1.0 - attention_mask) * -10000.0

        # conditional layer_norm
        if self.layer_norm_conds is None:
            conditional_emb = None
        else:
            conditional_emb = self.layer_norm_conds(inputs[index_])
            index_ += 1

        # Additional embeddings, e.g. POS/tone/word-level custom embeddings.
        if isinstance(self.layer_add_embs, nn.Module):
            # single extra embedding
            additional_embs = [self.layer_add_embs(inputs[index_])]
            index_ += 1
        elif isinstance(self.layer_add_embs, (tuple, list)):
            # multiple extra embeddings
            additional_embs = []
            for layer in self.layer_add_embs:
                assert isinstance(layer, nn.Module), 'Layer_add_embs element should be nn.Module'
                additional_embs.append(layer(inputs[index_]))
                index_ += 1
        else:
            additional_embs = None

        # Enter the embedding layer.
        hidden_states = self.embeddings(token_ids, segment_ids, conditional_emb, additional_embs)
        return [hidden_states, attention_mask, conditional_emb] + inputs[index_:]
    def apply_main_layers(self, inputs):
        """BERT's trunk: a stack of self-attention blocks.

        Per block: Att --> Add --> LN --> FFN --> Add --> LN.
        By convention inputs[0] is hidden_states, inputs[1] is attention_mask,
        inputs[2] is conditional_emb.
        """
        hidden_states, attention_mask, conditional_emb = inputs[:3]
        # Extra inputs (used by decoder-style models) carry the encoder state.
        if len(inputs[3:]) >= 2:
            encoder_hidden_state, encoder_attention_mask = inputs[3], inputs[4]
        else:
            encoder_hidden_state, encoder_attention_mask = None, None

        encoded_layers = [hidden_states]  # include the embedding output
        layer_inputs = [hidden_states, attention_mask, conditional_emb, encoder_hidden_state, encoder_attention_mask]
        for l_i, layer_module in enumerate(self.encoderLayer):
            layer_inputs = self.apply_on_layer_begin(l_i, layer_inputs)
            hidden_states = layer_module(*layer_inputs)
            layer_inputs[0] = hidden_states
            layer_inputs = self.apply_on_layer_end(l_i, layer_inputs)

            if self.output_all_encoded_layers:
                encoded_layers.append(hidden_states)
        if not self.output_all_encoded_layers:
            encoded_layers.append(hidden_states)
        return [encoded_layers, conditional_emb]
def apply_final_layers(self, inputs):
    """Assemble model outputs from the encoder results.

    Depending on the configured heads (pool / nsp / mlm) the return value is
    either a single item or a list of all non-None outputs.
    """
    encoded_layers, conditional_emb = inputs
    # Hidden states of the last transformer layer.
    sequence_output = encoded_layers[-1]
    if not self.output_all_encoded_layers:
        # Keep only the last layer instead of the full list.
        encoded_layers = encoded_layers[-1]

    pooled_output = None
    nsp_scores = None
    mlm_scores = None

    if self.with_pool:
        # Pool the [CLS] position through the dense + activation head.
        pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0]))
        if self.with_nsp:
            nsp_scores = self.nsp(pooled_output)

    if self.with_mlm:
        mlm_hidden_state = self.mlmDense(sequence_output)
        mlm_hidden_state = self.transform_act_fn(mlm_hidden_state)
        mlm_hidden_state = self.mlmLayerNorm((mlm_hidden_state, conditional_emb))
        mlm_scores = self.mlmDecoder(mlm_hidden_state)
        # with_mlm may be True (linear) or the name of an activation.
        act = get_activation('linear' if self.with_mlm is True else self.with_mlm)
        mlm_scores = act(mlm_scores)

    outputs = [v for v in (encoded_layers, pooled_output, mlm_scores, nsp_scores) if v is not None]
    return outputs if len(outputs) > 1 else outputs[0]
def load_variable(self, state_dict, name, prefix='bert'):
    """Fetch one checkpoint tensor, adapting it to this model's layout.

    :param state_dict: checkpoint state dict.
    :param name: checkpoint key to load.
    :param prefix: checkpoint key prefix (e.g. 'bert', 'electra').
    """
    variable = state_dict[name]
    # Tensors tied to the token embedding table go through load_embeddings.
    embedding_keys = {
        f'{prefix}.embeddings.word_embeddings.weight',
        'cls.predictions.bias',
        'cls.predictions.decoder.weight',
        'cls.predictions.decoder.bias',
    }
    if name in embedding_keys:
        return self.load_embeddings(variable)
    if name == f'{prefix}.embeddings.position_embeddings.weight':
        return self.load_pos_embeddings(variable)
    if name == 'cls.seq_relationship.weight':
        # The checkpoint stores the transposed NSP weight.
        return variable.T
    return variable
def variable_mapping(self, prefix='bert'):
    """Map this model's parameter names to HuggingFace-style checkpoint keys.

    :param prefix: checkpoint key prefix, e.g. 'bert' / 'electra' / 'roformer'.
    :return: dict of {bert4torch parameter name: checkpoint parameter name}.
    """
    # Embedding tables, pooler, NSP and MLM head parameters.
    mapping = {
        'embeddings.word_embeddings.weight': f'{prefix}.embeddings.word_embeddings.weight',
        'embeddings.position_embeddings.weight': f'{prefix}.embeddings.position_embeddings.weight',
        'embeddings.segment_embeddings.weight': f'{prefix}.embeddings.token_type_embeddings.weight',
        'embeddings.layerNorm.weight': f'{prefix}.embeddings.LayerNorm.weight',
        'embeddings.layerNorm.bias': f'{prefix}.embeddings.LayerNorm.bias',
        'pooler.weight': f'{prefix}.pooler.dense.weight',
        'pooler.bias': f'{prefix}.pooler.dense.bias',
        'nsp.weight': 'cls.seq_relationship.weight',
        'nsp.bias': 'cls.seq_relationship.bias',
        'mlmDense.weight': 'cls.predictions.transform.dense.weight',
        'mlmDense.bias': 'cls.predictions.transform.dense.bias',
        'mlmLayerNorm.weight': 'cls.predictions.transform.LayerNorm.weight',
        'mlmLayerNorm.bias': 'cls.predictions.transform.LayerNorm.bias',
        'mlmBias': 'cls.predictions.bias',
        'mlmDecoder.weight': 'cls.predictions.decoder.weight',
        'mlmDecoder.bias': 'cls.predictions.decoder.bias'
    }
    # Per-layer parameters: q/k/v/o projections, the two LayerNorms and the FFN.
    for i in range(self.num_hidden_layers):
        prefix_i = f'{prefix}.encoder.layer.%d.' % i
        mapping.update({f'encoderLayer.{i}.multiHeadAttention.q.weight': prefix_i + 'attention.self.query.weight',
                        f'encoderLayer.{i}.multiHeadAttention.q.bias': prefix_i + 'attention.self.query.bias',
                        f'encoderLayer.{i}.multiHeadAttention.k.weight': prefix_i + 'attention.self.key.weight',
                        f'encoderLayer.{i}.multiHeadAttention.k.bias': prefix_i + 'attention.self.key.bias',
                        f'encoderLayer.{i}.multiHeadAttention.v.weight': prefix_i + 'attention.self.value.weight',
                        f'encoderLayer.{i}.multiHeadAttention.v.bias': prefix_i + 'attention.self.value.bias',
                        f'encoderLayer.{i}.multiHeadAttention.o.weight': prefix_i + 'attention.output.dense.weight',
                        f'encoderLayer.{i}.multiHeadAttention.o.bias': prefix_i + 'attention.output.dense.bias',
                        f'encoderLayer.{i}.layerNorm1.weight': prefix_i + 'attention.output.LayerNorm.weight',
                        f'encoderLayer.{i}.layerNorm1.bias': prefix_i + 'attention.output.LayerNorm.bias',
                        f'encoderLayer.{i}.feedForward.intermediateDense.weight': prefix_i + 'intermediate.dense.weight',
                        f'encoderLayer.{i}.feedForward.intermediateDense.bias': prefix_i + 'intermediate.dense.bias',
                        f'encoderLayer.{i}.feedForward.outputDense.weight': prefix_i + 'output.dense.weight',
                        f'encoderLayer.{i}.feedForward.outputDense.bias': prefix_i + 'output.dense.bias',
                        f'encoderLayer.{i}.layerNorm2.weight': prefix_i + 'output.LayerNorm.weight',
                        f'encoderLayer.{i}.layerNorm2.bias': prefix_i + 'output.LayerNorm.bias'
                        })
    return mapping
class ALBERT(BERT):
    """ALBERT: a BERT variant whose encoder layers share one set of weights.

    The parent constructor builds the full layer stack; here it is truncated
    to a single shared layer that is applied ``num_hidden_layers`` times.
    """

    def __init__(self, *args, **kwargs):
        super(ALBERT, self).__init__(*args, **kwargs)
        # Keep only the first layer; it is reused at every depth.
        self.encoderLayer = nn.ModuleList([self.encoderLayer[0]])

    def apply_main_layers(self, inputs):
        """Transformer body: Att --> Add --> LN --> FFN --> Add --> LN.

        The single shared layer ``self.encoderLayer[0]`` is applied
        ``num_hidden_layers`` times.
        """
        hidden_states, attention_mask, conditional_emb = inputs[:3]
        if len(inputs[3:]) >= 2:
            encoder_hidden_state, encoder_attention_mask = inputs[3], inputs[4]
        else:
            encoder_hidden_state, encoder_attention_mask = None, None

        encoded_layers = [hidden_states]  # include the embedding output
        layer_inputs = [hidden_states, attention_mask, conditional_emb, encoder_hidden_state, encoder_attention_mask]
        for l_i in range(self.num_hidden_layers):
            layer_inputs = self.apply_on_layer_begin(l_i, layer_inputs)
            # Weight sharing: always apply layer 0.
            hidden_states = self.encoderLayer[0](*layer_inputs)
            layer_inputs[0] = hidden_states
            layer_inputs = self.apply_on_layer_end(l_i, layer_inputs)

            if self.output_all_encoded_layers:
                encoded_layers.append(hidden_states)
        if not self.output_all_encoded_layers:
            encoded_layers.append(hidden_states)
        return [encoded_layers, conditional_emb]

    def variable_mapping(self, prefix='albert'):
        """Map parameter names to HuggingFace ALBERT checkpoint keys."""
        mapping = {
            'embeddings.word_embeddings.weight': f'{prefix}.embeddings.word_embeddings.weight',
            'embeddings.position_embeddings.weight': f'{prefix}.embeddings.position_embeddings.weight',
            'embeddings.segment_embeddings.weight': f'{prefix}.embeddings.token_type_embeddings.weight',
            'embeddings.layerNorm.weight': f'{prefix}.embeddings.LayerNorm.weight',
            'embeddings.layerNorm.bias': f'{prefix}.embeddings.LayerNorm.bias',
            'embeddings.embedding_hidden_mapping_in.weight': f'{prefix}.encoder.embedding_hidden_mapping_in.weight',
            'embeddings.embedding_hidden_mapping_in.bias': f'{prefix}.encoder.embedding_hidden_mapping_in.bias',
            'pooler.weight': f'{prefix}.pooler.weight',
            'pooler.bias': f'{prefix}.pooler.bias',
            # ALBERT uses a SOP head; it is stored under the nsp parameter name here.
            'nsp.weight': 'sop_classifier.classifier.weight',
            'nsp.bias': 'sop_classifier.classifier.bias',
            'mlmDense.weight': 'predictions.dense.weight',
            'mlmDense.bias': 'predictions.dense.bias',
            'mlmLayerNorm.weight': 'predictions.LayerNorm.weight',
            'mlmLayerNorm.bias': 'predictions.LayerNorm.bias',
            'mlmBias': 'predictions.bias',
            'mlmDecoder.weight': 'predictions.decoder.weight',
            'mlmDecoder.bias': 'predictions.decoder.bias'
        }
        # Only the single shared layer exists, mapped to group 0 / layer 0.
        i = 0
        prefix_i = f'{prefix}.encoder.albert_layer_groups.{i}.albert_layers.{i}.'
        mapping.update({f'encoderLayer.{i}.multiHeadAttention.q.weight': prefix_i + 'attention.query.weight',
                        f'encoderLayer.{i}.multiHeadAttention.q.bias': prefix_i + 'attention.query.bias',
                        f'encoderLayer.{i}.multiHeadAttention.k.weight': prefix_i + 'attention.key.weight',
                        f'encoderLayer.{i}.multiHeadAttention.k.bias': prefix_i + 'attention.key.bias',
                        f'encoderLayer.{i}.multiHeadAttention.v.weight': prefix_i + 'attention.value.weight',
                        f'encoderLayer.{i}.multiHeadAttention.v.bias': prefix_i + 'attention.value.bias',
                        f'encoderLayer.{i}.multiHeadAttention.o.weight': prefix_i + 'attention.dense.weight',
                        f'encoderLayer.{i}.multiHeadAttention.o.bias': prefix_i + 'attention.dense.bias',
                        f'encoderLayer.{i}.layerNorm1.weight': prefix_i + 'attention.LayerNorm.weight',
                        f'encoderLayer.{i}.layerNorm1.bias': prefix_i + 'attention.LayerNorm.bias',
                        f'encoderLayer.{i}.feedForward.intermediateDense.weight': prefix_i + 'ffn.weight',
                        f'encoderLayer.{i}.feedForward.intermediateDense.bias': prefix_i + 'ffn.bias',
                        f'encoderLayer.{i}.feedForward.outputDense.weight': prefix_i + 'ffn_output.weight',
                        f'encoderLayer.{i}.feedForward.outputDense.bias': prefix_i + 'ffn_output.bias',
                        f'encoderLayer.{i}.layerNorm2.weight': prefix_i + 'full_layer_layer_norm.weight',
                        f'encoderLayer.{i}.layerNorm2.bias': prefix_i + 'full_layer_layer_norm.bias'
                        })
        return mapping

    def load_variable(self, state_dict, name):
        """Load a single checkpoint tensor, adapting embedding-tied weights."""
        variable = state_dict[name]
        if name in {'albert.embeddings.word_embeddings.weight', 'predictions.bias', 'predictions.decoder.weight', 'predictions.decoder.bias'}:
            return self.load_embeddings(variable)
        elif name == 'albert.embeddings.position_embeddings.weight':
            return self.load_pos_embeddings(variable)
        elif name == 'sop_classifier.classifier.weight':
            # The checkpoint stores the transposed SOP weight.
            return variable.T
        else:
            return variable
class ALBERT_Unshared(ALBERT):
    """ALBERT variant where the shared layer is deep-copied so that every
    depth position owns independent weights (layer-unshared)."""

    def __init__(self, *args, **kwargs):
        # Fix: the original called ``super(ALBERT_Unshared).__init__`` without
        # passing ``self``, which creates an unbound super object and never
        # initializes the model.
        super(ALBERT_Unshared, self).__init__(*args, **kwargs)
        # ALBERT keeps only one layer; replicate it so each depth position
        # gets its own trainable copy.
        self.encoderLayer = nn.ModuleList([copy.deepcopy(self.encoderLayer[0]) for _ in range(self.num_hidden_layers)])

    def apply_main_layers(self, inputs):
        """Transformer body: Att --> Add --> LN --> FFN --> Add --> LN.

        ``inputs[0:3]`` are hidden_states, attention_mask and conditional_emb;
        optional ``inputs[3]``/``inputs[4]`` carry the cross-attention
        encoder_hidden_state / encoder_attention_mask.
        """
        # Fix: unpack only the first three items (``inputs[:3]``, matching the
        # sibling implementations); bare unpacking raised ValueError whenever
        # cross-attention inputs were appended.
        hidden_states, attention_mask, conditional_emb = inputs[:3]
        if len(inputs[3:]) >= 2:
            encoder_hidden_state, encoder_attention_mask = inputs[3], inputs[4]
        else:
            encoder_hidden_state, encoder_attention_mask = None, None

        encoded_layers = [hidden_states]  # include the embedding output
        layer_inputs = [hidden_states, attention_mask, conditional_emb, encoder_hidden_state, encoder_attention_mask]
        for i in range(self.num_hidden_layers):
            layer_inputs = self.apply_on_layer_begin(i, layer_inputs)
            hidden_states = self.encoderLayer[i](*layer_inputs)
            layer_inputs[0] = hidden_states
            layer_inputs = self.apply_on_layer_end(i, layer_inputs)
            if self.output_all_encoded_layers:
                encoded_layers.append(hidden_states)
        if not self.output_all_encoded_layers:
            encoded_layers.append(hidden_states)
        return [encoded_layers, conditional_emb]
class NEZHA(BERT):
    """Huawei's NEZHA model.

    Reference: https://arxiv.org/abs/1909.00204
    """

    def __init__(self, *args, **kwargs):
        # 'typical_relative' p_bias disables absolute position embeddings in
        # the embedding stage; max_relative_position defaults to 64.
        kwargs.setdefault('max_relative_position', 64)
        kwargs['p_bias'] = 'typical_relative'
        super(NEZHA, self).__init__(*args, **kwargs)
class RoFormer(BERT):
    """BERT variant using rotary position embeddings (RoPE).

    Reference: https://kexue.fm/archives/8265
    """

    def __init__(self, *args, **kwargs):
        kwargs['p_bias'] = 'rotary'
        super(RoFormer, self).__init__(*args, **kwargs)

    def load_variable(self, state_dict, name, prefix='roformer'):
        # Same loading rules as BERT, only the checkpoint prefix differs.
        return super().load_variable(state_dict, name, prefix)

    def variable_mapping(self, prefix='roformer'):
        mapping = super().variable_mapping(prefix)
        # Rotary encoding has no learned absolute position table.
        mapping.pop('embeddings.position_embeddings.weight')
        return mapping
class RoFormerV2(RoFormer):
    """RoFormerV2.

    Changes vs RoFormer: removes bias terms, simplifies Norm to RMSNorm and
    tweaks initialization (currently still BERT init; fine-tuning unaffected).
    """

    @delete_arguments('with_pool', 'with_nsp')
    def __init__(self, *args, **kwargs):
        kwargs.update({'p_bias': 'rotary', 'weight': False, 'bias': False, 'norm_mode': 'rmsnorm'})
        super(RoFormerV2, self).__init__(*args, **kwargs)
        if self.with_mlm:
            # The MLM head is reduced to the bare decoder: drop the transform
            # dense / LayerNorm / bias and the decoder's bias parameter.
            del self.mlmLayerNorm
            del self.mlmBias
            del self.mlmDense
            self.mlmDecoder.register_parameter('bias', None)

    def variable_mapping(self, prefix='roformer'):
        """Drop all bias/LayerNorm entries from RoFormer's mapping."""
        mapping = super().variable_mapping(prefix)
        mapping_new = {}
        for k, v in mapping.items():
            if (not re.search('bias|layernorm', k.lower())) and (not re.search('bias|layernorm', v.lower())):
                mapping_new[k] = v
        return mapping_new

    def apply_final_layers(self, inputs):
        """Return the encoded layers and, when with_mlm, the decoder scores."""
        # Hidden states of the last transformer layer.
        encoded_layers, conditional_emb = inputs
        sequence_output = encoded_layers[-1]
        # Optionally keep only the last layer's output.
        if not self.output_all_encoded_layers:
            encoded_layers = encoded_layers[-1]

        # Optional MLM scores (no transform head, see __init__).
        if self.with_mlm:
            mlm_scores = self.mlmDecoder(sequence_output)
        else:
            mlm_scores = None

        outputs = [value for value in [encoded_layers, mlm_scores] if value is not None]
        return outputs if len(outputs) > 1 else outputs[0]
class GAU_alpha(RoFormerV2):
    """GAU-alpha: RoFormerV2 backbone with GatedAttentionUnit layers."""

    def __init__(self, *args, **kwargs):
        kwargs.update({'p_bias': 'rotary', 'weight': False, 'bias': False, 'norm_mode': 'rmsnorm', 'normalization': 'softmax_plus'})
        super().__init__(*args, **kwargs)

        layer = self.GAU_Layer(**kwargs)
        # Layers outside keep_hidden_layers are replaced by Identity (skipped).
        self.encoderLayer = nn.ModuleList([copy.deepcopy(layer) if layer_id in self.keep_hidden_layers else Identity() for layer_id in range(self.num_hidden_layers)])

    def load_variable(self, state_dict, name, prefix=''):
        """Load a single checkpoint tensor (embedding-tied ones via load_embeddings)."""
        variable = state_dict[name]
        return self.load_embeddings(variable) if name in {'embeddings.word_embeddings.weight', 'mlmDecoder.weight'} else variable

    def variable_mapping(self, prefix=''):
        '''Identity mapping: the convert script has already renamed the
        checkpoint keys to bert4torch's parameter names.
        '''
        return {k: k for k, _ in self.named_parameters()}

    class GAU_Layer(nn.Module):
        """Single block: GAU -> dropout -> residual add -> (conditional) LayerNorm."""

        def __init__(self, *args, **kwargs):
            super().__init__()
            self.gau = GatedAttentionUnit(**kwargs)
            self.dropout1 = nn.Dropout(kwargs.get('dropout_rate'))
            self.layerNorm1 = LayerNorm(**kwargs)

        def forward(self, hidden_states, attention_mask, conditional_emb=None, encoder_hidden_states=None, encoder_attention_mask=None):
            gau_hidden_states = self.gau(hidden_states, attention_mask)
            # Residual connection around the gated attention unit.
            hidden_states = hidden_states + self.dropout1(gau_hidden_states)
            hidden_states = self.layerNorm1((hidden_states, conditional_emb))
            return hidden_states
class ELECTRA(BERT):
    """Google's ELECTRA model.

    Reference: https://arxiv.org/abs/2003.10555
    """

    @insert_arguments(with_discriminator=False)
    @delete_arguments('with_pool', 'with_mlm', 'with_nsp')
    def __init__(self, max_position, **kwargs):
        super(ELECTRA, self).__init__(max_position, **kwargs)
        if self.with_discriminator:
            # Replaced-token-detection head: dense -> act -> linear(1) -> act.
            self.dense = nn.Linear(self.hidden_size, self.hidden_size)
            self.dense_act = get_activation(self.hidden_act)
            self.dense_prediction = nn.Linear(self.hidden_size, 1)
            # with_discriminator may be True (sigmoid) or an activation name.
            self.dense_prediction_act = get_activation('sigmoid') if self.with_discriminator is True else get_activation(self.with_discriminator)

    def apply_final_layers(self, inputs):
        # The parent returns only the hidden_states here (pool/mlm/nsp removed).
        hidden_states = super().apply_final_layers(inputs)
        if self.with_discriminator:
            logit = self.dense_act(self.dense(hidden_states))
            return [hidden_states, self.dense_prediction_act(self.dense_prediction(logit))]
        else:
            return hidden_states

    def load_variable(self, state_dict, name):
        """Load a single checkpoint tensor using the 'electra' prefix."""
        return super().load_variable(state_dict, name, prefix='electra')

    def variable_mapping(self):
        """BERT mapping with discriminator head added and unused heads removed."""
        mapping = super(ELECTRA, self).variable_mapping(prefix='electra')
        mapping.update({'dense.weight': 'discriminator_predictions.dense.weight',
                        'dense.bias': 'discriminator_predictions.dense.bias',
                        'dense_prediction.weight': 'discriminator_predictions.dense_prediction.weight',
                        'dense_prediction.bias': 'discriminator_predictions.dense_prediction.bias'}
                       )
        # ELECTRA has no pooler / NSP / MLM parameters.
        for del_key in ['pooler.weight', 'pooler.bias', 'nsp.weight', 'nsp.bias', 'mlmDense.weight', 'mlmDense.bias',
                        'mlmLayerNorm.weight', 'mlmLayerNorm.bias', 'mlmBias', 'mlmDecoder.weight', 'mlmDecoder.bias']:
            del mapping[del_key]
        return mapping
class Encoder(BERT):
    """BERT encoder that additionally returns its attention mask."""

    def __init__(self, *args, **kwargs):
        # Prefer the dedicated source-vocab size when provided.
        kwargs['vocab_size'] = kwargs.get('src_vocab_size', kwargs['vocab_size'])
        super().__init__(*args, **kwargs)
        self.encoder_attention_mask = None  # exposed for the decoder side

    def forward(self, inputs):
        """Run embedding -> encoder stack -> heads, then append the attention
        mask (needed later for decoder cross-attention).
        """
        embedded = self.apply_embeddings(inputs)
        mask_outputs = [embedded[1]]
        encoded = self.apply_main_layers(embedded)
        final = self.apply_final_layers(encoded)
        if isinstance(final, torch.Tensor):
            final = [final]
        return final + mask_outputs
class Decoder(LM_Mask, BERT):
    """Transformer decoder built on BERT layers with a causal (LM) mask.

    ``LM_Mask`` supplies the lower-triangular attention bias; cross-attention
    inputs arrive as ``inputs[3]``/``inputs[4]``.
    """

    @delete_arguments('with_pool', 'with_mlm', 'with_nsp')
    def __init__(self, *args, with_lm=True, tie_emb_prj_weight=True, **kwargs):
        # Prefer the dedicated target-vocab size when provided.
        kwargs['vocab_size'] = kwargs.get('tgt_vocab_size', kwargs['vocab_size'])
        kwargs['is_decoder'] = True  # mark as decoder
        super().__init__(*args, **kwargs)
        # Rename the layer stack: encoderLayer -> decoderLayer.
        self.decoderLayer = self.encoderLayer
        del self.encoderLayer
        self.with_lm = with_lm

        # Projection from hidden_states to vocabulary logits.
        if self.with_lm:
            self.final_dense = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
            if tie_emb_prj_weight:
                # Share the input embedding table with the output projection,
                # compensating with a 1/sqrt(hidden_size) logit scale.
                self.final_dense.weight = self.embeddings.word_embeddings.weight
                self.x_logit_scale = (self.hidden_size ** -0.5)
            else:
                self.x_logit_scale = 1.

    def apply_main_layers(self, inputs):
        """Decoder body: Att1 --> Add --> LN --> Att2 --> Add --> LN --> FFN --> Add --> LN."""
        hidden_states, attention_mask, conditional_emb, encoder_hidden_state, encoder_attention_mask = inputs[:5]
        decoded_layers = [hidden_states]  # include the embedding output
        layer_inputs = [hidden_states, attention_mask, conditional_emb, encoder_hidden_state, encoder_attention_mask]
        for i, layer_module in enumerate(self.decoderLayer):
            layer_inputs = self.apply_on_layer_begin(i, layer_inputs)
            hidden_states = layer_module(*layer_inputs)
            layer_inputs[0] = hidden_states
            layer_inputs = self.apply_on_layer_end(i, layer_inputs)
            if self.output_all_encoded_layers:
                decoded_layers.append(hidden_states)
        if not self.output_all_encoded_layers:
            decoded_layers.append(hidden_states)
        return [decoded_layers, conditional_emb]

    def apply_final_layers(self, inputs):
        """Return [hidden_states] plus, when with_lm, the vocabulary logits."""
        outputs = []
        # Top-layer hidden states, [btz, seq_len, hdsz].
        hidden_states = super().apply_final_layers(inputs)
        outputs.append(hidden_states)
        if self.with_lm:
            # Logits over the vocabulary, [btz, seq_len, vocab_size].
            logits = self.final_dense(hidden_states) * self.x_logit_scale
            # with_lm may be True (linear) or the name of an activation
            # (typically softmax).
            activation = get_activation('linear' if self.with_lm is True else self.with_lm)
            logits = activation(logits)
            outputs.append(logits)
        return outputs

    def variable_mapping(self, prefix='bert'):
        """Reuse BERT's mapping with encoderLayer renamed to decoderLayer."""
        raw_mapping = super().variable_mapping(prefix)
        mapping = {}
        for k, v in raw_mapping.items():
            mapping[k.replace('encoderLayer', 'decoderLayer')] = v
        # NOTE: the cross-attention weight mapping is intentionally not added
        # here; model-specific subclasses (e.g. BART) provide their own.
        return mapping
class Transformer(BERT_BASE):
    '''Generic encoder-decoder (seq2seq) structure.'''

    @delete_arguments('with_pool', 'with_mlm', 'with_nsp')
    def __init__(self, *args, tie_emb_src_tgt_weight=False, **kwargs):
        super(Transformer, self).__init__(*args, **kwargs)

        # encoder
        self.encoder = Encoder(*args, **kwargs)
        self.encoder.build(**kwargs)

        # decoder
        self.decoder = Decoder(*args, **kwargs)
        self.decoder.build(**kwargs)

        if tie_emb_src_tgt_weight:
            # Share the source/target token embedding tables.
            assert self.encoder.vocab_size == self.decoder.vocab_size, "To share word embedding, the vocab size of src/tgt shall be the same."
            self.encoder.embeddings.word_embeddings.weight = self.decoder.embeddings.word_embeddings.weight

    def forward(self, inputs):
        """Run the encoder, then the decoder on top of its outputs."""
        encoder_input, decoder_input = inputs[:2]

        # The encoder returns its hidden states plus its attention mask.
        encoder_hidden_state, encoder_attention_mask = self.encoder(encoder_input)

        # The decoder consumes its own inputs plus the encoder outputs.
        decoder_outputs = self.decoder(decoder_input + [encoder_hidden_state, encoder_attention_mask])
        # Also expose encoder_hidden_state for multi-task setups.
        return [encoder_hidden_state] + decoder_outputs
class BART(Transformer):
    '''BART encoder-decoder model (embeddings tied across src/tgt by default).'''

    def __init__(self, *args, tie_emb_src_tgt_weight=True, **kwargs):
        super(BART, self).__init__(*args, tie_emb_src_tgt_weight=tie_emb_src_tgt_weight, **kwargs)
        # Remembered so variable_mapping can pick the right checkpoint keys.
        self.tie_emb_src_tgt_weight = tie_emb_src_tgt_weight

    def load_variable(self, state_dict, name, prefix=''):
        """Load a single checkpoint tensor, adapting embedding/position tables."""
        variable = state_dict[name]
        if name in {
            'shared.weight',
            'encoder.embed_tokens.weight',
            'decoder.embed_tokens.weight',
        }:
            return self.load_embeddings(variable)
        elif name in {'encoder.embed_positions.weight', 'decoder.embed_positions.weight'}:
            return self.load_pos_embeddings(variable)
        else:
            return variable

    def variable_mapping(self, prefix=''):
        """Map parameter names to BART checkpoint keys.

        The checkpoint stores a single 'shared.weight' when the src/tgt
        embeddings are tied.
        """
        mapping = {
            'encoder.embeddings.word_embeddings.weight': 'shared.weight' if self.tie_emb_src_tgt_weight else 'encoder.embed_tokens.weight',
            'encoder.embeddings.position_embeddings.weight': 'encoder.embed_positions.weight',
            'encoder.embeddings.layerNorm.weight': 'encoder.layernorm_embedding.weight',
            'encoder.embeddings.layerNorm.bias': 'encoder.layernorm_embedding.bias',
            'decoder.embeddings.word_embeddings.weight': 'shared.weight' if self.tie_emb_src_tgt_weight else 'decoder.embed_tokens.weight',
            'decoder.embeddings.position_embeddings.weight': 'decoder.embed_positions.weight',
            'decoder.embeddings.layerNorm.weight': 'decoder.layernorm_embedding.weight',
            'decoder.embeddings.layerNorm.bias': 'decoder.layernorm_embedding.bias',
        }
        # Per-layer mapping: encoder self-attention + FFN; decoder additionally
        # has cross-attention (encoder_attn) and its LayerNorm (layerNorm3).
        for i in range(self.num_hidden_layers):
            mapping.update(
                {
                    f'encoder.encoderLayer.{i}.multiHeadAttention.q.weight': f'encoder.layers.{i}.self_attn.q_proj.weight',
                    f'encoder.encoderLayer.{i}.multiHeadAttention.q.bias': f'encoder.layers.{i}.self_attn.q_proj.bias',
                    f'encoder.encoderLayer.{i}.multiHeadAttention.k.weight': f'encoder.layers.{i}.self_attn.k_proj.weight',
                    f'encoder.encoderLayer.{i}.multiHeadAttention.k.bias': f'encoder.layers.{i}.self_attn.k_proj.bias',
                    f'encoder.encoderLayer.{i}.multiHeadAttention.v.weight': f'encoder.layers.{i}.self_attn.v_proj.weight',
                    f'encoder.encoderLayer.{i}.multiHeadAttention.v.bias': f'encoder.layers.{i}.self_attn.v_proj.bias',
                    f'encoder.encoderLayer.{i}.multiHeadAttention.o.weight': f'encoder.layers.{i}.self_attn.out_proj.weight',
                    f'encoder.encoderLayer.{i}.multiHeadAttention.o.bias': f'encoder.layers.{i}.self_attn.out_proj.bias',
                    f'encoder.encoderLayer.{i}.layerNorm1.weight': f'encoder.layers.{i}.self_attn_layer_norm.weight',
                    f'encoder.encoderLayer.{i}.layerNorm1.bias': f'encoder.layers.{i}.self_attn_layer_norm.bias',
                    f'encoder.encoderLayer.{i}.feedForward.intermediateDense.weight': f'encoder.layers.{i}.fc1.weight',
                    f'encoder.encoderLayer.{i}.feedForward.intermediateDense.bias': f'encoder.layers.{i}.fc1.bias',
                    f'encoder.encoderLayer.{i}.feedForward.outputDense.weight': f'encoder.layers.{i}.fc2.weight',
                    f'encoder.encoderLayer.{i}.feedForward.outputDense.bias': f'encoder.layers.{i}.fc2.bias',
                    f'encoder.encoderLayer.{i}.layerNorm2.weight': f'encoder.layers.{i}.final_layer_norm.weight',
                    f'encoder.encoderLayer.{i}.layerNorm2.bias': f'encoder.layers.{i}.final_layer_norm.bias',
                    f'decoder.decoderLayer.{i}.multiHeadAttention.q.weight': f'decoder.layers.{i}.self_attn.q_proj.weight',
                    f'decoder.decoderLayer.{i}.multiHeadAttention.q.bias': f'decoder.layers.{i}.self_attn.q_proj.bias',
                    f'decoder.decoderLayer.{i}.multiHeadAttention.k.weight': f'decoder.layers.{i}.self_attn.k_proj.weight',
                    f'decoder.decoderLayer.{i}.multiHeadAttention.k.bias': f'decoder.layers.{i}.self_attn.k_proj.bias',
                    f'decoder.decoderLayer.{i}.multiHeadAttention.v.weight': f'decoder.layers.{i}.self_attn.v_proj.weight',
                    f'decoder.decoderLayer.{i}.multiHeadAttention.v.bias': f'decoder.layers.{i}.self_attn.v_proj.bias',
                    f'decoder.decoderLayer.{i}.multiHeadAttention.o.weight': f'decoder.layers.{i}.self_attn.out_proj.weight',
                    f'decoder.decoderLayer.{i}.multiHeadAttention.o.bias': f'decoder.layers.{i}.self_attn.out_proj.bias',
                    f'decoder.decoderLayer.{i}.layerNorm1.weight': f'decoder.layers.{i}.self_attn_layer_norm.weight',
                    f'decoder.decoderLayer.{i}.layerNorm1.bias': f'decoder.layers.{i}.self_attn_layer_norm.bias',
                    f'decoder.decoderLayer.{i}.crossAttention.q.weight': f'decoder.layers.{i}.encoder_attn.q_proj.weight',
                    f'decoder.decoderLayer.{i}.crossAttention.q.bias': f'decoder.layers.{i}.encoder_attn.q_proj.bias',
                    f'decoder.decoderLayer.{i}.crossAttention.k.weight': f'decoder.layers.{i}.encoder_attn.k_proj.weight',
                    f'decoder.decoderLayer.{i}.crossAttention.k.bias': f'decoder.layers.{i}.encoder_attn.k_proj.bias',
                    f'decoder.decoderLayer.{i}.crossAttention.v.weight': f'decoder.layers.{i}.encoder_attn.v_proj.weight',
                    f'decoder.decoderLayer.{i}.crossAttention.v.bias': f'decoder.layers.{i}.encoder_attn.v_proj.bias',
                    f'decoder.decoderLayer.{i}.crossAttention.o.weight': f'decoder.layers.{i}.encoder_attn.out_proj.weight',
                    f'decoder.decoderLayer.{i}.crossAttention.o.bias': f'decoder.layers.{i}.encoder_attn.out_proj.bias',
                    f'decoder.decoderLayer.{i}.layerNorm3.weight': f'decoder.layers.{i}.encoder_attn_layer_norm.weight',
                    f'decoder.decoderLayer.{i}.layerNorm3.bias': f'decoder.layers.{i}.encoder_attn_layer_norm.bias',
                    f'decoder.decoderLayer.{i}.feedForward.intermediateDense.weight': f'decoder.layers.{i}.fc1.weight',
                    f'decoder.decoderLayer.{i}.feedForward.intermediateDense.bias': f'decoder.layers.{i}.fc1.bias',
                    f'decoder.decoderLayer.{i}.feedForward.outputDense.weight': f'decoder.layers.{i}.fc2.weight',
                    f'decoder.decoderLayer.{i}.feedForward.outputDense.bias': f'decoder.layers.{i}.fc2.bias',
                    f'decoder.decoderLayer.{i}.layerNorm2.weight': f'decoder.layers.{i}.final_layer_norm.weight',
                    f'decoder.decoderLayer.{i}.layerNorm2.bias': f'decoder.layers.{i}.final_layer_norm.bias'
                })
        return mapping
class T5_Encoder(Encoder):
    """T5 encoder stack (relative position bias, RMSNorm, no biases)."""

    @insert_arguments(version='t5.1.0')
    def __init__(self, *args, **kwargs):
        # p_bias='t5_relative' disables absolute position embeddings; T5 uses
        # no bias terms and RMSNorm throughout.
        kwargs.update({'p_bias': 't5_relative', 'relative_attention_num_buckets': kwargs.get('relative_attention_num_buckets'), 'version': self.version,
                       'bias': False, 'norm_mode': 'rmsnorm'})
        super().__init__(*args, **kwargs)
        del self.embeddings.layerNorm

        # T5 applies LayerNorm before each sublayer, so the layers are rebuilt
        # as T5Layer instances.
        layer = T5Layer(self.hidden_size, self.num_attention_heads, self.dropout_rate, self.attention_probs_dropout_prob,
                        self.intermediate_size, self.hidden_act, is_dropout=self.is_dropout, conditional_size=self.conditional_size,
                        **get_kw(BertLayer, kwargs))
        self.encoderLayer = nn.ModuleList([copy.deepcopy(layer) for _ in range(self.num_hidden_layers)])

        # Tie the relative-position tables of layers 1..N-1 to layer 0, so the
        # bias is effectively computed only once.
        for i in range(1, self.num_hidden_layers):
            self.encoderLayer[i].multiHeadAttention.relative_positions_encoding.weight = self.encoderLayer[0].multiHeadAttention.relative_positions_encoding.weight
        self.final_layer_norm = LayerNorm(self.hidden_size, eps=1e-12, conditional_size=self.conditional_size, bias=False, mode='rmsnorm')
        self.dropout = nn.Dropout(self.dropout_rate)

    def apply_final_layers(self, inputs):
        """Apply the final RMSNorm + dropout on top of the parent's output."""
        hidden_states = super().apply_final_layers(inputs)
        return self.dropout(self.final_layer_norm([hidden_states]))

    def load_variable(self, state_dict, name, prefix=''):
        """Load a single checkpoint tensor (embedding weights via load_embeddings)."""
        variable = state_dict[name]
        if name in {'encoder.embed_tokens.weight', 'shared.weight'}:
            return self.load_embeddings(variable)
        else:
            return variable

    def variable_mapping(self, prefix=''):
        """Map parameter names to HuggingFace T5 checkpoint keys.

        The checkpoint may store the embedding as 'shared.weight'.
        """
        mapping = {f'{prefix}embeddings.word_embeddings.weight': 'encoder.embed_tokens.weight',
                   f'{prefix}encoderLayer.0.multiHeadAttention.relative_positions_encoding.weight': 'encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight',
                   f'{prefix}final_layer_norm.weight': 'encoder.final_layer_norm.weight'}
        for i in range(self.num_hidden_layers):
            mapping.update({
                f'{prefix}encoderLayer.{i}.multiHeadAttention.q.weight': f'encoder.block.{i}.layer.0.SelfAttention.q.weight',
                f'{prefix}encoderLayer.{i}.multiHeadAttention.k.weight': f'encoder.block.{i}.layer.0.SelfAttention.k.weight',
                f'{prefix}encoderLayer.{i}.multiHeadAttention.v.weight': f'encoder.block.{i}.layer.0.SelfAttention.v.weight',
                f'{prefix}encoderLayer.{i}.multiHeadAttention.o.weight': f'encoder.block.{i}.layer.0.SelfAttention.o.weight',
                f'{prefix}encoderLayer.{i}.layerNorm1.weight': f'encoder.block.{i}.layer.0.layer_norm.weight',
                f'{prefix}encoderLayer.{i}.feedForward.outputDense.weight': f'encoder.block.{i}.layer.1.DenseReluDense.wo.weight',
                f'{prefix}encoderLayer.{i}.layerNorm2.weight': f'encoder.block.{i}.layer.1.layer_norm.weight',
            })
            # t5.1.0 has a single FFN input projection (wi); t5.1.1 is gated
            # and therefore has two (wi_0 / wi_1).
            if self.version.endswith('t5.1.0'):
                mapping.update({f'{prefix}encoderLayer.{i}.feedForward.intermediateDense.weight': f'encoder.block.{i}.layer.1.DenseReluDense.wi.weight'})
            elif self.version.endswith('t5.1.1'):
                mapping.update({f'{prefix}encoderLayer.{i}.feedForward.intermediateDense.weight': f'encoder.block.{i}.layer.1.DenseReluDense.wi_0.weight',
                                f'{prefix}encoderLayer.{i}.feedForward.intermediateDense1.weight': f'encoder.block.{i}.layer.1.DenseReluDense.wi_1.weight'})
        return mapping
class T5_Decoder(Decoder):
    """T5 decoder stack (pre-LayerNorm, RMSNorm, no biases, T5 relative position bias)."""
    @insert_arguments(version='t5.1.0')
    def __init__(self, *args, **kwargs):
        # p_bias='t5_relative' disables positional embeddings in the embedding
        # stage; T5 uses no biases and RMSNorm throughout.
        kwargs.update({'p_bias': 't5_relative', 'relative_attention_num_buckets': kwargs.get('relative_attention_num_buckets'), 'version': self.version,
                       'bias': False, 'norm_mode': 'rmsnorm'})
        super().__init__(*args, **kwargs)
        # T5 applies no LayerNorm after the embedding sum.
        del self.embeddings.layerNorm

        # T5 places LayerNorm before each sub-layer, so the stock layers built
        # by super().__init__ are replaced with T5Layer instances here.
        layer = T5Layer(self.hidden_size, self.num_attention_heads, self.dropout_rate, self.attention_probs_dropout_prob, self.intermediate_size, self.hidden_act,
                        is_dropout=self.is_dropout, conditional_size=self.conditional_size, is_decoder=True, **get_kw(BertLayer, kwargs))
        self.decoderLayer = nn.ModuleList([copy.deepcopy(layer) for _ in range(self.num_hidden_layers)])

        # Tie every layer's relative-position bias to layer 0's weight, so the
        # table is effectively computed (and stored) only once.
        for i in range(1, self.num_hidden_layers):
            self.decoderLayer[i].multiHeadAttention.relative_positions_encoding.weight = self.decoderLayer[0].multiHeadAttention.relative_positions_encoding.weight
        self.final_layer_norm = LayerNorm(self.hidden_size, eps=1e-12, conditional_size=self.conditional_size, bias=False, mode='rmsnorm')
        self.dropout = nn.Dropout(self.dropout_rate)

    def apply_final_layers(self, inputs):
        # Apply the final RMSNorm (+ dropout) to the last hidden state before
        # it is projected to logits by the parent class.
        inputs[0][1] = self.dropout(self.final_layer_norm([inputs[0][1]]))
        return super().apply_final_layers(inputs)

    def load_variable(self, state_dict, name, prefix=''):
        """Load a single checkpoint variable, routing embedding-tied weights
        through load_embeddings (handles keep_tokens/compound_tokens)."""
        variable = state_dict[name]
        if name in {f'decoder.embed_tokens.weight', 'lm_head.weight', 'shared.weight'}:
            return self.load_embeddings(variable)
        else:
            return variable

    def variable_mapping(self, prefix=''):
        """Map bert4torch parameter names to T5 checkpoint keys (decoder side).

        The checkpoint stores the tied embedding under 'shared.weight'; the
        decoder layer order is self-attn (layer.0), cross-attn (layer.1),
        FFN (layer.2).
        """
        mapping = {f'{prefix}embeddings.word_embeddings.weight': 'decoder.embed_tokens.weight',
                   f'{prefix}decoderLayer.0.multiHeadAttention.relative_positions_encoding.weight': 'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight',
                   f'{prefix}final_layer_norm.weight': 'decoder.final_layer_norm.weight',
                   f'{prefix}final_dense.weight': 'lm_head.weight'}
        for i in range(self.num_hidden_layers):
            mapping.update({
                f'{prefix}decoderLayer.{i}.multiHeadAttention.q.weight': f'decoder.block.{i}.layer.0.SelfAttention.q.weight',
                f'{prefix}decoderLayer.{i}.multiHeadAttention.k.weight': f'decoder.block.{i}.layer.0.SelfAttention.k.weight',
                f'{prefix}decoderLayer.{i}.multiHeadAttention.v.weight': f'decoder.block.{i}.layer.0.SelfAttention.v.weight',
                f'{prefix}decoderLayer.{i}.multiHeadAttention.o.weight': f'decoder.block.{i}.layer.0.SelfAttention.o.weight',
                f'{prefix}decoderLayer.{i}.layerNorm1.weight': f'decoder.block.{i}.layer.0.layer_norm.weight',
                f'{prefix}decoderLayer.{i}.crossAttention.q.weight': f'decoder.block.{i}.layer.1.EncDecAttention.q.weight',
                f'{prefix}decoderLayer.{i}.crossAttention.k.weight': f'decoder.block.{i}.layer.1.EncDecAttention.k.weight',
                f'{prefix}decoderLayer.{i}.crossAttention.v.weight': f'decoder.block.{i}.layer.1.EncDecAttention.v.weight',
                f'{prefix}decoderLayer.{i}.crossAttention.o.weight': f'decoder.block.{i}.layer.1.EncDecAttention.o.weight',
                f'{prefix}decoderLayer.{i}.layerNorm3.weight': f'decoder.block.{i}.layer.1.layer_norm.weight',
                f'{prefix}decoderLayer.{i}.feedForward.outputDense.weight': f'decoder.block.{i}.layer.2.DenseReluDense.wo.weight',
                f'{prefix}decoderLayer.{i}.layerNorm2.weight': f'decoder.block.{i}.layer.2.layer_norm.weight',
            })
            # t5.1.0: single FFN input projection; t5.1.1: gated pair wi_0/wi_1.
            if self.version.endswith('t5.1.0'):
                mapping.update({f'{prefix}decoderLayer.{i}.feedForward.intermediateDense.weight': f'decoder.block.{i}.layer.2.DenseReluDense.wi.weight'})
            elif self.version.endswith('t5.1.1'):
                mapping.update({f'{prefix}decoderLayer.{i}.feedForward.intermediateDense.weight': f'decoder.block.{i}.layer.2.DenseReluDense.wi_0.weight',
                                f'{prefix}decoderLayer.{i}.feedForward.intermediateDense1.weight': f'decoder.block.{i}.layer.2.DenseReluDense.wi_1.weight'})
        return mapping
class T5(Transformer):
    """Google's T5 model (Encoder-Decoder)."""
    @delete_arguments('with_pool', 'with_mlm', 'with_nsp')
    def __init__(self, *args, tie_emb_src_tgt_weight=True, **kwargs):
        # tie_emb_src_tgt_weight: when True, both encoder and decoder word
        # embeddings map to the checkpoint's single 'shared.weight' tensor.
        super(T5, self).__init__(*args, **kwargs)
        self.tie_emb_src_tgt_weight = tie_emb_src_tgt_weight

        # encoder
        self.encoder = T5_Encoder(*args, **kwargs)
        self.encoder.build(**kwargs)

        # decoder
        self.decoder = T5_Decoder(*args, **kwargs)
        self.decoder.build(**kwargs)

    def load_variable(self, state_dict, name, prefix=''):
        """Load a single checkpoint variable; embedding-tied tensors go
        through load_embeddings (vocab remapping), others pass unchanged."""
        variable = state_dict[name]
        if name in {'shared.weight', 'encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'}:
            return self.load_embeddings(variable)
        else:
            return variable

    def variable_mapping(self, prefix=''):
        # Combine encoder and decoder mappings under their respective prefixes;
        # when embeddings are tied, both point at 'shared.weight'.
        mapping = self.encoder.variable_mapping(prefix='encoder.')
        mapping.update(self.decoder.variable_mapping(prefix='decoder.'))
        if self.tie_emb_src_tgt_weight:
            mapping.update({'encoder.embeddings.word_embeddings.weight': 'shared.weight',
                            'decoder.embeddings.word_embeddings.weight': 'shared.weight'})
        return mapping
class GPT(LM_Mask, BERT):
    """OpenAI GPT model.
    Reference: https://github.com/openai/finetune-transformer-lm
    """
    @insert_arguments(final_activation='softmax')
    @delete_arguments('with_pool', 'with_mlm', 'with_nsp')
    def __init__(self, max_position, **kwargs):
        """GPT's embedding is the sum of token, position and segment embeddings;
        the main difference from BERT is that no LayerNormalization follows the sum.
        LM_Mask supplies the causal mask matching the pretrained ckpt; the output
        projection shares its weight with the word embedding, so it is taken
        directly from word_embeddings.
        """
        super(GPT, self).__init__(max_position, **kwargs)
        # no post-embedding LayerNorm in GPT
        del self.embeddings.layerNorm
        self.dense = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
        # weight tying: output projection == input embedding
        self.dense.weight = self.embeddings.word_embeddings.weight
        self.final_activation = get_activation(self.final_activation)

    def apply_final_layers(self, inputs):
        # hidden states -> vocab logits -> configured activation (softmax by default)
        hidden_state = super().apply_final_layers(inputs)
        logit = self.dense(hidden_state)
        return self.final_activation(logit)

    def load_variable(self, state_dict, name):
        # Delegate to BERT's loader with the 'gpt' checkpoint prefix.
        return super(GPT, self).load_variable(state_dict, name, prefix='gpt')

    def variable_mapping(self):
        """Map parameter names to the GPT checkpoint layout."""
        mapping = super(GPT, self).variable_mapping(prefix='gpt')
        return mapping
class GPT2(LM_Mask, BERT):
    """OpenAI GPT2 model.
    Reference: https://github.com/openai/finetune-transformer-lm
    """
    @insert_arguments(final_activation='softmax')
    @delete_arguments('with_pool', 'with_mlm', 'with_nsp')
    def __init__(self, max_position, **kwargs):
        """GPT2's embedding is the sum of token and position embeddings.
        1. Unlike BERT, no LayerNormalization follows the embedding sum.
        2. BERT applies LayerNorm after attention/FFN; OpenAI GPT2 applies it before.
        LM_Mask supplies the causal mask matching the pretrained ckpt; the output
        projection shares its weight with the word embedding.
        """
        super(GPT2, self).__init__(max_position, **kwargs)
        del self.embeddings.layerNorm
        # Replace the stock BertLayer stack with the pre-LN Gpt2Layer variant.
        layer = self.Gpt2Layer(self.hidden_size, self.num_attention_heads, self.dropout_rate, self.attention_probs_dropout_prob, self.intermediate_size, self.hidden_act, is_dropout=self.is_dropout, conditional_size=self.conditional_size)
        self.encoderLayer = nn.ModuleList([copy.deepcopy(layer) if layer_id in self.keep_hidden_layers else Identity() for layer_id in range(self.num_hidden_layers)])
        # Final LayerNorm applied just before the vocab projection (GPT2-specific).
        self.LayerNormFinal = LayerNorm(self.hidden_size, eps=1e-12, conditional_size=self.conditional_size)
        self.dense = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
        # weight tying: output projection == input embedding
        self.dense.weight = self.embeddings.word_embeddings.weight
        self.final_activation = get_activation(self.final_activation)

    def apply_final_layers(self, inputs):
        # final LN -> vocab logits -> configured activation
        hidden_state = super().apply_final_layers(inputs)
        logit = self.dense(self.LayerNormFinal([hidden_state]))
        return self.final_activation(logit)

    def load_variable(self, state_dict, name):
        # Delegate to BERT's loader with the 'gpt2' checkpoint prefix.
        return super(GPT2, self).load_variable(state_dict, name, prefix='gpt2')

    def variable_mapping(self):
        """Map parameter names to the GPT2 checkpoint layout."""
        mapping = super(GPT2, self).variable_mapping(prefix='gpt2')
        mapping.update({'LayerNormFinal.weight': 'gpt2.LayerNormFinal.weight',
                        'LayerNormFinal.bias': 'gpt2.LayerNormFinal.bias'})
        return mapping

    class Gpt2Layer(BertLayer):
        '''Defined here rather than in layer.py because it is specific to the
        gpt2 model and not reusable.
        Order: LN --> Att --> Add --> LN --> FFN --> Add
        '''
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def forward(self, hidden_states, attention_mask, conditional_emb=None, encoder_hidden_states=None, encoder_attention_mask=None):
            # BERT applies LayerNorm after attention/FFN; OpenAI GPT2 before.
            x = self.layerNorm1((hidden_states, conditional_emb))
            self_attn_output = self.multiHeadAttention(x, attention_mask)
            hidden_states = hidden_states + self.dropout1(self_attn_output)
            x = self.layerNorm2((hidden_states, conditional_emb))
            ffn_output = self.feedForward(x)
            hidden_states = hidden_states + self.dropout2(ffn_output)
            return hidden_states
class GPT2_ML(LM_Mask, BERT):
    """GPT2_ML model.
    Reference: https://github.com/imcaspar/gpt2-ml
    Note: despite the name, GPT2_ML is structurally closer to GPT; it presumably
    calls itself GPT2 because its released checkpoint reaches GPT2's 1.5B params.
    From the ckpt keys: unlike GPT it has a LayerNorm after the embedding, and
    unlike BERT its first residual connection is taken before LayerNorm.
    """
    @insert_arguments(final_activation='softmax')
    @delete_arguments('with_pool', 'with_mlm', 'with_nsp')
    def __init__(self, max_position, **kwargs):
        super().__init__(max_position, **kwargs)
        # Replace the stock BertLayer stack with the Gpt2MlLayer variant.
        layer = self.Gpt2MlLayer(self.hidden_size, self.num_attention_heads, self.dropout_rate, self.attention_probs_dropout_prob, self.intermediate_size, self.hidden_act, is_dropout=self.is_dropout, conditional_size=self.conditional_size)
        self.encoderLayer = nn.ModuleList([copy.deepcopy(layer) if layer_id in self.keep_hidden_layers else Identity() for layer_id in range(self.num_hidden_layers)])
        self.dense = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
        # weight tying: output projection == input embedding
        self.dense.weight = self.embeddings.word_embeddings.weight
        self.final_activation = get_activation(self.final_activation)

    def apply_final_layers(self, inputs):
        # hidden states -> vocab logits -> configured activation
        hidden_state = super().apply_final_layers(inputs)
        logit = self.dense(hidden_state)
        return self.final_activation(logit)

    def load_variable(self, state_dict, name):
        # Delegate to BERT's loader with the 'gpt2_ml' checkpoint prefix.
        return super(GPT2_ML, self).load_variable(state_dict, name, prefix='gpt2_ml')

    def variable_mapping(self):
        """Map parameter names to the GPT2_ML checkpoint layout."""
        mapping = super(GPT2_ML, self).variable_mapping(prefix='gpt2_ml')
        return mapping

    class Gpt2MlLayer(BertLayer):
        '''Defined here rather than in layer.py because it is specific to the
        gpt2_ml model and not reusable.
        Order: Att --> Add --> LN --> FFN --> Add --> LN
        '''
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def forward(self, hidden_states, attention_mask, conditional_emb=None, encoder_hidden_states=None, encoder_attention_mask=None):
            self_attn_output = self.multiHeadAttention(hidden_states, attention_mask)
            hidden_states = hidden_states + self.dropout1(self_attn_output)
            x = self.layerNorm1((hidden_states, conditional_emb))
            # BERT takes the residual after layerNorm; gpt2_ml takes it before.
            ffn_output = self.feedForward(x)
            hidden_states = hidden_states + self.dropout2(ffn_output)
            hidden_states = self.layerNorm2((hidden_states, conditional_emb))
            return hidden_states
class Transformer_XL(BERT):
    '''Transformer-XL model (loadable).
    Project: https://github.com/kimiyoung/transformer-xl
    Differences from the original project:
    1) AdaptiveEmbedding simplified (optional); ProjectedAdaptiveLogSoftmax not used —
       the last_hidden_state is returned directly
    2) mems initialized as zero tensors (original uses empty tensors) and include the
       last layer's output
    3) SinusoidalPositionEncoding here is laid out sin-block then cos-block instead of
       the usual interleaved sin/cos
    4) attention_mask uses 1e30 in multi_attn in place of the original 1000

    Fixes vs. previous revision:
    - ``apply_embeddings`` raised ``NameError`` for the default causal path
      (``attn_type in {'uni', 0}``) because only the 'bi' branch assigned the mask
      that was returned; both branches now assign ``non_tgt_mask``.
    - ``apply_main_layers`` read ``mems[0].size(0)`` (the batch dim) as ``mlen``;
      mems are ``(bsz, mem_len, hidden)`` per ``init_mems``, so ``size(1)`` is used,
      consistent with ``apply_embeddings``.
    - ``attn_type`` no longer clobbers a value already set by a subclass (XLNET sets
      ``'bi'`` before calling ``super().__init__``).
    '''
    @delete_arguments('with_pool', 'with_nsp', 'with_mlm')
    @insert_arguments(with_lm=False)
    def __init__(self, *args, mem_len=0, same_length=False, clamp_len=-1, **kwargs):
        # p_bias='other_relative' disables positional embeddings in the embedding stage
        kwargs.update({'p_bias': 'other_relative'})
        super().__init__(*args, **kwargs)
        self.mem_len, self.same_length, self.clamp_len = mem_len, same_length, clamp_len
        # Keep a subclass-provided default (e.g. XLNET sets 'bi' before super().__init__);
        # fall back to 0 (transformer-xl's unidirectional setting) otherwise.
        self.attn_type = kwargs.get('attn_type', getattr(self, 'attn_type', 0))

        # embedding: optionally the adaptive (bucketed) embedding of the original project
        if kwargs.get('adaptive_embedding'):
            cutoffs, div_val, sample_softmax = kwargs.get('cutoffs', []), kwargs.get('div_val', 1), kwargs.get('sample_softmax', False)
            self.embeddings = AdaptiveEmbedding(self.vocab_size, self.embedding_size, self.hidden_size, cutoffs, div_val, sample_softmax, **get_kw(AdaptiveEmbedding, kwargs))
        else:
            self.embeddings = nn.Embedding(self.vocab_size, self.embedding_size)
        self.pos_embeddings = XlnetPositionsEncoding(self.embedding_size)
        self.dropout = nn.Dropout(self.dropout_rate)

        # Relative-attention biases: shared across layers unless untie_r is set.
        if not kwargs.get('untie_r'):
            self.r_w_bias = nn.Parameter(torch.FloatTensor(self.num_attention_heads, self.attention_head_size))  # global content bias
            self.r_r_bias = nn.Parameter(torch.FloatTensor(self.num_attention_heads, self.attention_head_size))  # global position bias
            if self.segment_vocab_size > 0:
                self.r_s_bias = nn.Parameter(torch.FloatTensor(self.num_attention_heads, self.attention_head_size))  # global segment bias
        else:
            self.r_w_bias, self.r_r_bias = None, None
            self.r_s_bias = None

        # transformer block
        layer = XlnetLayer(self.hidden_size, self.num_attention_heads, self.dropout_rate, self.attention_probs_dropout_prob, self.intermediate_size, self.hidden_act,
                           is_dropout=self.is_dropout, conditional_size=self.conditional_size, r_w_bias=self.r_w_bias, r_r_bias=self.r_r_bias, r_s_bias=None,
                           **get_kw(BertLayer, kwargs))
        self.encoderLayer = nn.ModuleList([copy.deepcopy(layer) if layer_id in self.keep_hidden_layers else Identity() for layer_id in range(self.num_hidden_layers)])

        # optional LM head
        if self.with_lm:
            self.dense = nn.Linear(self.hidden_size, self.vocab_size, bias=True)

    def init_mems(self, bsz):
        '''Initialize mems, the per-layer hidden-state memory of length mem_len.
        Returns a list of (bsz, mem_len, hidden_size) zero tensors — one per layer
        plus one for the embedding output — or None when mem_len <= 0.
        '''
        if isinstance(self.mem_len, (int, float)) and (self.mem_len > 0):
            mems = []
            param = next(self.parameters())
            for _ in range(self.num_hidden_layers + 1):
                empty = torch.zeros(bsz, self.mem_len, self.hidden_size, dtype=param.dtype, device=param.device)
                mems.append(empty)
            return mems
        else:
            return None

    def _update_mems(self, hids, mlen, qlen):
        '''Slide the memory window: concatenate the new hidden states onto the
        cached ones along the sequence dim and keep the trailing mem_len steps.
        '''
        if self.mems is None:  # does not deal with None
            return None
        assert len(hids) == len(self.mems), "len(hids) != len(mems)"
        # There are `mlen + qlen` steps that can be cached into mems
        with torch.no_grad():
            new_mems = []
            end_idx = mlen + max(0, qlen)
            beg_idx = max(0, end_idx - self.mem_len)
            for i in range(len(hids)):
                cat = torch.cat([self.mems[i], hids[i]], dim=1)
                new_mems.append(cat[:, beg_idx:end_idx].detach())
        self.mems = new_mems

    def relative_positional_encoding(self, qlen, klen, device):
        # Build pos_emb using the sincos positional encoding; the signature matches
        # xlnet's for interoperability.
        pos_seq = torch.arange(klen - 1, -1, -1.0, device=device, dtype=torch.long)
        if self.clamp_len > 0:
            pos_seq.clamp_(max=self.clamp_len)
        pos_emb = self.dropout(self.pos_embeddings(pos_seq))  # reuses the word_emb dropout
        return pos_emb

    def create_mask(self, word_emb, qlen, klen, mlen):
        # Causal mask: all mlen memory positions are visible; each query position t
        # may attend to keys <= t (similar to UniLM, but UniLM drives this via
        # segment_ids instead).
        if self.same_length:
            # each position sees only a fixed-length window
            all_ones = word_emb.new_ones(qlen, klen)
            mask_len = klen - self.mem_len
            mask_shift_len = qlen - mask_len if mask_len > 0 else qlen
            attention_mask = 1 - (torch.triu(all_ones, 1 + mlen) + torch.tril(all_ones, -mask_shift_len)).byte()  # -1
        else:
            attention_mask = torch.tril(word_emb.new_ones(qlen, klen), diagonal=mlen).byte()  # [q_len, k_len] lower-triangular ones
        attention_mask = attention_mask[None, None, :, :]
        return attention_mask

    def apply_embeddings(self, inputs):
        '''Expected inputs: [token_ids, segment_ids]; conditional LayerNorm inputs
        are not supported yet.
        '''
        self.mems = self.init_mems(inputs[0].size(0))  # build mems

        # the simplified embeddings compute only the word embedding
        word_emb = self.dropout(self.embeddings(inputs[0]))
        index_ = 1
        btz, qlen = inputs[0].shape[:2]  # query length
        mlen = self.mems[0].size(1) if self.mems is not None else 0
        klen = mlen + qlen

        # relative positional encoding
        pos_emb = self.relative_positional_encoding(qlen, klen, word_emb.device)

        # segment embedding
        if self.segment_vocab_size > 0:
            segment_ids = inputs[index_]
            if mlen > 0:
                # memory positions are padded with segment id 0
                mem_pad = torch.zeros([btz, mlen], dtype=torch.long, device=word_emb.device)
                cat_ids = torch.cat([mem_pad, segment_ids], dim=1)
            else:
                cat_ids = segment_ids
            # `1` indicates not in the same segment [qlen x klen x bsz]
            segment_ids = (segment_ids[:, :, None] != cat_ids[:, None]).long()
            index_ += 1
        else:
            segment_ids = None

        if self.attn_type in {'uni', 0}:  # transformer_xl's unidirectional setting: 0
            # FIX: assign the same variable that is returned below (previously only
            # the 'bi' branch defined it, so this path raised NameError).
            non_tgt_mask = self.create_mask(word_emb, qlen, klen, mlen)
        elif self.attn_type == 'bi':
            attention_mask = (inputs[0] != self.token_pad_ids).long().unsqueeze(1).unsqueeze(2)
            non_tgt_mask = torch.eye(qlen).to(attention_mask)[None, None, :, :]
            non_tgt_mask = ((1 - attention_mask - non_tgt_mask) <= 0).long()
        return [word_emb, segment_ids, pos_emb, non_tgt_mask, None]

    def apply_main_layers(self, inputs):
        hidden_states, segment_ids, pos_emb, attention_mask, conditional_emb = inputs[:5]
        encoded_layers = [hidden_states]  # include the embedding output
        layer_inputs = [hidden_states, segment_ids, pos_emb, attention_mask, None, conditional_emb]
        for i, layer_module in enumerate(self.encoderLayer):
            mems_i = None if self.mems is None else self.mems[i]
            layer_inputs[-2] = mems_i
            layer_inputs = self.apply_on_layer_begin(i, layer_inputs)
            hidden_states = layer_module(*layer_inputs)
            layer_inputs[0] = hidden_states
            layer_inputs = self.apply_on_layer_end(i, layer_inputs)
            encoded_layers.append(hidden_states)

        # the original implementation shares one dropout between word_emb, pos_emb
        # and core_out (hidden_states)
        hidden_states = self.dropout(hidden_states)
        qlen = inputs[0].size(1)  # query length
        # FIX: mems are (bsz, mem_len, hidden) — mem length is dim 1, not dim 0
        # (matches apply_embeddings).
        mlen = self.mems[0].size(1) if self.mems is not None else 0
        self._update_mems(encoded_layers, mlen, qlen)

        if not self.output_all_encoded_layers:
            # return only the top layer (plus the embedding output)
            encoded_layers = encoded_layers[:1] + [hidden_states]
        return [encoded_layers, conditional_emb]

    def load_variable(self, state_dict, name, prefix=''):
        # Pretrained checkpoints use AdaptiveEmbedding, so vocab remapping via
        # keep_tokens/compound_tokens is not supported here.
        if (self.keep_tokens is not None) or (self.compound_tokens is not None):
            raise ValueError('Custom keep_tokens and compound_tokens is not yet supported in Transformer_XL')
        return state_dict[name]

    def variable_mapping(self, prefix=''):
        # Checkpoint keys match the module's own parameter names one-to-one.
        return {k: k for k, _ in self.named_parameters()}
class XLNET(Transformer_XL):
    '''XLNet model, simplified here for finetuning only — no perm_mask or
    target_mapping inputs.
    Expected inputs: [token_ids, segment_ids]
    '''
    def __init__(self, *args, bi_data=False, **kwargs):
        # NOTE(review): Transformer_XL.__init__ re-reads attn_type from kwargs with
        # default 0, which overwrites this 'bi' default whenever attn_type is not
        # explicitly passed — confirm callers pass attn_type='bi'.
        self.attn_type = kwargs.get('attn_type', 'bi')
        self.bi_data = bi_data
        kwargs['rel_shift_opt'] = 'xlnet'
        super().__init__(*args, **kwargs)

    def relative_positional_encoding(self, qlen, klen, device):
        # Build pos_emb with the sincos positional encoding; transformer_xl's
        # variant ends at -1 instead.
        if self.attn_type == 'bi':
            beg, end = klen, -qlen
        elif self.attn_type == "uni":
            beg, end = klen, -1
        else:
            raise ValueError(f"Unknown `attn_type` {self.attn_type}.")

        # forward positions
        pos_seq = torch.arange(beg, end, -1.0, device=device, dtype=torch.long)
        if self.clamp_len > 0:
            pos_seq.clamp_(max=self.clamp_len)
        fwd_pos_emb = self.pos_embeddings(pos_seq)

        # bidirectional data
        if self.bi_data:
            # NOTE(review): arange(-beg, -end, -1.0) counts downward from -klen,
            # which yields an empty tensor when -beg < -end; the reference XLNet
            # uses step +1.0 for the backward sequence — confirm before relying
            # on bi_data=True.
            pos_seq = torch.arange(-beg, -end, -1.0, device=device, dtype=torch.long)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            bwd_pos_emb = self.pos_embeddings(pos_seq)
            pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=0)
        else:
            pos_emb = fwd_pos_emb
        pos_emb = self.dropout(pos_emb)  # reuses the word_emb dropout
        return pos_emb

    def apply_final_layers(self, inputs):
        # With the LM head enabled, also return the vocab logits.
        hidden_state = super().apply_final_layers(inputs)
        if self.with_lm:
            return [hidden_state, self.dense(hidden_state)]
        else:
            return hidden_state

    def load_variable(self, state_dict, name, prefix='transformer'):
        """Load a single checkpoint variable, reshaping the per-head attention
        projections from (hidden, heads, head_size) layout to 2-D matrices."""
        variable = state_dict[name]
        if name in {f'{prefix}.word_embedding.weight', 'lm_loss.weight', 'lm_loss.bias'}:
            return self.load_embeddings(variable)
        elif re.search('rel_attn\.(q|k|v|r)$', name):
            # flatten head dims and transpose to nn.Linear's (out, in) layout
            return variable.reshape(variable.shape[0], -1).T
        # elif re.search('rel_attn\.(o|seg_embed)$', name):
        elif re.search('rel_attn\.(o)$', name):
            return variable.reshape(variable.shape[0], -1)
        else:
            return variable

    def variable_mapping(self, prefix='transformer'):
        """Map bert4torch parameter names to the XLNet checkpoint layout."""
        mapping = {
            'embeddings.weight': f'{prefix}.word_embedding.weight',
            'dense.weight': 'lm_loss.weight',
            'dense.bias': 'lm_loss.bias',
        }
        for i in range(self.num_hidden_layers):
            prefix_i = f'{prefix}.layer.%d.' % i
            mapping.update({f'encoderLayer.{i}.multiHeadAttention.q.weight': prefix_i + 'rel_attn.q',
                            f'encoderLayer.{i}.multiHeadAttention.k.weight': prefix_i + 'rel_attn.k',
                            f'encoderLayer.{i}.multiHeadAttention.v.weight': prefix_i + 'rel_attn.v',
                            f'encoderLayer.{i}.multiHeadAttention.o.weight': prefix_i + 'rel_attn.o',
                            f'encoderLayer.{i}.multiHeadAttention.r.weight': prefix_i + 'rel_attn.r',
                            f'encoderLayer.{i}.multiHeadAttention.r_r_bias': prefix_i + 'rel_attn.r_r_bias',
                            f'encoderLayer.{i}.multiHeadAttention.r_s_bias': prefix_i + 'rel_attn.r_s_bias',
                            f'encoderLayer.{i}.multiHeadAttention.r_w_bias': prefix_i + 'rel_attn.r_w_bias',
                            # f'encoderLayer.{i}.multiHeadAttention.seg_embed.weight': prefix_i + 'rel_attn.seg_embed',
                            f'encoderLayer.{i}.multiHeadAttention.seg_embed': prefix_i + 'rel_attn.seg_embed',
                            f'encoderLayer.{i}.layerNorm1.weight': prefix_i + 'rel_attn.layer_norm.weight',
                            f'encoderLayer.{i}.layerNorm1.bias': prefix_i + 'rel_attn.layer_norm.bias',
                            f'encoderLayer.{i}.feedForward.intermediateDense.weight': prefix_i + 'ff.layer_1.weight',
                            f'encoderLayer.{i}.feedForward.intermediateDense.bias': prefix_i + 'ff.layer_1.bias',
                            f'encoderLayer.{i}.feedForward.outputDense.weight': prefix_i + 'ff.layer_2.weight',
                            f'encoderLayer.{i}.feedForward.outputDense.bias': prefix_i + 'ff.layer_2.bias',
                            f'encoderLayer.{i}.layerNorm2.weight': prefix_i + 'ff.layer_norm.weight',
                            f'encoderLayer.{i}.layerNorm2.bias': prefix_i + 'ff.layer_norm.bias'})
        return mapping
def build_transformer_model(config_path=None, checkpoint_path=None, model='bert', application='encoder', **kwargs):
    """Build a model from a config file, optionally loading checkpoint weights.

    config_path: path to a JSON config; its values are overridden by **kwargs.
    checkpoint_path: optional pytorch checkpoint to load after building.
    model: model-name string (see the `models` table below) or a BERT_BASE subclass.
    application: 'encoder' (default), 'lm' or 'unilm'.
    """
    configs = {}
    if config_path is not None:
        configs.update(json.load(open(config_path)))
    configs.update(kwargs)
    # Normalize config key aliases to the names this library expects.
    if 'max_position' not in configs:
        configs['max_position'] = configs.get('max_position_embeddings', 512)
    if 'dropout_rate' not in configs:
        configs['dropout_rate'] = configs.get('hidden_dropout_prob')
    if 'segment_vocab_size' not in configs:
        configs['segment_vocab_size'] = configs.get('type_vocab_size', 2)

    models = {
        'bert': BERT,
        'roberta': BERT,
        'albert': ALBERT,
        'albert_unshared': ALBERT_Unshared,
        'nezha': NEZHA,
        'roformer': RoFormer,
        'roformer_v2': RoFormerV2,
        'gau_alpha': GAU_alpha,
        'electra': ELECTRA,
        'encoder': Encoder,
        'decoder': Decoder,
        'transformer': Transformer,
        'bart': BART,
        'gpt': GPT,
        'gpt2': GPT2,
        'gpt2_ml': GPT2_ML,
        't5': T5,
        't5_encoder': T5_Encoder,
        't5_decoder': T5_Decoder,
        't5.1.0': T5,
        't5.1.0_encoder': T5_Encoder,
        't5.1.0_decoder': T5_Decoder,
        't5.1.1': T5,
        't5.1.1_encoder': T5_Encoder,
        't5.1.1_decoder': T5_Decoder,
        'mt5.1.1': T5,
        'mt5.1.1_encoder': T5_Encoder,
        'mt5.1.1_decoder': T5_Decoder,
        'transformer_xl': Transformer_XL,
        'xlnet': XLNET,
    }

    if isinstance(model, str):  # a string selects one of the built-in models
        MODEL = models[model.lower()]
        # endswith also catches 'mt5.1.1'; the version switches the FFN variant
        if model.endswith('t5.1.1'):
            configs['version'] = model
    elif isinstance(model, type) and issubclass(model, BERT_BASE):  # a class means a user-defined model
        MODEL = model
    else:
        raise ValueError('"model" args type should be string or nn.Module')

    application = application.lower()
    if application in ['lm', 'unilm'] and model in ['electra', 't5', ]:
        raise ValueError(f'"{model}" model can not be used as "{application}" application.\n')

    # Wrap the class with the LM / UniLM attention-mask mixin when requested.
    if application == 'lm':
        MODEL = extend_with_language_model(MODEL)
    elif application == 'unilm':
        MODEL = extend_with_unified_language_model(MODEL)

    transformer = MODEL(**configs)
    transformer.build(**configs)
    transformer.apply(transformer.init_model_weights)  # initialize weights

    if checkpoint_path is not None:
        transformer.load_weights_from_pytorch_checkpoint(checkpoint_path)
    transformer.configs = configs
    return transformer
\ No newline at end of file
build/lib/bert4torch/optimizers.py
0 → 100644
View file @
66a1d0d0
from
torch.optim.lr_scheduler
import
LambdaLR
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """Linear warmup followed by linear decay (ported from transformers' optimization.py).

    Args:
        optimizer: the wrapped torch optimizer.
        num_warmup_steps: steps over which the lr ramps linearly from 0 to its base
            value; typically ``num_training_steps * warmup_proportion`` with a
            warmup proportion of about 0.05-0.15.
        num_training_steps: total number of optimization steps, typically
            ``train_batches * num_epoch``.
        last_epoch: index of the last step when resuming (-1 starts fresh).

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` implementing the schedule.
    """
    def lr_lambda(current_step: int):
        # Warmup phase: scale grows 0 -> 1 over num_warmup_steps.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Decay phase: scale shrinks 1 -> 0 over the remaining steps, floored at 0.
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, float(num_training_steps - current_step) / decay_span)

    return LambdaLR(optimizer, lr_lambda, last_epoch)
def extend_with_exponential_moving_average(model, decay=0.999):
    """Build an exponential-moving-average (EMA) tracker for ``model``'s weights.

    The tracker does not take part in gradient updates; it only maintains a
    shadow copy of the trainable parameters for use at prediction time. Note
    this is unrelated to the first/second-moment EMAs inside adaptive
    optimizers such as Adam.

    Usage:
        ema = extend_with_exponential_moving_average(model, 0.999)
        # training: after each optimizer.step(), refresh the shadow weights
        optimizer.step(); ema.step()
        # evaluation: swap in the EMA weights, then restore the raw ones
        ema.apply_ema_weights()
        ...  # evaluate (save with torch.save() here to persist EMA weights)
        ema.restore_raw_weights()
    """
    class ExponentialMovingAverage():
        '''Exponential moving average over the model's trainable parameters.'''
        def __init__(self, model, decay):
            self.model = model
            self.decay = decay
            # shadow (EMA) copy of every trainable parameter, seeded from the model
            self.ema_weights = {k: p.data.clone() for k, p in model.named_parameters() if p.requires_grad}
            # backup of the raw weights while the EMA weights are applied
            self.model_weights = {}

        def step(self):
            # Blend the current weights into the shadow copy:
            # ema <- decay * ema + (1 - decay) * current
            for k, p in self.model.named_parameters():
                if p.requires_grad:
                    assert k in self.ema_weights
                    blended = (1.0 - self.decay) * p.data + self.decay * self.ema_weights[k]
                    self.ema_weights[k] = blended.clone()

        def apply_ema_weights(self):
            # Swap the EMA weights into the model, backing up the raw ones.
            for k, p in self.model.named_parameters():
                if p.requires_grad:
                    assert k in self.ema_weights
                    self.model_weights[k] = p.data
                    p.data = self.ema_weights[k]

        def restore_raw_weights(self):
            # Put the raw weights back and drop the backup.
            for k, p in self.model.named_parameters():
                if p.requires_grad:
                    assert k in self.model_weights
                    p.data = self.model_weights[k]
            self.model_weights = {}

    return ExponentialMovingAverage(model, decay)
\ No newline at end of file
build/lib/bert4torch/snippets.py
0 → 100644
View file @
66a1d0d0
#! -*- coding: utf-8 -*-
# 其他代码合
import
unicodedata
import
six
import
numpy
as
np
import
re
import
torch
from
torch.nn.utils.rnn
import
pad_sequence
import
time
import
sys
import
collections
import
torch.nn
as
nn
from
torch.utils.data
import
Dataset
,
IterableDataset
import
math
import
gc
import
inspect
import
json
import
torch.nn.functional
as
F
import
random
import
warnings
import
os
# Python 2/3 compatibility shim: on Python 3 there is no `basestring`, so alias
# it to `str` for the isinstance checks below (see is_string).
is_py2 = six.PY2

if not is_py2:
    basestring = str
def take_along_dim(input_tensor, indices, dim=None):
    '''Compatibility shim for ``torch.take_along_dim`` on older PyTorch builds.

    Uses the native implementation when available (torch >= 1.9); otherwise
    falls back to flat indexing (``dim is None``) or ``np.take_along_axis``.

    Args:
        input_tensor: source tensor.
        indices: integer index tensor, broadcast-compatible per take_along_dim.
        dim: dimension to gather along; None means index into the flattened tensor.

    Note: the numpy fallback detaches from the autograd graph (cpu/numpy
    round-trip) — it was only lightly tested upstream.
    '''
    # Parse the version numerically: the previous lexicographic comparison
    # (torch.__version__ >= '1.9.0') mis-ordered versions such as '1.10.0',
    # sending current releases down the slow fallback path.
    ver = torch.__version__.split('+')[0]
    try:
        major_minor = tuple(int(v) for v in ver.split('.')[:2])
    except ValueError:
        major_minor = (0, 0)  # unparseable dev version: be conservative
    if major_minor >= (1, 9):
        return torch.take_along_dim(input_tensor, indices, dim)
    elif dim is None:
        return input_tensor.flatten()[indices]
    else:
        res = np.take_along_axis(input_tensor.cpu().numpy(), indices.cpu().numpy(), axis=dim)
        return torch.from_numpy(res).to(input_tensor.device)
def is_string(s):
    """Return True if ``s`` is a text string.

    Previously checked against the module-level ``basestring`` shim (aliased to
    ``str`` on Python 3 via the six compatibility block); checking ``str``
    directly is equivalent on Python 3 and removes the reliance on that
    monkeypatched name.
    """
    return isinstance(s, str)
def truncate_sequences(maxlen, indices, *sequences):
    """Trim the sequences in place until their combined length is <= maxlen.

    Each round removes one element from the currently longest sequence, at the
    position given by that sequence's entry in ``indices`` (a single index is
    broadcast to all sequences). Empty sequences are dropped up front.

    Args:
        maxlen: maximum allowed total length.
        indices: pop position per sequence (int, or list/tuple of ints).
        *sequences: the (mutable) sequences to truncate.

    Returns:
        The list of surviving sequences (mutated in place).
    """
    sequences = [seq for seq in sequences if seq]
    if not isinstance(indices, (list, tuple)):
        indices = [indices] * len(sequences)

    while sum(len(seq) for seq in sequences) > maxlen:
        # np.argmax picks the first longest sequence on ties
        longest = int(np.argmax([len(seq) for seq in sequences]))
        sequences[longest].pop(indices[longest])
    return sequences
def text_segmentate(text, maxlen, seps='\n', strips=None, truncate=True):
    """Split text into short pieces of at most roughly maxlen characters, by
    recursively splitting on each separator in ``seps`` in turn.

    Args:
        text: the input string (stripped of whitespace and ``strips`` first).
        maxlen: target maximum length per piece.
        seps: separators tried in order; seps[0] is used at this level, the
            rest are passed to the recursive calls.
        strips: extra characters to strip from the ends of the input.
        truncate: when True and the text is still over maxlen after all
            separators are exhausted, hard-cut it into maxlen-sized chunks.
    """
    text = text.strip().strip(strips)
    if seps and len(text) > maxlen:
        pieces = text.split(seps[0])
        text, texts = '', []
        for i, p in enumerate(pieces):
            # flush the accumulator (via the next separator level) once adding
            # this piece would exceed maxlen
            if text and p and len(text) + len(p) > maxlen - 1:
                texts.extend(text_segmentate(text, maxlen, seps[1:], strips, truncate))
                text = ''
            # re-append the separator except after the final piece
            if i + 1 == len(pieces):
                text = text + p
            else:
                text = text + p + seps[0]
        if text:
            texts.extend(text_segmentate(text, maxlen, seps[1:], strips, truncate))
        return texts
    elif truncate and (not seps) and (len(text) > maxlen):
        # separators exhausted, still too long, truncate=True: hard-cut
        return [text[i * maxlen: (i + 1) * maxlen] for i in range(0, int(np.ceil(len(text) / maxlen)))]
    else:
        return [text]
def merge_segmentate(sequences, maxlen, sep=''):
    '''Greedily merge m short fragments into n chunks of at most maxlen
    characters, joining adjacent fragments with ``sep``. Mainly used to glue
    small leftover pieces back together; a fragment already >= maxlen when the
    buffer is empty is emitted unchanged.
    '''
    merged = []
    buffer = ''
    for piece in sequences:
        if buffer and len(buffer + sep + piece) <= maxlen:
            # piece still fits: extend the current chunk
            buffer = buffer + sep + piece
        elif buffer:
            # chunk full: flush it and start over with this piece
            merged.append(buffer)
            buffer = piece
        elif len(piece) < maxlen:
            # buffer is empty
            buffer = piece
        else:
            # empty buffer and an oversized piece: pass it through as-is
            merged.append(piece)
            buffer = ''
    if buffer:
        merged.append(buffer)
    return merged
def text_augmentation(texts, noise_dict=None, noise_len=0, noise_p=0.0, skip_words=None, strategy='random', allow_dup=True):
    '''Simple EDA-style text augmentation: insert, delete or replace characters.

    texts: string or list of strings to augment (a list is modified in place).
    noise_dict: pool of noise strings (list/tuple/set of str) for insert/replace.
    noise_len: absolute number of edit positions; takes priority over noise_p.
    noise_p: fraction of the text length to edit (used when noise_len == 0).
    skip_words: phrase or list of phrases whose character positions must not be edited.
    strategy: 'insert', 'delete', 'replace' or 'random'.
    allow_dup: whether the same position may be sampled (edited) more than once.

    Returns the augmented list, or the single string when only one text was given.
    '''
    def insert(text, insert_idx, noise_dict):
        # append a random noise string after each selected position
        text = list(text)
        for i in insert_idx:
            text[i] = text[i] + random.choice(noise_dict)
        return ''.join(text)

    def delete(text, delete_idx):
        # drop the character at each selected position
        text = list(text)
        for i in delete_idx:
            text[i] = ''
        return ''.join(text)

    def replace(text, replace_idx, noise_dict):
        # overwrite each selected position with a random noise string
        text = list(text)
        for i in replace_idx:
            text[i] = random.choice(noise_dict)
        return ''.join(text)

    def search(pattern, sequence, keep_last=True):
        """Find every occurrence of pattern in sequence and return the set of
        covered indices; keep_last=False leaves each occurrence's final index
        out (so insertion right after a protected phrase stays allowed)."""
        n = len(pattern)
        pattern_idx_set = set()
        for i in range(len(sequence)):
            if sequence[i:i + n] == pattern:
                pattern_idx_set = pattern_idx_set.union(set(range(i, i + n))) if keep_last else pattern_idx_set.union(set(range(i, i + n - 1)))
        return pattern_idx_set

    # nothing to do when no noise amount is configured
    if (noise_len == 0) and (noise_p == 0):
        return texts

    assert strategy in {'insert', 'delete', 'replace', 'random'}, 'EDA strategy only support insert, delete, replace, random'

    if isinstance(texts, str):
        texts = [texts]

    if skip_words is None:
        skip_words = []
    elif isinstance(skip_words, str):
        skip_words = [skip_words]

    for id, text in enumerate(texts):
        sel_len = noise_len if noise_len > 0 else int(len(text) * noise_p)  # number of edit positions
        skip_idx = set()  # indices that must not be edited
        for item in skip_words:
            # for 'insert', the last index of a protected phrase may still take an insertion
            skip_idx = skip_idx.union(search(item, text, strategy != 'insert'))

        sel_idxs = [i for i in range(len(text)) if i not in skip_idx]  # candidate indices
        sel_len = sel_len if allow_dup else min(sel_len, len(sel_idxs))  # without replacement, sample size must not exceed the pool
        if (sel_len == 0) or (len(sel_idxs) == 0):  # skip when sampling is impossible
            continue
        sel_idx = np.random.choice(sel_idxs, sel_len, replace=allow_dup)
        if strategy == 'insert':
            texts[id] = insert(text, sel_idx, noise_dict)
        elif strategy == 'delete':
            texts[id] = delete(text, sel_idx)
        elif strategy == 'replace':
            texts[id] = replace(text, sel_idx, noise_dict)
        elif strategy == 'random':
            # pick one of the three edits at random (roughly 1/3 each; note the
            # second random.random() call makes delete/replace odds conditional)
            if random.random() < 0.333:
                skip_idx = set()  # recompute protected indices with insert semantics
                for item in skip_words:
                    skip_idx = skip_idx.union(search(item, text, keep_last=False))
                texts[id] = insert(text, sel_idx, noise_dict)
            elif random.random() < 0.667:
                texts[id] = delete(text, sel_idx)
            else:
                texts[id] = replace(text, sel_idx, noise_dict)

    return texts if len(texts) > 1 else texts[0]
def lowercase_and_normalize(text, never_split=()):
    """Lowercase ``text`` and strip combining marks (simple normalization).

    Tokens listed in ``never_split`` keep their original case.
    """
    # NOTE: the former Python-2 ``unicode(text)`` branch was removed; this
    # torch-based codebase only runs on Python 3 where str is already unicode.
    if never_split:
        # Lowercase everything except the protected special tokens.
        escaped_special_toks = [re.escape(s_tok) for s_tok in never_split]
        pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
        text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
    else:
        # BUGFIX: with no special tokens the old code built the pattern
        # "()|(.+?)", whose empty first group matched everywhere and crashed
        # on ``None.lower()``. Plain lowercasing is the intended behaviour.
        text = text.lower()
    # NFD-decompose, then drop combining marks (category 'Mn') to strip accents.
    text = unicodedata.normalize('NFD', text)
    text = ''.join(ch for ch in text if unicodedata.category(ch) != 'Mn')
    return text
def sequence_padding(inputs, length=None, value=0, seq_dims=1, mode='post'):
    """Pad a batch of sequences to a common length.

    inputs: list of lists/ndarrays (padded with numpy) or list of torch.Tensor
        (padded with pad_sequence, 'post' mode only).
    length: target length per padded dimension; inferred as the batch max when None.
    value: padding value.
    seq_dims: number of leading dimensions to pad.
    mode: 'post' pads at the end, 'pre' pads at the front.
    """
    if isinstance(inputs[0], (np.ndarray, list)):
        if length is None:
            # Per-dimension max over the batch for the first seq_dims axes.
            length = np.max([np.shape(x)[:seq_dims] for x in inputs], axis=0)
        elif not hasattr(length, '__getitem__'):
            # A scalar length is promoted to a one-element list.
            length = [length]

        # Slices that first truncate each sample to the target length.
        slices = [np.s_[:length[i]] for i in range(seq_dims)]
        slices = tuple(slices) if len(slices) > 1 else slices[0]
        pad_width = [(0, 0) for _ in np.shape(inputs[0])]

        outputs = []
        for x in inputs:
            x = x[slices]
            for i in range(seq_dims):
                if mode == 'post':
                    pad_width[i] = (0, length[i] - np.shape(x)[i])
                elif mode == 'pre':
                    pad_width[i] = (length[i] - np.shape(x)[i], 0)
                else:
                    raise ValueError('"mode" argument must be "post" or "pre".')
            x = np.pad(x, pad_width, 'constant', constant_values=value)
            outputs.append(x)
        return np.array(outputs)

    elif isinstance(inputs[0], torch.Tensor):
        assert mode == 'post', '"mode" argument must be "post" when element is torch.Tensor'
        if length is not None:
            # Truncate along the first dimension before padding.
            inputs = [i[:length] for i in inputs]
        return pad_sequence(inputs, padding_value=value, batch_first=True)
    else:
        raise ValueError('"input" argument must be tensor/list/ndarray.')
def insert_arguments(**arguments):
    """Decorator factory that injects extra keyword arguments into a method
    (mainly a class's ``__init__``): every declared argument is popped from
    the call's kwargs (falling back to its declared default) and stored on
    ``self`` before the wrapped function runs.
    """
    def actual_decorator(func):
        def new_func(self, *args, **kwargs):
            for name, default in arguments.items():
                # pop() with a default covers both "caller supplied it" and
                # "fall back to the decorator's default" in one step.
                setattr(self, name, kwargs.pop(name, default))
            return func(self, *args, **kwargs)
        return new_func
    return actual_decorator
def delete_arguments(*arguments):
    """Decorator factory that forbids certain keyword arguments on a method
    (mainly a class's ``__init__``): passing any listed name raises TypeError.
    """
    banned = arguments

    def actual_decorator(func):
        def new_func(self, *args, **kwargs):
            for name in banned:
                if name not in kwargs:
                    continue
                raise TypeError(
                    '%s got an unexpected keyword argument \'%s\'' % (self.__class__.__name__, name)
                )
            return func(self, *args, **kwargs)
        return new_func
    return actual_decorator
class Progbar(object):
    """Displays a progress bar.
    # Arguments
        target: Total number of steps expected, None if unknown.
        width: Progress bar width on screen.
        verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
        stateful_metrics: Iterable of string names of metrics that
            should *not* be averaged over time. Metrics in this list
            will be displayed as-is. All others will be averaged
            by the progbar before display.
        interval: Minimum visual progress update interval (in seconds).
    """
    def __init__(self, target, width=30, verbose=1, interval=0.05, stateful_metrics=None):
        self.target = target
        self.width = width
        self.verbose = verbose
        self.interval = interval
        if stateful_metrics:
            self.stateful_metrics = set(stateful_metrics)
        else:
            self.stateful_metrics = set()
        # In-place redraw only works on a real terminal or inside a notebook.
        self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()) or 'ipykernel' in sys.modules)
        self._total_width = 0        # width of the last rendered line (for erasing)
        self._seen_so_far = 0        # step index at the previous update
        self._values = collections.OrderedDict()  # metric -> [weighted sum, weight]
        self._start = time.time()
        self._last_update = 0

    def update(self, current, values=None):
        """Updates the progress bar.
        # Arguments
            current: Index of current step.
            values: List of tuples:
                `(name, value_for_last_step)`.
                If `name` is in `stateful_metrics`,
                `value_for_last_step` will be displayed as-is.
                Else, an average of the metric over time will be displayed.
        """
        values = values or []
        for k, v in values:
            if k not in self.stateful_metrics:
                # Running weighted average: weight is the number of steps
                # advanced since the last update.
                if k not in self._values:
                    self._values[k] = [v * (current - self._seen_so_far), current - self._seen_so_far]
                else:
                    self._values[k][0] += v * (current - self._seen_so_far)
                    self._values[k][1] += (current - self._seen_so_far)
            else:
                # Stateful metrics output a numeric value. This representation
                # means "take an average from a single value" but keeps the
                # numeric formatting.
                self._values[k] = [v, 1]
        self._seen_so_far = current

        now = time.time()
        info = ' - %.0fs' % (now - self._start)
        if self.verbose == 1:
            # Throttle redraws to self.interval, except for the final step.
            if (now - self._last_update < self.interval and self.target is not None and current < self.target):
                return

            prev_total_width = self._total_width
            if self._dynamic_display:
                # Erase the previous line and return the cursor to column 0.
                sys.stdout.write('\b' * prev_total_width)
                sys.stdout.write('\r')
            else:
                sys.stdout.write('\n')

            if self.target is not None:
                # e.g. " 25/100 [=======>......................]"
                numdigits = int(np.floor(np.log10(self.target))) + 1
                barstr = '%%%dd/%d [' % (numdigits, self.target)
                bar = barstr % current
                prog = float(current) / self.target
                prog_width = int(self.width * prog)
                if prog_width > 0:
                    bar += ('=' * (prog_width - 1))
                    if current < self.target:
                        bar += '>'
                    else:
                        bar += '='
                bar += ('.' * (self.width - prog_width))
                bar += ']'
            else:
                bar = '%7d/Unknown' % current

            self._total_width = len(bar)
            sys.stdout.write(bar)

            if current:
                time_per_unit = (now - self._start) / current
            else:
                time_per_unit = 0
            if self.target is not None and current < self.target:
                # Still running: show the estimated time remaining.
                eta = time_per_unit * (self.target - current)
                if eta > 3600:
                    eta_format = ('%d:%02d:%02d' % (eta // 3600, (eta % 3600) // 60, eta % 60))
                elif eta > 60:
                    eta_format = '%d:%02d' % (eta // 60, eta % 60)
                else:
                    eta_format = '%ds' % eta
                info = ' - ETA: %s' % eta_format
            else:
                # Finished (or unknown target): show the per-step time instead.
                if time_per_unit >= 1:
                    info += ' %.0fs/step' % time_per_unit
                elif time_per_unit >= 1e-3:
                    info += ' %.0fms/step' % (time_per_unit * 1e3)
                else:
                    info += ' %.0fus/step' % (time_per_unit * 1e6)

            for k in self._values:
                info += ' - %s:' % k
                if isinstance(self._values[k], list):
                    avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
                    if abs(avg) > 1e-3:
                        info += ' %.4f' % avg
                    else:
                        info += ' %.4e' % avg
                else:
                    info += ' %s' % self._values[k]

            self._total_width += len(info)
            if prev_total_width > self._total_width:
                # Pad with spaces so leftovers of the longer old line are erased.
                info += (' ' * (prev_total_width - self._total_width))

            if self.target is not None and current >= self.target:
                info += '\n'

            sys.stdout.write(info)
            sys.stdout.flush()

        elif self.verbose == 2:
            # Semi-verbose: print a single summary line at the end only.
            if self.target is None or current >= self.target:
                for k in self._values:
                    info += ' - %s:' % k
                    avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
                    if avg > 1e-3:
                        info += ' %.4f' % avg
                    else:
                        info += ' %.4e' % avg
                info += '\n'
                sys.stdout.write(info)
                sys.stdout.flush()

        self._last_update = now

    def add(self, n, values=None):
        # Convenience wrapper: advance the bar by n steps.
        self.update(self._seen_so_far + n, values)
class Callback(object):
    """Base class for training callbacks.

    Every hook is a no-op; subclasses override only the hooks they need.
    """

    def __init__(self):
        pass

    def on_train_begin(self, logs=None):
        """Called once before training starts."""

    def on_train_end(self, logs=None):
        """Called once after training finishes."""

    def on_epoch_begin(self, global_step, epoch, logs=None):
        """Called at the start of every epoch."""

    def on_epoch_end(self, global_step, epoch, logs=None):
        """Called at the end of every epoch."""

    def on_batch_begin(self, global_step, batch, logs=None):
        """Called before every batch."""

    def on_batch_end(self, global_step, batch, logs=None):
        """Called after every batch."""

    def on_dataloader_end(self, logs=None):
        """Called when the dataloader is exhausted."""
class ProgbarLogger(Callback):
    """Callback that prints metrics to stdout via a Progbar.

    # Arguments
        epochs: total number of epochs (for the "Epoch x/y" header).
        steps: number of batches per epoch (the progress bar target).
        metrics: ordered list of metric names to display.
        stateful_metrics: Iterable of string names of metrics that
            should *not* be averaged over an epoch.
            Metrics in this list will be logged as-is.
            All others will be averaged over time (e.g. loss, etc).
        verbose: 0 silent, 1 progress bar, 2 one line per epoch.
    """
    def __init__(self, epochs, steps, metrics, stateful_metrics=None, verbose=1):
        super(ProgbarLogger, self).__init__()
        if stateful_metrics:
            self.stateful_metrics = set(stateful_metrics)
        else:
            self.stateful_metrics = set()
        self.params = {'epochs': epochs, 'steps': steps, 'verbose': verbose, 'metrics': metrics}
        self.verbose = verbose
        self.epochs = epochs

    def add_metrics(self, metrics, add_position=None):
        # Insert additional metric names at add_position (defaults to the end),
        # skipping names that are already tracked.
        if add_position is None:
            add_position = len(self.params['metrics'])
        if isinstance(metrics, str):
            metrics = [metrics]

        add_metrics = []
        for metric in metrics:
            if metric not in self.params['metrics']:
                add_metrics.append(metric)
        self.params['metrics'] = self.params['metrics'][:add_position] + add_metrics + self.params['metrics'][add_position:]

    def on_train_begin(self, logs=None):
        if self.verbose:
            print('Start Training'.center(40, '='))

    def on_epoch_begin(self, global_step=None, epoch=None, logs=None):
        if self.verbose:
            print('Epoch %d/%d' % (epoch + 1, self.epochs))
        # Fresh progress bar for each epoch.
        self.target = self.params['steps']
        self.progbar = Progbar(target=self.target, verbose=self.verbose, stateful_metrics=self.stateful_metrics)
        self.seen = 0

    def on_batch_begin(self, global_step=None, batch=None, logs=None):
        if self.seen < self.target:
            self.log_values = []

    def on_batch_end(self, global_step=None, batch=None, logs=None):
        logs = logs or {}
        self.seen += 1
        for k in self.params['metrics']:
            if k in logs:
                self.log_values.append((k, logs[k]))

        # Skip progbar update for the last batch;
        # will be handled by on_epoch_end.
        if self.verbose and self.seen < self.target:
            self.progbar.update(self.seen, self.log_values)

    def on_epoch_end(self, global_step=None, epoch=None, logs=None):
        logs = logs or {}
        for k in self.params['metrics']:
            if k in logs:
                self.log_values.append((k, logs[k]))
        if self.verbose:
            self.progbar.update(self.seen, self.log_values)

    def on_train_end(self, logs=None):
        if self.verbose:
            print('Finish Training'.center(40, '='))
class EarlyStopping(Callback):
    """Stop-training strategy, ported from keras.

    Tracks a monitored metric and records the epoch at which it stopped
    improving (after ``patience`` epochs without an improvement of at
    least ``min_delta`` beyond ``baseline``/the best value seen).
    """
    def __init__(self, monitor='loss', min_delta=0, patience=0, verbose=0, mode='auto', baseline=None):
        super(EarlyStopping, self).__init__()
        self.monitor = monitor
        self.baseline = baseline
        self.patience = patience
        self.verbose = verbose
        self.min_delta = min_delta
        self.wait = 0
        self.stopped_epoch = 0

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('EarlyStopping mode %s is unknown, fallback to auto mode.' % mode, RuntimeWarning)
            mode = 'auto'
        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            # auto: accuracy-like metrics should grow, everything else shrink.
            self.monitor_op = np.greater if 'acc' in self.monitor else np.less

        # Sign min_delta so "current - min_delta vs best" works for both directions.
        self.min_delta = self.min_delta if self.monitor_op == np.greater else -self.min_delta

    def on_train_begin(self, logs=None):
        # Allow instances to be re-used across multiple fit() calls.
        self.wait = 0
        self.stopped_epoch = 0
        if self.baseline is not None:
            self.best = self.baseline
        else:
            # BUGFIX: np.Inf was removed in NumPy 2.0; np.inf is the portable spelling.
            self.best = np.inf if self.monitor_op == np.less else -np.inf

    def on_epoch_end(self, steps, epoch, logs=None):
        current = self.get_monitor_value(logs)
        if current is None:
            return
        if self.monitor_op(current - self.min_delta, self.best):
            # Improved: remember the new best and reset the patience counter.
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0 and self.verbose > 0:
            print(f'Epoch {self.stopped_epoch + 1}: early stopping\n')

    def get_monitor_value(self, logs):
        # Fetch the monitored metric from the logs dict; warn when missing.
        monitor_value = logs.get(self.monitor)
        if monitor_value is None:
            warnings.warn(
                'Early stopping conditioned on metric `%s` '
                'which is not available. Available metrics are: %s' % (self.monitor, ','.join(list(logs.keys()))), RuntimeWarning
            )
        return monitor_value
def metric_mapping(metric, y_pred, y_true):
    """Map a metric name to its computed value.

    Only 'accuracy' is implemented; any other name yields None.
    """
    if metric != 'accuracy':
        return None
    # Models may return several outputs; the first one carries the logits.
    logits = y_pred[0] if isinstance(y_pred, (list, tuple)) else y_pred
    predicted = torch.argmax(logits, dim=-1)
    correct = torch.sum(predicted.eq(y_true)).item()
    return correct / y_true.size(0)
def softmax(x, axis=-1):
    """Numerically stable softmax implemented with numpy."""
    # Subtract the per-axis max so exp() cannot overflow.
    shifted = x - x.max(axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=axis, keepdims=True)
class AutoRegressiveDecoder(object):
    """Base class for autoregressive generation/decoding.

    Provides two strategies: beam search and random sampling.
    """
    def __init__(self, start_id, end_id, maxlen, minlen=1, device='cpu'):
        self.start_id = start_id    # token id that starts generation (None = empty prefix)
        self.end_id = end_id        # token id that terminates a sequence
        self.maxlen = maxlen        # maximum number of decoding steps
        self.minlen = minlen        # minimum output length before end is honoured
        self.models = {}
        self.device = device
        if start_id is None:
            self.first_output_ids = torch.empty((1, 0), dtype=int, device=device)
        else:
            self.first_output_ids = torch.tensor([[self.start_id]], device=device)

    @staticmethod
    def wraps(default_rtype='probas', use_states=False):
        """Decorator that completes a user-defined ``predict`` function.

        It: 1. adds and handles the ``rtype`` argument ('probas'/'logits');
            2. normalizes the (prediction, states) return shape per ``use_states``;
            3. applies the temperature parameter.
        """
        def actual_decorator(predict):
            def new_predict(self, inputs, output_ids, states, temperature=1, rtype=default_rtype):
                assert rtype in ['probas', 'logits']
                prediction = predict(self, inputs, output_ids, states)
                if not use_states:
                    prediction = (prediction, None)

                if default_rtype == 'logits':
                    # Temperature is applied on logits before the softmax.
                    prediction = (nn.Softmax(dim=-1)(prediction[0] / temperature), prediction[1])
                elif temperature != 1:
                    # BUGFIX: the original called torch.power, which does not
                    # exist — torch.pow is the correct function, so any call
                    # with temperature != 1 crashed.
                    probas = torch.pow(prediction[0], 1.0 / temperature)
                    probas = probas / probas.sum(axis=-1, keepdims=True)
                    prediction = (probas, prediction[1])

                if rtype == 'probas':
                    return prediction
                else:
                    # Log-probabilities; the epsilon guards against log(0).
                    return torch.log(prediction[0] + 1e-12), prediction[1]
            return new_predict
        return actual_decorator

    def predict(self, inputs, output_ids, states=None):
        """User-defined step function for recursive prediction.

        Must be decorated with ``wraps``, passing default_rtype and use_states.
        default_rtype='probas' returns normalized probabilities;
        rtype='logits' returns pre-softmax scores or log-probabilities.
        Returns: a 2-tuple (scores-or-probas, states).
        """
        raise NotImplementedError

    def beam_search(self, inputs_raw, topk, states=None, temperature=1, min_ends=1, add_btz_dim=True):
        """Beam search decoding; ``topk`` is the beam size.

        Returns: the single best decoded id sequence.
        """
        inputs = []
        for i in inputs_raw:
            # Normalize every input to a tensor on the right device.
            if isinstance(i, torch.Tensor):
                pass
            elif isinstance(i, (list, tuple, np.ndarray)) and add_btz_dim:
                i = torch.tensor([i], device=self.device)
            elif isinstance(i, (list, tuple, np.ndarray)) and not add_btz_dim:
                i = torch.tensor(i, device=self.device)
            else:
                raise ValueError('Beam search inputs ele only support tensor、array、list、tuple')
            inputs.append(i)

        output_ids, output_scores = self.first_output_ids, torch.zeros(1, device=self.device)
        for step in range(self.maxlen):
            scores, states = self.predict(inputs, output_ids, states, temperature, 'logits')  # scores for this step
            if step == 0:  # after the first step, tile the inputs topk times
                inputs = [i.repeat([topk] + [1] * (len(i.shape) - 1)) for i in inputs]
            scores = output_scores.reshape((-1, 1)) + scores  # accumulated beam scores
            indices = scores.flatten().argsort(dim=-1, descending=True)[:topk]  # keep only topk
            indices_1 = torch.div(indices, scores.shape[1], rounding_mode='trunc')  # row (beam) index
            indices_2 = (indices % scores.shape[1]).reshape((-1, 1))  # column (token) index
            output_ids = torch.cat([output_ids[indices_1], indices_2], 1)  # extend the kept beams
            output_scores = take_along_dim(scores, indices, dim=None)  # update beam scores
            is_end = output_ids[:, -1] == self.end_id  # beams that just emitted end
            end_counts = (output_ids == self.end_id).sum(1)  # end tokens per beam
            if output_ids.shape[1] >= self.minlen:  # minimum length reached?
                best = output_scores.argmax()  # highest scoring beam
                if is_end[best] and end_counts[best] >= min_ends:
                    # Best beam has terminated: output it directly.
                    return output_ids[best]
                else:
                    # Otherwise drop finished beams and shrink the beam width.
                    flag = ~is_end | (end_counts < min_ends)  # beams still running
                    if not flag.all():
                        inputs = [i[flag] for i in inputs]
                        output_ids = output_ids[flag]
                        output_scores = output_scores[flag]
                        end_counts = end_counts[flag]
                        topk = flag.sum()
        # Max length reached: output the best beam so far.
        return output_ids[output_scores.argmax()]

    def random_sample(self, inputs, n, topk=None, topp=None, states=None, temperature=1, min_ends=1):
        """Randomly sample ``n`` decoded sequences.

        A non-None ``topk`` samples each step only among the topk most likely
        tokens; a non-None ``topp`` samples only among the smallest set of
        tokens whose cumulative probability reaches topp (nucleus sampling).
        Returns: a list of n decoded id sequences.
        """
        inputs = [torch.tensor([i], device=self.device) for i in inputs]
        output_ids = self.first_output_ids
        results = []
        for step in range(self.maxlen):
            probas, states = self.predict(inputs, output_ids, states, temperature, 'probas')  # step probabilities
            probas /= probas.sum(dim=-1, keepdims=True)  # make sure they are normalized
            if step == 0:  # after the first step, tile everything n times
                probas = probas.repeat([n] + [1] * (len(probas.shape) - 1))
                inputs = [i.repeat([n] + [1] * (len(i.shape) - 1)) for i in inputs]
                output_ids = output_ids.repeat([n] + [1] * (len(output_ids.shape) - 1))
            if topk is not None:
                k_indices = probas.argsort(dim=-1, descending=True)[:, :topk]  # keep topk
                probas = take_along_dim(probas, k_indices, dim=1)  # topk probabilities
                probas /= probas.sum(dim=1, keepdims=True)  # renormalize
            if topp is not None:
                p_indices = probas.argsort(dim=-1, descending=True)  # sort descending
                probas = take_along_dim(probas, p_indices, dim=-1)  # sorted probabilities
                cumsum_probas = torch.cumsum(probas, dim=-1)  # cumulative probabilities
                flag = torch.roll(cumsum_probas >= topp, 1, dims=1)  # mark tail beyond topp
                flag[:, 0] = False  # roll+zero implements a shift by one position
                probas[flag] = 0  # zero out the tail
                probas /= probas.sum(dim=1, keepdims=True)  # renormalize
            sample_func = lambda p: torch.multinomial(p, 1)  # sample by probability
            sample_ids = torch.stack([sample_func(p) for p in probas])
            sample_ids = sample_ids.reshape((-1, 1))  # align shape
            if topp is not None:
                sample_ids = take_along_dim(p_indices, sample_ids, dim=1)  # map back to original ids
            if topk is not None:
                sample_ids = take_along_dim(k_indices, sample_ids, dim=1)  # map back to original ids
            output_ids = torch.cat([output_ids, sample_ids], 1)  # append the sampled tokens
            is_end = output_ids[:, -1] == self.end_id  # sequences that just ended
            end_counts = (output_ids == self.end_id).sum(1)  # end tokens per sequence
            if output_ids.shape[1] >= self.minlen:  # minimum length reached?
                flag = is_end & (end_counts >= min_ends)  # finished sequences
                if flag.any():
                    for ids in output_ids[flag]:  # store finished sequences
                        results.append(ids)
                    flag = ~flag  # invert: keep only unfinished sequences
                    inputs = [i[flag] for i in inputs]
                    output_ids = output_ids[flag]
                    end_counts = end_counts[flag]
                    if len(output_ids) == 0:
                        break
        # Whatever is still unfinished at maxlen goes straight into the results.
        for ids in output_ids:
            results.append(ids)
        return results
def search_layer(model, layer_name, retrun_first=True):
    """Collect trainable parameters whose name contains ``layer_name``.

    Returns None when nothing matches; otherwise the first matching
    parameter, or the full list when ``retrun_first`` is False.
    """
    matches = [p for n, p in model.named_parameters() if p.requires_grad and layer_name in n]
    if not matches:
        return None
    return matches[0] if retrun_first else matches
class ListDataset(Dataset):
    """Map-style dataset backed by an in-memory list.

    Either pass ``file_path`` (str or list, handed to ``load_data``) or a
    ready-made ``data`` list.
    """
    def __init__(self, file_path=None, data=None, **kwargs):
        self.kwargs = kwargs
        if isinstance(file_path, (str, list)):
            self.data = self.load_data(file_path)
            return
        if isinstance(data, list):
            self.data = data
            return
        raise ValueError('The input args shall be str format file_path / list format dataset')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    @staticmethod
    def load_data(file_path):
        # Identity by default; subclasses override this to actually parse files.
        return file_path
class IterDataset(IterableDataset):
    """Streaming dataset that reads records lazily from file(s).

    Subclasses are expected to override ``load_data`` with a generator
    yielding one record at a time.
    """
    def __init__(self, file_path=None, **kwargs):
        self.kwargs = kwargs
        if not isinstance(file_path, (str, list)):
            raise ValueError('The input args shall be str format file_path / list format dataset')
        self.file_path = file_path

    def __iter__(self):
        # Delegates to load_data, which should return an iterator over records.
        return self.load_data(self.file_path)

    @staticmethod
    def load_data(file_path):
        return file_path
# sinusoid position encoding
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
    '''Build the classic sinusoidal position-encoding table.

    Even columns hold sin, odd columns hold cos, with the standard
    10000^(-2i/d_hid) frequency schedule.
    ``padding_idx`` is accepted for interface compatibility but unused.

    Returns: [n_position, d_hid] float tensor.
    '''
    # NOTE: an alternative second implementation used to sit after the return
    # statement below — it was unreachable dead code and has been removed.
    position = torch.arange(0, n_position, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_hid, 2).float() * (-math.log(10000.0) / d_hid))
    embeddings_table = torch.zeros(n_position, d_hid)
    embeddings_table[:, 0::2] = torch.sin(position * div_term)
    embeddings_table[:, 1::2] = torch.cos(position * div_term)
    return embeddings_table
def cal_ts_num(tensor_shape):
    '''Debug helper: count (and print) how many CUDA tensors with the given
    shape are currently tracked by the garbage collector. Useful when hunting
    tensor/memory leaks.
    '''
    count = 0
    for candidate in gc.get_objects():
        try:
            if not torch.is_tensor(candidate):
                # or (hasattr(candidate, 'data') and torch.is_tensor(candidate.data)):
                continue
            if candidate.is_cuda and candidate.size() == tensor_shape:
                print(candidate.shape)
                count += 1
        except Exception as e:
            # gc can hand back half-initialized objects; keep scanning.
            print('A trivial exception occured: {}'.format(e))
    print(count)
def get_kw(cls, kwargs):
    '''Return the subset of ``kwargs`` whose keys are NOT parameters of ``cls``.

    Used to forward only the "extra" keyword arguments to another consumer.
    '''
    # BUGFIX: inspect.getargspec was deprecated for years and removed in
    # Python 3.11; getfullargspec is the drop-in replacement (positional
    # argument names are still element [0] of the result).
    cls_args = set(inspect.getfullargspec(cls)[0])
    return {k: v for k, v in kwargs.items() if k not in cls_args}
class FGM():
    '''Fast Gradient Method adversarial training.

    ``attack`` perturbs the embedding weights along the (normalized) gradient
    direction; ``restore`` puts the original weights back.
    '''
    def __init__(self, model):
        self.model = model
        self.backup = {}

    def attack(self, epsilon=1., emb_name='word_embeddings', **kwargs):
        # ``emb_name`` must match the parameter name of the embedding in your
        # model, e.g. for ``self.emb = nn.Embedding(5000, 100)``.
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and emb_name in name):
                continue
            self.backup[name] = param.data.clone()
            grad_norm = torch.norm(param.grad)  # L2 norm by default
            # The nan guard matters under apex mixed-precision training.
            if grad_norm != 0 and not torch.isnan(grad_norm):
                perturbation = epsilon * param.grad / grad_norm
                param.data.add_(perturbation)

    def restore(self, emb_name='emb', **kwargs):
        # ``emb_name`` must match the parameter name of the embedding in your model.
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
class PGD():
    '''Projected Gradient Descent adversarial training.

    Multi-step variant of FGM: each ``attack`` step moves along the gradient
    and then projects the total perturbation back into an epsilon-ball around
    the original embedding weights.
    '''
    def __init__(self, model):
        self.model = model
        self.emb_backup = {}
        self.grad_backup = {}

    def attack(self, epsilon=1., alpha=0.3, emb_name='word_embeddings', is_first_attack=False, **kwargs):
        # ``emb_name`` must match the parameter name of the embedding in your model.
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and emb_name in name):
                continue
            if is_first_attack:
                # Only the very first step snapshots the clean weights.
                self.emb_backup[name] = param.data.clone()
            grad_norm = torch.norm(param.grad)
            # The nan guard matters under apex mixed-precision training.
            if grad_norm != 0 and not torch.isnan(grad_norm):
                step = alpha * param.grad / grad_norm
                param.data.add_(step)
                param.data = self.project(name, param.data, epsilon)

    def restore(self, emb_name='emb', **kwargs):
        # ``emb_name`` must match the parameter name of the embedding in your model.
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.emb_backup
                param.data = self.emb_backup[name]
        self.emb_backup = {}

    def project(self, param_name, param_data, epsilon):
        # Clip the accumulated perturbation to the epsilon-ball.
        delta = param_data - self.emb_backup[param_name]
        if torch.norm(delta) > epsilon:
            delta = epsilon * delta / torch.norm(delta)
        return self.emb_backup[param_name] + delta

    def backup_grad(self):
        for name, param in self.model.named_parameters():
            # Layers that join forward but not backward (e.g. some pooling
            # layers) have grad None — skip them.
            if param.requires_grad and (param.grad is not None):
                self.grad_backup[name] = param.grad.clone()

    def restore_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and (param.grad is not None):
                param.grad = self.grad_backup[name]
class VAT():
    '''Virtual Adversarial Training (ALUM-style).

    Reference: https://github.com/namisan/mt-dnn/blob/v0.2/alum/adv_masked_lm.py
    '''
    def __init__(self, model, emb_name='word_embeddings', noise_var=1e-5, noise_gamma=1e-6, adv_step_size=1e-3,
                 adv_alpha=1, norm_type='l2', **kwargs):
        self.model = model
        self.noise_var = noise_var  # variance of the initial noise
        self.noise_gamma = noise_gamma  # eps used when normalizing the perturbation
        self.adv_step_size = adv_step_size  # step size (learning rate) of the inner ascent
        self.adv_alpha = adv_alpha  # weight of the adversarial loss
        self.norm_type = norm_type  # normalization flavor for the projection
        self.embed = None
        # Capture the embedding output via a forward hook on the matching module.
        for (name, module) in self.model.named_modules():
            if emb_name in name:
                module.register_forward_hook(hook=self.hook)

    def hook(self, module, fea_in, fea_out):
        # Stores the latest embedding output for later perturbation.
        self.embed = fea_out
        return None

    def forward_(self, train_X, new_embed):
        # Re-run the model with embeddings substituted for the token ids in train_X.
        if isinstance(train_X, (tuple, list)):
            new_train_X = [new_embed] + train_X[1:]
            # Multi-argument forward vs single-argument forward dispatch.
            adv_output = self.model.forward(*new_train_X) if self.model.forward.__code__.co_argcount >= 3 else self.model.forward(new_train_X)
        elif isinstance(train_X, torch.Tensor):
            adv_output = self.model.forward(new_embed)
        return adv_output

    def virtual_adversarial_training(self, train_X, logits):
        # Initial random perturbation r.
        noise = self.embed.data.new(self.embed.size()).normal_(0, 1) * self.noise_var
        noise.requires_grad_()
        # x + r
        new_embed = self.embed.data.detach() + noise
        adv_output = self.forward_(train_X, new_embed)  # first forward pass
        adv_logits = adv_output[0] if isinstance(adv_output, (list, tuple)) else adv_output
        adv_loss = self.kl(adv_logits, logits.detach(), reduction="batchmean")
        delta_grad, = torch.autograd.grad(adv_loss, noise, only_inputs=True)
        norm = delta_grad.norm()
        # Vanishing or exploding gradient: bail out.
        if torch.isnan(norm) or torch.isinf(norm):
            return None
        # inner sum: move the noise along the gradient of the KL loss
        noise = noise + delta_grad * self.adv_step_size
        # projection: normalize the perturbation
        noise = self.adv_project(noise, norm_type=self.norm_type, eps=self.noise_gamma)
        new_embed = self.embed.data.detach() + noise
        new_embed = new_embed.detach()
        # Second forward pass with the refined perturbation.
        adv_output = self.forward_(train_X, new_embed)
        adv_logits = adv_output[0] if isinstance(adv_output, (list, tuple)) else adv_output
        # Symmetric KL between clean and adversarial logits.
        adv_loss_f = self.kl(adv_logits, logits.detach())
        adv_loss_b = self.kl(logits, adv_logits.detach())
        # Typically weighted 10 for pretraining and 1 for downstream tasks.
        adv_loss = (adv_loss_f + adv_loss_b) * self.adv_alpha
        return adv_loss

    @staticmethod
    def kl(inputs, targets, reduction="sum"):
        """
        Compute the KL divergence between two logit tensors.
        inputs: tensor of logits
        targets: tensor of logits
        """
        loss = F.kl_div(F.log_softmax(inputs, dim=-1), F.softmax(targets, dim=-1), reduction=reduction)
        return loss

    @staticmethod
    def adv_project(grad, norm_type='inf', eps=1e-6):
        """
        Normalize the perturbation with an L1 / L2 / L-inf rule.
        """
        if norm_type == 'l2':
            direction = grad / (torch.norm(grad, dim=-1, keepdim=True) + eps)
        elif norm_type == 'l1':
            direction = grad.sign()
        else:
            direction = grad / (grad.abs().max(-1, keepdim=True)[0] + eps)
        return direction
class WebServing(object):
    """Minimal web interface for exposing a Python function over HTTP.

    Usage:
        arguments = {'text': (None, True), 'n': (int, False)}
        web = WebServing(port=8864)
        web.route('/gen_synonyms', gen_synonyms, arguments)
        web.start()
        # then visit http://127.0.0.1:8864/gen_synonyms?text=你好

    Notes:
        A thin wrapper around bottlepy, intended only for quick local
        testing; no performance guarantees.

    Dependencies:
        pip install bottle
        pip install paste
        (paste is only needed when using server='paste')
    """
    def __init__(self, host='0.0.0.0', port=8000, server='paste'):
        # Imported lazily so the module can be used without bottle installed.
        import bottle
        self.host = host
        self.port = port
        self.server = server
        self.bottle = bottle

    def wraps(self, func, arguments, method='GET'):
        """Wrap ``func`` as an HTTP handler returning a JSON envelope.

        Args:
            func: function to expose; its return value must be
                json-serializable (json.dumps(func(inputs)) must succeed).
            arguments: dict declaring the parameters ``func`` needs; key is
                the parameter name, value[0] is a converter applied to the
                raw string from the request (or None for no conversion),
                value[1] is whether the parameter is required.
            method: 'GET' or 'POST'.

        Response codes in the JSON envelope: 0 = success, 1 = missing
        required argument, 2 = exception raised by ``func``.
        """
        def new_func():
            outputs = {'code': 0, 'desc': u'succeeded', 'data': {}}
            kwargs = {}
            for key, value in arguments.items():
                if method == 'GET':
                    result = self.bottle.request.GET.getunicode(key)
                else:
                    result = self.bottle.request.POST.getunicode(key)
                if result is None:
                    # Missing parameter: fail fast if it was required.
                    if value[1]:
                        outputs['code'] = 1
                        outputs['desc'] = 'lack of "%s" argument' % key
                        return json.dumps(outputs, ensure_ascii=False)
                else:
                    # Optional converter (request values arrive as strings).
                    if value[0] is not None:
                        result = value[0](result)
                    kwargs[key] = result
            try:
                outputs['data'] = func(**kwargs)
            except Exception as e:
                # Surface the exception message to the caller instead of a 500.
                outputs['code'] = 2
                outputs['desc'] = str(e)
            return json.dumps(outputs, ensure_ascii=False)
        return new_func

    def route(self, path, func, arguments, method='GET'):
        """Register ``func`` (wrapped by :meth:`wraps`) at ``path``."""
        func = self.wraps(func, arguments, method)
        self.bottle.route(path, method=method)(func)

    def start(self):
        """Start the blocking bottle server."""
        self.bottle.run(host=self.host, port=self.port, server=self.server)
def get_pool_emb(hidden_state=None, pooler=None, attention_mask=None, pool_strategy='cls', custom_layer=None):
    """Pool token-level hidden states into sentence embeddings.

    Args:
        hidden_state: a [btz, seq_len, hdsz] tensor, or a list/tuple of such
            tensors (one per layer); list input uses the last layer except
            for the 'first-last-avg' and 'custom' strategies.
        pooler: precomputed pooler output, returned as-is for 'pooler'.
        attention_mask: [btz, seq_len] mask; required by all averaging/max
            strategies to ignore padding positions.
        pool_strategy: one of 'pooler', 'cls', 'last-avg'/'mean',
            'last-max'/'max', 'first-last-avg', 'custom'.
        custom_layer: int or list/tuple of layer indices, only for 'custom'.

    Returns:
        A [btz, hdsz] tensor of sentence embeddings.

    Raises:
        ValueError: for an unknown ``pool_strategy``.
    """
    if pool_strategy == 'pooler':
        return pooler
    elif pool_strategy == 'cls':
        if isinstance(hidden_state, (list, tuple)):
            hidden_state = hidden_state[-1]
        assert isinstance(hidden_state, torch.Tensor), f'{pool_strategy} strategy request tensor hidden_state'
        return hidden_state[:, 0]
    elif pool_strategy in {'last-avg', 'mean'}:
        if isinstance(hidden_state, (list, tuple)):
            hidden_state = hidden_state[-1]
        assert isinstance(hidden_state, torch.Tensor), f'{pool_strategy} pooling strategy request tensor hidden_state'
        # Masked mean: sum over valid positions / count of valid positions.
        hid = torch.sum(hidden_state * attention_mask[:, :, None], dim=1)
        attention_mask = torch.sum(attention_mask, dim=1)[:, None]
        return hid / attention_mask
    elif pool_strategy in {'last-max', 'max'}:
        if isinstance(hidden_state, (list, tuple)):
            hidden_state = hidden_state[-1]
        assert isinstance(hidden_state, torch.Tensor), f'{pool_strategy} pooling strategy request tensor hidden_state'
        # NOTE: padding positions are zeroed (not -inf) before the max, so
        # 0 can win when all valid values are negative — kept as-is.
        hid = hidden_state * attention_mask[:, :, None]
        # BUG FIX: torch.max(..., dim=1) returns a (values, indices) named
        # tuple; return only the values tensor, consistent with the other
        # strategies.
        return torch.max(hid, dim=1)[0]
    elif pool_strategy == 'first-last-avg':
        assert isinstance(hidden_state, list), f'{pool_strategy} pooling strategy request list hidden_state'
        # Layer 1 (not 0, which is the embedding layer) plus the last layer.
        hid = torch.sum(hidden_state[1] * attention_mask[:, :, None], dim=1)
        hid += torch.sum(hidden_state[-1] * attention_mask[:, :, None], dim=1)
        attention_mask = torch.sum(attention_mask, dim=1)[:, None]
        return hid / (2 * attention_mask)
    elif pool_strategy == 'custom':
        # Average over user-specified layers.
        assert isinstance(hidden_state, list), f'{pool_strategy} pooling strategy request list hidden_state'
        assert isinstance(custom_layer, (int, list, tuple)), f'{pool_strategy} pooling strategy request int/list/tuple custom_layer'
        custom_layer = [custom_layer] if isinstance(custom_layer, int) else custom_layer
        hid = 0
        for i, layer in enumerate(custom_layer, start=1):
            hid += torch.sum(hidden_state[layer] * attention_mask[:, :, None], dim=1)
        attention_mask = torch.sum(attention_mask, dim=1)[:, None]
        return hid / (i * attention_mask)
    else:
        raise ValueError('pool_strategy illegal')
def seed_everything(seed=None):
    """Seed python, numpy and torch (CPU + CUDA) RNGs for reproducibility.

    Args:
        seed: desired seed; when None or outside the uint32 range a random
            valid seed is drawn instead.

    Returns:
        The seed actually used (always a valid uint32 value).
    """
    max_seed_value = np.iinfo(np.uint32).max
    min_seed_value = np.iinfo(np.uint32).min
    if (seed is None) or not (min_seed_value <= seed <= max_seed_value):
        # BUG FIX: the original called random.randint(...) without assigning
        # the result, leaving seed as None (and crashing torch.manual_seed).
        seed = random.randint(min_seed_value, max_seed_value)
    print(f"Global seed set to {seed}")
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # CUDA seeding is queued lazily and is a no-op on CPU-only machines.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    return seed
\ No newline at end of file
build/lib/bert4torch/tokenizers.py
0 → 100644
View file @
66a1d0d0
# coding=utf-8
"""Tokenization classes."""
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
collections
import
logging
import
unicodedata
from
io
import
open
from
bert4torch.snippets
import
truncate_sequences
,
is_string
,
lowercase_and_normalize
import
re
import
six
from
collections
import
OrderedDict
logger
=
logging
.
getLogger
(
__name__
)
is_py2
=
six
.
PY2
def load_vocab(dict_path, encoding="utf-8", simplified=False, startswith=None):
    """Load a vocabulary file into an OrderedDict mapping token -> index.

    Args:
        dict_path: path to a one-token-per-line vocab file.
        encoding: file encoding.
        simplified: when True, filter out redundant tokens (per
            ``Tokenizer._is_redundant``) and return the pair
            ``(new_token_dict, keep_tokens)`` where ``keep_tokens`` holds the
            surviving tokens' original indices.
        startswith: tokens to force to the front of the simplified vocab.

    Returns:
        The token dict, or ``(new_token_dict, keep_tokens)`` when simplified.
    """
    token_dict = collections.OrderedDict()
    with open(dict_path, "r", encoding=encoding) as reader:
        for index, line in enumerate(reader):
            token_dict[line.strip()] = index
    if not simplified:
        return token_dict
    # Drop redundant tokens (e.g. [unused1]) while remembering which
    # original indices survive, so pretrained embeddings can be gathered.
    new_token_dict, keep_tokens = {}, []
    for t in (startswith or []):
        new_token_dict[t] = len(new_token_dict)
        keep_tokens.append(token_dict[t])
    for t, _ in sorted(token_dict.items(), key=lambda kv: kv[1]):
        if t not in new_token_dict and not Tokenizer._is_redundant(t):
            new_token_dict[t] = len(new_token_dict)
            keep_tokens.append(token_dict[t])
    return new_token_dict, keep_tokens
def whitespace_tokenize(text):
    """Split *text* on runs of whitespace, ignoring leading/trailing blanks."""
    # str.split() with no separator already strips the ends, collapses
    # consecutive whitespace and yields [] for empty/blank input.
    return text.split()
class TokenizerBase(object):
    """Base class for tokenizers: special tokens, truncation and encoding."""
    def __init__(self, token_start='[CLS]', token_end='[SEP]', token_unk='[UNK]', token_pad='[PAD]', token_mask='[MASK]',
                 add_special_tokens=None, pre_tokenize=None, token_translate=None):
        """Parameters:
            token_start: classification token placed at position 0.
            token_end: sentence separator; acts as end-of-sequence for a
                single input, and as separator/final terminator for a pair.
            token_unk: unknown-token marker.
            token_pad: padding token.
            token_mask: mask token.
            add_special_tokens: extra token(s) that must never be split.
            pre_tokenize: optional external pre-segmentation function; when
                given, tokenize runs on each pre_tokenize(text) piece.
            token_translate: mapping applied after tokenization to replace
                certain special tokens with others.
        """
        self._token_pad = token_pad
        self._token_unk = token_unk
        self._token_mask = token_mask
        self._token_start = token_start
        self._token_end = token_end
        # Tokens that must never be split by downstream tokenizers.
        self.never_split = [self._token_unk, self._token_end, self._token_pad, self._token_start, self._token_mask]
        if add_special_tokens is not None:
            if isinstance(add_special_tokens, (tuple, list)):
                self.never_split.extend(add_special_tokens)
            elif isinstance(add_special_tokens, str):
                self.never_split.append(add_special_tokens)
        # Trie used to segment text around the special tokens.
        self.tokens_trie = self._create_trie(self.never_split)
        self._pre_tokenize = pre_tokenize
        self._token_translate = token_translate or {}
        self._token_translate_inv = {v: k for k, v in self._token_translate.items()}

    def _create_trie(self, unique_no_split_tokens):
        """Build a Trie over the never-split tokens."""
        trie = Trie()
        for token in unique_no_split_tokens:
            trie.add(token)
        return trie

    def tokenize(self, text, maxlen=None):
        """Tokenize ``text``, add start/end markers and optionally truncate."""
        # Apply token_translate substitutions on top of the raw tokenization.
        tokens = [self._token_translate.get(token) or token for token in self._tokenize(text)]
        if self._token_start is not None:
            tokens.insert(0, self._token_start)
        if self._token_end is not None:
            tokens.append(self._token_end)
        if maxlen is not None:
            # Truncate from the right while preserving the end marker.
            index = int(self._token_end is not None) + 1
            truncate_sequences(maxlen, -index, tokens)
        return tokens

    def token_to_id(self, token):
        """Map one token to its id (implemented by subclasses)."""
        raise NotImplementedError

    def tokens_to_ids(self, tokens):
        """Map a token sequence to the corresponding id sequence."""
        return [self.token_to_id(token) for token in tokens]

    def _encode(self, first_text, second_text=None, maxlen=None, pattern='S*E*E', truncate_from='right', return_offsets=False):
        """Encode one sample (or pair) into token ids and segment ids.

        ``pattern='S*E*E'`` means the pair is packed as [CLS] A [SEP] B [SEP]
        (the second sequence drops its leading start token).
        """
        first_tokens = self.tokenize(first_text) if is_string(first_text) else first_text
        if second_text is None:
            second_tokens = None
        elif is_string(second_text):
            second_tokens = self.tokenize(second_text)
        else:
            second_tokens = second_text
        if maxlen is not None:
            # Truncation strategy: repeatedly trim the longer sequence.
            if truncate_from == 'right':
                index = -int(self._token_end is not None) - 1
            elif truncate_from == 'left':
                index = int(self._token_start is not None)
            else:
                index = truncate_from
            if second_text is not None and pattern == 'S*E*E':
                # One start token is shared, so the budget grows by 1.
                maxlen += 1
            truncate_sequences(maxlen, index, first_tokens, second_tokens)
        first_token_ids = self.tokens_to_ids(first_tokens)
        first_segment_ids = [0] * len(first_token_ids)
        if second_text is not None:
            if pattern == 'S*E*E':
                # Drop the second sequence's start token.
                idx = int(bool(self._token_start))
                second_tokens = second_tokens[idx:]
            second_token_ids = self.tokens_to_ids(second_tokens)
            second_segment_ids = [1] * len(second_token_ids)
            first_token_ids.extend(second_token_ids)
            first_segment_ids.extend(second_segment_ids)
        encode_output = [first_token_ids, first_segment_ids]
        if return_offsets != False:
            # NOTE(review): when second_text is None this still calls
            # rematch(None, None), which looks like it would fail — confirm
            # callers only request offsets for text pairs.
            offset = self.rematch(first_text, first_tokens) + self.rematch(second_text, second_tokens)
            if return_offsets == 'transformers':
                # transformers-style [start, end) character spans.
                encode_output.append([[0, 0] if not k else [k[0], k[-1] + 1] for k in offset])
            else:
                encode_output.append(offset)
        return encode_output

    def encode(self, first_texts, second_texts=None, maxlen=None, pattern='S*E*E', truncate_from='right', return_offsets=False):
        """Encode one sample or a batch; returns per-sample lists (or the
        single sample's lists when the input was a plain string)."""
        return_list = False if isinstance(first_texts, str) else True
        first_texts = [first_texts] if isinstance(first_texts, str) else first_texts
        second_texts = [second_texts] if isinstance(second_texts, str) else second_texts
        first_token_ids, first_segment_ids, offsets = [], [], []
        if second_texts is None:
            second_texts = [None] * len(first_texts)
        assert len(first_texts) == len(second_texts), 'first_texts and second_texts should be same length'
        # Encode each sample independently.
        for first_text, second_text in zip(first_texts, second_texts):
            outputs = self._encode(first_text, second_text, maxlen, pattern, truncate_from, return_offsets)
            first_token_ids.append(outputs[0])
            first_segment_ids.append(outputs[1])
            if len(outputs) >= 3:
                offsets.append(outputs[2])
        encode_outputs = [first_token_ids, first_segment_ids]
        if return_offsets:
            encode_outputs.append(offsets)
        if not return_list:
            # Input was a single string: unwrap the batch dimension.
            encode_outputs = [item[0] for item in encode_outputs]
        return encode_outputs

    def id_to_token(self, i):
        """Map one id back to its token (implemented by subclasses)."""
        raise NotImplementedError

    def ids_to_tokens(self, ids):
        """Map an id sequence back to the corresponding token sequence."""
        return [self.id_to_token(i) for i in ids]

    def decode(self, ids):
        """Convert ids to readable text (implemented by subclasses)."""
        raise NotImplementedError

    def _tokenize(self, text):
        """Core tokenization routine (implemented by subclasses)."""
        raise NotImplementedError

    def rematch(self):
        """Map text to tokens; overridden with (text, tokens) in subclasses."""
        pass
class Tokenizer(TokenizerBase):
    """Native BERT tokenizer (basic tokenization + WordPiece)."""
    def __init__(self, token_dict, do_lower_case=True, do_basic_tokenize=True, do_tokenize_unk=False, **kwargs):
        """Parameters:
            token_dict: vocabulary dict, or a path to a vocab file.
            do_lower_case: whether to lowercase the input.
            do_basic_tokenize: whether to run BasicTokenizer before WordPiece.
            do_tokenize_unk: whether to emit [UNK] during tokenization, or
                defer unknown handling to the encode (token -> id) stage.
        """
        super(Tokenizer, self).__init__(**kwargs)
        if is_string(token_dict):
            token_dict = load_vocab(token_dict)
        self._do_lower_case = do_lower_case
        self._vocab_size = len(token_dict)
        self._token_dict = token_dict
        self._token_dict_inv = {v: k for k, v in token_dict.items()}
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, never_split=self.never_split)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self._token_dict, unk_token=self._token_unk, do_tokenize_unk=do_tokenize_unk)
        # Cache the ids of the special tokens that exist in this vocab.
        for token in ['pad', 'unk', 'mask', 'start', 'end']:
            try:
                _token_id = token_dict[getattr(self, '_token_%s' % token)]
                setattr(self, '_token_%s_id' % token, _token_id)
            except:
                # Special token absent from the vocab: skip silently.
                pass

    def _tokenize(self, text, pre_tokenize=True):
        """Core tokenization: optional pre-tokenizer, then trie split on
        special tokens, then basic + WordPiece tokenization."""
        # pre_tokenize logic follows bert4keras.
        if self._do_lower_case:
            text = lowercase_and_normalize(text, never_split=self.never_split)
        if pre_tokenize and self._pre_tokenize is not None:
            tokens = []
            for token in self._pre_tokenize(text):
                if token in self._token_dict:
                    tokens.append(token)
                else:
                    # Recurse without pre-tokenization on unknown pieces.
                    tokens.extend(self._tokenize(token, False))
            return tokens
        # The following mirrors the pytorch BERT tokenizer.
        # Trie split keeps special tokens intact.
        text_pieces = self.tokens_trie.split(text)
        split_tokens = []
        for text_piece in text_pieces:
            if not text_piece:
                continue
            elif text_piece in self._token_dict:
                split_tokens.append(text_piece)
            elif self.do_basic_tokenize:
                for token in self.basic_tokenizer.tokenize(text_piece):
                    for sub_token in self.wordpiece_tokenizer.tokenize(token):
                        split_tokens.append(sub_token)
            else:
                split_tokens.extend(self.wordpiece_tokenizer.tokenize(text_piece))
        return split_tokens

    def token_to_id(self, token):
        """Map a token to its vocab id, falling back to the [UNK] id."""
        return self._token_dict.get(token, self._token_unk_id)

    def id_to_token(self, id):
        """Map a vocab id back to its token."""
        return self._token_dict_inv[id]

    def decode(self, ids, tokens=None):
        """Convert ids (or pre-supplied tokens) back into readable text,
        re-joining WordPieces and fixing spacing around punctuation/CJK."""
        tokens = tokens or self.ids_to_tokens(ids)
        tokens = [token for token in tokens if not self._is_special(token)]
        # `flag` is assigned but never used below; kept for compatibility.
        text, flag = '', False
        for i, token in enumerate(tokens):
            if token[:2] == '##':
                # WordPiece continuation: glue onto the previous token.
                text += token[2:]
            elif len(token) == 1 and self._is_cjk_character(token):
                text += token
            elif len(token) == 1 and self._is_punctuation(token):
                text += token
                text += ' '
            elif i > 0 and self._is_cjk_character(text[-1]):
                text += token
            else:
                text += ' '
                text += token
        # Collapse duplicate spaces, then repair spacing introduced above:
        # English contractions, punctuation, and decimal numbers.
        text = re.sub(' +', ' ', text)
        text = re.sub('\' (re|m|s|t|ve|d|ll) ', '\'\\1 ', text)
        punctuation = self._cjk_punctuation() + '+-/={(<['
        punctuation_regex = '|'.join([re.escape(p) for p in punctuation])
        punctuation_regex = '(%s) ' % punctuation_regex
        text = re.sub(punctuation_regex, '\\1', text)
        text = re.sub('(\d\.) (\d)', '\\1\\2', text)
        return text.strip()

    @staticmethod
    def stem(token):
        """Return the token's "stem" (strip a leading ## if present)."""
        if token[:2] == '##':
            return token[2:]
        else:
            return token

    @staticmethod
    def _is_space(ch):
        """Whitespace-class character check."""
        return ch == ' ' or ch == '\n' or ch == '\r' or ch == '\t' or \
            unicodedata.category(ch) == 'Zs'

    @staticmethod
    def _is_punctuation(ch):
        """Punctuation check (covers both full- and half-width forms).

        Note: unicodedata.category may differ between py2 and py3; e.g.
        u'§' is 'So' under py2 but 'Po' under py3.
        """
        code = ord(ch)
        return 33 <= code <= 47 or \
            58 <= code <= 64 or \
            91 <= code <= 96 or \
            123 <= code <= 126 or \
            unicodedata.category(ch).startswith('P')

    @staticmethod
    def _cjk_punctuation():
        # The full set of CJK punctuation used when fixing decode() spacing.
        return u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\u00b7\uff01\uff1f\uff61\u3002'

    @staticmethod
    def _is_cjk_character(ch):
        """CJK character check (includes Chinese characters).

        Reference: https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        """
        code = ord(ch)
        return 0x4E00 <= code <= 0x9FFF or \
            0x3400 <= code <= 0x4DBF or \
            0x20000 <= code <= 0x2A6DF or \
            0x2A700 <= code <= 0x2B73F or \
            0x2B740 <= code <= 0x2B81F or \
            0x2B820 <= code <= 0x2CEAF or \
            0xF900 <= code <= 0xFAFF or \
            0x2F800 <= code <= 0x2FA1F

    @staticmethod
    def _is_control(ch):
        """Control character check (categories Cc and Cf)."""
        return unicodedata.category(ch) in ('Cc', 'Cf')

    @staticmethod
    def _is_special(ch):
        """Whether the token is a special marker like [CLS]/[SEP]."""
        return bool(ch) and (ch[0] == '[') and (ch[-1] == ']')

    @staticmethod
    def _is_redundant(token):
        """Whether the token is redundant (could never be produced by this
        tokenizer); returns True for redundant tokens, None otherwise."""
        if len(token) > 1:
            for ch in Tokenizer.stem(token):
                if (Tokenizer._is_cjk_character(ch) or Tokenizer._is_punctuation(ch)):
                    return True

    def rematch(self, text, tokens):
        """Return, for each token, the list of character indices in the
        original ``text`` it was produced from ([] for special tokens)."""
        if is_py2:
            text = unicode(text)
        if self._do_lower_case:
            text = text.lower()
        # Normalize the text the same way tokenization did, remembering for
        # each normalized character its index in the original string.
        normalized_text, char_mapping = '', []
        for i, ch in enumerate(text):
            if self._do_lower_case:
                ch = lowercase_and_normalize(ch, self.never_split)
            ch = ''.join([
                c for c in ch
                if not (ord(c) == 0 or ord(c) == 0xfffd or self._is_control(c))
            ])
            normalized_text += ch
            char_mapping.extend([i] * len(ch))
        text, token_mapping, offset = normalized_text, [], 0
        for token in tokens:
            if self._is_special(token):
                token_mapping.append([])
            else:
                # Find the stemmed token left-to-right in the normalized text.
                token = self.stem(token)
                start = text[offset:].index(token) + offset
                end = start + len(token)
                token_mapping.append(char_mapping[start:end])
                offset = end
        return token_mapping
class BasicTokenizer(object):
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
    def __init__(self, do_lower_case=True, never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
        """Constructs a BasicTokenizer.

        Args:
            do_lower_case: Whether to lower case the input.
            never_split: Tokens that must be kept intact (no lowercasing or
                punctuation splitting).
        """
        self.do_lower_case = do_lower_case
        self.never_split = never_split

    def tokenize(self, text):
        """Split ``text`` into basic tokens (whitespace + punctuation)."""
        text = self._clean_text(text)
        # Surround CJK characters with spaces so they become single tokens.
        text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case and token not in self.never_split:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))
        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        # NFD decomposition separates base characters from combining marks
        # (category Mn), which are then dropped.
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text."""
        if text in self.never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                # Punctuation becomes its own token and forces a new word.
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1
        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like the all of the other languages.
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or
                (cp >= 0x3400 and cp <= 0x4DBF) or
                (cp >= 0x20000 and cp <= 0x2A6DF) or
                (cp >= 0x2A700 and cp <= 0x2B73F) or
                (cp >= 0x2B740 and cp <= 0x2B81F) or
                (cp >= 0x2B820 and cp <= 0x2CEAF) or
                (cp >= 0xF900 and cp <= 0xFAFF) or
                (cp >= 0x2F800 and cp <= 0x2FA1F)):
            return True
        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            # Drop NUL, the replacement character, and control characters.
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)
class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""
    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100, do_tokenize_unk=False):
        # vocab: token -> id mapping used for membership tests.
        # do_tokenize_unk: when True, out-of-vocab words become unk_token
        # here; when False, any candidate substring is accepted (see
        # tokenize), deferring UNK handling to the encode stage.
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word
        self.do_tokenize_unk = do_tokenize_unk

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.

        NOTE(review): with do_tokenize_unk=False the inner match condition is
        always true, so the first (longest) candidate — the whole remaining
        word — is accepted even when out of vocab; actual splitting only
        happens when do_tokenize_unk=True. Confirm this is the intended
        "generate [UNK] at encode stage" behavior.
        """
        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                # Over-long word: emit [UNK] or keep it whole.
                output_tokens.append(self.unk_token if self.do_tokenize_unk else token)
                continue
            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                # Greedy longest-match: shrink the window until a vocab hit.
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if (substr in self.vocab) or (not self.do_tokenize_unk):
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    # No prefix of the remainder is in the vocab.
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end
            # Whether to convert to UNK at the tokenize stage.
            if self.do_tokenize_unk and is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
def
_is_whitespace
(
char
):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically contorl characters but we treat them
# as whitespace since they are generally considered as such.
if
char
==
" "
or
char
==
"
\t
"
or
char
==
"
\n
"
or
char
==
"
\r
"
:
return
True
cat
=
unicodedata
.
category
(
char
)
if
cat
==
"Zs"
:
return
True
return
False
def
_is_control
(
char
):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if
char
==
"
\t
"
or
char
==
"
\n
"
or
char
==
"
\r
"
:
return
False
cat
=
unicodedata
.
category
(
char
)
if
cat
.
startswith
(
"C"
):
return
True
return
False
def
_is_punctuation
(
char
):
"""Checks whether `chars` is a punctuation character."""
cp
=
ord
(
char
)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if
((
cp
>=
33
and
cp
<=
47
)
or
(
cp
>=
58
and
cp
<=
64
)
or
(
cp
>=
91
and
cp
<=
96
)
or
(
cp
>=
123
and
cp
<=
126
)):
return
True
cat
=
unicodedata
.
category
(
char
)
if
cat
.
startswith
(
"P"
):
return
True
return
False
def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not already), assuming utf-8 input.

    Raises ValueError for anything that is neither str nor bytes.
    """
    if isinstance(text, str):
        return text
    if isinstance(text, bytes):
        return text.decode("utf-8", "ignore")
    raise ValueError("Unsupported string type: %s" % (type(text)))
class SpTokenizer(TokenizerBase):
    """Wrapper around a SentencePiece model; same interface as Tokenizer."""
    def __init__(self, sp_model_path, **kwargs):
        super(SpTokenizer, self).__init__(**kwargs)
        # Imported lazily so the module works without sentencepiece installed.
        import sentencepiece as spm
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(sp_model_path)
        # pad/unk come from the SentencePiece model itself, overriding the
        # defaults set by TokenizerBase.__init__.
        self._token_pad = self.sp_model.id_to_piece(self.sp_model.pad_id())
        self._token_unk = self.sp_model.id_to_piece(self.sp_model.unk_id())
        self._vocab_size = self.sp_model.get_piece_size()
        # Cache the ids of the special tokens known to the model.
        for token in ['pad', 'unk', 'mask', 'start', 'end']:
            try:
                _token = getattr(self, '_token_%s' % token)
                _token_id = self.sp_model.piece_to_id(_token)
                setattr(self, '_token_%s_id' % token, _token_id)
            except:
                # Token absent or piece_to_id failed: skip silently.
                pass

    def token_to_id(self, token):
        """Map a piece to its SentencePiece id."""
        return self.sp_model.piece_to_id(token)

    def id_to_token(self, i):
        """Map an id back to its piece ('' for out-of-range ids)."""
        if i < self._vocab_size:
            return self.sp_model.id_to_piece(i)
        else:
            return ''

    def decode(self, ids):
        """Convert ids back to readable text via SentencePiece."""
        # Undo token_translate substitutions before decoding.
        tokens = [self._token_translate_inv.get(token) or token for token in self.ids_to_tokens(ids)]
        text = self.sp_model.decode_pieces(tokens)
        return convert_to_unicode(text)

    def _tokenize(self, text):
        """Core tokenization via SentencePiece's encode_as_pieces."""
        if self._pre_tokenize is not None:
            # Join pre-tokenized pieces with spaces before SP encoding.
            text = ' '.join(self._pre_tokenize(text))
        tokens = self.sp_model.encode_as_pieces(text)
        return tokens

    def _is_special(self, i):
        """Whether id ``i`` is a control/unknown/unused piece."""
        return self.sp_model.is_control(i) or \
            self.sp_model.is_unknown(i) or \
            self.sp_model.is_unused(i)

    def _is_decodable(self, i):
        """Whether id ``i`` should appear in decoded output."""
        return (i < self._vocab_size) and not self._is_special(i)
class Trie:
    """Ported directly from transformers' tokenization_utils.py, mainly to
    split text around special tokens without breaking them apart."""
    def __init__(self):
        # Nested dicts; a "" key marks the end of a stored word.
        self.data = {}

    def add(self, word: str):
        """Insert ``word`` into the trie (empty strings are ignored)."""
        if not word:
            # Prevent empty string
            return
        ref = self.data
        for char in word:
            ref[char] = char in ref and ref[char] or {}
            ref = ref[char]
        # Terminal marker.
        ref[""] = 1

    def split(self, text: str):
        """Split ``text`` so that every stored word becomes its own piece.

        Single-pass greedy matcher with lookahead so the longest stored
        word wins (e.g. extra_id_100 over extra_id_1).
        """
        # states maps a match's start index to its current trie pointer.
        states = OrderedDict()
        # This will contain every indices where we need
        # to cut.
        # We force to cut at offset 0 and len(text) (added later)
        offsets = [0]
        # This is used by the lookahead which needs to skip over
        # some text where the full match exceeded the place in the initial
        # for loop
        skip = 0
        # Main loop, Giving this algorithm O(n) complexity
        for current, current_char in enumerate(text):
            if skip and current < skip:
                # Prevents the lookahead for matching twice
                # like extra_id_100 and id_100
                continue
            # This will track every state
            # that stop matching, we need to stop tracking them.
            # If we look at "lowball", we're going to match "l" (add it to states), "o", "w", then
            # fail on "b", we need to remove 0 from the valid states.
            to_remove = set()
            # Whenever we found a match, we need to drop everything
            # this is a greedy algorithm, it will match on the first found token
            reset = False
            # In this case, we already have partial matches (But unfinished)
            for start, trie_pointer in states.items():
                if "" in trie_pointer:
                    # This is a final match, we need to reset and
                    # store the results in `offsets`.
                    # Lookahead to match longest first
                    # Important in case of extra_id_1 vs extra_id_100
                    # Here we are also actively looking for other earlier partial
                    # matches
                    # "[CLS]", "L", we need to match CLS even if L is special
                    for lookstart, looktrie_pointer in states.items():
                        if lookstart > start:
                            # This partial match is later, we can stop looking
                            break
                        elif lookstart < start:
                            # This partial match is earlier, the trie pointer
                            # was already updated, so index is + 1
                            lookahead_index = current + 1
                            end = current + 1
                        else:
                            # Here lookstart == start and
                            # looktrie_pointer == trie_pointer
                            # It wasn't updated yet so indices are current ones
                            lookahead_index = current
                            end = current
                        next_char = text[lookahead_index] if lookahead_index < len(text) else None
                        if "" in looktrie_pointer:
                            start = lookstart
                            end = lookahead_index
                            skip = lookahead_index
                        while next_char in looktrie_pointer:
                            looktrie_pointer = looktrie_pointer[next_char]
                            lookahead_index += 1
                            if "" in looktrie_pointer:
                                start = lookstart
                                end = lookahead_index
                                skip = lookahead_index
                            if lookahead_index == len(text):
                                # End of string
                                break
                            next_char = text[lookahead_index]
                        # End lookahead
                    # Storing and resetting
                    offsets.append(start)
                    offsets.append(end)
                    reset = True
                    break
                elif current_char in trie_pointer:
                    # The current character being looked at has a match within the trie
                    # update the pointer (it will be stored back into states later).
                    trie_pointer = trie_pointer[current_char]
                    # Storing back the new pointer into the states.
                    # Partial matches got longer by one.
                    states[start] = trie_pointer
                else:
                    # The new character has not match in the trie, we need
                    # to stop keeping track of this partial match.
                    # We can't do it directly within the loop because of how
                    # python iteration works
                    to_remove.add(start)
            # Either clearing the full start (we found a real match)
            # Or clearing only the partial matches that didn't work.
            if reset:
                states = {}
            else:
                for start in to_remove:
                    del states[start]
            # If this character is a starting character within the trie
            # start keeping track of this partial match.
            if current >= skip and current_char in self.data:
                states[current] = self.data[current_char]
        # We have a cut at the end with states.
        for start, trie_pointer in states.items():
            if "" in trie_pointer:
                # This is a final match, we need to reset and
                # store the results in `offsets`.
                end = len(text)
                offsets.append(start)
                offsets.append(end)
                # Longest cut is always the one with lower start so the first
                # item so we need to break.
                break
        return self.cut_text(text, offsets)

    def cut_text(self, text, offsets):
        """Slice ``text`` at the collected ``offsets`` into pieces."""
        # We have all the offsets now, we just need to do the actual splitting.
        # We need to eventually add the first part of the string and the eventual
        # last part.
        offsets.append(len(text))
        tokens = []
        start = 0
        for end in offsets:
            if start > end:
                logger.error(
                    "There was a bug in Trie algorithm in tokenization. Attempting to recover. Please report it anyway."
                )
                continue
            elif start == end:
                # This might happen if there's a match at index 0
                # we're also preventing zero-width cuts in case of two
                # consecutive matches
                continue
            tokens.append(text[start:end])
            start = end
        return tokens
dist/bert4torch-0.1.9-py3.9.egg
0 → 100644
View file @
66a1d0d0
File added
examples/Performance.md
0 → 100644
View file @
66a1d0d0
# 1. 文本分类
## 1.1 不同预训练模型的指标对比
-
[
情感分类数据集
](
https://github.com/bojone/bert4keras/blob/master/examples/datasets/sentiment.zip
)
+cls位分类
| solution | epoch | valid_acc | test_acc | comment |
| ---- | ---- | ---- | ---- | ---- |
| albert_small | 10/10 | 94.46 | 93.98 | small版本 |
| bert | 6/10 | 94.72 | 94.11 | —— |
| robert | 4/10 | 94.77 | 94.64 | —— |
| nezha | 7/10 | 95.07 | 94.72 | —— |
| xlnet | 6/10 | 95.00 | 94.24 | —— |
| electra | 10/10 | 94.94 | 94.78 | —— |
| roformer | 9/10 | 94.85 | 94.42 | —— |
| roformer_v2 | 3/10 | 95.78 | 96.09 | —— |
| gau_alpha | 2/10 | 95.25 | 94.46 | —— |
## 1.2 不同trick下的指标对比
-
trick测试+
[
情感分类数据集
](
https://github.com/bojone/bert4keras/blob/master/examples/datasets/sentiment.zip
)
+cls分类+无segment_input
| solution | epoch | valid_acc | test_acc | comment |
| ---- | ---- | ---- | ---- | ---- |
| bert | 10/10 | 94.90 | 94.78 | —— |
| fgm | 4/10 | 95.34 | 94.99 | —— |
| pgd | 6/10 | 95.34 | 94.64 | —— |
| gradient_penalty | 7/10 | 95.07 | 94.81 | —— |
| vat | 8/10 | 95.21 | 95.03 | —— |
| ema | 7/10 | 95.21 | 94.86 | —— |
| mix_up | 6/10 | 95.12 | 94.42 | —— |
| R-drop | 9/10 | 95.25 | 94.94 | —— |
| UDA | 8/10 | 94.90 | 95.56 | —— |
| semi-vat | 10/10 | 95.34 | 95.38 | —— |
| temporal_ensembling | 8/10 | 94.94 | 94.90 | —— |
# 2. 序列标注
-
[
人民日报数据集
](
http://s3.bmio.net/kashgari/china-people-daily-ner-corpus.tar.gz
)
+bert预训练模型
-
valid集指标
| solution | epoch | f1_token | f1_entity | comment |
| ---- | ---- | ---- | ---- | ---- |
| bert+crf | 18/20 | 96.89 | 96.05 | —— |
| bert+crf+init | 18/20 | 96.93 | 96.08 | 用训练数据初始化crf权重 |
| bert+crf+freeze | 11/20 | 96.89 | 96.13 | 用训练数据生成crf权重(不训练) |
| bert+cascade+crf | 5/20 | 98.10 | 96.26 | crf类别少所以f1_token偏高 |
| bert+crf+posseg | 13/20 | 97.32 | 96.55 | 加了词性输入 |
| bert+global_pointer | 18/20 | —— | 95.66 | —— |
| bert+efficient_global_pointer | 17/20 | —— | 96.55 | —— |
| bert+mrc | 7/20 | —— | 95.75 | —— |
| bert+span | 13/20 | —— | 96.31 | —— |
| bert+tplinker_plus | 20/20 | —— | 95.71 | 长度限制明显 |
| uie | 20/20 | —— | 96.57 | zeroshot:f1=60.8, fewshot-100样本:f1=85.82, 200样本:f1=86.40 |
# 3. 文本表示
## 3.1 无监督语义相似度
-
bert预训练模型 + 无监督finetune + cls位句向量(PromptBert除外)
-
五个中文数据集 + 5个epoch取最优值 + valid的spearmanr相关系数
-
继续finetune, 部分数据集有小幅提升
-
实验显示dropout_rate对结果影响较大
| solution | ATEC | BQ | LCQMC | PAWSX | STS-B | comment |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| Bert-whitening | 26.79 | 31.81| 56.34 | 17.22 | 67.45 | cls+不降维 |
| CT | 30.65 | 44.50| 68.67 | 16.20 | 69.27 | dropout=0.1, 收敛慢跑了10个epoch |
| CT_In_Batch_Neg | 32.47 | 47.09| 68.56 | 27.50 | 74.00 | dropout=0.1 |
| TSDAE | —— | 46.65| 65.30 | 12.54 | —— | dropout=0.1, ——表示该指标异常未记录 |
| SimCSE | 33.90 | 50.29| 71.81 | 13.14 | 71.09 | dropout=0.3 |
| ESimCSE | 34.05 | 50.54| 71.58 | 12.53 | 71.27 | dropout=0.3 |
| PromptBert | 33.98 | 49.89| 73.18 | 13.30 | 73.42 | dropout=0.3 |
## 3.2 有监督语义相似度
-
bert预训练模型 + 训练数据finetune + cls位句向量
-
五个中文数据集 + 5个epoch取最优值 + valid/test的spearmanr相关系数
-
STS-B任务是5分类,其余是2分类
| solution | ATEC | BQ | LCQMC | PAWSX | STS-B | comment |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| CoSENT |50.61 / 49.81|72.84 / 71.61|77.79 / 78.74|55.00 / 56.00|83.48 / 80.06| |
| ContrastiveLoss |50.02 / 49.19|72.52 / 70.98|77.49 / 78.27|58.21 / 57.65|69.87 / 68.58| STS-B转为2分类 |
| InfoNCE |47.77 / 46.99|69.86 / 68.14|71.74 / 74.54|52.82 / 54.21|83.31 / 78.72| STS-B转为2分类 |
|concat CrossEntropy|48.71 / 47.62|72.16 / 70.07|78.44 / 78.77|51.46 / 52.28|61.31 / 56.62| STS-B转为2分类 |
| CosineMSELoss |46.89 / 45.86|72.27 / 71.35|75.29 / 77.19|54.92 / 54.35|81.64 / 77.76| STS-B标准化到0-1 |
# 4. 关系提取
-
[
百度关系提取数据集
](
http://ai.baidu.com/broad/download?dataset=sked
)
| solution | f1 | comment |
| ---- | ---- | ---- |
| CasRel | 81.87 | |
| gplinker | 81.88 | |
| tplinker | 69.00 | seq_len=64 |
| tplinker_plus | 79.30 | seq_len=64 |
# 5. 文本生成
-
[
CSL数据集
](
https://github.com/CLUEbenchmark/CLGE
)
| solution | Rouge-L | Rouge-1 | Rouge-2 | BLEU | comment |
| ---- | ---- | ---- | ---- | ---- | ---- |
|bert+unlim| 58.42 | 61.77 | 49.10 | 37.74 | ---- |
| bart | 59.67 | 63.12 | 50.44 | 39.20 | ---- |
| mt5 | 61.34 | 64.51 | 52.59 | 41.98 | ---- |
|t5_pegasus| 59.15 | 62.72 | 48.53 | 38.16 | ---- |
| uer_t5 | 60.31 | 63.78 | 50.87 | 38.76 | ---- |
\ No newline at end of file
examples/README.md
0 → 100644
View file @
66a1d0d0
## example简介
### 基础测试
-
[
basic_extract_features.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/basic_extract_features.py
)
:测试BERT对句子的编码序列。
-
[
basic_gibbs_sampling_via_mlm.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/basic_gibbs_sampling_via_mlm.py
)
:利用BERT+Gibbs采样进行文本随机生成,参考
[
这里
](
https://kexue.fm/archives/8119
)
。
-
[
basic_language_model_nezha_gen_gpt.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/basic_language_model_nezha_gen_gpt.py
)
:测试
[
GPTBase(又叫NEZHA-GEN)
](
https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/NEZHA-Gen-TensorFlow
)
的生成效果。
-
[
basic_make_uncased_model_cased.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/basic_make_uncased_model_cased.py
)
:通过简单修改词表,使得不区分大小写的模型有区分大小写的能力。
-
[
basic_masked_language_model.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/basic_masked_language_model.py
)
:测试BERT的MLM模型效果。
-
[
basic_language_model_GAU_alpha.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/basic_language_model_GAU_alpha.py
)
:测试
[
GAU-alpha
](
https://github.com/ZhuiyiTechnology/GAU-alpha
)
的MLM模型效果。
-
[
basic_masked_language_model_roformer.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/basic_masked_language_model_roformer.py
)
:测试roformer的MLM模型效果。
-
[
basic_language_model_CDial_GPT.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/basic_language_model_CDial_GPT.py
)
:测试
[
CDial_GPT
](
https://github.com/thu-coai/CDial-GPT
)
的对话生成效果。
-
[
basic_language_model_gpt2_ml.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/basic_language_model_gpt2_ml.py
)
:测试
[
gpt2-ml
](
https://github.com/imcaspar/gpt2-ml
)
的生成效果。
-
[
basic_language_model_cpm_lm.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/basic_language_model_cpm_lm.py
)
:测试
[
CPM-Generate
](
https://github.com/TsinghuaAI/CPM-Generate
)
的生成效果。
-
[
basic_language_model_t5.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/basic_language_model_t5.py
)
:测试
[
uer-t5-small
](
https://huggingface.co/uer/t5-small-chinese-cluecorpussmall
)
的生成效果。
-
[
basic_language_model_simbert.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/basic_language_model_simbert.py
)
:测试
[
simbert
](
https://github.com/ZhuiyiTechnology/simbert
)
和
[
roformer-sim
](
https://github.com/ZhuiyiTechnology/roformer-sim
)
的生成效果和句子相似度效果。
-
[
basic_simple_web_serving_simbert.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/basic_simple_web_serving_simbert.py
)
: 测试自带的WebServing(将模型转化为Web接口)。
-
[
basic_language_model_transformer_xl.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/basic_language_model_transformer_xl.py
)
: 测试transformer_xl模型,做了一些简化,仅有英文预训练模型。
-
[
basic_language_model_xlnet.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/basic_language_model_xlnet.py
)
: 测试xlnet模型。
-
[
basic_language_model_nezha_gpt_dialog.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/basic/basic_language_model_nezha_gpt_dialog.py
)
: 测试
[
nezha_gpt_dialog
](
https://kexue.fm/archives/7718
)
。
### 文本分类
-
[
task_sentence_similarity_lcqmc.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_classfication/task_sentence_similarity_lcqmc.py
)
:句子对分类任务。
-
[
task_sentiment_classification.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_classfication/task_sentiment_classification.py
)
:情感分类任务,bert做简单文本分类
-
[
task_sentiment_classification_albert.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_classfication/task_sentiment_classification_albert.py
)
:情感分类任务,加载ALBERT模型。
-
[
task_sentiment_classification_xlnet.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_classfication/task_sentiment_classification_xlnet.py
)
:情感分类任务,加载XLNET模型。
-
[
task_sentiment_classification_hierarchical_position.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_classfication/task_sentiment_classification_hierarchical_position.py
)
:情感分类任务,层次分解位置编码做长文本的初始化
-
[
task_sentiment_classification_nezha.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_classfication/task_sentiment_classification_nezha.py
)
:情感分类任务,加载nezha模型
-
[
task_sentiment_classification_roformer.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_classfication/task_sentiment_classification_roformer.py
)
:情感分类任务,加载roformer权重
-
[
task_sentiment_classification_roformer_v2.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_classfication/task_sentiment_classification_roformer_v2.py
)
:情感分类任务,加载roformer_v2权重
-
[
task_sentiment_classification_electra.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_classfication/task_sentiment_classification_electra.py
)
:情感分类任务,加载electra权重
-
[
task_sentiment_classification_GAU_alpha.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_classfication/task_sentiment_classification_GAU_alpha.py
)
:情感分类任务,加载GAU-alpha权重
-
[
task_sentiment_classification_PET.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_classfication/task_sentiment_classification_PET.py
)
:情感分类项目,
[
Pattern-Exploiting-Training
](
https://github.com/bojone/Pattern-Exploiting-Training
)
,
[
bert4keras示例
](
https://github.com/bojone/Pattern-Exploiting-Training
)
-
[
task_sentiment_classification_P_tuning.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_classfication/task_sentiment_classification_P_tuning.py
)
:情感分类项目,
[
P-tuning
](
https://github.com/THUDM/P-tuning
)
,
[
bert4keras示例
](
https://github.com/bojone/P-tuning
)
-
[
Sohu_2022_ABSA
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_classfication/Sohu_2022_ABSA
)
:搜狐2022实体情感分类Top1方案复现和自己的baseline
-
[
Tianchi_News_Classification
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_classfication/Tianchi_News_Classification
)
:天池零基础入门NLP-新闻分类Top1方案复现
### 序列标注
-
[
task_sequence_labeling_ner_efficient_global_pointer.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling/task_sequence_labeling_ner_efficient_global_pointer.py
)
:ner例子,efficient_global_pointer的pytorch实现
-
[
task_sequence_labeling_ner_global_pointer.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling/task_sequence_labeling_ner_global_pointer.py
)
:ner例子,global_pointer的pytorch实现
-
[
task_sequence_labeling_ner_crf.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling/task_sequence_labeling_ner_crf.py
)
:ner例子,bert+crf
-
[
task_sequence_labeling_ner_crf_freeze.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling/task_sequence_labeling_ner_crf_freeze.py
)
:ner例子,bert+crf, 一种是用数据集来生成crf权重,第二种是来初始化
-
[
task_sequence_labeling_ner_cascade_crf.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling/task_sequence_labeling_ner_cascade_crf.py
)
:ner例子,bert+crf+级联
-
[
task_sequence_labeling_ner_crf_add_posseg.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling/task_sequence_labeling_ner_crf_add_posseg.py
)
:ner例子,bert+crf,词性作为输入
-
[
task_sequence_labeling_ner_tplinker_plus.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling/task_sequence_labeling_ner_tplinker_plus.py
)
:ner例子,改造了关系抽取
[
TPLinker
](
https://github.com/131250208/TPlinker-joint-extraction
)
-
[
task_sequence_labeling_ner_mrc.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling/task_sequence_labeling_ner_mrc.py
)
:ner例子,
[
mrc方案
](
https://github.com/z814081807/DeepNER
)
,用阅读理解的方式来做
-
[
task_sequence_labeling_ner_span.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling/task_sequence_labeling_ner_span.py
)
:ner例子,
[
span方案
](
https://github.com/z814081807/DeepNER
)
,用半指针-半标注方式来做
-
[
uie
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling/uie
)
:ner例子,
[
uie方案
](
https://github.com/universal-ie/UIE
)
,prompt+mrc模型结构
### 文本表示
-
[
task_sentence_embedding_unsup_bert_whitening.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_embedding/task_sentence_embedding_unsup_bert_whitening.py
)
:参考
[
bert_whitening
](
https://github.com/bojone/BERT-whitening
)
-
[
task_sentence_embedding_unsup_CT.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_embedding/task_sentence_embedding_unsup_CT.py
)
:参考
[
SentenceTransformer
](
https://www.sbert.net/index.html
)
-
[
task_sentence_embedding_unsup_CT_In-Batch_Negatives.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_embedding/task_sentence_embedding_unsup_CT_In-Batch_Negatives.py
)
:参考
[
SentenceTransformer
](
https://www.sbert.net/index.html
)
-
[
task_sentence_embedding_unsup_SimCSE.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_embedding/task_sentence_embedding_unsup_SimCSE.py
)
:参考
[
SentenceTransformer
](
https://www.sbert.net/index.html
)
和
[
科学空间版中文测试
](
https://kexue.fm/archives/8348
)
-
[
task_sentence_embedding_unsup_ESimCSE.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_embedding/task_sentence_embedding_unsup_ESimCSE.py
)
:参考
[
ESimCSE论文
](
https://arxiv.org/pdf/2109.04380.pdf
)
和
[
第三方实现
](
https://github.com/shuxinyin/SimCSE-Pytorch
)
-
[
task_sentence_embedding_unsup_TSDAE.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_embedding/task_sentence_embedding_unsup_TSDAE.py
)
:参考
[
SentenceTransformer
](
https://www.sbert.net/index.html
)
-
[
task_sentence_embedding_unsup_PromptBert.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_embedding/task_sentence_embedding_unsup_PromptBert.py
)
:
[
PromptBert
](
https://github.com/kongds/Prompt-BERT
)
方式
-
[
task_sentence_embedding_sup_ContrastiveLoss.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_embedding/task_sentence_embedding_sup_ContrastiveLoss.py
)
:参考
[
SentenceTransformer
](
https://www.sbert.net/index.html
)
-
[
task_sentence_embedding_sup_CosineMSELoss.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_embedding/task_sentence_embedding_sup_CosineMSELoss.py
)
:参考
[
SentenceTransformer
](
https://www.sbert.net/index.html
)
-
[
task_sentence_embedding_sup_concat_CrossEntropyLoss.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_embedding/task_sentence_embedding_sup_concat_CrossEntropyLoss.py
)
:参考
[
SentenceTransformer
](
https://www.sbert.net/index.html
)
-
[
task_sentence_embedding_sup_MultiNegtiveRankingLoss.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_embedding/task_sentence_embedding_sup_MultiNegtiveRankingLoss.py
)
:参考
[
SentenceTransformer
](
https://www.sbert.net/index.html
)
-
[
task_sentence_embedding_sup_CoSENT.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_embedding/task_sentence_embedding_sup_CoSENT.py
)
:参考
[
CoSENT
](
https://kexue.fm/archives/8847
)
-
[
task_sentence_embedding_DimensionalityReduction.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_embedding/task_sentence_embedding_DimensionalityReduction.py
)
:参考
[
SentenceTransformer
](
https://www.sbert.net/index.html
)
-
[
task_sentence_embedding_model_distillation.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_embedding/task_sentence_embedding_model_distillation.py
)
:参考
[
SentenceTransformer
](
https://www.sbert.net/index.html
)
-
[
FinanceFAQ
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/sentence_embedding/FinanceFAQ
)
:金融领域FAQ两阶段(召回+排序)pipline
### 关系提取
-
[
task_relation_extraction_CasRel.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/relation_extraction/task_relation_extraction_CasRel.py
)
:结合BERT以及自行设计的“半指针-半标注”结构来做
[
关系抽取
](
https://kexue.fm/archives/7161
)
。
-
[
task_relation_extraction_gplinker.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/relation_extraction/task_relation_extraction_gplinker.py
)
:结合GlobalPointer做关系抽取
[
GPLinker
](
https://kexue.fm/archives/8888
)
。
-
[
task_relation_extraction_tplinker.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/relation_extraction/task_relation_extraction_tplinker.py
)
:tplinker关系抽取
[
TPLinker
](
https://github.com/131250208/TPlinker-joint-extraction
)
。
-
[
task_relation_extraction_tplinker_plus.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/relation_extraction/task_relation_extraction_tplinker_plus.py
)
:tplinker关系抽取
[
TPLinkerPlus
](
https://github.com/131250208/TPlinker-joint-extraction
)
。
### 文本生成
-
[
task_seq2seq_autotitle_unilm.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/seq2seq/task_seq2seq_autotitle_unilm.py
)
:通过
[
UniLM
](
https://kexue.fm/archives/6933
)
式的Seq2Seq模型来做新闻标题生成。
-
[
task_seq2seq_autotitle_csl_bart.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/seq2seq/task_seq2seq_autotitle_csl_bart.py
)
:通过BART来做新闻标题生成
-
[
task_seq2seq_autotitle_csl_uer_t5.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/seq2seq/task_seq2seq_autotitle_csl_uer_t5.py
)
:通过T5来做新闻标题生成,用的
[
uer-t5-small
](
https://huggingface.co/uer/t5-small-chinese-cluecorpussmall
)
-
[
task_seq2seq_autotitle_csl.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/seq2seq/task_seq2seq_autotitle_csl.py
)
:通过
[
UniLM
](
https://kexue.fm/archives/6933
)
式的Seq2Seq模型来做论文标题生成。
-
[
task_seq2seq_autotitle_csl_mt5.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/seq2seq/task_seq2seq_autotitle_csl_mt5.py
)
:通过
[
google_mt
](
https://huggingface.co/google/mt5-base
)
的Seq2Seq模型来做论文标题生成。
-
[
task_question_answer_generation_by_seq2seq.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/seq2seq/task_question_answer_generation_by_seq2seq.py
)
:通过
[
UniLM
](
https://kexue.fm/archives/6933
)
式的Seq2Seq模型来做
[
问答对自动构建
](
https://kexue.fm/archives/7630
)
,属于自回归文本生成。
-
[
task_reading_comprehension_by_mlm.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/seq2seq/task_reading_comprehension_by_mlm.py
)
:通过MLM模型来做
[
阅读理解问答
](
https://kexue.fm/archives/7148
)
,属于简单的非自回归文本生成。
-
[
task_reading_comprehension_by_seq2seq.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/seq2seq/task_reading_comprehension_by_seq2seq.py
)
:通过
[
UniLM
](
https://kexue.fm/archives/6933
)
式的Seq2Seq模型来做
[
阅读理解问答
](
https://kexue.fm/archives/7115
)
,属于自回归文本生成。
-
[
task_seq2seq_simbert.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/seq2seq/task_seq2seq_simbert.py
)
:相似问生成,数据增广,参考
[
SimBERT
](
https://kexue.fm/archives/7427
)
-
[
task_seq2seq_ape210k_math_word_problem.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/seq2seq/task_seq2seq_ape210k_math_word_problem.py
)
:bert+unilm硬刚小学数学题,参考
[
博客
](
https://kexue.fm/archives/7809
)
-
[
task_kgclue_seq2seq.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/seq2seq/task_kgclue_seq2seq.py
)
:seq2seq+前缀树,参考
[
博客
](
https://kexue.fm/archives/8802
)
### 训练Trick
-
[
task_sentiment_adversarial_training.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/training_trick/task_sentiment_adversarial_training.py
)
:通过对抗训练,虚拟对抗训练,梯度惩罚等措施来提升分类效果。
-
[
task_sentiment_virtual_adversarial_training.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/training_trick/task_sentiment_virtual_adversarial_training.py
)
:通过半监督的虚拟对抗训练等措施来提升分类效果。
-
[
task_sentiment_UDA.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/training_trick/task_sentiment_UDA.py
)
:通过
[
UDA
](
https://arxiv.org/abs/1904.12848
)
半监督学习提升分类效果,在原来Loss上加一致性损失。
-
[
task_sentiment_mixup.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/training_trick/task_sentiment_mixup.py
)
:通过
[
Mixup
](
https://github.com/vikasverma1077/manifold_mixup
)
提升模型泛化性能。
-
[
task_sentiment_exponential_moving_average.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/training_trick/task_sentiment_exponential_moving_average.py
)
:EMA指数滑动平均
-
[
task_sentiment_TemporalEnsembling.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/training_trick/task_sentiment_TemporalEnsembling.py
)
:通过
[
TemporalEnsembling官方项目
](
https://github.com/s-laine/tempens
)
和
[
pytorch第三方实现
](
https://github.com/ferretj/temporal-ensembling
)
提升模型泛化性能。
-
[
task_sentiment_R-Drop.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/training_trick/task_sentiment_R-Drop.py
)
:通过
[
R-Drop
](
https://github.com/dropreg/R-Drop
)
提升分类效果,可以视为用dropout加噪下的UDA。
-
[
task_amp.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/training_trick/task_amp.py
)
:Pytorch的amp混合精度训练
-
[
task_data_parallel.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/training_trick/task_data_parallel.py
)
:DataParallel模式的多GPU训练方式
-
[
task_distributed_data_parallel.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/training_trick/task_distributed_data_parallel.py
)
:DistributedDataParallel模式的多GPU训练方式
### 预训练
-
[
roberta_pretrain
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/pretrain/roberta_pretrain
)
:roberta的mlm预训练,数据生成代码和训练代码
-
[
simbert_v2_pretrain
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/pretrain/simbert_v2_pretrain
)
:相似问生成,数据增广,三个步骤:1-
[
弱监督
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/pretrain/simbert_v2_pretrain/simbert_v2_stage1.py
)
,2-
[
蒸馏
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/pretrain/simbert_v2_pretrain/simbert_v2_stage2.py
)
,3-
[
有监督
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/pretrain/simbert_v2_pretrain/simbert_v2_supervised.py
)
,参考
[
SimBERT-V2
](
https://kexue.fm/archives/8454
)
### 其他
-
[
task_conditional_language_model.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/others/task_conditional_language_model.py
)
:结合BERT+
[
ConditionalLayerNormalization
](
https://kexue.fm/archives/7124
)
做条件语言模型。
-
[
task_language_model.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/others/task_language_model.py
)
:加载BERT的预训练权重做无条件语言模型,效果上等价于GPT。
-
[
task_iflytek_bert_of_theseus.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/others/task_iflytek_bert_of_theseus.py
)
:通过
[
BERT-of-Theseus
](
https://kexue.fm/archives/7575
)
来进行模型压缩。
-
[
task_language_model_chinese_chess.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/others/task_language_model_chinese_chess.py
)
:用GPT的方式下中国象棋,过程请参考
[
博客
](
https://kexue.fm/archives/7877
)
。
-
[
task_custom_fit_progress.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/others/task_custom_fit_progress.py
)
:教程,自定义训练过程fit函数(集成了训练进度条展示),可用于满足如半精度,梯度裁剪等高阶需求。
-
[
task_load_transformers_model.py
](
https://github.com/Tongjilibo/bert4torch/blob/master/examples/others/task_load_transformers_model.py
)
:教程,加载transformer包中模型,可以使用bert4torch中继承的对抗训练等trick。
## 用到的数据集
| 数据集名称 | 用途 | 下载链接 |
| ---- | ---- | ---- |
|人民日报数据集|实体识别|
[
china-people-daily-ner-corpus
](
http://s3.bmio.net/kashgari/china-people-daily-ner-corpus.tar.gz
)
|百度关系抽取|关系抽取|
[
BD_Knowledge_Extraction
](
http://ai.baidu.com/broad/download?dataset=sked
)
|Sentiment|情感分类|
[
Sentiment
](
https://github.com/bojone/bert4keras/blob/master/examples/datasets/sentiment.zip
)
|THUCNews|文本分类、文本生成|
[
THUCNews
](
http://thuctc.thunlp.org/#%E4%B8%AD%E6%96%87%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%95%B0%E6%8D%AE%E9%9B%86THUCNews
)
|ATEC| 文本相似度 |
[
ATEC
](
https://github.com/IceFlameWorm/NLP_Datasets/tree/master/ATEC
)
|BQ| 文本相似度 |
[
BQ
](
http://icrc.hitsz.edu.cn/info/1037/1162.htm
)
|LCQMC| 文本相似度 |
[
LCQMC
](
http://icrc.hitsz.edu.cn/Article/show/171.html
)
|PAWSX| 文本相似度 |
[
PAWSX
](
https://arxiv.org/abs/1908.11828
)
|STS-B| 文本相似度 |
[
STS-B
](
https://github.com/pluto-junzeng/CNSD
)
\ No newline at end of file
examples/basic/basic_extract_features.py
0 → 100644
View file @
66a1d0d0
#! -*- coding: utf-8 -*-
# Sanity check: extract BERT encoder features for a single sentence.
import torch
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer

# Pretrained checkpoint location — adjust to your own paths.
root_model_path = "F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12"
vocab_path = root_model_path + "/vocab.txt"
config_path = root_model_path + "/bert_config.json"
checkpoint_path = root_model_path + '/pytorch_model.bin'

# Build the tokenizer and the model with pretrained weights.
tokenizer = Tokenizer(vocab_path, do_lower_case=True)
model = build_transformer_model(config_path, checkpoint_path)

# Encode a sample sentence into token/segment id tensors (batch of 1).
token_ids, segment_ids = tokenizer.encode(u'语言模型')
token_ids = torch.tensor([token_ids])
segment_ids = torch.tensor([segment_ids])

print('\n ===== predicting =====\n')
model.eval()
with torch.no_grad():
    # First output is the sequence of hidden states.
    print(model([token_ids, segment_ids])[0])

"""
输出:
[[[-0.63251007  0.2030236   0.07936534 ...  0.49122632 -0.20493352
    0.2575253 ]
  [-0.7588351   0.09651865  1.0718756  ... -0.6109694   0.04312154
    0.03881441]
  [ 0.5477043  -0.792117    0.44435206 ...  0.42449304  0.41105673
    0.08222899]
  [-0.2924238   0.6052722   0.49968526 ...  0.8604137  -0.6533166
    0.5369075 ]
  [-0.7473459   0.49431565  0.7185162  ...  0.3848612  -0.74090636
    0.39056838]
  [-0.8741375  -0.21650358  1.338839   ...  0.5816864  -0.4373226
    0.56181806]]]
"""
\ No newline at end of file
examples/basic/basic_gibbs_sampling_via_mlm.py
0 → 100644
View file @
66a1d0d0
#! -*- coding: utf-8 -*-
# Usability test: random text generation via Gibbs sampling with BERT's MLM head.
# Each step masks one random position and re-samples it from the MLM distribution.
from tqdm import tqdm
import numpy as np
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer
import torch
import torch.nn as nn

# Pretrained BERT checkpoint paths — adjust to your own setup.
root_model_path = "F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12"
vocab_path = root_model_path + "/vocab.txt"
config_path = root_model_path + "/bert_config.json"
checkpoint_path = root_model_path + '/pytorch_model.bin'

tokenizer = Tokenizer(vocab_path, do_lower_case=True)  # build the tokenizer
# Build the model with an MLM softmax head and load the pretrained weights.
model = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, with_mlm='softmax')

sentences = []
init_sent = u'科学技术是第一生产力。'  # seed sentence, or None to start from all-[MASK]
minlen, maxlen = 8, 32  # random sequence-length bounds when no seed is given
steps = 10000  # total Gibbs sampling steps
converged_steps = 1000  # burn-in: samples before this index are discarded below
vocab_size = tokenizer._vocab_size

if init_sent is None:
    # No seed: start from a fully masked sequence of random length.
    length = np.random.randint(minlen, maxlen + 1)
    tokens = ['[CLS]'] + ['[MASK]'] * length + ['[SEP]']
    token_ids = tokenizer.tokens_to_ids(tokens)
    segment_ids = [0] * len(token_ids)
else:
    # Start from the provided seed sentence.
    token_ids, segment_ids = tokenizer.encode(init_sent)
    length = len(token_ids) - 2  # number of positions excluding [CLS]/[SEP]

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

model.eval()
with torch.no_grad():
    for _ in tqdm(range(steps), desc='Sampling'):
        # Gibbs sampling step: mask one random token, then re-sample it from the MLM distribution.
        i = np.random.choice(length) + 1  # +1 skips the [CLS] position
        token_ids[i] = tokenizer._token_mask_id
        token_ids_tensor, segment_ids_tensor = torch.tensor([token_ids], device=device), torch.tensor([segment_ids], device=device)
        _, probas = model([token_ids_tensor, segment_ids_tensor])
        probas = probas[0, i]  # MLM distribution at the masked position
        token = np.random.choice(vocab_size, p=probas.cpu().numpy())
        token_ids[i] = token
        sentences.append(tokenizer.decode(token_ids))

print(u'部分随机采样结: ')
# Show 10 random samples taken after the burn-in period.
for _ in range(10):
    print(np.random.choice(sentences[converged_steps:]))
examples/basic/basic_language_model_CDial_GPT.py
0 → 100644
View file @
66a1d0d0
#! -*- coding: utf-8 -*-
# Basic test: Chinese GPT model, base version, CDial-GPT flavour.
# Project: https://github.com/thu-coai/CDial-GPT
# Reference implementation: https://github.com/bojone/CDial-GPT-tf
# Weights must be converted before loading; see the convert_script folder.
import torch
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer
from bert4torch.snippets import AutoRegressiveDecoder

# Converted checkpoint paths — adjust to your own setup.
config_path = 'F:/Projects/pretrain_ckpt/gpt/[thu-coai_torch_base]--CDial-GPT-LCCC-base/bert4torch_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/gpt/[thu-coai_torch_base]--CDial-GPT-LCCC-base/bert4torch_pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/gpt/[thu-coai_torch_base]--CDial-GPT-LCCC-base/bert4torch_vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = Tokenizer(dict_path, do_lower_case=True)  # build the tokenizer
# Token ids of the two dialogue role markers.
speakers = [tokenizer.token_to_id('[speaker1]'), tokenizer.token_to_id('[speaker2]')]

# The config sets shared_segment_embeddings=True, so segment embeddings are
# generated from the word-embedding weights.
model = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    model='gpt',
).to(device)  # build the model and load weights
class ChatBot(AutoRegressiveDecoder):
    """Chit-chat reply generation based on random sampling."""

    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        # Extend the prompt with the tokens generated so far; every generated
        # token inherits the segment id of the last prompt token (the current
        # speaker marker), then return the next-token logits.
        token_ids, segment_ids = inputs
        generated_segments = torch.zeros_like(output_ids) + token_ids[0, -1]
        full_token_ids = torch.cat([token_ids, output_ids], 1)
        full_segment_ids = torch.cat([segment_ids, generated_segments], 1)
        logits = model.predict([full_token_ids, full_segment_ids])
        return logits[:, -1, :]

    def response(self, texts, n=1, topk=5):
        # The prompt opens with the start token followed by the speaker1 marker.
        token_ids = [tokenizer._token_start_id, speakers[0]]
        segment_ids = [tokenizer._token_start_id, speakers[0]]
        for turn, utterance in enumerate(texts):
            current = speakers[turn % 2]
            other = speakers[(turn + 1) % 2]
            # Strip [CLS]/[SEP] from the encoded utterance and append the
            # next speaker's marker; segments track the current speaker,
            # except the marker itself which belongs to the next speaker.
            piece = tokenizer.encode(utterance)[0][1:-1] + [other]
            token_ids.extend(piece)
            segment_ids.extend([current] * len(piece))
            segment_ids[-1] = other
        results = self.random_sample([token_ids, segment_ids], n, topk)  # random sampling
        return tokenizer.decode(results[0].cpu().numpy())
# start_id=None continues from the given prompt; generation stops at the end token.
chatbot = ChatBot(start_id=None, end_id=tokenizer._token_end_id, maxlen=32, device=device)
# Three-turn dialogue history; the model generates the next reply.
print(chatbot.response([u'别爱我没结果', u'你这样会失去我的', u'失去了又能怎样']))
"""
回复是随机的,例如:你还有我 | 那就不要爱我 | 你是不是傻 | 等等。
"""
\ No newline at end of file
examples/basic/basic_language_model_GAU_alpha.py
0 → 100644
View file @
66a1d0d0
#! -*- coding: utf-8 -*-
# Basic test: GAU_alpha MLM prediction; results verified to match the bert4keras version.
# In tests, medium/long texts perform clearly better than short texts.
# Blog: https://kexue.fm/archives/9052
# Weight conversion script: ./convert_script/convert_GAU_alpha.py
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer
import torch

# Model paths — please replace with your own.
config_path = 'F:/Projects/pretrain_ckpt/gau/[sushen-torch]--chinese_GAU-alpha-char_L-24_H-768/bert_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/gau/[sushen-torch]--chinese_GAU-alpha-char_L-24_H-768/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/gau/[sushen-torch]--chinese_GAU-alpha-char_L-24_H-768/vocab.txt'

# Build the tokenizer.
tokenizer = Tokenizer(dict_path, do_lower_case=True)
# Build the model and load weights; with_mlm='softmax' adds the MLM head.
model = build_transformer_model(config_path, checkpoint_path, model='gau_alpha', with_mlm='softmax')

token_ids, segments_ids = tokenizer.encode("近期正是上市公司财报密集披露的时间,但有多家龙头公司的业绩令投资者失望")
# Mask token positions 5 and 6; the MLM head will predict them below.
token_ids[5] = token_ids[6] = tokenizer._token_mask_id
print(''.join(tokenizer.ids_to_tokens(token_ids)))

tokens_ids_tensor = torch.tensor([token_ids])
segment_ids_tensor = torch.tensor([segments_ids])

# with_mlm must have been passed above so the model returns MLM probabilities.
model.eval()
with torch.no_grad():
    _, probas = model([tokens_ids_tensor, segment_ids_tensor])
    # Take the argmax prediction for the two masked positions and decode.
    result = torch.argmax(probas[0, 5:7], dim=-1).numpy()
    print(tokenizer.decode(result))
examples/basic/basic_language_model_cpm_lm.py
0 → 100644
View file @
66a1d0d0
#! -*- coding: utf-8 -*-
# Basic test: Tsinghua's open-source Chinese GPT2 model (2.6B parameters).
# Project: https://github.com/TsinghuaAI/CPM-Generate
# Blog introduction: https://kexue.fm/archives/7912
# Weights must be converted before loading; see the convert_script folder.
import numpy as np
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import SpTokenizer
from bert4torch.snippets import AutoRegressiveDecoder
import torch
import jieba
jieba.initialize()

# Converted model paths — adjust to your own setup.
config_path = 'F:/Projects/pretrain_ckpt/gpt2/[cpm_gpt2_torch]--cpm_lm_2.6b/bert4torch_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/gpt2/[cpm_gpt2_torch]--cpm_lm_2.6b/bert4torch_pytorch_model.bin'
spm_path = 'F:/Projects/pretrain_ckpt/gpt2/[cpm_gpt2_torch]--cpm_lm_2.6b/chinese_vocab.model'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def pre_tokenize(text):
    """Pre-tokenization hook: segment *text* with jieba, replacing ' ' with
    '\u2582' (LOWER THREE EIGHTHS BLOCK) and '\n' with '\u2583' (LOWER THREE
    QUARTERS BLOCK) so whitespace survives sentencepiece encoding.
    """
    words = []
    for word in jieba.cut(text, cut_all=False):
        words.append(word.replace(' ', u'\u2582').replace('\n', u'\u2583'))
    return words
# Build the sentencepiece tokenizer; no explicit start/end tokens, and the
# pre-mapped newline marker '\u2583' is translated to the <cls> token.
tokenizer = SpTokenizer(
    spm_path,
    token_start=None,
    token_end=None,
    pre_tokenize=pre_tokenize,
    token_translate={u'\u2583': '<cls>'}  # '\n' is replaced by <cls>
)  # build the tokenizer

# Build the GPT2 model (segment_vocab_size=0: no segment embeddings) and load weights.
model = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, model='gpt2', segment_vocab_size=0).to(device)  # build the model and load weights
class TextExpansion(AutoRegressiveDecoder):
    """Text continuation based on random sampling."""

    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        # Concatenate the prompt with what has been generated so far and
        # return the logits of the final position (next-token distribution).
        context = torch.cat([inputs[0], output_ids], 1)
        return model.predict([context])[:, -1, :]

    def generate(self, text, n=1, topp=0.95, temperature=1):
        """Generate *n* continuations of *text* via nucleus sampling.

        Output is somewhat random; if only few-shot quality matters,
        consider switching the decoder to beam search.
        """
        prompt_ids, _ = tokenizer.encode(text)
        sampled = self.random_sample([prompt_ids], n, topp=topp, temperature=temperature)  # random sampling
        completions = [prompt_ids + [int(t) for t in ids.cpu().numpy()] for ids in sampled]
        decoded = [tokenizer.decode(ids) for ids in completions]
        return [self.post_replace(s) for s in decoded]

    def post_replace(self, text):
        # Undo the pre_tokenize substitutions: drop raw spaces, then map
        # '\u2582' back to ' ' and '\u2583' back to '\n'.
        for src, dst in ((' ', ''), (u'\u2582', ' '), (u'\u2583', '\n')):
            text = text.replace(src, dst)
        return text
# Decoder: sampling stops at id 3 (<cls>, which doubles as the newline token).
text_expansion = TextExpansion(start_id=None, end_id=3,  # 3 is <cls>, also the newline token
                               maxlen=16, device=device)

# Commonsense completion; expected output: 北京
query = u"""
美国的首都是华盛顿
法国的首都是巴黎
日本的首都是东京
中国的首都是
"""
# query[1:-1] strips the leading/trailing newline of the triple-quoted literal.
print(text_expansion.generate(query[1:-1], 1)[0])

# Word translation; expected output: bird
query = u"""
狗 dog
猫 cat
猪 pig
鸟
"""
print(text_expansion.generate(query[1:-1], 1)[0])

# Subject extraction; expected output: 杨振宁
query = u"""
从1931年起,华罗庚在清华大学边学习边工作 华罗庚
在一间简陋的房间里,陈景润攻克了“哥德巴赫猜想” 陈景润
在这里,丘成桐得到IBM奖学金 丘成桐
杨振宁在粒子物理学、统计力学和凝聚态物理等领域作出里程碑性贡献
"""
print(text_expansion.generate(query[1:-1], 1)[0])

# Triple extraction; expected output: 张红,体重,140斤
query = u"""
姚明的身高是211cm,是很多人心目中的偶像。 ->姚明,身高,211cm
毛泽东是绍兴人,早年在长沙读书。->毛泽东,出生地,绍兴
虽然周杰伦在欧洲办的婚礼,但是他是土生土长的中国人->周杰伦,国籍,中国
小明出生于武汉,但是却不喜欢在武汉生成,长大后去了北京。->小明,出生地,武汉
吴亦凡是很多人的偶像,但是他却是加拿大人,另很多人失望->吴亦凡,国籍,加拿大
武耀的生日在5月8号,这一天,大家都为他庆祝了生日->武耀,生日,5月8号
《青花瓷》是周杰伦最得意的一首歌。->周杰伦,作品,《青花瓷》
北京是中国的首都。->中国,首都,北京
蒋碧的家乡在盘龙城,毕业后去了深圳工作。->蒋碧,籍贯,盘龙城
上周我们和王立一起去了他的家乡云南玩昨天才回到了武汉。->王立,籍贯,云南
昨天11月17号,我和朋友一起去了海底捞,期间服务员为我的朋友刘章庆祝了生日。->刘章,生日,11月17号
张红的体重达到了140斤,她很苦恼。->
"""
print(text_expansion.generate(query[1:-1], 1)[0])
examples/basic/basic_language_model_gpt2_ml.py
0 → 100644
View file @
66a1d0d0
#! -*- coding: utf-8 -*-
# Basic test of the gpt2_ml model.
# Project (tf version): https://github.com/imcaspar/gpt2-ml
# The weights must be converted before loading; see the convert_script folder.
import torch
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer
from bert4torch.snippets import AutoRegressiveDecoder

# Local paths to the converted checkpoint, config and vocab.
config_path = 'F:/Projects/pretrain_ckpt/gpt2/[gpt2-ml_torch_15g]/bert4torch_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/gpt2/[gpt2-ml_torch_15g]/bert4torch_pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/gpt2/[gpt2-ml_torch_15g]/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Tokenizer: no BOS/EOS tokens, lowercase input.
tokenizer = Tokenizer(dict_path, token_start=None, token_end=None, do_lower_case=True)
# Build the model and load the weights; GPT2-style models use no segment embeddings.
model = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path,
                                model='gpt2_ml', segment_vocab_size=0).to(device)
class ArticleCompletion(AutoRegressiveDecoder):
    """Article continuation decoded by random (nucleus) sampling."""

    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        # Run the LM over prompt + generated ids, keep only the last position.
        full_ids = torch.cat([inputs[0], output_ids], 1)
        return model.predict([full_ids])[:, -1, :]

    def generate(self, text, n=1, topp=0.95):
        prompt_ids, _ = tokenizer.encode(text)
        samples = self.random_sample([prompt_ids], n, topp=topp)  # nucleus sampling
        return [text + tokenizer.decode(sample.cpu().numpy()) for sample in samples]
# Decoder: stop at id 511 (the Chinese full stop); generate 128-256 new tokens.
article_completion = ArticleCompletion(start_id=None, end_id=511,  # 511 is the Chinese full stop
                                       maxlen=256, minlen=128, device=device)

for text in [u'今天天气不错', u'双十一', u'科学空间']:
    print(article_completion.generate(text))

"""
部分结果:
>>> article_completion.generate(u'今天天气不错')
[u'今天天气不错。昨天的天气是多云到晴的天气,今天的天气还不错,不会太冷。明后两天天气还是比较好的。不过今天的天气比较闷热,最高温度在30℃左右,明后两天天气会更加热。预计今天的最高温度为30℃,明后两天的最 高温度为32℃左右,今天的最高气温将在30℃左右。(记者李莉)。新华网重庆频道诚邀广大网友投稿,您可以用相机或手机记录下身边的感人故事,精彩瞬间。请将作者、拍摄时间、地点和简要说明连同照片发给我们,我们将精选其中的好图、美图在页面上展示,让所有新华网友共赏。[投稿] 。本报讯(记者陈敏华) 今年上半年,重庆市各级公安机关在全力抓好']
>>> article_completion.generate(u'双十一')
[u'双十一大是中国共产党在新的历史起点上召开的一次十分重要的代表大会, 是全面落实科学发展观、推进中国特色社会主义伟大事业的一次重要会议。会议的召开, 是党和政府对新世纪新阶段我国改革开放和社会主义现代化建设 事业的新的历史任务的一次重要总动员, 必将对我们党全面推进党的建']
>>> article_completion.generate(u'科学空间')
[u'科学空间站上的两个机器人在进入轨道后,一边在轨道上工作,一边用它们的身体和心脏在空间站上的一个大气层进行活动,以确保它们在进入地球之后不会因太阳风暴而受到影响;而另外一个机器人则在进入轨道的过程中,通 过机器人与地球上的大气层相互作用,使地球的大气层不断地向地球的大气层中转移,以使其能够在空间站上工作,并且使用它们的身体和心脏来完成它们的各种任务。']
"""
examples/basic/basic_language_model_nezha_gen_gpt.py
0 → 100644
View file @
66a1d0d0
#! -*- coding: utf-8 -*-
# Basic test: Chinese GPT model (base), open-sourced by Huawei.
# Weights: https://pan.baidu.com/s/1-FB0yl1uxYDCGIRvU1XNzQ code: xynn (converted to pytorch)
# Reference project: https://github.com/bojone/chinese-gen
import torch
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer
from bert4torch.snippets import AutoRegressiveDecoder

# Local paths to the converted checkpoint, config and vocab.
config_path = 'F:/Projects/pretrain_ckpt/bert/[huawei_noah_tf_base]--chinese_nezha_gpt_L-12_H-768_A-12/config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/bert/[huawei_noah_tf_base]--chinese_nezha_gpt_L-12_H-768_A-12/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/bert/[huawei_noah_tf_base]--chinese_nezha_gpt_L-12_H-768_A-12/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = Tokenizer(dict_path, do_lower_case=True)
# Build the model and load the weights; `application='lm'` adds the LM head.
model = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path,
                                segment_vocab_size=0,  # drop the segment_ids input
                                application='lm',).to(device)
class ArticleCompletion(AutoRegressiveDecoder):
    """Article continuation decoded by random sampling."""

    @AutoRegressiveDecoder.wraps(default_rtype='logits')
    def predict(self, inputs, output_ids, states):
        # Append everything generated so far to the prompt.
        token_ids = torch.cat([inputs[0], output_ids], 1)
        # With application='lm' the model returns (hidden_states, lm_scores);
        # keep only the scores of the last position.
        _, mlm_scores = model.predict([token_ids])
        return mlm_scores[:, -1, :]

    def generate(self, text, n=1, topp=0.95):
        # Drop the trailing [SEP] so generation continues the raw text.
        token_ids = tokenizer.encode(text)[0][:-1]
        results = self.random_sample([token_ids], n, topp=topp)  # nucleus sampling
        return [text + tokenizer.decode(ids.cpu().numpy()) for ids in results]
# Decoder: stop at id 511 (the Chinese full stop); generate 128-256 new tokens.
article_completion = ArticleCompletion(start_id=None, end_id=511,  # 511 is the Chinese full stop
                                       maxlen=256, minlen=128, device=device)

print(article_completion.generate(u'今天天气不错'))

"""
部分结果:
>>> article_completion.generate(u'今天天气不错')
[u'今天天气不错。昨天的天气是多云到晴的天气,今天的天气还不错,不会太冷。明后两天天气还是比较好的。不过今天的天气比较闷热,最高温度在30℃左右,明后两天天气会更加热。预计今天的最高温度为30℃,明后两天的最 高温度为32℃左右,今天的最高气温将在30℃左右。(记者李莉)。新华网重庆频道诚邀广大网友投稿,您可以用相机或手机记录下身边的感人故事,精彩瞬间。请将作者、拍摄时间、地点和简要说明连同照片发给我们,我们将精选其中的好图、美图在页面上展示,让所有新华网友共赏。[投稿] 。本报讯(记者陈敏华) 今年上半年,重庆市各级公安机关在全力抓好']
>>> article_completion.generate(u'双十一')
[u'双十一大是中国共产党在新的历史起点上召开的一次十分重要的代表大会, 是全面落实科学发展观、推进中国特色社会主义伟大事业的一次重要会议。会议的召开, 是党和政府对新世纪新阶段我国改革开放和社会主义现代化建设 事业的新的历史任务的一次重要总动员, 必将对我们党全面推进党的建']
>>> article_completion.generate(u'科学空间')
[u'科学空间站上的两个机器人在进入轨道后,一边在轨道上工作,一边用它们的身体和心脏在空间站上的一个大气层进行活动,以确保它们在进入地球之后不会因太阳风暴而受到影响;而另外一个机器人则在进入轨道的过程中,通 过机器人与地球上的大气层相互作用,使地球的大气层不断地向地球的大气层中转移,以使其能够在空间站上工作,并且使用它们的身体和心脏来完成它们的各种任务。']
"""
examples/basic/basic_language_model_nezha_gpt_dialog.py
0 → 100644
View file @
66a1d0d0
#! -*- coding: utf-8 -*-
# NEZHA model for chit-chat; this script only provides the test side.
# Original project: https://github.com/bojone/nezha_gpt_dialog
# Weight conversion script: https://github.com/Tongjilibo/bert4torch/blob/master/examples/convert_script/convert_nezha_gpt_dialog.py
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer
from bert4torch.snippets import AutoRegressiveDecoder
import torch

# NEZHA checkpoint paths.
config_path = 'F:/Projects/pretrain_ckpt/nezha/[sushen_tf_base]--nezha_gpt_dialog/config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/nezha/[sushen_tf_base]--nezha_gpt_dialog/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/nezha/[sushen_tf_base]--nezha_gpt_dialog/vocab.txt'

# Build the tokenizer.
tokenizer = Tokenizer(dict_path, do_lower_case=True)
# Build the model and load the weights; `application='lm'` adds the LM head.
model = build_transformer_model(config_path, checkpoint_path, model='nezha', application='lm',)
class ChatBot(AutoRegressiveDecoder):
    """Chit-chat bot decoded with random sampling."""

    @AutoRegressiveDecoder.wraps(default_rtype='logits')
    def predict(self, inputs, output_ids, states):
        history_ids, history_segs = inputs
        all_token_ids = torch.cat([history_ids, output_ids], 1)
        # The reply's segment id is the complement of the last history segment id.
        reply_segs = torch.ones_like(output_ids) - history_segs[0, -1]
        all_segment_ids = torch.cat([history_segs, reply_segs], 1)
        return model.predict([all_token_ids, all_segment_ids])[-1][:, -1]

    def response(self, texts, topk=5):
        # Dialogue history starts with a single [CLS] in segment 0.
        token_ids = [tokenizer._token_start_id]
        segment_ids = [0]
        for turn, text in enumerate(texts):
            # Strip each utterance's leading [CLS]; alternate segment ids per turn.
            utt_ids = tokenizer.encode(text)[0][1:]
            token_ids += utt_ids
            segment_ids += [turn % 2] * len(utt_ids)
        sampled = self.random_sample([token_ids, segment_ids], 1, topk)
        return tokenizer.decode(sampled[0].cpu().numpy())
# Stop decoding at [SEP]; replies are capped at 32 tokens.
chatbot = ChatBot(start_id=None, end_id=tokenizer._token_end_id, maxlen=32)
print(chatbot.response([u'别爱我没结果', u'你这样会失去我的', u'失去了又能怎样']))
"""
回复是随机的,例如:那你还爱我吗 | 不知道 | 爱情是不是不能因为一点小事就否定了 | 我会一直爱你,你一个人会很辛苦 | 等等。
"""
\ No newline at end of file
examples/basic/basic_language_model_simbert.py
0 → 100644
View file @
66a1d0d0
#! -*- coding: utf-8 -*-
# SimBERT/RoFormer-Sim: test paraphrase generation and sentence-similarity.
# Official project: https://github.com/ZhuiyiTechnology/simbert
# Official project: https://github.com/ZhuiyiTechnology/roformer-sim
import torch
from bert4torch.models import build_transformer_model, BaseModel
from bert4torch.snippets import sequence_padding, AutoRegressiveDecoder, get_pool_emb
from bert4torch.tokenizers import Tokenizer, load_vocab

# Basic settings.
maxlen = 32
choice = 'simbert_v2'  # one of: simbert, simbert_v2
if choice == 'simbert':
    args_model_path = "F:/Projects/pretrain_ckpt/simbert/[sushen_torch_base]--simbert_chinese_base"
    args_model = 'bert'
else:
    args_model_path = "F:/Projects/pretrain_ckpt/simbert/[sushen_torch_base]--roformer_chinese_sim_char_base"
    args_model = 'roformer'

# Load simbert (or roformer_v2) weights from the chosen directory.
root_model_path = args_model_path
dict_path = root_model_path + "/vocab.txt"
config_path = root_model_path + "/config.json"
checkpoint_path = root_model_path + '/pytorch_model.bin'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load and simplify the vocab, then build the tokenizer; `keep_tokens` maps
# the trimmed vocab back to the original embedding rows.
token_dict, keep_tokens = load_vocab(
    dict_path=dict_path,
    simplified=True,
    startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
)
tokenizer = Tokenizer(token_dict, do_lower_case=True)
# 建立加载模型
class Model(BaseModel):
    """UniLM-style SimBERT wrapper returning (seq2seq logits, sentence embedding)."""

    def __init__(self, pool_method='cls'):
        super().__init__()
        self.bert = build_transformer_model(
            config_path=config_path,
            checkpoint_path=checkpoint_path,
            with_pool='linear',
            model=args_model,
            application='unilm',
            keep_tokens=keep_tokens,
        )
        self.pool_method = pool_method

    def forward(self, token_ids, segment_ids):
        hidden, pooled, lm_logits = self.bert([token_ids, segment_ids])
        # Non-padding mask = token_ids > 0; pool according to self.pool_method.
        mask = token_ids.gt(0).long()
        sentence_emb = get_pool_emb(hidden, pooled, mask, self.pool_method)
        return lm_logits, sentence_emb
# Instantiate the wrapper and move it to the selected device; the [CLS]
# vector is used as the sentence embedding.
model = Model(pool_method='cls').to(device)
class SynonymsGenerator(AutoRegressiveDecoder):
    """Seq2seq decoder that samples paraphrases of the input sentence."""

    @AutoRegressiveDecoder.wraps('logits')
    def predict(self, inputs, output_ids, states):
        src_ids, src_segs = inputs
        cat_ids = torch.cat([src_ids, output_ids], 1)
        # Generated tokens all belong to segment 1 (UniLM convention).
        cat_segs = torch.cat([src_segs, torch.ones_like(output_ids, device=device)], 1)
        lm_logits, _ = model.predict([cat_ids, cat_segs])
        return lm_logits[:, -1, :]

    def generate(self, text, n=1, topk=5):
        src_ids, src_segs = tokenizer.encode(text, maxlen=maxlen)
        sampled = self.random_sample([src_ids, src_segs], n, topk)  # top-k random sampling
        return [tokenizer.decode(one.cpu().numpy()) for one in sampled]
# Decoder instance: stop at [SEP]; candidate length is capped at `maxlen`.
synonyms_generator = SynonymsGenerator(start_id=None, end_id=tokenizer._token_end_id,
                                       maxlen=maxlen, device=device)
def cal_sen_emb(text_list):
    """Encode a list of texts and return their sentence embeddings."""
    encoded = [tokenizer.encode(t) for t in text_list]
    token_ids = [pair[0] for pair in encoded]
    segment_ids = [pair[1] for pair in encoded]
    # Pad to a rectangular batch and move to the model's device.
    token_ids = torch.tensor(sequence_padding(token_ids), dtype=torch.long, device=device)
    segment_ids = torch.tensor(sequence_padding(segment_ids), dtype=torch.long, device=device)
    _, embeddings = model.predict([token_ids, segment_ids])
    return embeddings
def gen_synonyms(text, n=100, k=20):
    """Generate n candidate paraphrases of `text` and return the k most similar.

    Candidates come from the seq2seq sampler; they are then ranked by cosine
    similarity of their encoder sentence embeddings.

    Example:
        >>> gen_synonyms(u'微信和支付宝哪个好?')
        [
            u'微信和支付宝,哪个好?',
            u'微信和支付宝哪个好',
            u'支付宝和微信哪个好',
            ...
        ]
    """
    candidates = synonyms_generator.generate(text, n)
    candidates = [c for c in set(candidates) if c != text]  # drop exact copies of the input
    pool = [text] + candidates  # the original text sits at index 0
    embs = cal_sen_emb(pool)
    embs /= (embs ** 2).sum(dim=1, keepdims=True) ** 0.5  # L2-normalize each row
    # Ascending sort of the negative dot product == descending cosine similarity.
    order = torch.matmul(embs[1:], -embs[0]).argsort()
    return [pool[idx + 1] for idx in order[:k]]
if __name__ == '__main__':
    choice = 'generate'  # one of: generate, similarity
    if choice == 'generate':
        # Paraphrase generation demo.
        print(gen_synonyms('我想去北京玩玩可以吗', 10, 10))
    elif choice == 'similarity':
        # Sentence-similarity demo: cosine between target and each candidate.
        target_text = '我想去首都北京玩玩'
        text_list = ['我想去北京玩', '北京有啥好玩的吗?我想去看看', '好渴望去北京游玩啊']
        Z = cal_sen_emb([target_text] + text_list)
        # L2-normalize so the dot product below equals cosine similarity.
        Z /= (Z ** 2).sum(dim=1, keepdims=True) ** 0.5
        similarity = torch.matmul(Z[1:], Z[0])
        for i, line in enumerate(text_list):
            print(f'cos_sim: {similarity[i].item():.4f}, tgt_text: "{target_text}", cal_text: "{line}"')
    else:
        # Fall-through branch: reload previously fine-tuned weights.
        model.load_weights('./best_model.pt')
examples/basic/basic_language_model_t5.py
0 → 100644
View file @
66a1d0d0
#! -*- coding: utf-8 -*-
# Run the pretrained t5-chinese model directly for prediction (BertTokenizer).
# This t5 uses the t5.1.0 architecture.
import torch
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer, load_vocab
from bert4torch.snippets import AutoRegressiveDecoder

# Checkpoint paths.
config_path = 'F:/Projects/pretrain_ckpt/t5/[uer_t5_torch]--t5-small-chinese-cluecorpussmall/bert4torch_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/t5/[uer_t5_torch]--t5-small-chinese-cluecorpussmall/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/t5/[uer_t5_torch]--t5-small-chinese-cluecorpussmall/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the vocab (not simplified here) and build the tokenizer.
token_dict = load_vocab(
    dict_path=dict_path,
    simplified=False,
    startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
)
tokenizer = Tokenizer(token_dict, do_lower_case=True)

# Build the t5.1.0 encoder-decoder and load the weights.
model = build_transformer_model(
    config_path,
    checkpoint_path,
    model='t5.1.0',
    segment_vocab_size=0,
    attention_scale=False,
    is_dropout=True,
).to(device)
class AutoTitle(AutoRegressiveDecoder):
    """Seq2seq decoder for the t5 encoder-decoder."""

    @AutoRegressiveDecoder.wraps(default_rtype='logits')
    def predict(self, inputs, output_ids, states):
        token_ids = inputs[0]
        # Encoder inputs and decoder inputs are passed as two separate lists;
        # take the decoder scores and keep only the last position.
        return model.predict([[token_ids], [output_ids]])[-1][:, -1, :]

    def generate(self, text, topk=1, topp=0.95):
        token_ids, _ = tokenizer.encode(text, maxlen=256)
        output_ids = self.beam_search([token_ids], topk=topk)  # beam search decoding
        return tokenizer.decode(output_ids.cpu().numpy())
# end_id=1 here; using tokenizer._token_end_id instead yields shorter outputs.
autotitle = AutoTitle(start_id=tokenizer._token_start_id, end_id=1, maxlen=32, device=device)

if __name__ == '__main__':
    print(autotitle.generate('中国的首都是extra0京'))
examples/basic/basic_language_model_transformer_xl.py
0 → 100644
View file @
66a1d0d0
#! -*- coding: utf-8 -*-
# Run the transformer_xl model; it is rarely used and no Chinese pretrained
# model was found, so the English pretrained model from `transformers` is
# used to verify correctness (last_hidden_state compared by debugging into
# the transformers package; it matches this framework).
# Conversion script: convert_script/convert_transformer_xl.py
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

pretrained_model = "F:/Projects/pretrain_ckpt/transformer_xl/[english_hugging_face_torch]--transfo-xl-wt103"

# ----------------------transformers package----------------------
tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
model = AutoModelForCausalLM.from_pretrained(pretrained_model)
model.eval()
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
# The hidden states can only be inspected with a debugger breakpoint here.
with torch.no_grad():
    outputs = model(**inputs, labels=inputs["input_ids"])
    # NOTE: TransfoXL's output field is `losses` (per-token), not `loss`.
    loss = outputs.losses
print('transforms loss: ', loss)

# ----------------------bert4torch setup----------------------
from bert4torch.models import build_transformer_model

config_path = f'{pretrained_model}/bert4torch_config.json'
checkpoint_path = f'{pretrained_model}/bert4torch_pytorch_model.bin'
model = build_transformer_model(
    config_path,
    checkpoint_path=checkpoint_path,
    model='transformer_xl',
)
print('bert4torch last_hidden_state: ', model.predict([inputs['input_ids']]))
# tensor([[[ 0.1027,  0.0604, -0.2585,  ...,  0.3137, -0.2679,  0.1036],
#          [ 0.3482, -0.0458, -0.4582,  ...,  0.0242, -0.0721,  0.2311],
#          [ 0.3426, -0.1353, -0.4145,  ...,  0.1123,  0.1374,  0.1313],
#          [ 0.0038, -0.0978, -0.5570,  ...,  0.0487, -0.1891, -0.0608],
#          [-0.2155, -0.1388, -0.5549,  ..., -0.1458,  0.0774,  0.0419],
#          [ 0.0967, -0.1781, -0.4328,  ..., -0.1831, -0.0808,  0.0890]]])
\ No newline at end of file
examples/basic/basic_language_model_xlnet.py
0 → 100644
View file @
66a1d0d0
# Compare the transformers XLNet output against the bert4torch implementation.
from transformers import XLNetTokenizer, XLNetModel
import torch

pretrained_model = "F:/Projects/pretrain_ckpt/xlnet/[hit_torch_base]--chinese-xlnet-base"

# Reference run with the transformers package.
tokenizer = XLNetTokenizer.from_pretrained(pretrained_model)
model = XLNetModel.from_pretrained(pretrained_model)
inputs = tokenizer(["你好啊,我叫张三", "天气不错啊"], padding=True, return_tensors="pt")
outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state
print('--------transformers last_hidden_state--------\n', last_hidden_states)

# ----------------------bert4torch setup----------------------
from bert4torch.models import build_transformer_model

config_path = f'{pretrained_model}/bert4torch_config.json'
checkpoint_path = f'{pretrained_model}/pytorch_model.bin'
model = build_transformer_model(
    config_path,
    checkpoint_path=checkpoint_path,
    model='xlnet',
    # with_lm=True
    token_pad_ids=tokenizer.pad_token_id,  # needed so padding matches the reference run
)
print('--------bert4torch last_hidden_state--------\n', model.predict([inputs['input_ids'], inputs['token_type_ids']]))
\ No newline at end of file
examples/basic/basic_make_uncased_model_cased.py
0 → 100644
View file @
66a1d0d0
#! -*- coding: utf-8 -*-
# Give an uncased model the ability to distinguish case by simply editing
# the vocab: add uppercased variants of the English tokens to the vocab and
# adjust the model's embedding layer accordingly.
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer, load_vocab
import torch

root_model_path = "F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12"
vocab_path = root_model_path + "/vocab.txt"
config_path = root_model_path + "/bert_config.json"
checkpoint_path = root_model_path + '/pytorch_model.bin'

# Original (uncased) vocabulary.
token_dict = load_vocab(vocab_path)
new_token_dict = token_dict.copy()
# For each new token, compound_tokens records the old token id(s) whose
# embeddings are averaged to initialize it.
compound_tokens = []

for t, i in sorted(token_dict.items(), key=lambda s: s[1]):
    # Two cases are handled: 1) first letter capitalized; 2) whole word uppercased.
    # Under Python2 this adds 5594 tokens; under Python3, 5596.
    tokens = []
    if t.isalpha():
        tokens.extend([t[:1].upper() + t[1:], t.upper()])
    elif t[:2] == '##' and t[2:].isalpha():
        tokens.append(t.upper())
    for token in tokens:
        if token not in new_token_dict:
            # New token inherits the embedding of its lowercase original (id i).
            compound_tokens.append([i])
            new_token_dict[token] = len(new_token_dict)
# Case-sensitive tokenizer over the expanded vocab.
tokenizer = Tokenizer(new_token_dict, do_lower_case=False)

model = build_transformer_model(
    config_path,
    checkpoint_path,
    compound_tokens=compound_tokens,  # new tokens, initialized by averaging old-token embeddings
)
text = u'Welcome to BEIJING.'
tokens = tokenizer.tokenize(text)
print(tokens)
"""
输出:['[CLS]', u'Welcome', u'to', u'BE', u'##I', u'##JING', u'.', '[SEP]']
"""

# Encode and run a forward pass on the single example.
token_ids, segment_ids = tokenizer.encode(text)
token_ids, segment_ids = torch.tensor([token_ids]), torch.tensor([segment_ids])

model.eval()
with torch.no_grad():
    print(model([token_ids, segment_ids])[0])
"""
输出:
[[[-1.4999904e-01  1.9651388e-01 -1.7924258e-01 ...  7.8269649e-01
    2.2241375e-01  1.1325148e-01]
  [-4.5268752e-02  5.5090344e-01  7.4699545e-01 ... -4.7773960e-01
   -1.7562288e-01  4.1265407e-01]
  [ 7.0158571e-02  1.7816302e-01  3.6949167e-01 ...  9.6258509e-01
   -8.4678203e-01  6.3776302e-01]
  ...
  [ 9.3637377e-01  3.0232478e-02  8.1411439e-01 ...  7.9186147e-01
    7.5704646e-01 -8.3475001e-04]
  [ 2.3699696e-01  2.9953337e-01  8.1962071e-02 ... -1.3776925e-01
    3.8681498e-01  3.2553676e-01]
  [ 1.9728680e-01  7.7782705e-02  5.2951699e-01 ...  8.9622810e-02
   -2.3932748e-02  6.9600858e-02]]]
"""
Prev
1
2
3
4
5
6
…
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment