Commit 741ac4ae authored by liucong

Remove some code and documentation

parent 816b3d52
Pipeline #297 failed with stages in 0 seconds
import os
import numpy as np
from transformers import BertTokenizerFast
import migraphx
# Load the vocabulary
print("INFO: Loading the vocabulary")
vocab_file = os.path.join('./model', 'vocab_shici.txt')
tokenizer = BertTokenizerFast(vocab_file, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")

# Set the maximum input shape
maxInput = {"input": [1, 1024]}

# Parse the ONNX model
print("INFO: Parsing and compiling the model")
model = migraphx.parse_onnx("./model/GPT2_shici.onnx", map_input_dims=maxInput)
inputName = model.get_parameter_names()[0]
inputShape = model.get_parameter_shapes()[inputName].lens()
print("inputName:{0} \ninputShape:{1}".format(inputName, inputShape))

# Compile for the GPU target
model.compile(t=migraphx.get_target("gpu"), device_id=0)

print('Start matching verses with GPT2; press CTRL + Z to exit')
while True:
    try:
        history = []
        text = input("user:")
        text_ids = tokenizer.encode(text, add_special_tokens=False)
        history.extend(text_ids)
        # Prepend the [CLS] token and build a [1, seq_len] int64 batch
        input_ids = [tokenizer.cls_token_id]
        input_ids.extend(text_ids)
        input_ids = np.array(input_ids, dtype=np.int64)
        input_ids = np.expand_dims(input_ids, axis=0)
        max_len = 50
        for _ in range(max_len):
            # Reshape the model to the current sequence length
            inputShapes = [input_ids.shape[0], input_ids.shape[1]]
            inputShapeMap = {inputName: inputShapes}
            model.reshape(inputs=inputShapeMap)
            # Run inference; the output is a flattened [seq_len * vocab_size] buffer
            result = model.run({inputName: migraphx.argument(input_ids)})
            logits = [float(x) for x in result[0].tolist()]
            # Keep only the scores of the last position (the vocabulary size is 22557)
            score = []
            for index in range((input_ids.shape[1] - 1) * 22557, input_ids.shape[1] * 22557):
                score.append(logits[index])
            # Set the score of [UNK] to -inf so the prediction can never be [UNK]
            score[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
            # Sort the vocabulary indices by score in descending order
            index_and_score = sorted(enumerate(score), key=lambda x: x[1], reverse=True)
            # Greedy decoding: take the highest-scoring token as the prediction
            next_token = index_and_score[0][0]
            if next_token == tokenizer.convert_tokens_to_ids('[SEP]'):  # stop at the [SEP] end marker
                break
            history.append(next_token)  # accumulate the generated tokens in history
            # Append the new token and restore the [1, seq_len + 1] shape
            next_token = np.array(next_token, dtype=np.int64)
            input_ids = np.append(input_ids, next_token)
            input_ids = np.expand_dims(input_ids, axis=0)
        text = tokenizer.convert_ids_to_tokens(history)
        print("chatbot:" + "".join(text))
    except (EOFError, KeyboardInterrupt):  # exit on CTRL + Z (EOF) or CTRL + C
        break
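The inner loop recovers the last position's scores by index arithmetic over the flattened output and then sorts the whole vocabulary just to pick the maximum. A minimal equivalent sketch of the same greedy-decoding step in plainer numpy (assuming the same flattened [seq_len * 22557] logits buffer as the script above; pick_next_token is a hypothetical helper, not part of the original code):

import numpy as np

VOCAB_SIZE = 22557  # vocabulary size used by the model above

def pick_next_token(flat_logits, seq_len, unk_id):
    # View the flattened buffer as [seq_len, VOCAB_SIZE] and keep the last row
    scores = np.asarray(flat_logits, dtype=np.float32).reshape(seq_len, VOCAB_SIZE)[-1]
    scores[unk_id] = -np.inf       # never predict [UNK]
    return int(np.argmax(scores))  # greedy decoding: highest score wins

An argmax over the last row gives the same greedy choice as the full sort, without the O(V log V) sorting cost.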
os
numpy
transformers
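Of these dependencies, only numpy and transformers are installable from PyPI (e.g. pip install numpy transformers); os is part of the Python standard library, and the migraphx Python module typically comes with a ROCm/MIGraphX installation rather than pip.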