# basic_make_uncased_model_cased.py
#! -*- coding: utf-8 -*-
# Give an uncased (lower-cased) model the ability to distinguish case by simply modifying the vocabulary.
# Basic idea: add upper-cased variants of the English tokens to the vocab and adjust the model's Embedding layer accordingly.

from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer, load_vocab
import torch

root_model_path = "F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12"
vocab_path = root_model_path + "/vocab.txt"
config_path = root_model_path + "/bert_config.json"
checkpoint_path = root_model_path + '/pytorch_model.bin'


token_dict = load_vocab(vocab_path)
new_token_dict = token_dict.copy()
compound_tokens = []

for t, i in sorted(token_dict.items(), key=lambda s: s[1]):
    # Two cases are handled here: 1) first letter capitalized; 2) whole word upper-cased.
    # Under Python 2 this adds 5594 new tokens; under Python 3 it adds 5596.
    tokens = []
    if t.isalpha():
        tokens.extend([t[:1].upper() + t[1:], t.upper()])
    elif t[:2] == '##' and t[2:].isalpha():
        tokens.append(t.upper())
    for token in tokens:
        if token not in new_token_dict:
            compound_tokens.append([i])
            new_token_dict[token] = len(new_token_dict)
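# Illustration of what the loop builds (the index is hypothetical, not from the real vocab):
# if 'welcome' had id 10673 in the original vocab, the loop appends 'Welcome' and 'WELCOME'
# to the end of new_token_dict and records [10673] in compound_tokens once for each of them,
# so every new token remembers which old token(s) it should be initialized from.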

tokenizer = Tokenizer(new_token_dict, do_lower_case=False)

model = build_transformer_model(
    config_path,
    checkpoint_path,
    compound_tokens=compound_tokens,  # add the new tokens, each initialized from the average of its source tokens' embeddings
)
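
# Rough sketch of what the compound_tokens initialization amounts to (a sketch of the idea,
# not bert4torch's actual implementation; the state-dict key below is illustrative only):
#
#   emb = torch.load(checkpoint_path)['bert.embeddings.word_embeddings.weight']   # [old_vocab, hidden]
#   new_rows = torch.stack([emb[ids].mean(dim=0) for ids in compound_tokens])     # one row per new token
#   extended = torch.cat([emb, new_rows], dim=0)                                  # [old_vocab + new, hidden]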

text = u'Welcome to BEIJING.'
tokens = tokenizer.tokenize(text)
print(tokens)
"""
Output: ['[CLS]', u'Welcome', u'to', u'BE', u'##I', u'##JING', u'.', '[SEP]']
"""

token_ids, segment_ids = tokenizer.encode(text)
token_ids, segment_ids = torch.tensor([token_ids]), torch.tensor([segment_ids])
model.eval()
with torch.no_grad():
    print(model([token_ids, segment_ids])[0])
"""
Output:
[[[-1.4999904e-01  1.9651388e-01 -1.7924258e-01 ...  7.8269649e-01
    2.2241375e-01  1.1325148e-01]
  [-4.5268752e-02  5.5090344e-01  7.4699545e-01 ... -4.7773960e-01
   -1.7562288e-01  4.1265407e-01]
  [ 7.0158571e-02  1.7816302e-01  3.6949167e-01 ...  9.6258509e-01
   -8.4678203e-01  6.3776302e-01]
  ...
  [ 9.3637377e-01  3.0232478e-02  8.1411439e-01 ...  7.9186147e-01
    7.5704646e-01 -8.3475001e-04]
  [ 2.3699696e-01  2.9953337e-01  8.1962071e-02 ... -1.3776925e-01
    3.8681498e-01  3.2553676e-01]
  [ 1.9728680e-01  7.7782705e-02  5.2951699e-01 ...  8.9622810e-02
   -2.3932748e-02  6.9600858e-02]]]
"""