Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
988eb4e6
Commit
988eb4e6
authored
Dec 13, 2024
by
zhuwenwen
Browse files
update offline_streaming_inference_chat_demo.py
parent
54ddee7f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
101 additions
and
98 deletions
+101
-98
examples/offline_streaming_inference_chat_demo.py
examples/offline_streaming_inference_chat_demo.py
+101
-98
No files found.
examples/offline_streaming_inference_chat_demo.py
View file @
988eb4e6
...
...
@@ -9,10 +9,13 @@ from transformers import AutoTokenizer
import
logging
import
argparse
import
sys
vllm_logger
=
logging
.
getLogger
(
"vllm"
)
vllm_logger
.
setLevel
(
logging
.
WARNING
)
class
FlexibleArgumentParser
(
argparse
.
ArgumentParser
):
if
__name__
==
'__main__'
:
vllm_logger
=
logging
.
getLogger
(
"vllm"
)
vllm_logger
.
setLevel
(
logging
.
WARNING
)
class
FlexibleArgumentParser
(
argparse
.
ArgumentParser
):
"""ArgumentParser that allows both underscore and dash in names."""
def
parse_args
(
self
,
args
=
None
,
namespace
=
None
):
...
...
@@ -35,36 +38,36 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
return
super
().
parse_args
(
processed_args
,
namespace
)
parser
=
FlexibleArgumentParser
()
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
args
=
parser
.
parse_args
()
parser
=
FlexibleArgumentParser
()
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
args
=
parser
.
parse_args
()
# chat = [
# {"role": "user", "content": "Hello, how are you?"},
# {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
# {"role": "user", "content": "I'd like to show off how chat templating works!"},
# ]
# chat = [
# {"role": "user", "content": "Hello, how are you?"},
# {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
# {"role": "user", "content": "I'd like to show off how chat templating works!"},
# ]
tokenizer
=
AutoTokenizer
.
from_pretrained
(
args
.
model
)
# try:
# f = open(args.template,'r')
# tokenizer.chat_template = f.read()
# except Exception as e:
# print('except:',e)
# finally:
# f.close()
tokenizer
=
AutoTokenizer
.
from_pretrained
(
args
.
model
)
# try:
# f = open(args.template,'r')
# tokenizer.chat_template = f.read()
# except Exception as e:
# print('except:',e)
# finally:
# f.close()
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
)
engine_args
=
AsyncEngineArgs
.
from_cli_args
(
args
)
engine
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
)
model_name
=
args
.
model
.
split
(
"/"
)[
-
1
]
if
args
.
model
.
split
(
"/"
)[
-
1
]
!=
""
else
args
.
model
.
split
(
"/"
)[
-
2
]
print
(
f
"欢迎使用
{
model_name
}
模型,输入内容即可进行对话,stop 终止程序"
)
model_name
=
args
.
model
.
split
(
"/"
)[
-
1
]
if
args
.
model
.
split
(
"/"
)[
-
1
]
!=
""
else
args
.
model
.
split
(
"/"
)[
-
2
]
print
(
f
"欢迎使用
{
model_name
}
模型,输入内容即可进行对话,stop 终止程序"
)
def
build_prompt
(
history
):
def
build_prompt
(
history
):
prompt
=
""
for
query
,
response
in
history
:
prompt
+=
f
"
\n\n
用户:
{
query
}
"
...
...
@@ -72,8 +75,8 @@ def build_prompt(history):
return
prompt
history
=
[]
while
True
:
history
=
[]
while
True
:
query
=
input
(
"
\n
用户:"
)
if
query
.
strip
()
==
"stop"
:
break
...
...
@@ -107,5 +110,5 @@ while True:
asyncio
.
run
(
process_results
())
history
.
append
({
"role"
:
"assistant"
,
"content"
:
response
})
print
()
print
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment