Commit d74132ca (unverified), authored by Ning Xie, committed by GitHub
Browse files

fix offline inference chat response prompt (#32088)


Signed-off-by: Andy Xie <andy.xning@gmail.com>
parent a34abc49
...@@ -9,7 +9,7 @@ Usage: ...@@ -9,7 +9,7 @@ Usage:
python examples/offline_inference/context_extension.py python examples/offline_inference/context_extension.py
""" """
from vllm import LLM, SamplingParams from vllm import LLM, RequestOutput, SamplingParams
def create_llm(): def create_llm():
...@@ -45,13 +45,15 @@ def run_llm_chat(llm): ...@@ -45,13 +45,15 @@ def run_llm_chat(llm):
{"role": "assistant", "content": "Hello! How can I assist you today?"}, {"role": "assistant", "content": "Hello! How can I assist you today?"},
] ]
outputs = llm.chat(conversation, sampling_params, use_tqdm=False) outputs = llm.chat(conversation, sampling_params, use_tqdm=False)
return outputs return outputs, [
conversation,
]
def print_outputs(outputs): def print_outputs(outputs: list[RequestOutput], conversations: list):
print("\nGenerated Outputs:\n" + "-" * 80) print("\nGenerated Outputs:\n" + "-" * 80)
for output in outputs: for i, output in enumerate(outputs):
prompt = output.prompt prompt = conversations[i]
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\n") print(f"Prompt: {prompt!r}\n")
print(f"Generated text: {generated_text!r}") print(f"Generated text: {generated_text!r}")
...@@ -60,8 +62,8 @@ def print_outputs(outputs): ...@@ -60,8 +62,8 @@ def print_outputs(outputs):
def main(): def main():
llm = create_llm() llm = create_llm()
outputs = run_llm_chat(llm) outputs, conversations = run_llm_chat(llm)
print_outputs(outputs) print_outputs(outputs, conversations)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -152,9 +152,12 @@ def main(args): ...@@ -152,9 +152,12 @@ def main(args):
# print the generated text # print the generated text
if args.print_output: if args.print_output:
for output in outputs: for i, output in enumerate(outputs):
print("-" * 50) print("-" * 50)
print(f"prompt: {output.prompt}") if not args.custom_mm_prompts:
print(f"prompt: {prompts[i].prompt}")
else:
print(f"prompt: {prompts[i]}")
print(f"generated text: {output.outputs[0].text}") print(f"generated text: {output.outputs[0].text}")
print("-" * 50) print("-" * 50)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment