import os
import numpy as np
from transformers import BertTokenizerFast
import onnxruntime as ort

# Load the vocabulary
print("INFO: Loading the vocabulary")
vocab_file = os.path.join('../Resource/', 'vocab_shici.txt')
tokenizer = BertTokenizerFast(vocab_file, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
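
# Optional sanity check (a sketch; not required for the script to run): the
# model's output logit dimension must match the tokenizer's vocabulary size,
# which is 22557 entries for this vocab file.
# assert len(tokenizer) == 22557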

# Load the model
print("INFO: Parsing and compiling the model")
sess_options = ort.SessionOptions()

# Set the graph optimization level
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC

# Whether to enable profiling
sess_options.enable_profiling = False

# Load the ONNX model onto the DCU via the ROCm execution provider
dcu_session = ort.InferenceSession("../Resource/GPT2_shici.onnx", sess_options,
                                   providers=['ROCMExecutionProvider'])
input_name = dcu_session.get_inputs()[0].name
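
# If no ROCm-capable device is available, the same model can be run on the CPU
# instead (a minimal alternative; 'CPUExecutionProvider' ships with the stock
# onnxruntime package):
# cpu_session = ort.InferenceSession("../Resource/GPT2_shici.onnx", sess_options,
#                                    providers=['CPUExecutionProvider'])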

print('Start composing poetry with GPT2; press CTRL + Z to exit')
while True:
    try:
        history = []
        text = input("user:")
        text_ids = tokenizer.encode(text, add_special_tokens=False)
        history.extend(text_ids)
        # Prepend [CLS] and shape the sequence as a (1, seq_len) int64 batch
        input_ids = [tokenizer.cls_token_id]
        input_ids.extend(text_ids)
        input_ids = np.array(input_ids, dtype=np.int64)
        input_ids = np.expand_dims(input_ids, axis=0)
        
        max_len = 50
        for _ in range(max_len):

            # Inference: the model returns logits of shape (1, seq_len, vocab_size)
            result = dcu_session.run(None, input_feed={input_name: input_ids})
            logits = np.array(result[0])

            # Only the last position's logits are needed to predict the next token
            next_token_logits = logits[0, -1, :]

            # Set the score of [UNK] to -inf so the prediction can never be [UNK]
            next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')

            # Greedy decoding: take the highest-scoring token as the prediction
            next_token = int(np.argmax(next_token_logits))
            if next_token == tokenizer.convert_tokens_to_ids('[SEP]'):    # [SEP] end marker reached; stop the loop
                break
            history.append(next_token)                                    # collect the reply in the history list

            # Append the prediction and feed the extended sequence back in
            input_ids = np.append(input_ids, next_token)
            input_ids = np.expand_dims(input_ids, axis=0)
        
        # history holds the prompt ids plus the generated reply; decode back to text
        text = tokenizer.convert_ids_to_tokens(history)
        print("chatbot:" + "".join(text))

    except (KeyboardInterrupt, EOFError):    # CTRL + Z sends EOF; CTRL + C interrupts
        break
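
# A possible variation on the greedy decoding above: sample the next token from
# the top-k candidates to get more varied verses. A sketch, assuming the
# next_token_logits array computed inside the loop:
#
#     k = 8
#     top_k = np.argpartition(next_token_logits, -k)[-k:]
#     probs = np.exp(next_token_logits[top_k] - next_token_logits[top_k].max())
#     probs /= probs.sum()
#     next_token = int(np.random.choice(top_k, p=probs))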