Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
GPT2_migraphx
Commits
62fb7f6b
Commit
62fb7f6b
authored
May 22, 2023
by
liucong
Browse files
提交基于MIGraphX的GPT2推理代码和文档
parents
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
127 additions
and
0 deletions
+127
-0
README.md
README.md
+54
-0
Sample_picture.png
Sample_picture.png
+0
-0
gpt2.py
gpt2.py
+70
-0
requirement.txt
requirement.txt
+3
-0
No files found.
README.md
0 → 100644
View file @
62fb7f6b
# Generative Pre-Training2(GPT2)
## 模型介绍
GPT2模型:第二代生成式预训练模型(Generative Pre-Training2)。
## 模型结构
GPT2主要使用Transformer的Decoder模块为特征提取器,并对Transformer Decoder进行了一些改动,原本的Decoder包含了两个Multi-Head Attention结构,而GPT2只保留了Mask Multi-Head Attention。
## 推理
### 环境配置
在光源可拉取推理的docker镜像,GPT2模型推理的镜像如下:
```
python
docker
pull
image
.
sourcefind
.
cn
:
5000
/
dcu
/
admin
/
base
/
custom
:
ort_dcu_1
.
14.0
_migraphx2
.
5.2
_dtk22
.
10.1
```
在光合开发者社区可下载MIGraphX安装包,python依赖安装:
```
python
pip
install
-
r
requirement
.
txt
```
### 运行示例
我们提供了基于MIGraphX的推理脚本,版本依赖:
-
Migraphx(DCU版本) >= 2.5.2
gpt2.py是基于Migraphx的推理脚本,使用需安装好MIGraphX,由于GPT2模型是动态推理的,需要设置动态shape模式,再执行推理。使用方法:
```
python
# 设置动态shape模式
export
MIGRAPHX_DYNAMIC_SHAPE
=
1
# 执行推理
python
gpt2
.
py
```
推理结果为:
<img
src=
"./Sample_picture.png"
style=
"zoom:90%;"
align=
middle
>
## 历史版本
https://developer.hpccube.com/codes/modelzoo/gpt2_migraphx
## 参考资料
https://github.com/yangjianxin1/GPT2-chitchat
https://github.com/Morizeyao/GPT2-Chinese
\ No newline at end of file
Sample_picture.png
0 → 100644
View file @
62fb7f6b
31.4 KB
gpt2.py
0 → 100644
View file @
62fb7f6b
import
os
import
numpy
as
np
from
transformers
import
BertTokenizerFast
import
migraphx
# 加载词汇表
print
(
"INFO: Complete loading the vocabulary"
)
vocab_file
=
os
.
path
.
join
(
'../../../Resource/Models/NLP/GPT2'
,
'vocab_shici.txt'
)
tokenizer
=
BertTokenizerFast
(
vocab_file
,
sep_token
=
"[SEP]"
,
pad_token
=
"[PAD]"
,
cls_token
=
"[CLS]"
)
# 设置最大输入shape
maxInput
=
{
"input"
:[
1
,
1024
]}
# 加载模型
print
(
"INFO: Parsing and compiling the model"
)
model
=
migraphx
.
parse_onnx
(
"../../../Resource/Models/NLP/GPT2/GPT2_shici.onnx"
,
map_input_dims
=
maxInput
)
inputName
=
model
.
get_parameter_names
()[
0
]
inputShape
=
model
.
get_parameter_shapes
()[
inputName
].
lens
()
print
(
"inputName:{0}
\n
inputShape:{1}"
.
format
(
inputName
,
inputShape
))
# 编译
model
.
compile
(
t
=
migraphx
.
get_target
(
"gpu"
),
device_id
=
0
)
print
(
'开始和GPT2对诗,输入CTRL + Z以退出'
)
while
True
:
try
:
history
=
[]
text
=
input
(
"user:"
)
text_ids
=
tokenizer
.
encode
(
text
,
add_special_tokens
=
False
)
history
.
extend
(
text_ids
)
input_ids
=
[
tokenizer
.
cls_token_id
]
input_ids
.
extend
(
text_ids
)
input_ids
=
np
.
array
(
input_ids
,
dtype
=
np
.
int64
)
input_ids
=
np
.
expand_dims
(
input_ids
,
axis
=
0
)
max_len
=
50
for
_
in
range
(
max_len
):
# 执行reshape
inputShapes
=
[
input_ids
.
shape
[
0
],
input_ids
.
shape
[
1
]]
inputShapeMap
=
{
inputName
:
inputShapes
}
model
.
reshape
(
inputs
=
inputShapeMap
)
# 推理
result
=
model
.
run
({
inputName
:
migraphx
.
argument
(
input_ids
)})
logits
=
[
float
(
x
)
for
x
in
result
[
0
].
tolist
()]
# 对于[UNK]的概率设为无穷小,模型的预测结果不可能是[UNK]
logits
[
tokenizer
.
convert_tokens_to_ids
(
'[UNK]'
)]
=
-
float
(
'Inf'
)
# 排序
score
=
[]
for
index
in
range
((
input_ids
.
shape
[
1
]
-
1
)
*
22557
,
input_ids
.
shape
[
1
]
*
22557
):
score
.
append
(
logits
[
index
])
index_and_score
=
sorted
(
enumerate
(
score
),
key
=
lambda
x
:
x
[
1
],
reverse
=
True
)
# 取概率值最大的作为预测结果
next_token
=
index_and_score
[
0
][
0
]
if
next_token
==
tokenizer
.
convert_tokens_to_ids
(
'[SEP]'
):
# 遇到[SEP]结束标志符,结束循环
break
history
.
append
(
next_token
)
# 结果存放在response列表中
next_token
=
np
.
array
(
next_token
,
dtype
=
np
.
int64
)
input_ids
=
np
.
append
(
input_ids
,
next_token
)
input_ids
=
np
.
expand_dims
(
input_ids
,
axis
=
0
)
text
=
tokenizer
.
convert_ids_to_tokens
(
history
)
print
(
"chatbot:"
+
""
.
join
(
text
))
except
KeyboardInterrupt
:
break
requirement.txt
0 → 100644
View file @
62fb7f6b
os
numpy
transformers
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment