Commit 988eb4e6 authored by zhuwenwen's avatar zhuwenwen
Browse files

update offline_streaming_inference_chat_demo.py

parent 54ddee7f
...@@ -9,103 +9,106 @@ from transformers import AutoTokenizer ...@@ -9,103 +9,106 @@ from transformers import AutoTokenizer
import logging
import argparse
import sys


class FlexibleArgumentParser(argparse.ArgumentParser):
    """ArgumentParser that allows both underscore and dash in names."""

    def parse_args(self, args=None, namespace=None):
        """Parse ``args`` after rewriting ``--some_flag`` as ``--some-flag``.

        Only the option name of arguments starting with ``--`` is rewritten;
        anything after the first ``=`` (the value) and every argument that
        does not start with ``--`` is passed through unchanged.

        Args:
            args: Argument strings; defaults to ``sys.argv[1:]``.
            namespace: Optional namespace forwarded to ``argparse``.

        Returns:
            The namespace produced by ``argparse.ArgumentParser.parse_args``.
        """
        if args is None:
            args = sys.argv[1:]

        processed_args = []
        for arg in args:
            if arg.startswith('--'):
                # partition() yields an empty separator and value when the
                # argument has no '=', so one expression covers both the
                # "--key=value" and the bare "--key" spelling.
                key, sep, value = arg.partition('=')
                key = '--' + key[len('--'):].replace('_', '-')
                processed_args.append(key + sep + value)
            else:
                processed_args.append(arg)

        return super().parse_args(processed_args, namespace)


async def stream_response(results_generator):
    """Print generated text incrementally and return the final text.

    Each item yielded by ``results_generator`` is read through
    ``item.outputs[0].text`` and treated as the text generated so far, so
    only the part not yet printed is written to stdout.

    Args:
        results_generator: Async iterable of engine outputs.

    Returns:
        The text of the last output received, or ``""`` if none arrived.
    """
    printed = 0
    response = ""
    async for output in results_generator:
        response = output.outputs[0].text
        print(response[printed:], end="", flush=True)
        printed = len(response)
    return response


if __name__ == '__main__':
    vllm_logger = logging.getLogger("vllm")
    vllm_logger.setLevel(logging.WARNING)

    parser = FlexibleArgumentParser()
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    # chat = [
    #    {"role": "user", "content": "Hello, how are you?"},
    #    {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
    #    {"role": "user", "content": "I'd like to show off how chat templating works!"},
    # ]

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    # try:
    #     f = open(args.template,'r')
    #     tokenizer.chat_template = f.read()
    # except Exception as e:
    #     print('except:',e)
    # finally:
    #     f.close()

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    # Last path component of the model path, ignoring trailing slashes.
    model_name = args.model.rstrip("/").split("/")[-1]
    print(f"欢迎使用{model_name}模型,输入内容即可进行对话,stop 终止程序")

    def build_prompt(history):
        """Render ``(query, response)`` pairs as a plain-text transcript.

        NOTE(review): nothing in the loop below calls this helper, and the
        loop stores ``history`` as role/content dicts rather than pairs.
        """
        prompt = ""
        for query, response in history:
            prompt += f"\n\n用户:{query}"
            prompt += f"\n\n{model_name}:{response}"
        return prompt

    history = []
    while True:
        query = input("\n用户:")
        if query.strip() == "stop":
            break
        history.append({"role": "user", "content": query})
        # add_generation_prompt=True appends the template's assistant-turn
        # header so the model answers instead of continuing the user turn.
        new_query = tokenizer.apply_chat_template(
            history, tokenize=False, add_generation_prompt=True)

        results_generator = engine.generate(
            new_query,
            SamplingParams(temperature=0.0, max_tokens=100),
            0,
        )

        response = asyncio.run(stream_response(results_generator))
        history.append({"role": "assistant", "content": response})
        print()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment