Commit 988eb4e6 authored by zhuwenwen's avatar zhuwenwen
Browse files

update offline_streaming_inference_chat_demo.py

parent 54ddee7f
......@@ -9,10 +9,13 @@ from transformers import AutoTokenizer
import logging
import argparse
import sys
vllm_logger = logging.getLogger("vllm")
vllm_logger.setLevel(logging.WARNING)
class FlexibleArgumentParser(argparse.ArgumentParser):
if __name__ == '__main__':
vllm_logger = logging.getLogger("vllm")
vllm_logger.setLevel(logging.WARNING)
class FlexibleArgumentParser(argparse.ArgumentParser):
"""ArgumentParser that allows both underscore and dash in names."""
def parse_args(self, args=None, namespace=None):
......@@ -35,36 +38,36 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
return super().parse_args(processed_args, namespace)
parser = FlexibleArgumentParser()
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
parser = FlexibleArgumentParser()
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
# chat = [
# {"role": "user", "content": "Hello, how are you?"},
# {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
# {"role": "user", "content": "I'd like to show off how chat templating works!"},
# ]
# chat = [
# {"role": "user", "content": "Hello, how are you?"},
# {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
# {"role": "user", "content": "I'd like to show off how chat templating works!"},
# ]
tokenizer = AutoTokenizer.from_pretrained(args.model)
# try:
# f = open(args.template,'r')
# tokenizer.chat_template = f.read()
# except Exception as e:
# print('except:',e)
# finally:
# f.close()
tokenizer = AutoTokenizer.from_pretrained(args.model)
# try:
# f = open(args.template,'r')
# tokenizer.chat_template = f.read()
# except Exception as e:
# print('except:',e)
# finally:
# f.close()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
model_name = args.model.split("/")[-1] if args.model.split("/")[-1] !="" else args.model.split("/")[-2]
print(f"欢迎使用{model_name}模型,输入内容即可进行对话,stop 终止程序")
model_name = args.model.split("/")[-1] if args.model.split("/")[-1] !="" else args.model.split("/")[-2]
print(f"欢迎使用{model_name}模型,输入内容即可进行对话,stop 终止程序")
def build_prompt(history):
def build_prompt(history):
prompt = ""
for query, response in history:
prompt += f"\n\n用户:{query}"
......@@ -72,8 +75,8 @@ def build_prompt(history):
return prompt
history = []
while True:
history = []
while True:
query = input("\n用户:")
if query.strip() == "stop":
break
......@@ -107,5 +110,5 @@ while True:
asyncio.run(process_results())
history.append({"role": "assistant", "content": response})
print()
print()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment