Commit 988eb4e6 authored by zhuwenwen's avatar zhuwenwen
Browse files

update offline_streaming_inference_chat_demo.py

parent 54ddee7f
...@@ -9,10 +9,13 @@ from transformers import AutoTokenizer ...@@ -9,10 +9,13 @@ from transformers import AutoTokenizer
import logging import logging
import argparse import argparse
import sys import sys
vllm_logger = logging.getLogger("vllm")
vllm_logger.setLevel(logging.WARNING)
class FlexibleArgumentParser(argparse.ArgumentParser):
if __name__ == '__main__':
vllm_logger = logging.getLogger("vllm")
vllm_logger.setLevel(logging.WARNING)
class FlexibleArgumentParser(argparse.ArgumentParser):
"""ArgumentParser that allows both underscore and dash in names.""" """ArgumentParser that allows both underscore and dash in names."""
def parse_args(self, args=None, namespace=None): def parse_args(self, args=None, namespace=None):
...@@ -35,36 +38,36 @@ class FlexibleArgumentParser(argparse.ArgumentParser): ...@@ -35,36 +38,36 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
return super().parse_args(processed_args, namespace) return super().parse_args(processed_args, namespace)
parser = FlexibleArgumentParser() parser = FlexibleArgumentParser()
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
# chat = [ # chat = [
# {"role": "user", "content": "Hello, how are you?"}, # {"role": "user", "content": "Hello, how are you?"},
# {"role": "assistant", "content": "I'm doing great. How can I help you today?"}, # {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
# {"role": "user", "content": "I'd like to show off how chat templating works!"}, # {"role": "user", "content": "I'd like to show off how chat templating works!"},
# ] # ]
tokenizer = AutoTokenizer.from_pretrained(args.model) tokenizer = AutoTokenizer.from_pretrained(args.model)
# try: # try:
# f = open(args.template,'r') # f = open(args.template,'r')
# tokenizer.chat_template = f.read() # tokenizer.chat_template = f.read()
# except Exception as e: # except Exception as e:
# print('except:',e) # print('except:',e)
# finally: # finally:
# f.close() # f.close()
engine_args = AsyncEngineArgs.from_cli_args(args) engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args) engine = AsyncLLMEngine.from_engine_args(engine_args)
model_name = args.model.split("/")[-1] if args.model.split("/")[-1] !="" else args.model.split("/")[-2] model_name = args.model.split("/")[-1] if args.model.split("/")[-1] !="" else args.model.split("/")[-2]
print(f"欢迎使用{model_name}模型,输入内容即可进行对话,stop 终止程序") print(f"欢迎使用{model_name}模型,输入内容即可进行对话,stop 终止程序")
def build_prompt(history): def build_prompt(history):
prompt = "" prompt = ""
for query, response in history: for query, response in history:
prompt += f"\n\n用户:{query}" prompt += f"\n\n用户:{query}"
...@@ -72,8 +75,8 @@ def build_prompt(history): ...@@ -72,8 +75,8 @@ def build_prompt(history):
return prompt return prompt
history = [] history = []
while True: while True:
query = input("\n用户:") query = input("\n用户:")
if query.strip() == "stop": if query.strip() == "stop":
break break
...@@ -107,5 +110,5 @@ while True: ...@@ -107,5 +110,5 @@ while True:
asyncio.run(process_results()) asyncio.run(process_results())
history.append({"role": "assistant", "content": response}) history.append({"role": "assistant", "content": response})
print() print()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment