Commit 62fb7f6b authored by liucong's avatar liucong

Commit GPT2 inference code and documentation based on MIGraphX

# Generative Pre-Training 2 (GPT2)
## Model Introduction
GPT2 is the second-generation generative pre-training model (Generative Pre-Training 2).
## Model Structure
GPT2 uses the Transformer decoder module as its feature extractor, with one modification to the decoder block: the original decoder contains two multi-head attention sub-layers, while GPT2 keeps only the masked multi-head attention.
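The masked (causal) self-attention that GPT2 keeps can be illustrated with a single-head NumPy sketch. The dimensions and random weights below are toy values for illustration only, not GPT2's actual parameters:

```python
import numpy as np

def masked_self_attention(x, w_q, w_k, w_v):
    """Single-head masked self-attention: position i may only attend to
    positions <= i, as in the GPT2 decoder block."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(q.shape[-1])
    # Causal mask: future positions get -inf before the softmax
    mask = np.triu(np.ones_like(scores, dtype=bool), k=1)
    scores[mask] = -np.inf
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.default_rng(0)
seq_len, d_model = 4, 8
x = rng.standard_normal((seq_len, d_model))
w_q, w_k, w_v = (rng.standard_normal((d_model, d_model)) for _ in range(3))
out, attn = masked_self_attention(x, w_q, w_k, w_v)
print(out.shape)                            # (4, 8)
print(np.allclose(np.triu(attn, k=1), 0))   # True: no attention to future tokens
```

Each row of the attention matrix sums to 1, and all entries above the diagonal are exactly zero, which is what prevents the model from "seeing" future tokens during generation.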
## Inference
### Environment Setup
The inference Docker image can be pulled from 光源 (SourceFind). The image for GPT2 inference is:
```shell
docker pull image.sourcefind.cn:5000/dcu/admin/base/custom:ort_dcu_1.14.0_migraphx2.5.2_dtk22.10.1
```
The MIGraphX installation package can be downloaded from the 光合 developer community. Install the Python dependencies with:
```shell
pip install -r requirement.txt
```
### Running the Example
We provide an inference script based on MIGraphX. Version requirements:
- MIGraphX (DCU build) >= 2.5.2

gpt2.py is the MIGraphX-based inference script; MIGraphX must be installed before use. Because GPT2 inference uses dynamic input shapes, enable dynamic shape mode before running. Usage:
```shell
# Enable dynamic shape mode
export MIGRAPHX_DYNAMIC_SHAPE=1
# Run inference
python gpt2.py
```
The inference result:
<img src="./Sample_picture.png" style="zoom:90%;" align=middle>
## Previous Versions
https://developer.hpccube.com/codes/modelzoo/gpt2_migraphx
## References
https://github.com/yangjianxin1/GPT2-chitchat
https://github.com/Morizeyao/GPT2-Chinese
import os
import numpy as np
from transformers import BertTokenizerFast
import migraphx

VOCAB_SIZE = 22557  # vocabulary size of the shici (classical poetry) model

# Load the vocabulary
vocab_file = os.path.join('../../../Resource/Models/NLP/GPT2', 'vocab_shici.txt')
tokenizer = BertTokenizerFast(vocab_file, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
print("INFO: Completed loading the vocabulary")

# Set the maximum input shape
maxInput = {"input": [1, 1024]}

# Parse the ONNX model
print("INFO: Parsing and compiling the model")
model = migraphx.parse_onnx("../../../Resource/Models/NLP/GPT2/GPT2_shici.onnx", map_input_dims=maxInput)
inputName = model.get_parameter_names()[0]
inputShape = model.get_parameter_shapes()[inputName].lens()
print("inputName:{0} \ninputShape:{1}".format(inputName, inputShape))

# Compile for the GPU target
model.compile(t=migraphx.get_target("gpu"), device_id=0)

print('Start composing poetry with GPT2; press CTRL + C to exit')
while True:
    try:
        history = []
        text = input("user:")
        text_ids = tokenizer.encode(text, add_special_tokens=False)
        history.extend(text_ids)
        input_ids = [tokenizer.cls_token_id]
        input_ids.extend(text_ids)
        input_ids = np.array(input_ids, dtype=np.int64)
        input_ids = np.expand_dims(input_ids, axis=0)
        max_len = 50
        for _ in range(max_len):
            # Reshape the model to the current sequence length
            inputShapes = [input_ids.shape[0], input_ids.shape[1]]
            inputShapeMap = {inputName: inputShapes}
            model.reshape(inputs=inputShapeMap)
            # Run inference
            result = model.run({inputName: migraphx.argument(input_ids)})
            logits = [float(x) for x in result[0].tolist()]
            # Keep only the logits for the last position of the flattened
            # [seq_len * VOCAB_SIZE] output
            offset = (input_ids.shape[1] - 1) * VOCAB_SIZE
            score = logits[offset:offset + VOCAB_SIZE]
            # Set the score of [UNK] to -inf so it can never be predicted
            score[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
            # Sort by score, highest first
            index_and_score = sorted(enumerate(score), key=lambda x: x[1], reverse=True)
            # Take the highest-probability token as the prediction
            next_token = index_and_score[0][0]
            if next_token == tokenizer.convert_tokens_to_ids('[SEP]'):  # stop at the [SEP] end marker
                break
            history.append(next_token)  # accumulate generated tokens in history
            next_token = np.array(next_token, dtype=np.int64)
            input_ids = np.append(input_ids, next_token)
            input_ids = np.expand_dims(input_ids, axis=0)
        text = tokenizer.convert_ids_to_tokens(history)
        print("chatbot:" + "".join(text))
    except KeyboardInterrupt:
        break
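The script's greedy decoding step slices the last position out of a flattened [seq_len * 22557] logits list and masks out [UNK] before taking the argmax. That step can be sketched in isolation; the tiny vocabulary and logit values below are made up for illustration:

```python
import numpy as np

def greedy_next_token(flat_logits, vocab_size, unk_id):
    """Pick the highest-scoring token for the LAST position from logits
    flattened to a single [seq_len * vocab_size] list, excluding [UNK]."""
    logits = np.asarray(flat_logits, dtype=np.float64).reshape(-1, vocab_size)
    last = logits[-1].copy()
    last[unk_id] = -np.inf  # [UNK] must never be predicted
    return int(np.argmax(last))

# Toy example: seq_len = 2, vocab_size = 5, [UNK] at index 1
flat = [0.1, 0.2, 0.3, 0.0, 0.05,   # logits for position 0
        0.0, 0.9, 0.4, 0.2, 0.1]    # logits for position 1 (the last one)
print(greedy_next_token(flat, vocab_size=5, unk_id=1))  # → 2
```

Index 1 holds the largest raw logit (0.9) in the last position, but it is the [UNK] slot, so the function returns index 2 instead; this mirrors how the script guarantees [UNK] never appears in generated poetry.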
numpy
transformers