Unverified Commit 6ae996a8 authored by Reid's avatar Reid Committed by GitHub
Browse files

[Misc] refactor argument parsing in examples (#16635)


Signed-off-by: default avatarreidliu41 <reid201711@gmail.com>
Co-authored-by: default avatarreidliu41 <reid201711@gmail.com>
parent b590adfd
...@@ -156,16 +156,13 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]): ...@@ -156,16 +156,13 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
print("-" * 50) print("-" * 50)
def main(args: Namespace):
run_encode(args.model_name, args.modality, args.seed)
model_example_map = { model_example_map = {
"e5_v": run_e5_v, "e5_v": run_e5_v,
"vlm2vec": run_vlm2vec, "vlm2vec": run_vlm2vec,
} }
if __name__ == "__main__":
def parse_args():
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with ' description='Demo on using vLLM for offline inference with '
'vision language models for multimodal embedding') 'vision language models for multimodal embedding')
...@@ -184,6 +181,13 @@ if __name__ == "__main__": ...@@ -184,6 +181,13 @@ if __name__ == "__main__":
type=int, type=int,
default=None, default=None,
help="Set the seed when initializing `vllm.LLM`.") help="Set the seed when initializing `vllm.LLM`.")
return parser.parse_args()
args = parser.parse_args()
def main(args: Namespace):
run_encode(args.model_name, args.modality, args.seed)
if __name__ == "__main__":
args = parse_args()
main(args) main(args)
...@@ -767,22 +767,7 @@ def run_chat(model: str, question: str, image_urls: list[str], ...@@ -767,22 +767,7 @@ def run_chat(model: str, question: str, image_urls: list[str],
print("-" * 50) print("-" * 50)
def main(args: Namespace): def parse_args():
model = args.model_type
method = args.method
seed = args.seed
image_urls = IMAGE_URLS[:args.num_images]
if method == "generate":
run_generate(model, QUESTION, image_urls, seed)
elif method == "chat":
run_chat(model, QUESTION, image_urls, seed)
else:
raise ValueError(f"Invalid method: {method}")
if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with ' description='Demo on using vLLM for offline inference with '
'vision language models that support multi-image input for text ' 'vision language models that support multi-image input for text '
...@@ -808,6 +793,24 @@ if __name__ == "__main__": ...@@ -808,6 +793,24 @@ if __name__ == "__main__":
choices=list(range(1, 13)), # 12 is the max number of images choices=list(range(1, 13)), # 12 is the max number of images
default=2, default=2,
help="Number of images to use for the demo.") help="Number of images to use for the demo.")
return parser.parse_args()
args = parser.parse_args() def main(args: Namespace):
model = args.model_type
method = args.method
seed = args.seed
image_urls = IMAGE_URLS[:args.num_images]
if method == "generate":
run_generate(model, QUESTION, image_urls, seed)
elif method == "chat":
run_chat(model, QUESTION, image_urls, seed)
else:
raise ValueError(f"Invalid method: {method}")
if __name__ == "__main__":
args = parse_args()
main(args) main(args)
...@@ -58,6 +58,16 @@ def get_response(response: requests.Response) -> list[str]: ...@@ -58,6 +58,16 @@ def get_response(response: requests.Response) -> list[str]:
return output return output
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--n", type=int, default=1)
parser.add_argument("--prompt", type=str, default="San Francisco is a")
parser.add_argument("--stream", action="store_true")
return parser.parse_args()
def main(args: Namespace): def main(args: Namespace):
prompt = args.prompt prompt = args.prompt
api_url = f"http://{args.host}:{args.port}/generate" api_url = f"http://{args.host}:{args.port}/generate"
...@@ -82,11 +92,5 @@ def main(args: Namespace): ...@@ -82,11 +92,5 @@ def main(args: Namespace):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() args = parse_args()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--n", type=int, default=1)
parser.add_argument("--prompt", type=str, default="San Francisco is a")
parser.add_argument("--stream", action="store_true")
args = parser.parse_args()
main(args) main(args)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Example for starting a Gradio OpenAI Chatbot Webserver
Start vLLM API server:
vllm serve meta-llama/Llama-2-7b-chat-hf
Start Gradio OpenAI Chatbot Webserver:
python examples/online_serving/gradio_openai_chatbot_webserver.py \
-m meta-llama/Llama-2-7b-chat-hf
Note that `pip install --upgrade gradio` is needed to run this example.
More details: https://github.com/gradio-app/gradio
If your antivirus software blocks the download of frpc for gradio,
you can install it manually by following these steps:
1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64
2. Rename the downloaded file to: frpc_linux_amd64_v0.3
3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc
"""
import argparse import argparse
import gradio as gr import gradio as gr
from openai import OpenAI from openai import OpenAI
# Argument parser setup
parser = argparse.ArgumentParser( def create_openai_client(api_key, base_url):
description='Chatbot Interface with Customizable Parameters') return OpenAI(api_key=api_key, base_url=base_url)
parser.add_argument('--model-url',
type=str,
default='http://localhost:8000/v1', def format_history_to_openai(history):
help='Model URL')
parser.add_argument('-m',
'--model',
type=str,
required=True,
help='Model name for the chatbot')
parser.add_argument('--temp',
type=float,
default=0.8,
help='Temperature for text generation')
parser.add_argument('--stop-token-ids',
type=str,
default='',
help='Comma-separated stop token IDs')
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8001)
# Parse the arguments
args = parser.parse_args()
# Set OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = args.model_url
# Create an OpenAI client to interact with the API server
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
def predict(message, history):
# Convert chat history to OpenAI format
history_openai_format = [{ history_openai_format = [{
"role": "system", "role": "system",
"content": "You are a great ai assistant." "content": "You are a great AI assistant."
}] }]
for human, assistant in history: for human, assistant in history:
history_openai_format.append({"role": "user", "content": human}) history_openai_format.append({"role": "user", "content": human})
...@@ -54,31 +38,92 @@ def predict(message, history): ...@@ -54,31 +38,92 @@ def predict(message, history):
"role": "assistant", "role": "assistant",
"content": assistant "content": assistant
}) })
return history_openai_format
def predict(message, history, client, model_name, temp, stop_token_ids):
# Format history to OpenAI chat format
history_openai_format = format_history_to_openai(history)
history_openai_format.append({"role": "user", "content": message}) history_openai_format.append({"role": "user", "content": message})
# Create a chat completion request and send it to the API server # Send request to OpenAI API (vLLM server)
stream = client.chat.completions.create( stream = client.chat.completions.create(
model=args.model, # Model name to use model=model_name,
messages=history_openai_format, # Chat history messages=history_openai_format,
temperature=args.temp, # Temperature for text generation temperature=temp,
stream=True, # Stream response stream=True,
extra_body={ extra_body={
'repetition_penalty': 'repetition_penalty':
1, 1,
'stop_token_ids': [ 'stop_token_ids':
int(id.strip()) for id in args.stop_token_ids.split(',') [int(id.strip())
if id.strip() for id in stop_token_ids.split(',')] if stop_token_ids else []
] if args.stop_token_ids else []
}) })
# Read and return generated text from response stream # Collect all chunks and concatenate them into a full message
partial_message = "" full_message = ""
for chunk in stream: for chunk in stream:
partial_message += (chunk.choices[0].delta.content or "") full_message += (chunk.choices[0].delta.content or "")
yield partial_message
# Return the full message as a single response
return full_message
def parse_args():
parser = argparse.ArgumentParser(
description='Chatbot Interface with Customizable Parameters')
parser.add_argument('--model-url',
type=str,
default='http://localhost:8000/v1',
help='Model URL')
parser.add_argument('-m',
'--model',
type=str,
required=True,
help='Model name for the chatbot')
parser.add_argument('--temp',
type=float,
default=0.8,
help='Temperature for text generation')
parser.add_argument('--stop-token-ids',
type=str,
default='',
help='Comma-separated stop token IDs')
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8001)
return parser.parse_args()
def build_gradio_interface(client, model_name, temp, stop_token_ids):
def chat_predict(message, history):
return predict(message, history, client, model_name, temp,
stop_token_ids)
return gr.ChatInterface(fn=chat_predict,
title="Chatbot Interface",
description="A simple chatbot powered by vLLM")
def main():
# Parse the arguments
args = parse_args()
# Set OpenAI's API key and API base to use vLLM's API server
openai_api_key = "EMPTY"
openai_api_base = args.model_url
# Create an OpenAI client
client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)
# Define the Gradio chatbot interface using the predict function
gradio_interface = build_gradio_interface(client, args.model, args.temp,
args.stop_token_ids)
gradio_interface.queue().launch(server_name=args.host,
server_port=args.port,
share=True)
# Create and launch a chat interface with Gradio if __name__ == "__main__":
gr.ChatInterface(predict).queue().launch(server_name=args.host, main()
server_port=args.port,
share=True)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Example for starting a Gradio Webserver
Start vLLM API server:
python -m vllm.entrypoints.api_server \
--model meta-llama/Llama-2-7b-chat-hf
Start Webserver:
python examples/online_serving/gradio_webserver.py
Note that `pip install --upgrade gradio` is needed to run this example.
More details: https://github.com/gradio-app/gradio
If your antivirus software blocks the download of frpc for gradio,
you can install it manually by following these steps:
1. Download this file: https://cdn-media.huggingface.co/frpc-gradio-0.3/frpc_linux_amd64
2. Rename the downloaded file to: frpc_linux_amd64_v0.3
3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc
"""
import argparse import argparse
import json import json
...@@ -39,16 +56,23 @@ def build_demo(): ...@@ -39,16 +56,23 @@ def build_demo():
return demo return demo
if __name__ == "__main__": def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default=None) parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8001) parser.add_argument("--port", type=int, default=8001)
parser.add_argument("--model-url", parser.add_argument("--model-url",
type=str, type=str,
default="http://localhost:8000/generate") default="http://localhost:8000/generate")
args = parser.parse_args() return parser.parse_args()
def main(args):
demo = build_demo() demo = build_demo()
demo.queue().launch(server_name=args.host, demo.queue().launch(server_name=args.host,
server_port=args.port, server_port=args.port,
share=True) share=True)
if __name__ == "__main__":
args = parse_args()
main(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment