basic_language_model_GAU_alpha.py 1.5 KB
Newer Older
yangzhong's avatar
yangzhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#! -*- coding: utf-8 -*-
# 基础测试:GAU_alpha的mlm预测,和bert4keras版本比对一致
# 测试中长文本效果明显高于短文本效果
# 博客:https://kexue.fm/archives/9052
# 权重转换脚本:./convert_script/convert_GAU_alpha.py

from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer
import torch

# 加载模型,请更换成自己的路径
config_path = 'F:/Projects/pretrain_ckpt/gau/[sushen-torch]--chinese_GAU-alpha-char_L-24_H-768/bert_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/gau/[sushen-torch]--chinese_GAU-alpha-char_L-24_H-768/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/gau/[sushen-torch]--chinese_GAU-alpha-char_L-24_H-768/vocab.txt'


# 建立分词器
tokenizer = Tokenizer(dict_path, do_lower_case=True)
model = build_transformer_model(config_path, checkpoint_path, model='gau_alpha', with_mlm='softmax')  # 建立模型,加载权重

token_ids, segments_ids = tokenizer.encode("近期正是上市公司财报密集披露的时间,但有多家龙头公司的业绩令投资者失望")
token_ids[5] = token_ids[6] = tokenizer._token_mask_id
print(''.join(tokenizer.ids_to_tokens(token_ids)))

tokens_ids_tensor = torch.tensor([token_ids])
segment_ids_tensor = torch.tensor([segments_ids])

# 需要传入参数with_mlm
model.eval()
with torch.no_grad():
    _, probas = model([tokens_ids_tensor, segment_ids_tensor])
    result = torch.argmax(probas[0, 5:7], dim=-1).numpy()
    print(tokenizer.decode(result))