basic_language_model_CDial_GPT.py 2.61 KB
Newer Older
yangzhong's avatar
yangzhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#! -*- coding: utf-8 -*-
# 基本测试:中文GPT模型,base版本,CDial-GPT版
# 项目链接:https://github.com/thu-coai/CDial-GPT
# 参考项目:https://github.com/bojone/CDial-GPT-tf
# 权重需转换后方可加载,转换脚本见convert_script文件夹

import torch
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer
from bert4torch.snippets import AutoRegressiveDecoder

config_path = 'F:/Projects/pretrain_ckpt/gpt/[thu-coai_torch_base]--CDial-GPT-LCCC-base/bert4torch_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/gpt/[thu-coai_torch_base]--CDial-GPT-LCCC-base/bert4torch_pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/gpt/[thu-coai_torch_base]--CDial-GPT-LCCC-base/bert4torch_vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = Tokenizer(dict_path, do_lower_case=True)  # 建立分词器
speakers = [tokenizer.token_to_id('[speaker1]'), tokenizer.token_to_id('[speaker2]')]

# config中设置shared_segment_embeddings=True,segment embedding用word embedding的权重生成
model = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    model='gpt',
).to(device)  # 建立模型,加载权重


class ChatBot(AutoRegressiveDecoder):
    """基于随机采样的闲聊回复
    """
    @AutoRegressiveDecoder.wraps(default_rtype='probas')
    def predict(self, inputs, output_ids, states):
        token_ids, segment_ids = inputs
        curr_segment_ids = torch.zeros_like(output_ids) + token_ids[0, -1]
        token_ids = torch.cat([token_ids, output_ids], 1)
        segment_ids = torch.cat([segment_ids, curr_segment_ids], 1)
        logits = model.predict([token_ids, segment_ids])
        return logits[:, -1, :]

    def response(self, texts, n=1, topk=5):
        token_ids = [tokenizer._token_start_id, speakers[0]]
        segment_ids = [tokenizer._token_start_id, speakers[0]]
        for i, text in enumerate(texts):
            ids = tokenizer.encode(text)[0][1:-1] + [speakers[(i + 1) % 2]]
            token_ids.extend(ids)
            segment_ids.extend([speakers[i % 2]] * len(ids))
            segment_ids[-1] = speakers[(i + 1) % 2]
        results = self.random_sample([token_ids, segment_ids], n, topk)  # 基于随机采样
        return tokenizer.decode(results[0].cpu().numpy())


chatbot  = ChatBot(start_id=None, end_id=tokenizer._token_end_id, maxlen=32, device=device)

print(chatbot.response([u'别爱我没结果', u'你这样会失去我的', u'失去了又能怎样']))
"""
回复是随机的,例如:你还有我 | 那就不要爱我 | 你是不是傻 | 等等。
"""