Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
FinGPT-glm_pytorch
Commits
7c61eced
Commit
7c61eced
authored
May 10, 2024
by
wanglch
Browse files
Upload New File
parent
127c7a76
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
68 additions
and
0 deletions
+68
-0
infernece_FinGPT.py
infernece_FinGPT.py
+68
-0
No files found.
infernece_FinGPT.py
0 → 100644
View file @
7c61eced
import
os
import
platform
import
signal
from
transformers
import
AutoTokenizer
,
AutoModel
import
readline
from
peft
import
PeftModel
# Path to the merged FinGPT / ChatGLM2-6B checkpoint on local disk.
base_model = '/FinGPT/FinGPT_mt_chatglm2-6b-merged'

# trust_remote_code is required: ChatGLM2 ships its own modeling code
# alongside the checkpoint, which transformers must be allowed to execute.
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
model = AutoModel.from_pretrained(
    base_model,
    trust_remote_code=True,
    device_map="auto",
)
# Multi-GPU support: use the two lines below instead of the call above,
# and set num_gpus to the number of cards you actually have.
# from utils import load_model_on_gpus
# model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2)
model = model.eval()

# Pick the terminal-clear command for the current platform.
os_name = platform.system()
if os_name == 'Windows':
    clear_command = 'cls'
else:
    clear_command = 'clear'

# Set to True by the SIGINT handler to abort the current streamed reply.
stop_stream = False
def build_prompt(history):
    """Render the chat history as a plain-text transcript.

    Args:
        history: Sequence of (query, response) pairs, oldest first.

    Returns:
        The welcome banner followed by one user/model exchange per pair.
    """
    parts = ["欢迎使用 FinGPT-ChatGLM2-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"]
    for user_text, model_text in history:
        parts.append(f"\n\n用户:{user_text}")
        parts.append(f"\n\nFinGPT:{model_text}")
    return "".join(parts)
def signal_handler(sig, frame):
    """SIGINT handler: request that the current streamed reply stop.

    Fix: the first parameter was named ``signal``, shadowing the imported
    ``signal`` module inside the handler; renamed to ``sig``. The signal
    machinery invokes handlers positionally, so callers are unaffected.

    Args:
        sig: Signal number delivered by the OS (unused).
        frame: Stack frame active when the signal arrived (unused).
    """
    global stop_stream
    stop_stream = True
def main():
    """Run the interactive chat loop against the loaded FinGPT model.

    Commands typed at the prompt:
      * ``stop``  -- exit the program.
      * ``clear`` -- clear the screen and reset the conversation state.
    Anything else is sent to the model and the reply is streamed to stdout.
    """
    # Fix: the handler existed but was never registered, so Ctrl+C killed
    # the whole process. Registering it makes SIGINT merely set stop_stream,
    # which aborts only the reply currently being streamed.
    signal.signal(signal.SIGINT, signal_handler)
    # Same banner is shown on startup and after "clear"; bind it once.
    welcome = "欢迎使用由中科曙光智能与计算产业事业部开发的FinGPT-ChatGLM2-6B金融大模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
    os.system(clear_command)
    past_key_values, history = None, []
    global stop_stream
    print(welcome)
    while True:
        query = input("\n用户:")
        if query.strip() == "stop":
            break
        if query.strip() == "clear":
            # Reset the KV cache and history so the next turn starts fresh.
            past_key_values, history = None, []
            os.system(clear_command)
            print(welcome)
            continue
        print("\nFinGPT:", end="")
        current_length = 0  # number of reply characters already echoed
        for response, history, past_key_values in model.stream_chat(
            tokenizer,
            query,
            history=history,
            past_key_values=past_key_values,
            return_past_key_values=True,
            temperature=0.8,
        ):
            if stop_stream:
                # SIGINT arrived mid-stream: abandon this reply only.
                stop_stream = False
                break
            else:
                # stream_chat yields the full reply so far; print only the
                # newly generated suffix.
                print(response[current_length:], end="", flush=True)
                current_length = len(response)
        print("")
# Script entry point: start the interactive chat loop when run directly
# (not when imported as a module).
if __name__ == "__main__":
    main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment