Unverified Commit 27bebcd8 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Convert `examples` to `ruff-format` (#18400)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent e7523c2e
......@@ -45,8 +45,7 @@ if dist.get_rank() == 0:
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\n"
f"Generated text: {generated_text!r}\n")
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n")
print("-" * 50)
"""
Further tips:
......
......@@ -20,10 +20,12 @@ sampling_params = SamplingParams(temperature=0, top_p=1.0, n=N, max_tokens=16)
def main():
# Set `enforce_eager=True` to avoid ahead-of-time compilation.
# In real workloads, `enforace_eager` should be `False`.
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct",
max_num_batched_tokens=64,
max_num_seqs=4,
max_model_len=128)
llm = LLM(
model="Qwen/Qwen2-1.5B-Instruct",
max_num_batched_tokens=64,
max_num_seqs=4,
max_model_len=128,
)
outputs = llm.generate(prompts, sampling_params)
print("-" * 50)
for output, answer in zip(outputs, answers):
......
......@@ -6,6 +6,7 @@ the correct prompt format on vision language models for text generation.
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
import os
import random
from contextlib import contextmanager
......@@ -49,9 +50,13 @@ def run_aria(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1},
)
prompts = [(f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
"<|im_end|>\n<|im_start|>assistant\n")
for question in questions]
prompts = [
(
f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
"<|im_end|>\n<|im_start|>assistant\n"
)
for question in questions
]
stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
......@@ -135,8 +140,7 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
)
prompts = [
f"<|User|>: <image>\n{question}\n\n<|Assistant|>:"
for question in questions
f"<|User|>: <image>\n{question}\n\n<|Assistant|>:" for question in questions
]
return ModelRequestData(
......@@ -198,9 +202,14 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1},
)
prompts = [("<bos><start_of_turn>user\n"
f"<start_of_image>{question}<end_of_turn>\n"
"<start_of_turn>model\n") for question in questions]
prompts = [
(
"<bos><start_of_turn>user\n"
f"<start_of_image>{question}<end_of_turn>\n"
"<start_of_turn>model\n"
)
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
......@@ -225,7 +234,8 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
prompts = [
f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\
{question}<|assistant|>" for question in questions
{question}<|assistant|>"
for question in questions
]
stop_token_ids = [151329, 151336, 151338]
......@@ -250,15 +260,13 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1},
)
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
messages = [[{
'role': 'user',
'content': f"<image>\n{question}"
}] for question in questions]
prompts = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
messages = [
[{"role": "user", "content": f"<image>\n{question}"}] for question in questions
]
prompts = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Stop tokens for H2OVL-Mississippi
# https://huggingface.co/h2oai/h2ovl-mississippi-800m
......@@ -284,15 +292,14 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
# if you are running out of memory, you can reduce the "longest_edge".
# see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
mm_processor_kwargs={
"size": {
"longest_edge": 3 * 364
},
"size": {"longest_edge": 3 * 364},
},
limit_mm_per_prompt={modality: 1},
)
prompts = [(
f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:"
) for question in questions]
prompts = [
(f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:")
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
......@@ -311,9 +318,7 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
max_num_seqs=2,
enforce_eager=True,
mm_processor_kwargs={
"max_image_size": {
"longest_edge": 384
},
"max_image_size": {"longest_edge": 384},
},
limit_mm_per_prompt={modality: 1},
)
......@@ -330,7 +335,6 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
# InternVL
def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "OpenGVLab/InternVL3-2B"
engine_args = EngineArgs(
......@@ -345,15 +349,14 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
elif modality == "video":
placeholder = "<video>"
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
messages = [[{
'role': 'user',
'content': f"{placeholder}\n{question}"
}] for question in questions]
prompts = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
messages = [
[{"role": "user", "content": f"{placeholder}\n{question}"}]
for question in questions
]
prompts = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Stop tokens for InternVL
# models variants may have different stop tokens
......@@ -361,9 +364,7 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
# https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
stop_token_ids = [
token_id for token_id in stop_token_ids if token_id is not None
]
stop_token_ids = [token_id for token_id in stop_token_ids if token_id is not None]
return ModelRequestData(
engine_args=engine_args,
......@@ -379,7 +380,8 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
prompts = [
"<|im_user|>user<|im_middle|><|media_start|>image<|media_content|>"
f"<|media_pad|><|media_end|>{question}<|im_end|>"
"<|im_assistant|>assistant<|im_middle|>" for question in questions
"<|im_assistant|>assistant<|im_middle|>"
for question in questions
]
engine_args = EngineArgs(
......@@ -399,9 +401,7 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
def run_llava(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
prompts = [
f"USER: <image>\n{question}\nASSISTANT:" for question in questions
]
prompts = [f"USER: <image>\n{question}\nASSISTANT:" for question in questions]
engine_args = EngineArgs(
model="llava-hf/llava-1.5-7b-hf",
......@@ -434,13 +434,10 @@ def run_llava_next(questions: list[str], modality: str) -> ModelRequestData:
# LlaVA-NeXT-Video
# Currently only support for video input
def run_llava_next_video(questions: list[str],
modality: str) -> ModelRequestData:
def run_llava_next_video(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "video"
prompts = [
f"USER: <video>\n{question} ASSISTANT:" for question in questions
]
prompts = [f"USER: <video>\n{question} ASSISTANT:" for question in questions]
engine_args = EngineArgs(
model="llava-hf/LLaVA-NeXT-Video-7B-hf",
max_model_len=8192,
......@@ -455,19 +452,19 @@ def run_llava_next_video(questions: list[str],
# LLaVA-OneVision
def run_llava_onevision(questions: list[str],
modality: str) -> ModelRequestData:
def run_llava_onevision(questions: list[str], modality: str) -> ModelRequestData:
if modality == "video":
prompts = [
f"<|im_start|>user <video>\n{question}<|im_end|> \
<|im_start|>assistant\n" for question in questions
<|im_start|>assistant\n"
for question in questions
]
elif modality == "image":
prompts = [
f"<|im_start|>user <image>\n{question}<|im_end|> \
<|im_start|>assistant\n" for question in questions
<|im_start|>assistant\n"
for question in questions
]
engine_args = EngineArgs(
......@@ -486,11 +483,8 @@ def run_llava_onevision(questions: list[str],
def run_mantis(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa: E501
prompts = [
llama3_template.format(f"{question}\n<image>")
for question in questions
]
llama3_template = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" # noqa: E501
prompts = [llama3_template.format(f"{question}\n<image>") for question in questions]
engine_args = EngineArgs(
model="TIGER-Lab/Mantis-8B-siglip-llama3",
......@@ -530,8 +524,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
# 2.6: image, video
# o2.6: image, video, audio
# model_name = "openbmb/MiniCPM-o-2_6"
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
......@@ -547,7 +540,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
# stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
# 2.6 / o2.6
stop_tokens = ['<|im_end|>', '<|endoftext|>']
stop_tokens = ["<|im_end|>", "<|endoftext|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
modality_placeholder = {
......@@ -557,12 +550,16 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
prompts = [
tokenizer.apply_chat_template(
[{
'role': 'user',
'content': f"{modality_placeholder[modality]}\n{question}"
}],
[
{
"role": "user",
"content": f"{modality_placeholder[modality]}\n{question}",
}
],
tokenize=False,
add_generation_prompt=True) for question in questions
add_generation_prompt=True,
)
for question in questions
]
return ModelRequestData(
......@@ -622,19 +619,18 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [[{
"role":
"user",
"content": [{
"type": "image"
}, {
"type": "text",
"text": question
}]
}] for question in questions]
prompts = tokenizer.apply_chat_template(messages,
add_generation_prompt=True,
tokenize=False)
messages = [
[
{
"role": "user",
"content": [{"type": "image"}, {"type": "text", "text": question}],
}
]
for question in questions
]
prompts = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
)
return ModelRequestData(
engine_args=engine_args,
......@@ -657,19 +653,18 @@ def run_llama4(questions: list[str], modality: str) -> ModelRequestData:
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [[{
"role":
"user",
"content": [{
"type": "image"
}, {
"type": "text",
"text": f"{question}"
}]
}] for question in questions]
prompts = tokenizer.apply_chat_template(messages,
add_generation_prompt=True,
tokenize=False)
messages = [
[
{
"role": "user",
"content": [{"type": "image"}, {"type": "text", "text": f"{question}"}],
}
]
for question in questions
]
prompts = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False
)
stop_token_ids = None
return ModelRequestData(
engine_args=engine_args,
......@@ -693,7 +688,8 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
prompts = [
f"<|im_start|>user <image>\n{question}<|im_end|> \
<|im_start|>assistant\n" for question in questions
<|im_start|>assistant\n"
for question in questions
]
return ModelRequestData(
......@@ -717,15 +713,13 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1},
)
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
messages = [[{
'role': 'user',
'content': f"<image>\n{question}"
}] for question in questions]
prompts = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
messages = [
[{"role": "user", "content": f"<image>\n{question}"}] for question in questions
]
prompts = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return ModelRequestData(
engine_args=engine_args,
......@@ -748,15 +742,13 @@ def run_ovis(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1},
)
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
messages = [[{
'role': 'user',
'content': f"<image>\n{question}"
}] for question in questions]
prompts = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
messages = [
[{"role": "user", "content": f"<image>\n{question}"}] for question in questions
]
prompts = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return ModelRequestData(
engine_args=engine_args,
......@@ -847,8 +839,7 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
# we have to manually specify the path of the lora weights.
vision_lora_path = os.path.join(model_path, "vision-lora")
prompts = [
f"<|user|><|image_1|>{question}<|end|><|assistant|>"
for question in questions
f"<|user|><|image_1|>{question}<|end|><|assistant|>" for question in questions
]
engine_args = EngineArgs(
model=model_path,
......@@ -915,7 +906,6 @@ def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
# Qwen2-VL
def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "Qwen/Qwen2-VL-7B-Instruct"
engine_args = EngineArgs(
......@@ -936,10 +926,13 @@ def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
placeholder = "<|video_pad|>"
prompts = [
("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n") for question in questions
(
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
for question in questions
]
return ModelRequestData(
......@@ -950,7 +943,6 @@ def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
# Qwen2.5-VL
def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
engine_args = EngineArgs(
......@@ -971,10 +963,13 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
placeholder = "<|video_pad|>"
prompts = [
("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n") for question in questions
(
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
for question in questions
]
return ModelRequestData(
......@@ -1007,12 +1002,18 @@ def run_qwen2_5_omni(questions: list[str], modality: str):
default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech.")
"generating text and speech."
)
prompts = [(f"<|im_start|>system\n{default_system}<|im_end|>\n"
f"<|im_start|>user\n<|vision_bos|>{placeholder}<|vision_eos|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n") for question in questions]
prompts = [
(
f"<|im_start|>system\n{default_system}<|im_end|>\n"
f"<|im_start|>user\n<|vision_bos|>{placeholder}<|vision_eos|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
for question in questions
]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
......@@ -1032,15 +1033,13 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1},
)
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
messages = [[{
'role': 'user',
'content': f"<image>\n{question}"
}] for question in questions]
prompts = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
messages = [
[{"role": "user", "content": f"<image>\n{question}"}] for question in questions
]
prompts = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Stop tokens for SkyworkR1V
# https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/conversation.py
......@@ -1104,8 +1103,7 @@ def get_multi_modal_input(args):
"""
if args.modality == "image":
# Input image and question
image = convert_image_mode(
ImageAsset("cherry_blossom").pil_image, "RGB")
image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
img_questions = [
"What is the content of this image?",
"Describe the content of this image in detail.",
......@@ -1120,8 +1118,7 @@ def get_multi_modal_input(args):
if args.modality == "video":
# Input video and question
video = VideoAsset(name="baby_reading",
num_frames=args.num_frames).np_ndarrays
video = VideoAsset(name="baby_reading", num_frames=args.num_frames).np_ndarrays
vid_questions = ["Why is this video funny?"]
return {
......@@ -1133,12 +1130,13 @@ def get_multi_modal_input(args):
raise ValueError(msg)
def apply_image_repeat(image_repeat_prob, num_prompts, data,
prompts: list[str], modality):
"""Repeats images with provided probability of "image_repeat_prob".
def apply_image_repeat(
image_repeat_prob, num_prompts, data, prompts: list[str], modality
):
"""Repeats images with provided probability of "image_repeat_prob".
Used to simulate hit/miss for the MM preprocessor cache.
"""
assert (image_repeat_prob <= 1.0 and image_repeat_prob >= 0)
assert image_repeat_prob <= 1.0 and image_repeat_prob >= 0
no_yes = [0, 1]
probs = [1.0 - image_repeat_prob, image_repeat_prob]
......@@ -1153,12 +1151,12 @@ def apply_image_repeat(image_repeat_prob, num_prompts, data,
new_val = (i // 256 // 256, i // 256, i % 256)
cur_image.putpixel((0, 0), new_val)
inputs.append({
"prompt": prompts[i % len(prompts)],
"multi_modal_data": {
modality: cur_image
inputs.append(
{
"prompt": prompts[i % len(prompts)],
"multi_modal_data": {modality: cur_image},
}
})
)
return inputs
......@@ -1167,6 +1165,7 @@ def apply_image_repeat(image_repeat_prob, num_prompts, data,
def time_counter(enable: bool):
if enable:
import time
start_time = time.time()
yield
elapsed_time = time.time() - start_time
......@@ -1179,54 +1178,65 @@ def time_counter(enable: bool):
def parse_args():
parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with '
'vision language models for text generation')
parser.add_argument('--model-type',
'-m',
type=str,
default="llava",
choices=model_example_map.keys(),
help='Huggingface "model_type".')
parser.add_argument('--num-prompts',
type=int,
default=4,
help='Number of prompts to run.')
parser.add_argument('--modality',
type=str,
default="image",
choices=['image', 'video'],
help='Modality of the input.')
parser.add_argument('--num-frames',
type=int,
default=16,
help='Number of frames to extract from the video.')
parser.add_argument("--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.")
description="Demo on using vLLM for offline inference with "
"vision language models for text generation"
)
parser.add_argument(
"--model-type",
"-m",
type=str,
default="llava",
choices=model_example_map.keys(),
help='Huggingface "model_type".',
)
parser.add_argument(
"--num-prompts", type=int, default=4, help="Number of prompts to run."
)
parser.add_argument(
"--modality",
type=str,
default="image",
choices=["image", "video"],
help="Modality of the input.",
)
parser.add_argument(
"--num-frames",
type=int,
default=16,
help="Number of frames to extract from the video.",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.",
)
parser.add_argument(
'--image-repeat-prob',
"--image-repeat-prob",
type=float,
default=None,
help='Simulates the hit-ratio for multi-modal preprocessor cache'
' (if enabled)')
help="Simulates the hit-ratio for multi-modal preprocessor cache (if enabled)",
)
parser.add_argument(
'--disable-mm-preprocessor-cache',
action='store_true',
help='If True, disables caching of multi-modal preprocessor/mapper.')
"--disable-mm-preprocessor-cache",
action="store_true",
help="If True, disables caching of multi-modal preprocessor/mapper.",
)
parser.add_argument(
'--time-generate',
action='store_true',
help='If True, then print the total generate() call time')
"--time-generate",
action="store_true",
help="If True, then print the total generate() call time",
)
parser.add_argument(
'--use-different-prompt-per-request',
action='store_true',
help='If True, then use different prompt (with the same multi-modal '
'data) for each request.')
"--use-different-prompt-per-request",
action="store_true",
help="If True, then use different prompt (with the same multi-modal "
"data) for each request.",
)
return parser.parse_args()
......@@ -1245,7 +1255,8 @@ def main(args):
# Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {})
req_data.engine_args.limit_mm_per_prompt or {}
)
engine_args = asdict(req_data.engine_args) | {
"seed": args.seed,
......@@ -1254,44 +1265,46 @@ def main(args):
llm = LLM(**engine_args)
# Don't want to check the flag multiple times, so just hijack `prompts`.
prompts = req_data.prompts if args.use_different_prompt_per_request else [
req_data.prompts[0]
]
prompts = (
req_data.prompts
if args.use_different_prompt_per_request
else [req_data.prompts[0]]
)
# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(temperature=0.2,
max_tokens=64,
stop_token_ids=req_data.stop_token_ids)
sampling_params = SamplingParams(
temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
)
assert args.num_prompts > 0
if args.num_prompts == 1:
# Single inference
inputs = {
"prompt": prompts[0],
"multi_modal_data": {
modality: data
},
"multi_modal_data": {modality: data},
}
else:
# Batch inference
if args.image_repeat_prob is not None:
# Repeat images with specified probability of "image_repeat_prob"
inputs = apply_image_repeat(args.image_repeat_prob,
args.num_prompts, data, prompts,
modality)
inputs = apply_image_repeat(
args.image_repeat_prob, args.num_prompts, data, prompts, modality
)
else:
# Use the same image for all prompts
inputs = [{
"prompt": prompts[i % len(prompts)],
"multi_modal_data": {
modality: data
},
} for i in range(args.num_prompts)]
inputs = [
{
"prompt": prompts[i % len(prompts)],
"multi_modal_data": {modality: data},
}
for i in range(args.num_prompts)
]
# Add LoRA request if applicable
lora_request = (req_data.lora_requests *
args.num_prompts if req_data.lora_requests else None)
lora_request = (
req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
)
with time_counter(args.time_generate):
outputs = llm.generate(
......
......@@ -6,6 +6,7 @@ the correct prompt format on vision language models for multimodal embedding.
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
from argparse import Namespace
from dataclasses import asdict
from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args
......@@ -44,19 +45,17 @@ class ModelRequestData(NamedTuple):
def run_e5_v(query: Query) -> ModelRequestData:
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501
llama3_template = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" # noqa: E501
if query["modality"] == "text":
text = query["text"]
prompt = llama3_template.format(
f"{text}\nSummary above sentence in one word: ")
prompt = llama3_template.format(f"{text}\nSummary above sentence in one word: ")
image = None
elif query["modality"] == "image":
prompt = llama3_template.format(
"<image>\nSummary above image in one word: ")
prompt = llama3_template.format("<image>\nSummary above image in one word: ")
image = query["image"]
else:
modality = query['modality']
modality = query["modality"]
raise ValueError(f"Unsupported query modality: '{modality}'")
engine_args = EngineArgs(
......@@ -83,10 +82,12 @@ def run_vlm2vec(query: Query) -> ModelRequestData:
image = query["image"]
elif query["modality"] == "text+image":
text = query["text"]
prompt = f"<|image_1|> Represent the given image with the following question: {text}" # noqa: E501
prompt = (
f"<|image_1|> Represent the given image with the following question: {text}" # noqa: E501
)
image = query["image"]
else:
modality = query['modality']
modality = query["modality"]
raise ValueError(f"Unsupported query modality: '{modality}'")
engine_args = EngineArgs(
......@@ -136,7 +137,8 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
# Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {})
req_data.engine_args.limit_mm_per_prompt or {}
)
engine_args = asdict(req_data.engine_args) | {"seed": seed}
llm = LLM(**engine_args)
......@@ -145,10 +147,12 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
if req_data.image is not None:
mm_data["image"] = req_data.image
outputs = llm.embed({
"prompt": req_data.prompt,
"multi_modal_data": mm_data,
})
outputs = llm.embed(
{
"prompt": req_data.prompt,
"multi_modal_data": mm_data,
}
)
print("-" * 50)
for output in outputs:
......@@ -164,23 +168,30 @@ model_example_map = {
def parse_args():
parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with '
'vision language models for multimodal embedding')
parser.add_argument('--model-name',
'-m',
type=str,
default="vlm2vec",
choices=model_example_map.keys(),
help='The name of the embedding model.')
parser.add_argument('--modality',
type=str,
default="image",
choices=get_args(QueryModality),
help='Modality of the input.')
parser.add_argument("--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.")
description="Demo on using vLLM for offline inference with "
"vision language models for multimodal embedding"
)
parser.add_argument(
"--model-name",
"-m",
type=str,
default="vlm2vec",
choices=model_example_map.keys(),
help="The name of the embedding model.",
)
parser.add_argument(
"--modality",
type=str,
default="image",
choices=get_args(QueryModality),
help="Modality of the input.",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.",
)
return parser.parse_args()
......
......@@ -4,6 +4,7 @@ This example shows how to use vLLM for running offline inference with
multi-image input on vision language models for text generation,
using the chat template defined by the model.
"""
import os
from argparse import Namespace
from dataclasses import asdict
......@@ -59,8 +60,9 @@ def load_aria(question: str, image_urls: list[str]) -> ModelRequestData:
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = "<fim_prefix><|img|><fim_suffix>\n" * len(image_urls)
prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n"
"<|im_start|>assistant\n")
prompt = (
f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n<|im_start|>assistant\n"
)
stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
return ModelRequestData(
......@@ -81,23 +83,21 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData:
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
"role":
"user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
}]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
}
]
processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return ModelRequestData(
engine_args=engine_args,
......@@ -106,8 +106,7 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData:
)
def load_deepseek_vl2(question: str,
image_urls: list[str]) -> ModelRequestData:
def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "deepseek-ai/deepseek-vl2-tiny"
engine_args = EngineArgs(
......@@ -118,8 +117,9 @@ def load_deepseek_vl2(question: str,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholder = "".join(f"image_{i}:<image>\n"
for i, _ in enumerate(image_urls, start=1))
placeholder = "".join(
f"image_{i}:<image>\n" for i, _ in enumerate(image_urls, start=1)
)
prompt = f"<|User|>: {placeholder}{question}\n\n<|Assistant|>:"
return ModelRequestData(
......@@ -140,23 +140,21 @@ def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
"role":
"user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
}]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
}
]
processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return ModelRequestData(
engine_args=engine_args,
......@@ -176,15 +174,15 @@ def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData:
mm_processor_kwargs={"max_dynamic_patch": 4},
)
placeholders = "\n".join(f"Image-{i}: <image>\n"
for i, _ in enumerate(image_urls, start=1))
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
placeholders = "\n".join(
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
)
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Stop tokens for H2OVL-Mississippi
# https://huggingface.co/h2oai/h2ovl-mississippi-800m
......@@ -211,14 +209,13 @@ def load_idefics3(question: str, image_urls: list[str]) -> ModelRequestData:
# if you are running out of memory, you can reduce the "longest_edge".
# see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
mm_processor_kwargs={
"size": {
"longest_edge": 2 * 364
},
"size": {"longest_edge": 2 * 364},
},
)
placeholders = "\n".join(f"Image-{i}: <image>\n"
for i, _ in enumerate(image_urls, start=1))
placeholders = "\n".join(
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
)
prompt = f"<|begin_of_text|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501
return ModelRequestData(
engine_args=engine_args,
......@@ -238,15 +235,16 @@ def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData:
enforce_eager=True,
limit_mm_per_prompt={"image": len(image_urls)},
mm_processor_kwargs={
"max_image_size": {
"longest_edge": 384
},
"max_image_size": {"longest_edge": 384},
},
)
placeholders = "\n".join(f"Image-{i}: <image>\n"
for i, _ in enumerate(image_urls, start=1))
prompt = f"<|im_start|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501
placeholders = "\n".join(
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
)
prompt = (
f"<|im_start|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
......@@ -265,15 +263,15 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
mm_processor_kwargs={"max_dynamic_patch": 4},
)
placeholders = "\n".join(f"Image-{i}: <image>\n"
for i, _ in enumerate(image_urls, start=1))
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
placeholders = "\n".join(
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
)
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Stop tokens for InternVL
# models variants may have different stop tokens
......@@ -301,23 +299,21 @@ def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
"role":
"user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
}]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
}
]
processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return ModelRequestData(
engine_args=engine_args,
......@@ -338,24 +334,21 @@ def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
"role":
"user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
}]
messages = [
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
}
]
processor = AutoProcessor.from_pretrained(model_name,
trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return ModelRequestData(
engine_args=engine_args,
......@@ -419,15 +412,15 @@ def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData:
mm_processor_kwargs={"max_dynamic_patch": 4},
)
placeholders = "\n".join(f"Image-{i}: <image>\n"
for i, _ in enumerate(image_urls, start=1))
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
placeholders = "\n".join(
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
)
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return ModelRequestData(
engine_args=engine_args,
......@@ -449,15 +442,15 @@ def load_ovis(question: str, image_urls: list[str]) -> ModelRequestData:
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = "\n".join(f"Image-{i}: <image>\n"
for i, _ in enumerate(image_urls, start=1))
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
placeholders = "\n".join(
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
)
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return ModelRequestData(
engine_args=engine_args,
......@@ -509,8 +502,9 @@ def load_phi3v(question: str, image_urls: list[str]) -> ModelRequestData:
limit_mm_per_prompt={"image": len(image_urls)},
mm_processor_kwargs={"num_crops": 4},
)
placeholders = "\n".join(f"<|image_{i}|>"
for i, _ in enumerate(image_urls, start=1))
placeholders = "\n".join(
f"<|image_{i}|>" for i, _ in enumerate(image_urls, start=1)
)
prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
return ModelRequestData(
......@@ -542,8 +536,7 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
mm_processor_kwargs={"dynamic_hd": 4},
)
placeholders = "".join(f"<|image_{i}|>"
for i, _ in enumerate(image_urls, start=1))
placeholders = "".join(f"<|image_{i}|>" for i, _ in enumerate(image_urls, start=1))
prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
return ModelRequestData(
......@@ -554,8 +547,7 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
)
def load_qwen_vl_chat(question: str,
image_urls: list[str]) -> ModelRequestData:
def load_qwen_vl_chat(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "Qwen/Qwen-VL-Chat"
engine_args = EngineArgs(
model=model_name,
......@@ -565,24 +557,26 @@ def load_qwen_vl_chat(question: str,
hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]},
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = "".join(f"Picture {i}: <img></img>\n"
for i, _ in enumerate(image_urls, start=1))
placeholders = "".join(
f"Picture {i}: <img></img>\n" for i, _ in enumerate(image_urls, start=1)
)
# This model does not have a chat_template attribute on its tokenizer,
# so we need to explicitly pass it. We use ChatML since it's used in the
# generation utils of the model:
# https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating
chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" # noqa: E501
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True,
chat_template=chat_template)
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
chat_template=chat_template,
)
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
......@@ -600,9 +594,11 @@ def load_qwen2_vl(question: str, image_urls: list[str]) -> ModelRequestData:
try:
from qwen_vl_utils import process_vision_info
except ModuleNotFoundError:
print('WARNING: `qwen-vl-utils` not installed, input images will not '
'be automatically resized. You can enable this functionality by '
'`pip install qwen-vl-utils`.')
print(
"WARNING: `qwen-vl-utils` not installed, input images will not "
"be automatically resized. You can enable this functionality by "
"`pip install qwen-vl-utils`."
)
process_vision_info = None
model_name = "Qwen/Qwen2-VL-7B-Instruct"
......@@ -616,26 +612,22 @@ def load_qwen2_vl(question: str, image_urls: list[str]) -> ModelRequestData:
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role":
"user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
}]
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
},
]
processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
if process_vision_info is None:
image_data = [fetch_image(url) for url in image_urls]
......@@ -653,9 +645,11 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
try:
from qwen_vl_utils import process_vision_info
except ModuleNotFoundError:
print('WARNING: `qwen-vl-utils` not installed, input images will not '
'be automatically resized. You can enable this functionality by '
'`pip install qwen-vl-utils`.')
print(
"WARNING: `qwen-vl-utils` not installed, input images will not "
"be automatically resized. You can enable this functionality by "
"`pip install qwen-vl-utils`."
)
process_vision_info = None
model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
......@@ -668,32 +662,27 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role":
"user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
}]
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": [
*placeholders,
{"type": "text", "text": question},
],
},
]
processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
if process_vision_info is None:
image_data = [fetch_image(url) for url in image_urls]
else:
image_data, _ = process_vision_info(messages,
return_video_kwargs=False)
image_data, _ = process_vision_info(messages, return_video_kwargs=False)
return ModelRequestData(
engine_args=engine_args,
......@@ -726,23 +715,20 @@ model_example_map = {
}
def run_generate(model, question: str, image_urls: list[str],
seed: Optional[int]):
def run_generate(model, question: str, image_urls: list[str], seed: Optional[int]):
req_data = model_example_map[model](question, image_urls)
engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
llm = LLM(**engine_args)
sampling_params = SamplingParams(temperature=0.0,
max_tokens=256,
stop_token_ids=req_data.stop_token_ids)
sampling_params = SamplingParams(
temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
)
outputs = llm.generate(
{
"prompt": req_data.prompt,
"multi_modal_data": {
"image": req_data.image_data
},
"multi_modal_data": {"image": req_data.image_data},
},
sampling_params=sampling_params,
lora_request=req_data.lora_requests,
......@@ -755,38 +741,40 @@ def run_generate(model, question: str, image_urls: list[str],
print("-" * 50)
def run_chat(model: str, question: str, image_urls: list[str],
seed: Optional[int]):
def run_chat(model: str, question: str, image_urls: list[str], seed: Optional[int]):
req_data = model_example_map[model](question, image_urls)
# Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {})
req_data.engine_args.limit_mm_per_prompt or {}
)
engine_args = asdict(req_data.engine_args) | {"seed": seed}
llm = LLM(**engine_args)
sampling_params = SamplingParams(temperature=0.0,
max_tokens=256,
stop_token_ids=req_data.stop_token_ids)
sampling_params = SamplingParams(
temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
)
outputs = llm.chat(
[{
"role":
"user",
"content": [
{
"type": "text",
"text": question,
},
*({
"type": "image_url",
"image_url": {
"url": image_url
[
{
"role": "user",
"content": [
{
"type": "text",
"text": question,
},
} for image_url in image_urls),
],
}],
*(
{
"type": "image_url",
"image_url": {"url": image_url},
}
for image_url in image_urls
),
],
}
],
sampling_params=sampling_params,
chat_template=req_data.chat_template,
lora_request=req_data.lora_requests,
......@@ -801,32 +789,39 @@ def run_chat(model: str, question: str, image_urls: list[str],
def parse_args():
parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with '
'vision language models that support multi-image input for text '
'generation')
parser.add_argument('--model-type',
'-m',
type=str,
default="phi3_v",
choices=model_example_map.keys(),
help='Huggingface "model_type".')
parser.add_argument("--method",
type=str,
default="generate",
choices=["generate", "chat"],
help="The method to run in `vllm.LLM`.")
parser.add_argument("--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.")
description="Demo on using vLLM for offline inference with "
"vision language models that support multi-image input for text "
"generation"
)
parser.add_argument(
"--model-type",
"-m",
type=str,
default="phi3_v",
choices=model_example_map.keys(),
help='Huggingface "model_type".',
)
parser.add_argument(
"--method",
type=str,
default="generate",
choices=["generate", "chat"],
help="The method to run in `vllm.LLM`.",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.",
)
parser.add_argument(
"--num-images",
"-n",
type=int,
choices=list(range(1,
len(IMAGE_URLS) + 1)), # the max number of images
choices=list(range(1, len(IMAGE_URLS) + 1)), # the max number of images
default=2,
help="Number of images to use for the demo.")
help="Number of images to use for the demo.",
)
return parser.parse_args()
......@@ -835,7 +830,7 @@ def main(args: Namespace):
method = args.method
seed = args.seed
image_urls = IMAGE_URLS[:args.num_images]
image_urls = IMAGE_URLS[: args.num_images]
if method == "generate":
run_generate(model, QUESTION, image_urls, seed)
......
......@@ -17,16 +17,15 @@ import requests
def clear_line(n: int = 1) -> None:
LINE_UP = '\033[1A'
LINE_CLEAR = '\x1b[2K'
LINE_UP = "\033[1A"
LINE_CLEAR = "\x1b[2K"
for _ in range(n):
print(LINE_UP, end=LINE_CLEAR, flush=True)
def post_http_request(prompt: str,
api_url: str,
n: int = 1,
stream: bool = False) -> requests.Response:
def post_http_request(
prompt: str, api_url: str, n: int = 1, stream: bool = False
) -> requests.Response:
headers = {"User-Agent": "Test Client"}
pload = {
"prompt": prompt,
......@@ -35,17 +34,14 @@ def post_http_request(prompt: str,
"max_tokens": 16,
"stream": stream,
}
response = requests.post(api_url,
headers=headers,
json=pload,
stream=stream)
response = requests.post(api_url, headers=headers, json=pload, stream=stream)
return response
def get_streaming_response(response: requests.Response) -> Iterable[list[str]]:
for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b"\n"):
for chunk in response.iter_lines(
chunk_size=8192, decode_unicode=False, delimiter=b"\n"
):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"]
......
......@@ -6,6 +6,7 @@ Note that `pip install cohere` is needed to run this example.
run: vllm serve BAAI/bge-reranker-base
"""
from typing import Union
import cohere
......@@ -16,28 +17,28 @@ model = "BAAI/bge-reranker-base"
query = "What is the capital of France?"
documents = [
"The capital of France is Paris", "Reranking is fun!",
"vLLM is an open-source framework for fast AI serving"
"The capital of France is Paris",
"Reranking is fun!",
"vLLM is an open-source framework for fast AI serving",
]
def cohere_rerank(client: Union[Client, ClientV2], model: str, query: str,
documents: list[str]) -> dict:
def cohere_rerank(
client: Union[Client, ClientV2], model: str, query: str, documents: list[str]
) -> dict:
return client.rerank(model=model, query=query, documents=documents)
def main():
# cohere v1 client
cohere_v1 = cohere.Client(base_url="http://localhost:8000",
api_key="sk-fake-key")
cohere_v1 = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
rerank_v1_result = cohere_rerank(cohere_v1, model, query, documents)
print("-" * 50)
print("rerank_v1_result:\n", rerank_v1_result)
print("-" * 50)
# or the v2
cohere_v2 = cohere.ClientV2("sk-fake-key",
base_url="http://localhost:8000")
cohere_v2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000")
rerank_v2_result = cohere_rerank(cohere_v2, model, query, documents)
print("rerank_v2_result:\n", rerank_v2_result)
print("-" * 50)
......
......@@ -13,6 +13,7 @@ launch this proxy demo through:
Note: This demo will be removed once the PDController implemented in PR 15343
(https://github.com/vllm-project/vllm/pull/15343) supports XpYd.
"""
import argparse
import ipaddress
import itertools
......@@ -26,8 +27,7 @@ from typing import Callable, Optional
import aiohttp
import requests
import uvicorn
from fastapi import (APIRouter, Depends, FastAPI, Header, HTTPException,
Request, status)
from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException, Request, status
from fastapi.responses import JSONResponse, StreamingResponse
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
......@@ -36,24 +36,24 @@ logging.basicConfig(level=logging.INFO)
class SchedulingPolicy(ABC):
@abstractmethod
def schedule(self, cycler: itertools.cycle):
raise NotImplementedError("Scheduling Proxy is not set.")
class Proxy:
def __init__(
self,
prefill_instances: list[str],
decode_instances: list[str],
model: str,
scheduling_policy: SchedulingPolicy,
custom_create_completion: Optional[Callable[[Request],
StreamingResponse]] = None,
custom_create_chat_completion: Optional[Callable[
[Request], StreamingResponse]] = None,
custom_create_completion: Optional[
Callable[[Request], StreamingResponse]
] = None,
custom_create_chat_completion: Optional[
Callable[[Request], StreamingResponse]
] = None,
):
self.prefill_instances = prefill_instances
self.decode_instances = decode_instances
......@@ -68,30 +68,30 @@ class Proxy:
def setup_routes(self):
self.router.post(
"/v1/completions",
dependencies=[
Depends(self.validate_json_request)
])(self.custom_create_completion if self.
custom_create_completion else self.create_completion)
"/v1/completions", dependencies=[Depends(self.validate_json_request)]
)(
self.custom_create_completion
if self.custom_create_completion
else self.create_completion
)
self.router.post(
"/v1/chat/completions",
dependencies=[
Depends(self.validate_json_request)
])(self.custom_create_chat_completion if self.
custom_create_chat_completion else self.create_chat_completion)
self.router.get("/status",
response_class=JSONResponse)(self.get_status)
self.router.post("/instances/add",
dependencies=[Depends(self.api_key_authenticate)
])(self.add_instance_endpoint)
"/v1/chat/completions", dependencies=[Depends(self.validate_json_request)]
)(
self.custom_create_chat_completion
if self.custom_create_chat_completion
else self.create_chat_completion
)
self.router.get("/status", response_class=JSONResponse)(self.get_status)
self.router.post(
"/instances/add", dependencies=[Depends(self.api_key_authenticate)]
)(self.add_instance_endpoint)
async def validate_json_request(self, raw_request: Request):
content_type = raw_request.headers.get("content-type", "").lower()
if content_type != "application/json":
raise HTTPException(
status_code=415,
detail=
"Unsupported Media Type: Only 'application/json' is allowed",
detail="Unsupported Media Type: Only 'application/json' is allowed",
)
def api_key_authenticate(self, x_api_key: str = Header(...)):
......@@ -103,8 +103,7 @@ class Proxy:
detail="Server configuration error.",
)
if x_api_key != expected_api_key:
logger.warning("Unauthorized access attempt with API Key: %s",
x_api_key)
logger.warning("Unauthorized access attempt with API Key: %s", x_api_key)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Forbidden: Invalid API Key.",
......@@ -113,8 +112,7 @@ class Proxy:
async def validate_instance(self, instance: str) -> bool:
url = f"http://{instance}/v1/models"
try:
async with aiohttp.ClientSession(
timeout=AIOHTTP_TIMEOUT) as client:
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as client:
logger.info("Verifying %s ...", instance)
async with client.get(url) as response:
if response.status == 200:
......@@ -122,12 +120,15 @@ class Proxy:
if "data" in data and len(data["data"]) > 0:
model_cur = data["data"][0].get("id", "")
if model_cur == self.model:
logger.info("Instance: %s could be added.",
instance)
logger.info("Instance: %s could be added.", instance)
return True
else:
logger.warning("Mismatch model %s : %s != %s",
instance, model_cur, self.model)
logger.warning(
"Mismatch model %s : %s != %s",
instance,
model_cur,
self.model,
)
return False
else:
return False
......@@ -147,48 +148,47 @@ class Proxy:
instance_type = data.get("type")
instance = data.get("instance")
if instance_type not in ["prefill", "decode"]:
raise HTTPException(status_code=400,
detail="Invalid instance type.")
raise HTTPException(status_code=400, detail="Invalid instance type.")
if not instance or ":" not in instance:
raise HTTPException(status_code=400,
detail="Invalid instance format.")
raise HTTPException(status_code=400, detail="Invalid instance format.")
host, port_str = instance.split(":")
try:
if host != "localhost":
ipaddress.ip_address(host)
port = int(port_str)
if not (0 < port < 65536):
raise HTTPException(status_code=400,
detail="Invalid port number.")
raise HTTPException(status_code=400, detail="Invalid port number.")
except Exception as e:
raise HTTPException(status_code=400,
detail="Invalid instance address.") from e
raise HTTPException(
status_code=400, detail="Invalid instance address."
) from e
is_valid = await self.validate_instance(instance)
if not is_valid:
raise HTTPException(status_code=400,
detail="Instance validation failed.")
raise HTTPException(
status_code=400, detail="Instance validation failed."
)
if instance_type == "prefill":
if instance not in self.prefill_instances:
self.prefill_instances.append(instance)
self.prefill_cycler = itertools.cycle(
self.prefill_instances)
self.prefill_cycler = itertools.cycle(self.prefill_instances)
else:
raise HTTPException(status_code=400,
detail="Instance already exists.")
raise HTTPException(
status_code=400, detail="Instance already exists."
)
else:
if instance not in self.decode_instances:
self.decode_instances.append(instance)
self.decode_cycler = itertools.cycle(self.decode_instances)
else:
raise HTTPException(status_code=400,
detail="Instance already exists.")
raise HTTPException(
status_code=400, detail="Instance already exists."
)
return JSONResponse(content={
"message":
f"Added {instance} to {instance_type}_instances."
})
return JSONResponse(
content={"message": f"Added {instance} to {instance_type}_instances."}
)
except HTTPException as http_exc:
raise http_exc
except Exception as e:
......@@ -197,16 +197,16 @@ class Proxy:
async def forward_request(self, url, data, use_chunked=True):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
try:
async with session.post(url=url, json=data,
headers=headers) as response:
async with session.post(
url=url, json=data, headers=headers
) as response:
if 200 <= response.status < 300 or 400 <= response.status < 500: # noqa: E501
if use_chunked:
async for chunk_bytes in response.content.iter_chunked( # noqa: E501
1024):
1024
):
yield chunk_bytes
else:
content = await response.read()
......@@ -217,20 +217,21 @@ class Proxy:
error_content = json.loads(error_content)
except json.JSONDecodeError:
error_content = error_content
logger.error("Request failed with status %s: %s",
response.status, error_content)
logger.error(
"Request failed with status %s: %s",
response.status,
error_content,
)
raise HTTPException(
status_code=response.status,
detail=
f"Request failed with status {response.status}: "
detail=f"Request failed with status {response.status}: "
f"{error_content}",
)
except aiohttp.ClientError as e:
logger.error("ClientError occurred: %s", str(e))
raise HTTPException(
status_code=502,
detail=
"Bad Gateway: Error communicating with upstream server.",
detail="Bad Gateway: Error communicating with upstream server.",
) from e
except Exception as e:
logger.error("Unexpected error: %s", str(e))
......@@ -258,8 +259,8 @@ class Proxy:
prefill_instance = self.schedule(self.prefill_cycler)
try:
async for _ in self.forward_request(
f"http://{prefill_instance}/v1/completions",
kv_prepare_request):
f"http://{prefill_instance}/v1/completions", kv_prepare_request
):
continue
except HTTPException as http_exc:
self.remove_instance_endpoint("prefill", prefill_instance)
......@@ -270,7 +271,8 @@ class Proxy:
try:
generator = self.forward_request(
f"http://{decode_instance}/v1/completions", request)
f"http://{decode_instance}/v1/completions", request
)
except HTTPException as http_exc:
self.remove_instance_endpoint("decode", decode_instance)
raise http_exc
......@@ -295,8 +297,8 @@ class Proxy:
prefill_instance = self.schedule(self.prefill_cycler)
try:
async for _ in self.forward_request(
f"http://{prefill_instance}/v1/chat/completions",
kv_prepare_request):
f"http://{prefill_instance}/v1/chat/completions", kv_prepare_request
):
continue
except HTTPException as http_exc:
self.remove_instance_endpoint("prefill", prefill_instance)
......@@ -306,8 +308,8 @@ class Proxy:
try:
generator = self.forward_request(
"http://" + decode_instance + "/v1/chat/completions",
request)
"http://" + decode_instance + "/v1/chat/completions", request
)
except HTTPException as http_exc:
self.remove_instance_endpoint("decode", decode_instance)
raise http_exc
......@@ -318,20 +320,20 @@ class Proxy:
error_messages = [str(e) for e in exc_info if e]
print("Error occurred in disagg proxy server")
print(error_messages)
return StreamingResponse(content=iter(error_messages),
media_type="text/event-stream")
return StreamingResponse(
content=iter(error_messages), media_type="text/event-stream"
)
def remove_instance_endpoint(self, instance_type, instance):
if (instance_type == "decode" and instance in self.decode_instances):
if instance_type == "decode" and instance in self.decode_instances:
self.decode_instances.remove(instance)
self.decode_cycler = itertools.cycle(self.decode_instances)
if (instance_type == "prefill" and instance in self.decode_instances):
if instance_type == "prefill" and instance in self.decode_instances:
self.prefill_instances.remove(instance)
self.prefill_cycler = itertools.cycle(self.decode_instances)
class RoundRobinSchedulingPolicy(SchedulingPolicy):
def __init__(self):
super().__init__()
......@@ -340,15 +342,12 @@ class RoundRobinSchedulingPolicy(SchedulingPolicy):
class ProxyServer:
def __init__(
self,
args: argparse.Namespace,
scheduling_policy: Optional[SchedulingPolicy] = None,
create_completion: Optional[Callable[[Request],
StreamingResponse]] = None,
create_chat_completion: Optional[Callable[[Request],
StreamingResponse]] = None,
create_completion: Optional[Callable[[Request], StreamingResponse]] = None,
create_chat_completion: Optional[Callable[[Request], StreamingResponse]] = None,
):
self.validate_parsed_serve_args(args)
self.port = args.port
......@@ -356,8 +355,11 @@ class ProxyServer:
prefill_instances=[] if args.prefill is None else args.prefill,
decode_instances=[] if args.decode is None else args.decode,
model=args.model,
scheduling_policy=(scheduling_policy if scheduling_policy
is not None else RoundRobinSchedulingPolicy()),
scheduling_policy=(
scheduling_policy
if scheduling_policy is not None
else RoundRobinSchedulingPolicy()
),
custom_create_completion=create_completion,
custom_create_chat_completion=create_chat_completion,
)
......@@ -382,11 +384,9 @@ class ProxyServer:
ipaddress.ip_address(host)
port = int(port)
if not (0 < port < 65536):
raise ValueError(
f"Invalid port number in instance: {instance}")
raise ValueError(f"Invalid port number in instance: {instance}")
except Exception as e:
raise ValueError(
f"Invalid instance {instance}: {str(e)}") from e
raise ValueError(f"Invalid instance {instance}: {str(e)}") from e
def verify_model_config(self, instances: list, model: str) -> None:
model_suffix = model.split("/")[-1]
......@@ -399,12 +399,14 @@ class ProxyServer:
if model_cur_suffix != model_suffix:
raise ValueError(
f"{instance} serves a different model: "
f"{model_cur} != {model}")
f"{model_cur} != {model}"
)
else:
raise ValueError(f"Cannot get model id from {instance}!")
except requests.RequestException as e:
raise ValueError(
f"Error communicating with {instance}: {str(e)}") from e
f"Error communicating with {instance}: {str(e)}"
) from e
def run_server(self):
app = FastAPI()
......@@ -417,11 +419,7 @@ class ProxyServer:
def parse_args():
# Todo: allow more config
parser = argparse.ArgumentParser("vLLM disaggregated proxy server.")
parser.add_argument("--model",
"-m",
type=str,
required=True,
help="Model name")
parser.add_argument("--model", "-m", type=str, required=True, help="Model name")
parser.add_argument(
"--prefill",
......
......@@ -17,6 +17,7 @@ you can install it manually by following these steps:
2. Rename the downloaded file to: frpc_linux_amd64_v0.3
3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc
"""
import argparse
import gradio as gr
......@@ -24,16 +25,12 @@ from openai import OpenAI
def format_history_to_openai(history):
history_openai_format = [{
"role": "system",
"content": "You are a great AI assistant."
}]
history_openai_format = [
{"role": "system", "content": "You are a great AI assistant."}
]
for human, assistant in history:
history_openai_format.append({"role": "user", "content": human})
history_openai_format.append({
"role": "assistant",
"content": assistant
})
history_openai_format.append({"role": "assistant", "content": assistant})
return history_openai_format
......@@ -49,17 +46,17 @@ def predict(message, history, client, model_name, temp, stop_token_ids):
temperature=temp,
stream=True,
extra_body={
'repetition_penalty':
1,
'stop_token_ids':
[int(id.strip())
for id in stop_token_ids.split(',')] if stop_token_ids else []
})
"repetition_penalty": 1,
"stop_token_ids": [int(id.strip()) for id in stop_token_ids.split(",")]
if stop_token_ids
else [],
},
)
# Collect all chunks and concatenate them into a full message
full_message = ""
for chunk in stream:
full_message += (chunk.choices[0].delta.content or "")
full_message += chunk.choices[0].delta.content or ""
# Return the full message as a single response
return full_message
......@@ -67,38 +64,34 @@ def predict(message, history, client, model_name, temp, stop_token_ids):
def parse_args():
parser = argparse.ArgumentParser(
description='Chatbot Interface with Customizable Parameters')
parser.add_argument('--model-url',
type=str,
default='http://localhost:8000/v1',
help='Model URL')
parser.add_argument('-m',
'--model',
type=str,
required=True,
help='Model name for the chatbot')
parser.add_argument('--temp',
type=float,
default=0.8,
help='Temperature for text generation')
parser.add_argument('--stop-token-ids',
type=str,
default='',
help='Comma-separated stop token IDs')
description="Chatbot Interface with Customizable Parameters"
)
parser.add_argument(
"--model-url", type=str, default="http://localhost:8000/v1", help="Model URL"
)
parser.add_argument(
"-m", "--model", type=str, required=True, help="Model name for the chatbot"
)
parser.add_argument(
"--temp", type=float, default=0.8, help="Temperature for text generation"
)
parser.add_argument(
"--stop-token-ids", type=str, default="", help="Comma-separated stop token IDs"
)
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8001)
return parser.parse_args()
def build_gradio_interface(client, model_name, temp, stop_token_ids):
def chat_predict(message, history):
return predict(message, history, client, model_name, temp,
stop_token_ids)
return predict(message, history, client, model_name, temp, stop_token_ids)
return gr.ChatInterface(fn=chat_predict,
title="Chatbot Interface",
description="A simple chatbot powered by vLLM")
return gr.ChatInterface(
fn=chat_predict,
title="Chatbot Interface",
description="A simple chatbot powered by vLLM",
)
def main():
......@@ -113,12 +106,13 @@ def main():
client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)
# Define the Gradio chatbot interface using the predict function
gradio_interface = build_gradio_interface(client, args.model, args.temp,
args.stop_token_ids)
gradio_interface = build_gradio_interface(
client, args.model, args.temp, args.stop_token_ids
)
gradio_interface.queue().launch(server_name=args.host,
server_port=args.port,
share=True)
gradio_interface.queue().launch(
server_name=args.host, server_port=args.port, share=True
)
if __name__ == "__main__":
......
......@@ -17,6 +17,7 @@ you can install it manually by following these steps:
2. Rename the downloaded file to: frpc_linux_amd64_v0.3
3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc
"""
import argparse
import json
......@@ -31,14 +32,11 @@ def http_bot(prompt):
"stream": True,
"max_tokens": 128,
}
response = requests.post(args.model_url,
headers=headers,
json=pload,
stream=True)
for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b"\n"):
response = requests.post(args.model_url, headers=headers, json=pload, stream=True)
for chunk in response.iter_lines(
chunk_size=8192, decode_unicode=False, delimiter=b"\n"
):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"][0]
......@@ -48,10 +46,10 @@ def http_bot(prompt):
def build_demo():
with gr.Blocks() as demo:
gr.Markdown("# vLLM text completion demo\n")
inputbox = gr.Textbox(label="Input",
placeholder="Enter text and press ENTER")
outputbox = gr.Textbox(label="Output",
placeholder="Generated result from the model")
inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")
outputbox = gr.Textbox(
label="Output", placeholder="Generated result from the model"
)
inputbox.submit(http_bot, [inputbox], [outputbox])
return demo
......@@ -60,17 +58,15 @@ def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8001)
parser.add_argument("--model-url",
type=str,
default="http://localhost:8000/generate")
parser.add_argument(
"--model-url", type=str, default="http://localhost:8000/generate"
)
return parser.parse_args()
def main(args):
demo = build_demo()
demo.queue().launch(server_name=args.host,
server_port=args.port,
share=True)
demo.queue().launch(server_name=args.host, server_port=args.port, share=True)
if __name__ == "__main__":
......
......@@ -5,6 +5,7 @@ Jina and Cohere https://jina.ai/reranker
run: vllm serve BAAI/bge-reranker-base
"""
import json
import requests
......@@ -14,14 +15,13 @@ url = "http://127.0.0.1:8000/rerank"
headers = {"accept": "application/json", "Content-Type": "application/json"}
data = {
"model":
"BAAI/bge-reranker-base",
"query":
"What is the capital of France?",
"model": "BAAI/bge-reranker-base",
"query": "What is the capital of France?",
"documents": [
"The capital of Brazil is Brasilia.",
"The capital of France is Paris.", "Horses and cows are both animals"
]
"The capital of France is Paris.",
"Horses and cows are both animals",
],
}
......
......@@ -9,17 +9,14 @@ from msgspec.msgpack import Decoder
#
# Types copied from vllm.distributed.kv_events
#
class EventBatch(msgspec.Struct, array_like=True, omit_defaults=True,
gc=False):
class EventBatch(msgspec.Struct, array_like=True, omit_defaults=True, gc=False):
ts: float
events: list[Any]
class KVCacheEvent(msgspec.Struct,
array_like=True,
omit_defaults=True,
gc=False,
tag=True):
class KVCacheEvent(
msgspec.Struct, array_like=True, omit_defaults=True, gc=False, tag=True
):
"""Base class for all KV cache-related events"""
......@@ -77,8 +74,9 @@ def main():
if last_seq >= 0 and seq > last_seq + 1:
missed = seq - last_seq - 1
print(f"Missed {missed} messages"
f" (last: {last_seq}, current: {seq})")
print(
f"Missed {missed} messages (last: {last_seq}, current: {seq})"
)
replay.send((last_seq + 1).to_bytes(8, "big"))
......
......@@ -12,26 +12,22 @@ from openai import OpenAI
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
messages = [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role": "assistant",
"content": "The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}]
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Who won the world series in 2020?"},
{
"role": "assistant",
"content": "The Los Angeles Dodgers won the World Series in 2020.",
},
{"role": "user", "content": "Where was it played?"},
]
def parse_args():
parser = argparse.ArgumentParser(description="Client for vLLM API server")
parser.add_argument("--stream",
action="store_true",
help="Enable streaming response")
parser.add_argument(
"--stream", action="store_true", help="Enable streaming response"
)
return parser.parse_args()
......
......@@ -43,7 +43,7 @@ def encode_base64_content_from_url(content_url: str) -> str:
with requests.get(content_url) as response:
response.raise_for_status()
result = base64.b64encode(response.content).decode('utf-8')
result = base64.b64encode(response.content).decode("utf-8")
return result
......@@ -51,10 +51,7 @@ def encode_base64_content_from_url(content_url: str) -> str:
# Text-only inference
def run_text_only(model: str) -> None:
chat_completion = client.chat.completions.create(
messages=[{
"role": "user",
"content": "What's the capital of France?"
}],
messages=[{"role": "user", "content": "What's the capital of France?"}],
model=model,
max_completion_tokens=64,
)
......@@ -65,26 +62,21 @@ def run_text_only(model: str) -> None:
# Single-image input inference
def run_single_image(model: str) -> None:
## Use image url in the payload
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
chat_completion_from_url = client.chat.completions.create(
messages=[{
"role":
"user",
"content": [
{
"type": "text",
"text": "What's in this image?"
},
{
"type": "image_url",
"image_url": {
"url": image_url
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": image_url},
},
},
],
}],
],
}
],
model=model,
max_completion_tokens=64,
)
......@@ -95,22 +87,18 @@ def run_single_image(model: str) -> None:
## Use base64 encoded image in the payload
image_base64 = encode_base64_content_from_url(image_url)
chat_completion_from_base64 = client.chat.completions.create(
messages=[{
"role":
"user",
"content": [
{
"type": "text",
"text": "What's in this image?"
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
},
},
],
}],
],
}
],
model=model,
max_completion_tokens=64,
)
......@@ -124,28 +112,22 @@ def run_multi_image(model: str) -> None:
image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"
image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"
chat_completion_from_url = client.chat.completions.create(
messages=[{
"role":
"user",
"content": [
{
"type": "text",
"text": "What are the animals in these images?"
},
{
"type": "image_url",
"image_url": {
"url": image_url_duck
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What are the animals in these images?"},
{
"type": "image_url",
"image_url": {"url": image_url_duck},
},
},
{
"type": "image_url",
"image_url": {
"url": image_url_lion
{
"type": "image_url",
"image_url": {"url": image_url_lion},
},
},
],
}],
],
}
],
model=model,
max_completion_tokens=64,
)
......@@ -161,22 +143,18 @@ def run_video(model: str) -> None:
## Use video url in the payload
chat_completion_from_url = client.chat.completions.create(
messages=[{
"role":
"user",
"content": [
{
"type": "text",
"text": "What's in this video?"
},
{
"type": "video_url",
"video_url": {
"url": video_url
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this video?"},
{
"type": "video_url",
"video_url": {"url": video_url},
},
},
],
}],
],
}
],
model=model,
max_completion_tokens=64,
)
......@@ -186,22 +164,18 @@ def run_video(model: str) -> None:
## Use base64 encoded video in the payload
chat_completion_from_base64 = client.chat.completions.create(
messages=[{
"role":
"user",
"content": [
{
"type": "text",
"text": "What's in this video?"
},
{
"type": "video_url",
"video_url": {
"url": f"data:video/mp4;base64,{video_base64}"
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this video?"},
{
"type": "video_url",
"video_url": {"url": f"data:video/mp4;base64,{video_base64}"},
},
},
],
}],
],
}
],
model=model,
max_completion_tokens=64,
)
......@@ -219,24 +193,22 @@ def run_audio(model: str) -> None:
# OpenAI-compatible schema (`input_audio`)
chat_completion_from_base64 = client.chat.completions.create(
messages=[{
"role":
"user",
"content": [
{
"type": "text",
"text": "What's in this audio?"
},
{
"type": "input_audio",
"input_audio": {
# Any format supported by librosa is supported
"data": audio_base64,
"format": "wav"
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this audio?"},
{
"type": "input_audio",
"input_audio": {
# Any format supported by librosa is supported
"data": audio_base64,
"format": "wav",
},
},
},
],
}],
],
}
],
model=model,
max_completion_tokens=64,
)
......@@ -246,23 +218,21 @@ def run_audio(model: str) -> None:
# HTTP URL
chat_completion_from_url = client.chat.completions.create(
messages=[{
"role":
"user",
"content": [
{
"type": "text",
"text": "What's in this audio?"
},
{
"type": "audio_url",
"audio_url": {
# Any format supported by librosa is supported
"url": audio_url
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this audio?"},
{
"type": "audio_url",
"audio_url": {
# Any format supported by librosa is supported
"url": audio_url
},
},
},
],
}],
],
}
],
model=model,
max_completion_tokens=64,
)
......@@ -272,23 +242,21 @@ def run_audio(model: str) -> None:
# base64 URL
chat_completion_from_base64 = client.chat.completions.create(
messages=[{
"role":
"user",
"content": [
{
"type": "text",
"text": "What's in this audio?"
},
{
"type": "audio_url",
"audio_url": {
# Any format supported by librosa is supported
"url": f"data:audio/ogg;base64,{audio_base64}"
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What's in this audio?"},
{
"type": "audio_url",
"audio_url": {
# Any format supported by librosa is supported
"url": f"data:audio/ogg;base64,{audio_base64}"
},
},
},
],
}],
],
}
],
model=model,
max_completion_tokens=64,
)
......@@ -308,14 +276,17 @@ example_function_map = {
def parse_args():
parser = FlexibleArgumentParser(
description='Demo on using OpenAI client for online serving with '
'multimodal language models served with vLLM.')
parser.add_argument('--chat-type',
'-c',
type=str,
default="single-image",
choices=list(example_function_map.keys()),
help='Conversation type with multimodal data.')
description="Demo on using OpenAI client for online serving with "
"multimodal language models served with vLLM."
)
parser.add_argument(
"--chat-type",
"-c",
type=str,
default="single-image",
choices=list(example_function_map.keys()),
help="Conversation type with multimodal data.",
)
return parser.parse_args()
......
......@@ -16,6 +16,7 @@ vllm serve NousResearch/Hermes-2-Pro-Llama-3-8B \
--chat-template examples/tool_chat_template_hermes.jinja \
--enable-auto-tool-choice --tool-call-parser hermes
"""
import json
from typing import Any
......@@ -25,55 +26,55 @@ from openai import OpenAI
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
tools = [{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to find the weather for, e.g. 'San Francisco'"
},
"state": {
"type":
"string",
"description":
"the two-letter abbreviation for the state that the city is"
" in, e.g. 'CA' which would mean 'California'"
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"]
}
properties = {
"city": {
"type": "string",
"description": "The city to find the weather for, e.g. 'San Francisco'",
},
"state": {
"type": "string",
"description": "the two-letter abbreviation for the state that the city is"
" in, e.g. 'CA' which would mean 'California'",
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
}
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": properties,
"required": ["city", "state", "unit"],
},
"required": ["city", "state", "unit"]
}
},
}
}]
messages = [{
"role": "user",
"content": "Hi! How are you doing today?"
}, {
"role": "assistant",
"content": "I'm doing well! How can I help you?"
}, {
"role":
"user",
"content":
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
}]
def get_current_weather(city: str, state: str, unit: 'str'):
return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
"partly cloudly, with highs in the 90's.")
]
messages = [
{"role": "user", "content": "Hi! How are you doing today?"},
{"role": "assistant", "content": "I'm doing well! How can I help you?"},
{
"role": "user",
"content": (
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
),
},
]
def get_current_weather(city: str, state: str, unit: "str"):
return (
"The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
"partly cloudly, with highs in the 90's."
)
def handle_tool_calls_stream(
......@@ -82,10 +83,9 @@ def handle_tool_calls_stream(
model: str,
tools: list[dict[str, Any]],
) -> list[Any]:
tool_calls_stream = client.chat.completions.create(messages=messages,
model=model,
tools=tools,
stream=True)
tool_calls_stream = client.chat.completions.create(
messages=messages, model=model, tools=tools, stream=True
)
chunks = []
print("chunks: ")
for chunk in tool_calls_stream:
......@@ -106,8 +106,7 @@ def handle_tool_calls_arguments(chunks: list[Any]) -> list[str]:
tool_call = chunk.choices[0].delta.tool_calls[0]
if tool_call.index != tool_call_idx:
if tool_call_idx >= 0:
print(f"streamed tool call arguments: "
f"{arguments[tool_call_idx]}")
print(f"streamed tool call arguments: {arguments[tool_call_idx]}")
tool_call_idx = chunk.choices[0].delta.tool_calls[0].index
arguments.append("")
if tool_call.id:
......@@ -115,8 +114,7 @@ def handle_tool_calls_arguments(chunks: list[Any]) -> list[str]:
if tool_call.function:
if tool_call.function.name:
print(
f"streamed tool call name: {tool_call.function.name}")
print(f"streamed tool call name: {tool_call.function.name}")
if tool_call.function.arguments:
arguments[tool_call_idx] += tool_call.function.arguments
......@@ -136,9 +134,9 @@ def main():
models = client.models.list()
model = models.data[0].id
chat_completion = client.chat.completions.create(messages=messages,
model=model,
tools=tools)
chat_completion = client.chat.completions.create(
messages=messages, model=model, tools=tools
)
print("-" * 70)
print("Chat completion results:")
......@@ -158,10 +156,12 @@ def main():
print("-" * 70)
# Add tool call results to the conversation
messages.append({
"role": "assistant",
"tool_calls": chat_completion.choices[0].message.tool_calls
})
messages.append(
{
"role": "assistant",
"tool_calls": chat_completion.choices[0].message.tool_calls,
}
)
# Now, simulate a tool call
available_tools = {"get_current_weather": get_current_weather}
......@@ -172,17 +172,18 @@ def main():
args = json.loads(call.function.arguments)
result = tool_to_call(**args)
print("tool_to_call result: ", result)
messages.append({
"role": "tool",
"content": result,
"tool_call_id": call.id,
"name": call.function.name
})
chat_completion_2 = client.chat.completions.create(messages=messages,
model=model,
tools=tools,
stream=False)
messages.append(
{
"role": "tool",
"content": result,
"tool_call_id": call.id,
"name": call.function.name,
}
)
chat_completion_2 = client.chat.completions.create(
messages=messages, model=model, tools=tools, stream=False
)
print("Chat completion2 results:")
print(chat_completion_2)
print("-" * 70)
......
......@@ -28,18 +28,16 @@ tools = [
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to find the weather for"
"type": "string",
"description": "The city to find the weather for"
", e.g. 'San Francisco'",
},
"state": {
"type":
"string",
"description":
"the two-letter abbreviation for the state that the "
"city is in, e.g. 'CA' which would mean 'California'",
"type": "string",
"description": (
"the two-letter abbreviation for the state that the "
"city is in, e.g. 'CA' which would mean 'California'"
),
},
"unit": {
"type": "string",
......@@ -60,22 +58,20 @@ tools = [
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to get the forecast for, e.g. 'New York'",
"type": "string",
"description": (
"The city to get the forecast for, e.g. 'New York'"
),
},
"state": {
"type":
"string",
"description":
"The two-letter abbreviation for the state, e.g. 'NY'",
"type": "string",
"description": (
"The two-letter abbreviation for the state, e.g. 'NY'"
),
},
"days": {
"type":
"integer",
"description":
"Number of days to get the forecast for (1-7)",
"type": "integer",
"description": "Number of days to get the forecast for (1-7)",
},
"unit": {
"type": "string",
......@@ -90,19 +86,11 @@ tools = [
]
messages = [
{"role": "user", "content": "Hi! How are you doing today?"},
{"role": "assistant", "content": "I'm doing well! How can I help you?"},
{
"role": "user",
"content": "Hi! How are you doing today?"
},
{
"role": "assistant",
"content": "I'm doing well! How can I help you?"
},
{
"role":
"user",
"content":
"Can you tell me what the current weather is in Dallas \
"content": "Can you tell me what the current weather is in Dallas \
and the forecast for the next 5 days, in fahrenheit?",
},
]
......@@ -123,17 +111,16 @@ def main():
model=model,
tools=tools,
tool_choice="required",
stream=True # Enable streaming response
stream=True, # Enable streaming response
)
for chunk in chat_completion:
if chunk.choices and chunk.choices[0].delta.tool_calls:
print(chunk.choices[0].delta.tool_calls)
chat_completion = client.chat.completions.create(messages=messages,
model=model,
tools=tools,
tool_choice="required")
chat_completion = client.chat.completions.create(
messages=messages, model=model, tools=tools, tool_choice="required"
)
print(chat_completion.choices[0].message.tool_calls)
......
......@@ -20,10 +20,9 @@ openai_api_base = "http://localhost:8000/v1"
def guided_choice_completion(client: OpenAI, model: str):
completion = client.chat.completions.create(
model=model,
messages=[{
"role": "user",
"content": "Classify this sentiment: vLLM is wonderful!"
}],
messages=[
{"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
],
extra_body={"guided_choice": ["positive", "negative"]},
)
return completion.choices[0].message.content
......@@ -31,20 +30,21 @@ def guided_choice_completion(client: OpenAI, model: str):
# Guided decoding by Regex
def guided_regex_completion(client: OpenAI, model: str):
prompt = ("Generate an email address for Alan Turing, who works in Enigma."
"End in .com and new line. Example result:"
"alan.turing@enigma.com\n")
prompt = (
"Generate an email address for Alan Turing, who works in Enigma."
"End in .com and new line. Example result:"
"alan.turing@enigma.com\n"
)
completion = client.chat.completions.create(
model=model,
messages=[{
"role": "user",
"content": prompt,
}],
extra_body={
"guided_regex": r"\w+@\w+\.com\n",
"stop": ["\n"]
},
messages=[
{
"role": "user",
"content": prompt,
}
],
extra_body={"guided_regex": r"\w+@\w+\.com\n", "stop": ["\n"]},
)
return completion.choices[0].message.content
......@@ -66,14 +66,18 @@ class CarDescription(BaseModel):
def guided_json_completion(client: OpenAI, model: str):
json_schema = CarDescription.model_json_schema()
prompt = ("Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's")
prompt = (
"Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's"
)
completion = client.chat.completions.create(
model=model,
messages=[{
"role": "user",
"content": prompt,
}],
messages=[
{
"role": "user",
"content": prompt,
}
],
extra_body={"guided_json": json_schema},
)
return completion.choices[0].message.content
......@@ -95,14 +99,18 @@ def guided_grammar_completion(client: OpenAI, model: str):
number ::= "1 " | "2 "
"""
prompt = ("Generate an SQL query to show the 'username' and 'email'"
"from the 'users' table.")
prompt = (
"Generate an SQL query to show the 'username' and 'email'"
"from the 'users' table."
)
completion = client.chat.completions.create(
model=model,
messages=[{
"role": "user",
"content": prompt,
}],
messages=[
{
"role": "user",
"content": prompt,
}
],
extra_body={"guided_grammar": simplified_sql_grammar},
)
return completion.choices[0].message.content
......@@ -110,19 +118,23 @@ def guided_grammar_completion(client: OpenAI, model: str):
# Extra backend options
def extra_backend_options_completion(client: OpenAI, model: str):
prompt = ("Generate an email address for Alan Turing, who works in Enigma."
"End in .com and new line. Example result:"
"alan.turing@enigma.com\n")
prompt = (
"Generate an email address for Alan Turing, who works in Enigma."
"End in .com and new line. Example result:"
"alan.turing@enigma.com\n"
)
try:
# The guided_decoding_disable_fallback option forces vLLM to use
# xgrammar, so when it fails you get a 400 with the reason why
completion = client.chat.completions.create(
model=model,
messages=[{
"role": "user",
"content": prompt,
}],
messages=[
{
"role": "user",
"content": prompt,
}
],
extra_body={
"guided_regex": r"\w+@\w+\.com\n",
"stop": ["\n"],
......
......@@ -17,11 +17,10 @@ def main():
api_key=openai_api_key,
)
messages = [{
"role":
"user",
"content":
"""
messages = [
{
"role": "user",
"content": """
You have access to the following function to retrieve the weather in a city:
{
......@@ -58,29 +57,28 @@ You are a helpful assistant.
Given the previous instructions, what is the weather in New York City, Boston,
and San Francisco?
"""
}]
""",
}
]
response = client.chat.completions.create(
model=client.models.list().data[0].id,
messages=messages,
response_format={
"type":
"structural_tag",
"structures": [{
"begin": "<function=get_weather>",
"schema": {
"type": "object",
"properties": {
"city": {
"type": "string"
}
}
},
"end": "</function>"
}],
"triggers": ["<function="]
})
"type": "structural_tag",
"structures": [
{
"begin": "<function=get_weather>",
"schema": {
"type": "object",
"properties": {"city": {"type": "string"}},
},
"end": "</function>",
}
],
"triggers": ["<function="],
},
)
print(response)
......
......@@ -27,21 +27,22 @@ openai_api_base = "http://localhost:8000/v1"
def print_completion_details(completion):
print("reasoning_content: ",
completion.choices[0].message.reasoning_content)
print("reasoning_content: ", completion.choices[0].message.reasoning_content)
print("content: ", completion.choices[0].message.content)
# Guided decoding by Regex
def guided_regex_completion(client: OpenAI, model: str):
prompt = ("What is the capital of France?")
prompt = "What is the capital of France?"
completion = client.chat.completions.create(
model=model,
messages=[{
"role": "user",
"content": prompt,
}],
messages=[
{
"role": "user",
"content": prompt,
}
],
extra_body={
"guided_regex": "(Paris|London)",
},
......@@ -57,13 +58,15 @@ class People(BaseModel):
def guided_json_completion(client: OpenAI, model: str):
json_schema = People.model_json_schema()
prompt = ("Generate a JSON with the name and age of one random person.")
prompt = "Generate a JSON with the name and age of one random person."
completion = client.chat.completions.create(
model=model,
messages=[{
"role": "user",
"content": prompt,
}],
messages=[
{
"role": "user",
"content": prompt,
}
],
extra_body={"guided_json": json_schema},
)
print_completion_details(completion)
......@@ -86,14 +89,18 @@ class CarDescription(BaseModel):
def guided_car_json_completion(client: OpenAI, model: str):
json_schema = CarDescription.model_json_schema()
prompt = ("Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's")
prompt = (
"Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's"
)
completion = client.chat.completions.create(
model=model,
messages=[{
"role": "user",
"content": prompt,
}],
messages=[
{
"role": "user",
"content": prompt,
}
],
extra_body={"guided_json": json_schema},
)
print_completion_details(completion)
......@@ -116,14 +123,18 @@ def guided_grammar_completion(client: OpenAI, model: str):
"""
# This may be very slow https://github.com/vllm-project/vllm/issues/12122
prompt = ("Generate an SQL query to show the 'username' and 'email'"
"from the 'users' table.")
prompt = (
"Generate an SQL query to show the 'username' and 'email'"
"from the 'users' table."
)
completion = client.chat.completions.create(
model=model,
messages=[{
"role": "user",
"content": prompt,
}],
messages=[
{
"role": "user",
"content": prompt,
}
],
extra_body={"guided_grammar": simplified_sql_grammar},
)
print_completion_details(completion)
......
......@@ -20,9 +20,11 @@ from openai import OpenAI
# Now, simulate a tool call
def get_current_weather(city: str, state: str, unit: 'str'):
return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
"partly cloudly, with highs in the 90's.")
def get_current_weather(city: str, state: str, unit: "str"):
return (
"The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
"partly cloudly, with highs in the 90's."
)
available_tools = {"get_current_weather": get_current_weather}
......@@ -31,49 +33,47 @@ available_tools = {"get_current_weather": get_current_weather}
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
tools = [{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to find the weather for, e.g. 'San Francisco'"
},
"state": {
"type":
"string",
"description":
"the two-letter abbreviation for the state that the city is"
" in, e.g. 'CA' which would mean 'California'"
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"]
}
properties = {
"city": {
"type": "string",
"description": "The city to find the weather for, e.g. 'San Francisco'",
},
"state": {
"type": "string",
"description": "the two-letter abbreviation for the state that the city is"
" in, e.g. 'CA' which would mean 'California'",
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
}
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": properties,
"required": ["city", "state", "unit"],
},
"required": ["city", "state", "unit"]
}
},
}
}]
messages = [{
"role": "user",
"content": "Hi! How are you doing today?"
}, {
"role": "assistant",
"content": "I'm doing well! How can I help you?"
}, {
"role":
"user",
"content":
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
}]
]
messages = [
{"role": "user", "content": "Hi! How are you doing today?"},
{"role": "assistant", "content": "I'm doing well! How can I help you?"},
{
"role": "user",
"content": (
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
),
},
]
def extract_reasoning_and_calls(chunks: list):
......@@ -110,73 +110,55 @@ def main():
models = client.models.list()
model = models.data[0].id
print("---------Full Generate With Automatic Function Calling-------------")
tool_calls = client.chat.completions.create(
messages=messages, model=model, tools=tools
)
print(f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}")
print(f"function name: {tool_calls.choices[0].message.tool_calls[0].function.name}")
print(
"---------Full Generate With Automatic Function Calling-------------")
tool_calls = client.chat.completions.create(messages=messages,
model=model,
tools=tools)
print(
f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}"
f"function arguments: "
f"{tool_calls.choices[0].message.tool_calls[0].function.arguments}"
)
print(f"function name: "
f"{tool_calls.choices[0].message.tool_calls[0].function.name}")
print(f"function arguments: "
f"{tool_calls.choices[0].message.tool_calls[0].function.arguments}")
print(
"----------Stream Generate With Automatic Function Calling-----------")
tool_calls_stream = client.chat.completions.create(messages=messages,
model=model,
tools=tools,
stream=True)
print("----------Stream Generate With Automatic Function Calling-----------")
tool_calls_stream = client.chat.completions.create(
messages=messages, model=model, tools=tools, stream=True
)
chunks = list(tool_calls_stream)
reasoning_content, arguments, function_names = extract_reasoning_and_calls(
chunks)
reasoning_content, arguments, function_names = extract_reasoning_and_calls(chunks)
print(f"reasoning_content: {reasoning_content}")
print(f"function name: {function_names[0]}")
print(f"function arguments: {arguments[0]}")
print(
"----------Full Generate With Named Function Calling-----------------")
tool_calls = client.chat.completions.create(messages=messages,
model=model,
tools=tools,
tool_choice={
"type": "function",
"function": {
"name":
"get_current_weather"
}
})
print("----------Full Generate With Named Function Calling-----------------")
tool_calls = client.chat.completions.create(
messages=messages,
model=model,
tools=tools,
tool_choice={"type": "function", "function": {"name": "get_current_weather"}},
)
tool_call = tool_calls.choices[0].message.tool_calls[0].function
print(
f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}"
)
print(f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}")
print(f"function name: {tool_call.name}")
print(f"function arguments: {tool_call.arguments}")
print(
"----------Stream Generate With Named Function Calling--------------")
print("----------Stream Generate With Named Function Calling--------------")
tool_calls_stream = client.chat.completions.create(
messages=messages,
model=model,
tools=tools,
tool_choice={
"type": "function",
"function": {
"name": "get_current_weather"
}
},
stream=True)
tool_choice={"type": "function", "function": {"name": "get_current_weather"}},
stream=True,
)
chunks = list(tool_calls_stream)
reasoning_content, arguments, function_names = extract_reasoning_and_calls(
chunks)
reasoning_content, arguments, function_names = extract_reasoning_and_calls(chunks)
print(f"reasoning_content: {reasoning_content}")
print(f"function name: {function_names[0]}")
print(f"function arguments: {arguments[0]}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment