Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
GPT2_migraphx
Commits
741ac4ae
Commit
741ac4ae
authored
May 29, 2023
by
liucong
Browse files
删除部分代码和文档
parent
816b3d52
Pipeline
#297
failed with stages
in 0 seconds
Changes
4
Pipelines
1
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
0 additions
and
22630 deletions
+0
-22630
Sample_picture.png
Sample_picture.png
+0
-0
gpt2.py
gpt2.py
+0
-70
model/vocab_shici.txt
model/vocab_shici.txt
+0
-22557
requirement.txt
requirement.txt
+0
-3
No files found.
Sample_picture.png
deleted
100644 → 0
View file @
816b3d52
31.4 KB
gpt2.py
deleted
100644 → 0
View file @
816b3d52
import
os
import
numpy
as
np
from
transformers
import
BertTokenizerFast
import
migraphx
# Load the vocabulary for the poem model and build a BERT fast tokenizer
# around it, declaring the special tokens the vocab file provides.
print("INFO: Complete loading the vocabulary")
vocab_file = os.path.join('./model', 'vocab_shici.txt')
tokenizer = BertTokenizerFast(
    vocab_file,
    sep_token="[SEP]",
    pad_token="[PAD]",
    cls_token="[CLS]",
)

# Largest input shape the model will be compiled for: batch 1, up to
# 1024 tokens ("input" is the ONNX graph's input name).
maxInput = {"input": [1, 1024]}

# Parse the ONNX model with MIGraphX, pinning the maximum input dims.
print("INFO: Parsing and compiling the model")
model = migraphx.parse_onnx("./model/GPT2_shici.onnx", map_input_dims=maxInput)

# Introspect the (single) graph input: its name and its shape.
inputName = model.get_parameter_names()[0]
inputShape = model.get_parameter_shapes()[inputName].lens()
print(f"inputName:{inputName}\ninputShape:{inputShape}")

# Compile the parsed program for GPU execution on device 0.
model.compile(t=migraphx.get_target("gpu"), device_id=0)
# Interactive loop: read a prompt, greedily generate up to `max_len`
# continuation tokens, then print prompt + continuation as one line.
print('开始和GPT2对诗,输入CTRL + Z以退出')

# Loop-invariant special-token ids, hoisted out of the generation loop
# (the original re-resolved them on every decoding step).
unk_id = tokenizer.convert_tokens_to_ids('[UNK]')
sep_id = tokenizer.convert_tokens_to_ids('[SEP]')
max_len = 50  # maximum number of generated tokens per reply

while True:
    try:
        text = input("user:")
        # Encode the prompt without special tokens; `history` accumulates
        # the ids that will be decoded and shown (prompt + generated).
        text_ids = tokenizer.encode(text, add_special_tokens=False)
        history = list(text_ids)
        # Model input: [CLS] + prompt, as a (1, seq_len) int64 array.
        input_ids = np.array([[tokenizer.cls_token_id] + text_ids],
                             dtype=np.int64)
        for _ in range(max_len):
            seq_len = input_ids.shape[1]
            # Re-shape the compiled program to the current sequence length.
            model.reshape(inputs={inputName: [input_ids.shape[0], seq_len]})
            # Inference. The output is read back as a flat list of
            # seq_len * vocab_size logits.
            result = model.run({inputName: migraphx.argument(input_ids)})
            logits = [float(x) for x in result[0].tolist()]
            # Derive the vocabulary size from the output instead of
            # hard-coding 22557 (the line count of vocab_shici.txt), so the
            # code survives a vocabulary change.
            vocab_size = len(logits) // seq_len
            # Only the last position's logits predict the next token.
            score = logits[(seq_len - 1) * vocab_size: seq_len * vocab_size]
            # Forbid [UNK] as a prediction. Bug fix: the original masked
            # logits[unk_id], i.e. position 0's [UNK] logit, which only
            # affected the slice actually used when seq_len == 1.
            score[unk_id] = -float('Inf')
            # Greedy decoding: take the argmax of the last-position logits
            # (the original sorted all (index, score) pairs just to take
            # the first one).
            next_token = max(range(len(score)), key=score.__getitem__)
            if next_token == sep_id:
                # [SEP] marks the end of the model's reply.
                break
            history.append(next_token)
            # Extend the (1, seq_len) input to (1, seq_len + 1) in place of
            # the original flatten-with-np.append + expand_dims round trip.
            step = np.array([[next_token]], dtype=np.int64)
            input_ids = np.concatenate([input_ids, step], axis=1)
        # Decode the accumulated ids and show the reply.
        text = tokenizer.convert_ids_to_tokens(history)
        print("chatbot:" + "".join(text))
    except (EOFError, KeyboardInterrupt):
        # Bug fix: CTRL+Z / CTRL+D makes input() raise EOFError, which the
        # banner advertises as the exit path but the original never caught
        # (it crashed); CTRL+C raises KeyboardInterrupt. Both now exit.
        break
model/vocab_shici.txt
deleted
100644 → 0
View file @
816b3d52
This diff is collapsed.
Click to expand it.
requirement.txt
deleted
100644 → 0
View file @
816b3d52
os
numpy
transformers
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment