"vllm/vscode:/vscode.git/clone" did not exist on "fe921763212b881d9629d04c2eaab4496e136fa5"
Unverified Commit 27bebcd8 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Convert `examples` to `ruff-format` (#18400)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent e7523c2e
...@@ -45,8 +45,7 @@ if dist.get_rank() == 0: ...@@ -45,8 +45,7 @@ if dist.get_rank() == 0:
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\n" print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n")
f"Generated text: {generated_text!r}\n")
print("-" * 50) print("-" * 50)
""" """
Further tips: Further tips:
......
...@@ -20,10 +20,12 @@ sampling_params = SamplingParams(temperature=0, top_p=1.0, n=N, max_tokens=16) ...@@ -20,10 +20,12 @@ sampling_params = SamplingParams(temperature=0, top_p=1.0, n=N, max_tokens=16)
def main(): def main():
# Set `enforce_eager=True` to avoid ahead-of-time compilation. # Set `enforce_eager=True` to avoid ahead-of-time compilation.
# In real workloads, `enforace_eager` should be `False`. # In real workloads, `enforace_eager` should be `False`.
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", llm = LLM(
max_num_batched_tokens=64, model="Qwen/Qwen2-1.5B-Instruct",
max_num_seqs=4, max_num_batched_tokens=64,
max_model_len=128) max_num_seqs=4,
max_model_len=128,
)
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
print("-" * 50) print("-" * 50)
for output, answer in zip(outputs, answers): for output, answer in zip(outputs, answers):
......
...@@ -6,6 +6,7 @@ the correct prompt format on vision language models for text generation. ...@@ -6,6 +6,7 @@ the correct prompt format on vision language models for text generation.
For most models, the prompt format should follow corresponding examples For most models, the prompt format should follow corresponding examples
on HuggingFace model repository. on HuggingFace model repository.
""" """
import os import os
import random import random
from contextlib import contextmanager from contextlib import contextmanager
...@@ -49,9 +50,13 @@ def run_aria(questions: list[str], modality: str) -> ModelRequestData: ...@@ -49,9 +50,13 @@ def run_aria(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
) )
prompts = [(f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}" prompts = [
"<|im_end|>\n<|im_start|>assistant\n") (
for question in questions] f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
"<|im_end|>\n<|im_start|>assistant\n"
)
for question in questions
]
stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519] stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
...@@ -135,8 +140,7 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData: ...@@ -135,8 +140,7 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
) )
prompts = [ prompts = [
f"<|User|>: <image>\n{question}\n\n<|Assistant|>:" f"<|User|>: <image>\n{question}\n\n<|Assistant|>:" for question in questions
for question in questions
] ]
return ModelRequestData( return ModelRequestData(
...@@ -198,9 +202,14 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData: ...@@ -198,9 +202,14 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
) )
prompts = [("<bos><start_of_turn>user\n" prompts = [
f"<start_of_image>{question}<end_of_turn>\n" (
"<start_of_turn>model\n") for question in questions] "<bos><start_of_turn>user\n"
f"<start_of_image>{question}<end_of_turn>\n"
"<start_of_turn>model\n"
)
for question in questions
]
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
...@@ -225,7 +234,8 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData: ...@@ -225,7 +234,8 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData:
prompts = [ prompts = [
f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\ f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\
{question}<|assistant|>" for question in questions {question}<|assistant|>"
for question in questions
] ]
stop_token_ids = [151329, 151336, 151338] stop_token_ids = [151329, 151336, 151338]
...@@ -250,15 +260,13 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData: ...@@ -250,15 +260,13 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
) )
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True) messages = [
messages = [[{ [{"role": "user", "content": f"<image>\n{question}"}] for question in questions
'role': 'user', ]
'content': f"<image>\n{question}" prompts = tokenizer.apply_chat_template(
}] for question in questions] messages, tokenize=False, add_generation_prompt=True
prompts = tokenizer.apply_chat_template(messages, )
tokenize=False,
add_generation_prompt=True)
# Stop tokens for H2OVL-Mississippi # Stop tokens for H2OVL-Mississippi
# https://huggingface.co/h2oai/h2ovl-mississippi-800m # https://huggingface.co/h2oai/h2ovl-mississippi-800m
...@@ -284,15 +292,14 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData: ...@@ -284,15 +292,14 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
# if you are running out of memory, you can reduce the "longest_edge". # if you are running out of memory, you can reduce the "longest_edge".
# see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations # see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
mm_processor_kwargs={ mm_processor_kwargs={
"size": { "size": {"longest_edge": 3 * 364},
"longest_edge": 3 * 364
},
}, },
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
) )
prompts = [( prompts = [
f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:" (f"<|begin_of_text|>User:<image>{question}<end_of_utterance>\nAssistant:")
) for question in questions] for question in questions
]
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
...@@ -311,9 +318,7 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData: ...@@ -311,9 +318,7 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
max_num_seqs=2, max_num_seqs=2,
enforce_eager=True, enforce_eager=True,
mm_processor_kwargs={ mm_processor_kwargs={
"max_image_size": { "max_image_size": {"longest_edge": 384},
"longest_edge": 384
},
}, },
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
) )
...@@ -330,7 +335,6 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData: ...@@ -330,7 +335,6 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData:
# InternVL # InternVL
def run_internvl(questions: list[str], modality: str) -> ModelRequestData: def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "OpenGVLab/InternVL3-2B" model_name = "OpenGVLab/InternVL3-2B"
engine_args = EngineArgs( engine_args = EngineArgs(
...@@ -345,15 +349,14 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData: ...@@ -345,15 +349,14 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
elif modality == "video": elif modality == "video":
placeholder = "<video>" placeholder = "<video>"
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True) messages = [
messages = [[{ [{"role": "user", "content": f"{placeholder}\n{question}"}]
'role': 'user', for question in questions
'content': f"{placeholder}\n{question}" ]
}] for question in questions] prompts = tokenizer.apply_chat_template(
prompts = tokenizer.apply_chat_template(messages, messages, tokenize=False, add_generation_prompt=True
tokenize=False, )
add_generation_prompt=True)
# Stop tokens for InternVL # Stop tokens for InternVL
# models variants may have different stop tokens # models variants may have different stop tokens
...@@ -361,9 +364,7 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData: ...@@ -361,9 +364,7 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData:
# https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
stop_token_ids = [ stop_token_ids = [token_id for token_id in stop_token_ids if token_id is not None]
token_id for token_id in stop_token_ids if token_id is not None
]
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
...@@ -379,7 +380,8 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData: ...@@ -379,7 +380,8 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
prompts = [ prompts = [
"<|im_user|>user<|im_middle|><|media_start|>image<|media_content|>" "<|im_user|>user<|im_middle|><|media_start|>image<|media_content|>"
f"<|media_pad|><|media_end|>{question}<|im_end|>" f"<|media_pad|><|media_end|>{question}<|im_end|>"
"<|im_assistant|>assistant<|im_middle|>" for question in questions "<|im_assistant|>assistant<|im_middle|>"
for question in questions
] ]
engine_args = EngineArgs( engine_args = EngineArgs(
...@@ -399,9 +401,7 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData: ...@@ -399,9 +401,7 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
def run_llava(questions: list[str], modality: str) -> ModelRequestData: def run_llava(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image" assert modality == "image"
prompts = [ prompts = [f"USER: <image>\n{question}\nASSISTANT:" for question in questions]
f"USER: <image>\n{question}\nASSISTANT:" for question in questions
]
engine_args = EngineArgs( engine_args = EngineArgs(
model="llava-hf/llava-1.5-7b-hf", model="llava-hf/llava-1.5-7b-hf",
...@@ -434,13 +434,10 @@ def run_llava_next(questions: list[str], modality: str) -> ModelRequestData: ...@@ -434,13 +434,10 @@ def run_llava_next(questions: list[str], modality: str) -> ModelRequestData:
# LlaVA-NeXT-Video # LlaVA-NeXT-Video
# Currently only support for video input # Currently only support for video input
def run_llava_next_video(questions: list[str], def run_llava_next_video(questions: list[str], modality: str) -> ModelRequestData:
modality: str) -> ModelRequestData:
assert modality == "video" assert modality == "video"
prompts = [ prompts = [f"USER: <video>\n{question} ASSISTANT:" for question in questions]
f"USER: <video>\n{question} ASSISTANT:" for question in questions
]
engine_args = EngineArgs( engine_args = EngineArgs(
model="llava-hf/LLaVA-NeXT-Video-7B-hf", model="llava-hf/LLaVA-NeXT-Video-7B-hf",
max_model_len=8192, max_model_len=8192,
...@@ -455,19 +452,19 @@ def run_llava_next_video(questions: list[str], ...@@ -455,19 +452,19 @@ def run_llava_next_video(questions: list[str],
# LLaVA-OneVision # LLaVA-OneVision
def run_llava_onevision(questions: list[str], def run_llava_onevision(questions: list[str], modality: str) -> ModelRequestData:
modality: str) -> ModelRequestData:
if modality == "video": if modality == "video":
prompts = [ prompts = [
f"<|im_start|>user <video>\n{question}<|im_end|> \ f"<|im_start|>user <video>\n{question}<|im_end|> \
<|im_start|>assistant\n" for question in questions <|im_start|>assistant\n"
for question in questions
] ]
elif modality == "image": elif modality == "image":
prompts = [ prompts = [
f"<|im_start|>user <image>\n{question}<|im_end|> \ f"<|im_start|>user <image>\n{question}<|im_end|> \
<|im_start|>assistant\n" for question in questions <|im_start|>assistant\n"
for question in questions
] ]
engine_args = EngineArgs( engine_args = EngineArgs(
...@@ -486,11 +483,8 @@ def run_llava_onevision(questions: list[str], ...@@ -486,11 +483,8 @@ def run_llava_onevision(questions: list[str],
def run_mantis(questions: list[str], modality: str) -> ModelRequestData: def run_mantis(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image" assert modality == "image"
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n' # noqa: E501 llama3_template = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" # noqa: E501
prompts = [ prompts = [llama3_template.format(f"{question}\n<image>") for question in questions]
llama3_template.format(f"{question}\n<image>")
for question in questions
]
engine_args = EngineArgs( engine_args = EngineArgs(
model="TIGER-Lab/Mantis-8B-siglip-llama3", model="TIGER-Lab/Mantis-8B-siglip-llama3",
...@@ -530,8 +524,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name): ...@@ -530,8 +524,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
# 2.6: image, video # 2.6: image, video
# o2.6: image, video, audio # o2.6: image, video, audio
# model_name = "openbmb/MiniCPM-o-2_6" # model_name = "openbmb/MiniCPM-o-2_6"
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True)
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
max_model_len=4096, max_model_len=4096,
...@@ -547,7 +540,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name): ...@@ -547,7 +540,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
# stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id] # stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
# 2.6 / o2.6 # 2.6 / o2.6
stop_tokens = ['<|im_end|>', '<|endoftext|>'] stop_tokens = ["<|im_end|>", "<|endoftext|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
modality_placeholder = { modality_placeholder = {
...@@ -557,12 +550,16 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name): ...@@ -557,12 +550,16 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name):
prompts = [ prompts = [
tokenizer.apply_chat_template( tokenizer.apply_chat_template(
[{ [
'role': 'user', {
'content': f"{modality_placeholder[modality]}\n{question}" "role": "user",
}], "content": f"{modality_placeholder[modality]}\n{question}",
}
],
tokenize=False, tokenize=False,
add_generation_prompt=True) for question in questions add_generation_prompt=True,
)
for question in questions
] ]
return ModelRequestData( return ModelRequestData(
...@@ -622,19 +619,18 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData: ...@@ -622,19 +619,18 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
) )
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [[{ messages = [
"role": [
"user", {
"content": [{ "role": "user",
"type": "image" "content": [{"type": "image"}, {"type": "text", "text": question}],
}, { }
"type": "text", ]
"text": question for question in questions
}] ]
}] for question in questions] prompts = tokenizer.apply_chat_template(
prompts = tokenizer.apply_chat_template(messages, messages, add_generation_prompt=True, tokenize=False
add_generation_prompt=True, )
tokenize=False)
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
...@@ -657,19 +653,18 @@ def run_llama4(questions: list[str], modality: str) -> ModelRequestData: ...@@ -657,19 +653,18 @@ def run_llama4(questions: list[str], modality: str) -> ModelRequestData:
) )
tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [[{ messages = [
"role": [
"user", {
"content": [{ "role": "user",
"type": "image" "content": [{"type": "image"}, {"type": "text", "text": f"{question}"}],
}, { }
"type": "text", ]
"text": f"{question}" for question in questions
}] ]
}] for question in questions] prompts = tokenizer.apply_chat_template(
prompts = tokenizer.apply_chat_template(messages, messages, add_generation_prompt=True, tokenize=False
add_generation_prompt=True, )
tokenize=False)
stop_token_ids = None stop_token_ids = None
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
...@@ -693,7 +688,8 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData: ...@@ -693,7 +688,8 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
prompts = [ prompts = [
f"<|im_start|>user <image>\n{question}<|im_end|> \ f"<|im_start|>user <image>\n{question}<|im_end|> \
<|im_start|>assistant\n" for question in questions <|im_start|>assistant\n"
for question in questions
] ]
return ModelRequestData( return ModelRequestData(
...@@ -717,15 +713,13 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData: ...@@ -717,15 +713,13 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
) )
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True) messages = [
messages = [[{ [{"role": "user", "content": f"<image>\n{question}"}] for question in questions
'role': 'user', ]
'content': f"<image>\n{question}" prompts = tokenizer.apply_chat_template(
}] for question in questions] messages, tokenize=False, add_generation_prompt=True
prompts = tokenizer.apply_chat_template(messages, )
tokenize=False,
add_generation_prompt=True)
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
...@@ -748,15 +742,13 @@ def run_ovis(questions: list[str], modality: str) -> ModelRequestData: ...@@ -748,15 +742,13 @@ def run_ovis(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
) )
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True) messages = [
messages = [[{ [{"role": "user", "content": f"<image>\n{question}"}] for question in questions
'role': 'user', ]
'content': f"<image>\n{question}" prompts = tokenizer.apply_chat_template(
}] for question in questions] messages, tokenize=False, add_generation_prompt=True
prompts = tokenizer.apply_chat_template(messages, )
tokenize=False,
add_generation_prompt=True)
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
...@@ -847,8 +839,7 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData: ...@@ -847,8 +839,7 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData:
# we have to manually specify the path of the lora weights. # we have to manually specify the path of the lora weights.
vision_lora_path = os.path.join(model_path, "vision-lora") vision_lora_path = os.path.join(model_path, "vision-lora")
prompts = [ prompts = [
f"<|user|><|image_1|>{question}<|end|><|assistant|>" f"<|user|><|image_1|>{question}<|end|><|assistant|>" for question in questions
for question in questions
] ]
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_path, model=model_path,
...@@ -915,7 +906,6 @@ def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData: ...@@ -915,7 +906,6 @@ def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData:
# Qwen2-VL # Qwen2-VL
def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData: def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "Qwen/Qwen2-VL-7B-Instruct" model_name = "Qwen/Qwen2-VL-7B-Instruct"
engine_args = EngineArgs( engine_args = EngineArgs(
...@@ -936,10 +926,13 @@ def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData: ...@@ -936,10 +926,13 @@ def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
placeholder = "<|video_pad|>" placeholder = "<|video_pad|>"
prompts = [ prompts = [
("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" (
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"{question}<|im_end|>\n" f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
"<|im_start|>assistant\n") for question in questions f"{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
for question in questions
] ]
return ModelRequestData( return ModelRequestData(
...@@ -950,7 +943,6 @@ def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData: ...@@ -950,7 +943,6 @@ def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData:
# Qwen2.5-VL # Qwen2.5-VL
def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData: def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "Qwen/Qwen2.5-VL-3B-Instruct" model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
engine_args = EngineArgs( engine_args = EngineArgs(
...@@ -971,10 +963,13 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData: ...@@ -971,10 +963,13 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
placeholder = "<|video_pad|>" placeholder = "<|video_pad|>"
prompts = [ prompts = [
("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" (
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"{question}<|im_end|>\n" f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
"<|im_start|>assistant\n") for question in questions f"{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
for question in questions
] ]
return ModelRequestData( return ModelRequestData(
...@@ -1007,12 +1002,18 @@ def run_qwen2_5_omni(questions: list[str], modality: str): ...@@ -1007,12 +1002,18 @@ def run_qwen2_5_omni(questions: list[str], modality: str):
default_system = ( default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba " "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as " "Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech.") "generating text and speech."
)
prompts = [(f"<|im_start|>system\n{default_system}<|im_end|>\n" prompts = [
f"<|im_start|>user\n<|vision_bos|>{placeholder}<|vision_eos|>" (
f"{question}<|im_end|>\n" f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>assistant\n") for question in questions] f"<|im_start|>user\n<|vision_bos|>{placeholder}<|vision_eos|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
for question in questions
]
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
prompts=prompts, prompts=prompts,
...@@ -1032,15 +1033,13 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: ...@@ -1032,15 +1033,13 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
) )
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True) messages = [
messages = [[{ [{"role": "user", "content": f"<image>\n{question}"}] for question in questions
'role': 'user', ]
'content': f"<image>\n{question}" prompts = tokenizer.apply_chat_template(
}] for question in questions] messages, tokenize=False, add_generation_prompt=True
prompts = tokenizer.apply_chat_template(messages, )
tokenize=False,
add_generation_prompt=True)
# Stop tokens for SkyworkR1V # Stop tokens for SkyworkR1V
# https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/conversation.py # https://huggingface.co/Skywork/Skywork-R1V-38B/blob/main/conversation.py
...@@ -1104,8 +1103,7 @@ def get_multi_modal_input(args): ...@@ -1104,8 +1103,7 @@ def get_multi_modal_input(args):
""" """
if args.modality == "image": if args.modality == "image":
# Input image and question # Input image and question
image = convert_image_mode( image = convert_image_mode(ImageAsset("cherry_blossom").pil_image, "RGB")
ImageAsset("cherry_blossom").pil_image, "RGB")
img_questions = [ img_questions = [
"What is the content of this image?", "What is the content of this image?",
"Describe the content of this image in detail.", "Describe the content of this image in detail.",
...@@ -1120,8 +1118,7 @@ def get_multi_modal_input(args): ...@@ -1120,8 +1118,7 @@ def get_multi_modal_input(args):
if args.modality == "video": if args.modality == "video":
# Input video and question # Input video and question
video = VideoAsset(name="baby_reading", video = VideoAsset(name="baby_reading", num_frames=args.num_frames).np_ndarrays
num_frames=args.num_frames).np_ndarrays
vid_questions = ["Why is this video funny?"] vid_questions = ["Why is this video funny?"]
return { return {
...@@ -1133,12 +1130,13 @@ def get_multi_modal_input(args): ...@@ -1133,12 +1130,13 @@ def get_multi_modal_input(args):
raise ValueError(msg) raise ValueError(msg)
def apply_image_repeat(image_repeat_prob, num_prompts, data, def apply_image_repeat(
prompts: list[str], modality): image_repeat_prob, num_prompts, data, prompts: list[str], modality
"""Repeats images with provided probability of "image_repeat_prob". ):
"""Repeats images with provided probability of "image_repeat_prob".
Used to simulate hit/miss for the MM preprocessor cache. Used to simulate hit/miss for the MM preprocessor cache.
""" """
assert (image_repeat_prob <= 1.0 and image_repeat_prob >= 0) assert image_repeat_prob <= 1.0 and image_repeat_prob >= 0
no_yes = [0, 1] no_yes = [0, 1]
probs = [1.0 - image_repeat_prob, image_repeat_prob] probs = [1.0 - image_repeat_prob, image_repeat_prob]
...@@ -1153,12 +1151,12 @@ def apply_image_repeat(image_repeat_prob, num_prompts, data, ...@@ -1153,12 +1151,12 @@ def apply_image_repeat(image_repeat_prob, num_prompts, data,
new_val = (i // 256 // 256, i // 256, i % 256) new_val = (i // 256 // 256, i // 256, i % 256)
cur_image.putpixel((0, 0), new_val) cur_image.putpixel((0, 0), new_val)
inputs.append({ inputs.append(
"prompt": prompts[i % len(prompts)], {
"multi_modal_data": { "prompt": prompts[i % len(prompts)],
modality: cur_image "multi_modal_data": {modality: cur_image},
} }
}) )
return inputs return inputs
...@@ -1167,6 +1165,7 @@ def apply_image_repeat(image_repeat_prob, num_prompts, data, ...@@ -1167,6 +1165,7 @@ def apply_image_repeat(image_repeat_prob, num_prompts, data,
def time_counter(enable: bool): def time_counter(enable: bool):
if enable: if enable:
import time import time
start_time = time.time() start_time = time.time()
yield yield
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
...@@ -1179,54 +1178,65 @@ def time_counter(enable: bool): ...@@ -1179,54 +1178,65 @@ def time_counter(enable: bool):
def parse_args(): def parse_args():
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with ' description="Demo on using vLLM for offline inference with "
'vision language models for text generation') "vision language models for text generation"
parser.add_argument('--model-type', )
'-m', parser.add_argument(
type=str, "--model-type",
default="llava", "-m",
choices=model_example_map.keys(), type=str,
help='Huggingface "model_type".') default="llava",
parser.add_argument('--num-prompts', choices=model_example_map.keys(),
type=int, help='Huggingface "model_type".',
default=4, )
help='Number of prompts to run.') parser.add_argument(
parser.add_argument('--modality', "--num-prompts", type=int, default=4, help="Number of prompts to run."
type=str, )
default="image", parser.add_argument(
choices=['image', 'video'], "--modality",
help='Modality of the input.') type=str,
parser.add_argument('--num-frames', default="image",
type=int, choices=["image", "video"],
default=16, help="Modality of the input.",
help='Number of frames to extract from the video.') )
parser.add_argument("--seed", parser.add_argument(
type=int, "--num-frames",
default=None, type=int,
help="Set the seed when initializing `vllm.LLM`.") default=16,
help="Number of frames to extract from the video.",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.",
)
parser.add_argument( parser.add_argument(
'--image-repeat-prob', "--image-repeat-prob",
type=float, type=float,
default=None, default=None,
help='Simulates the hit-ratio for multi-modal preprocessor cache' help="Simulates the hit-ratio for multi-modal preprocessor cache (if enabled)",
' (if enabled)') )
parser.add_argument( parser.add_argument(
'--disable-mm-preprocessor-cache', "--disable-mm-preprocessor-cache",
action='store_true', action="store_true",
help='If True, disables caching of multi-modal preprocessor/mapper.') help="If True, disables caching of multi-modal preprocessor/mapper.",
)
parser.add_argument( parser.add_argument(
'--time-generate', "--time-generate",
action='store_true', action="store_true",
help='If True, then print the total generate() call time') help="If True, then print the total generate() call time",
)
parser.add_argument( parser.add_argument(
'--use-different-prompt-per-request', "--use-different-prompt-per-request",
action='store_true', action="store_true",
help='If True, then use different prompt (with the same multi-modal ' help="If True, then use different prompt (with the same multi-modal "
'data) for each request.') "data) for each request.",
)
return parser.parse_args() return parser.parse_args()
...@@ -1245,7 +1255,8 @@ def main(args): ...@@ -1245,7 +1255,8 @@ def main(args):
# Disable other modalities to save memory # Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0} default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict( req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {}) req_data.engine_args.limit_mm_per_prompt or {}
)
engine_args = asdict(req_data.engine_args) | { engine_args = asdict(req_data.engine_args) | {
"seed": args.seed, "seed": args.seed,
...@@ -1254,44 +1265,46 @@ def main(args): ...@@ -1254,44 +1265,46 @@ def main(args):
llm = LLM(**engine_args) llm = LLM(**engine_args)
# Don't want to check the flag multiple times, so just hijack `prompts`. # Don't want to check the flag multiple times, so just hijack `prompts`.
prompts = req_data.prompts if args.use_different_prompt_per_request else [ prompts = (
req_data.prompts[0] req_data.prompts
] if args.use_different_prompt_per_request
else [req_data.prompts[0]]
)
# We set temperature to 0.2 so that outputs can be different # We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference. # even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(temperature=0.2, sampling_params = SamplingParams(
max_tokens=64, temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
stop_token_ids=req_data.stop_token_ids) )
assert args.num_prompts > 0 assert args.num_prompts > 0
if args.num_prompts == 1: if args.num_prompts == 1:
# Single inference # Single inference
inputs = { inputs = {
"prompt": prompts[0], "prompt": prompts[0],
"multi_modal_data": { "multi_modal_data": {modality: data},
modality: data
},
} }
else: else:
# Batch inference # Batch inference
if args.image_repeat_prob is not None: if args.image_repeat_prob is not None:
# Repeat images with specified probability of "image_repeat_prob" # Repeat images with specified probability of "image_repeat_prob"
inputs = apply_image_repeat(args.image_repeat_prob, inputs = apply_image_repeat(
args.num_prompts, data, prompts, args.image_repeat_prob, args.num_prompts, data, prompts, modality
modality) )
else: else:
# Use the same image for all prompts # Use the same image for all prompts
inputs = [{ inputs = [
"prompt": prompts[i % len(prompts)], {
"multi_modal_data": { "prompt": prompts[i % len(prompts)],
modality: data "multi_modal_data": {modality: data},
}, }
} for i in range(args.num_prompts)] for i in range(args.num_prompts)
]
# Add LoRA request if applicable # Add LoRA request if applicable
lora_request = (req_data.lora_requests * lora_request = (
args.num_prompts if req_data.lora_requests else None) req_data.lora_requests * args.num_prompts if req_data.lora_requests else None
)
with time_counter(args.time_generate): with time_counter(args.time_generate):
outputs = llm.generate( outputs = llm.generate(
......
...@@ -6,6 +6,7 @@ the correct prompt format on vision language models for multimodal embedding. ...@@ -6,6 +6,7 @@ the correct prompt format on vision language models for multimodal embedding.
For most models, the prompt format should follow corresponding examples For most models, the prompt format should follow corresponding examples
on HuggingFace model repository. on HuggingFace model repository.
""" """
from argparse import Namespace from argparse import Namespace
from dataclasses import asdict from dataclasses import asdict
from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args
...@@ -44,19 +45,17 @@ class ModelRequestData(NamedTuple): ...@@ -44,19 +45,17 @@ class ModelRequestData(NamedTuple):
def run_e5_v(query: Query) -> ModelRequestData: def run_e5_v(query: Query) -> ModelRequestData:
llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501 llama3_template = "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n" # noqa: E501
if query["modality"] == "text": if query["modality"] == "text":
text = query["text"] text = query["text"]
prompt = llama3_template.format( prompt = llama3_template.format(f"{text}\nSummary above sentence in one word: ")
f"{text}\nSummary above sentence in one word: ")
image = None image = None
elif query["modality"] == "image": elif query["modality"] == "image":
prompt = llama3_template.format( prompt = llama3_template.format("<image>\nSummary above image in one word: ")
"<image>\nSummary above image in one word: ")
image = query["image"] image = query["image"]
else: else:
modality = query['modality'] modality = query["modality"]
raise ValueError(f"Unsupported query modality: '{modality}'") raise ValueError(f"Unsupported query modality: '{modality}'")
engine_args = EngineArgs( engine_args = EngineArgs(
...@@ -83,10 +82,12 @@ def run_vlm2vec(query: Query) -> ModelRequestData: ...@@ -83,10 +82,12 @@ def run_vlm2vec(query: Query) -> ModelRequestData:
image = query["image"] image = query["image"]
elif query["modality"] == "text+image": elif query["modality"] == "text+image":
text = query["text"] text = query["text"]
prompt = f"<|image_1|> Represent the given image with the following question: {text}" # noqa: E501 prompt = (
f"<|image_1|> Represent the given image with the following question: {text}" # noqa: E501
)
image = query["image"] image = query["image"]
else: else:
modality = query['modality'] modality = query["modality"]
raise ValueError(f"Unsupported query modality: '{modality}'") raise ValueError(f"Unsupported query modality: '{modality}'")
engine_args = EngineArgs( engine_args = EngineArgs(
...@@ -136,7 +137,8 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]): ...@@ -136,7 +137,8 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
# Disable other modalities to save memory # Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0} default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict( req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {}) req_data.engine_args.limit_mm_per_prompt or {}
)
engine_args = asdict(req_data.engine_args) | {"seed": seed} engine_args = asdict(req_data.engine_args) | {"seed": seed}
llm = LLM(**engine_args) llm = LLM(**engine_args)
...@@ -145,10 +147,12 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]): ...@@ -145,10 +147,12 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
if req_data.image is not None: if req_data.image is not None:
mm_data["image"] = req_data.image mm_data["image"] = req_data.image
outputs = llm.embed({ outputs = llm.embed(
"prompt": req_data.prompt, {
"multi_modal_data": mm_data, "prompt": req_data.prompt,
}) "multi_modal_data": mm_data,
}
)
print("-" * 50) print("-" * 50)
for output in outputs: for output in outputs:
...@@ -164,23 +168,30 @@ model_example_map = { ...@@ -164,23 +168,30 @@ model_example_map = {
def parse_args(): def parse_args():
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with ' description="Demo on using vLLM for offline inference with "
'vision language models for multimodal embedding') "vision language models for multimodal embedding"
parser.add_argument('--model-name', )
'-m', parser.add_argument(
type=str, "--model-name",
default="vlm2vec", "-m",
choices=model_example_map.keys(), type=str,
help='The name of the embedding model.') default="vlm2vec",
parser.add_argument('--modality', choices=model_example_map.keys(),
type=str, help="The name of the embedding model.",
default="image", )
choices=get_args(QueryModality), parser.add_argument(
help='Modality of the input.') "--modality",
parser.add_argument("--seed", type=str,
type=int, default="image",
default=None, choices=get_args(QueryModality),
help="Set the seed when initializing `vllm.LLM`.") help="Modality of the input.",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.",
)
return parser.parse_args() return parser.parse_args()
......
...@@ -4,6 +4,7 @@ This example shows how to use vLLM for running offline inference with ...@@ -4,6 +4,7 @@ This example shows how to use vLLM for running offline inference with
multi-image input on vision language models for text generation, multi-image input on vision language models for text generation,
using the chat template defined by the model. using the chat template defined by the model.
""" """
import os import os
from argparse import Namespace from argparse import Namespace
from dataclasses import asdict from dataclasses import asdict
...@@ -59,8 +60,9 @@ def load_aria(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -59,8 +60,9 @@ def load_aria(question: str, image_urls: list[str]) -> ModelRequestData:
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
) )
placeholders = "<fim_prefix><|img|><fim_suffix>\n" * len(image_urls) placeholders = "<fim_prefix><|img|><fim_suffix>\n" * len(image_urls)
prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n" prompt = (
"<|im_start|>assistant\n") f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n<|im_start|>assistant\n"
)
stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519] stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
return ModelRequestData( return ModelRequestData(
...@@ -81,23 +83,21 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -81,23 +83,21 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData:
) )
placeholders = [{"type": "image", "image": url} for url in image_urls] placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{ messages = [
"role": {
"user", "role": "user",
"content": [ "content": [
*placeholders, *placeholders,
{ {"type": "text", "text": question},
"type": "text", ],
"text": question }
}, ]
],
}]
processor = AutoProcessor.from_pretrained(model_name) processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages, prompt = processor.apply_chat_template(
tokenize=False, messages, tokenize=False, add_generation_prompt=True
add_generation_prompt=True) )
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
...@@ -106,8 +106,7 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -106,8 +106,7 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData:
) )
def load_deepseek_vl2(question: str, def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData:
image_urls: list[str]) -> ModelRequestData:
model_name = "deepseek-ai/deepseek-vl2-tiny" model_name = "deepseek-ai/deepseek-vl2-tiny"
engine_args = EngineArgs( engine_args = EngineArgs(
...@@ -118,8 +117,9 @@ def load_deepseek_vl2(question: str, ...@@ -118,8 +117,9 @@ def load_deepseek_vl2(question: str,
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
) )
placeholder = "".join(f"image_{i}:<image>\n" placeholder = "".join(
for i, _ in enumerate(image_urls, start=1)) f"image_{i}:<image>\n" for i, _ in enumerate(image_urls, start=1)
)
prompt = f"<|User|>: {placeholder}{question}\n\n<|Assistant|>:" prompt = f"<|User|>: {placeholder}{question}\n\n<|Assistant|>:"
return ModelRequestData( return ModelRequestData(
...@@ -140,23 +140,21 @@ def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -140,23 +140,21 @@ def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
) )
placeholders = [{"type": "image", "image": url} for url in image_urls] placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{ messages = [
"role": {
"user", "role": "user",
"content": [ "content": [
*placeholders, *placeholders,
{ {"type": "text", "text": question},
"type": "text", ],
"text": question }
}, ]
],
}]
processor = AutoProcessor.from_pretrained(model_name) processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages, prompt = processor.apply_chat_template(
tokenize=False, messages, tokenize=False, add_generation_prompt=True
add_generation_prompt=True) )
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
...@@ -176,15 +174,15 @@ def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -176,15 +174,15 @@ def load_h2ovl(question: str, image_urls: list[str]) -> ModelRequestData:
mm_processor_kwargs={"max_dynamic_patch": 4}, mm_processor_kwargs={"max_dynamic_patch": 4},
) )
placeholders = "\n".join(f"Image-{i}: <image>\n" placeholders = "\n".join(
for i, _ in enumerate(image_urls, start=1)) f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] )
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True) prompt = tokenizer.apply_chat_template(
prompt = tokenizer.apply_chat_template(messages, messages, tokenize=False, add_generation_prompt=True
tokenize=False, )
add_generation_prompt=True)
# Stop tokens for H2OVL-Mississippi # Stop tokens for H2OVL-Mississippi
# https://huggingface.co/h2oai/h2ovl-mississippi-800m # https://huggingface.co/h2oai/h2ovl-mississippi-800m
...@@ -211,14 +209,13 @@ def load_idefics3(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -211,14 +209,13 @@ def load_idefics3(question: str, image_urls: list[str]) -> ModelRequestData:
# if you are running out of memory, you can reduce the "longest_edge". # if you are running out of memory, you can reduce the "longest_edge".
# see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations # see: https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3#model-optimizations
mm_processor_kwargs={ mm_processor_kwargs={
"size": { "size": {"longest_edge": 2 * 364},
"longest_edge": 2 * 364
},
}, },
) )
placeholders = "\n".join(f"Image-{i}: <image>\n" placeholders = "\n".join(
for i, _ in enumerate(image_urls, start=1)) f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
)
prompt = f"<|begin_of_text|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501 prompt = f"<|begin_of_text|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
...@@ -238,15 +235,16 @@ def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -238,15 +235,16 @@ def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData:
enforce_eager=True, enforce_eager=True,
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
mm_processor_kwargs={ mm_processor_kwargs={
"max_image_size": { "max_image_size": {"longest_edge": 384},
"longest_edge": 384
},
}, },
) )
placeholders = "\n".join(f"Image-{i}: <image>\n" placeholders = "\n".join(
for i, _ in enumerate(image_urls, start=1)) f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
prompt = f"<|im_start|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501 )
prompt = (
f"<|im_start|>User:{placeholders}\n{question}<end_of_utterance>\nAssistant:" # noqa: E501
)
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
prompt=prompt, prompt=prompt,
...@@ -265,15 +263,15 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -265,15 +263,15 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
mm_processor_kwargs={"max_dynamic_patch": 4}, mm_processor_kwargs={"max_dynamic_patch": 4},
) )
placeholders = "\n".join(f"Image-{i}: <image>\n" placeholders = "\n".join(
for i, _ in enumerate(image_urls, start=1)) f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] )
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True) prompt = tokenizer.apply_chat_template(
prompt = tokenizer.apply_chat_template(messages, messages, tokenize=False, add_generation_prompt=True
tokenize=False, )
add_generation_prompt=True)
# Stop tokens for InternVL # Stop tokens for InternVL
# models variants may have different stop tokens # models variants may have different stop tokens
...@@ -301,23 +299,21 @@ def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -301,23 +299,21 @@ def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
) )
placeholders = [{"type": "image", "image": url} for url in image_urls] placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{ messages = [
"role": {
"user", "role": "user",
"content": [ "content": [
*placeholders, *placeholders,
{ {"type": "text", "text": question},
"type": "text", ],
"text": question }
}, ]
],
}]
processor = AutoProcessor.from_pretrained(model_name) processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages, prompt = processor.apply_chat_template(
tokenize=False, messages, tokenize=False, add_generation_prompt=True
add_generation_prompt=True) )
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
...@@ -338,24 +334,21 @@ def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -338,24 +334,21 @@ def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
) )
placeholders = [{"type": "image", "image": url} for url in image_urls] placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{ messages = [
"role": {
"user", "role": "user",
"content": [ "content": [
*placeholders, *placeholders,
{ {"type": "text", "text": question},
"type": "text", ],
"text": question }
}, ]
],
}]
processor = AutoProcessor.from_pretrained(model_name, processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True)
prompt = processor.apply_chat_template(messages, prompt = processor.apply_chat_template(
tokenize=False, messages, tokenize=False, add_generation_prompt=True
add_generation_prompt=True) )
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
...@@ -419,15 +412,15 @@ def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -419,15 +412,15 @@ def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData:
mm_processor_kwargs={"max_dynamic_patch": 4}, mm_processor_kwargs={"max_dynamic_patch": 4},
) )
placeholders = "\n".join(f"Image-{i}: <image>\n" placeholders = "\n".join(
for i, _ in enumerate(image_urls, start=1)) f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] )
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True) prompt = tokenizer.apply_chat_template(
prompt = tokenizer.apply_chat_template(messages, messages, tokenize=False, add_generation_prompt=True
tokenize=False, )
add_generation_prompt=True)
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
...@@ -449,15 +442,15 @@ def load_ovis(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -449,15 +442,15 @@ def load_ovis(question: str, image_urls: list[str]) -> ModelRequestData:
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
) )
placeholders = "\n".join(f"Image-{i}: <image>\n" placeholders = "\n".join(
for i, _ in enumerate(image_urls, start=1)) f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] )
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True) prompt = tokenizer.apply_chat_template(
prompt = tokenizer.apply_chat_template(messages, messages, tokenize=False, add_generation_prompt=True
tokenize=False, )
add_generation_prompt=True)
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
...@@ -509,8 +502,9 @@ def load_phi3v(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -509,8 +502,9 @@ def load_phi3v(question: str, image_urls: list[str]) -> ModelRequestData:
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
mm_processor_kwargs={"num_crops": 4}, mm_processor_kwargs={"num_crops": 4},
) )
placeholders = "\n".join(f"<|image_{i}|>" placeholders = "\n".join(
for i, _ in enumerate(image_urls, start=1)) f"<|image_{i}|>" for i, _ in enumerate(image_urls, start=1)
)
prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n" prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n"
return ModelRequestData( return ModelRequestData(
...@@ -542,8 +536,7 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -542,8 +536,7 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
mm_processor_kwargs={"dynamic_hd": 4}, mm_processor_kwargs={"dynamic_hd": 4},
) )
placeholders = "".join(f"<|image_{i}|>" placeholders = "".join(f"<|image_{i}|>" for i, _ in enumerate(image_urls, start=1))
for i, _ in enumerate(image_urls, start=1))
prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>" prompt = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
return ModelRequestData( return ModelRequestData(
...@@ -554,8 +547,7 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -554,8 +547,7 @@ def load_phi4mm(question: str, image_urls: list[str]) -> ModelRequestData:
) )
def load_qwen_vl_chat(question: str, def load_qwen_vl_chat(question: str, image_urls: list[str]) -> ModelRequestData:
image_urls: list[str]) -> ModelRequestData:
model_name = "Qwen/Qwen-VL-Chat" model_name = "Qwen/Qwen-VL-Chat"
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
...@@ -565,24 +557,26 @@ def load_qwen_vl_chat(question: str, ...@@ -565,24 +557,26 @@ def load_qwen_vl_chat(question: str,
hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}, hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]},
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
) )
placeholders = "".join(f"Picture {i}: <img></img>\n" placeholders = "".join(
for i, _ in enumerate(image_urls, start=1)) f"Picture {i}: <img></img>\n" for i, _ in enumerate(image_urls, start=1)
)
# This model does not have a chat_template attribute on its tokenizer, # This model does not have a chat_template attribute on its tokenizer,
# so we need to explicitly pass it. We use ChatML since it's used in the # so we need to explicitly pass it. We use ChatML since it's used in the
# generation utils of the model: # generation utils of the model:
# https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265 # https://huggingface.co/Qwen/Qwen-VL-Chat/blob/main/qwen_generation_utils.py#L265
tokenizer = AutoTokenizer.from_pretrained(model_name, tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
trust_remote_code=True)
# Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating # Copied from: https://huggingface.co/docs/transformers/main/en/chat_templating
chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" # noqa: E501 chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" # noqa: E501
messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
prompt = tokenizer.apply_chat_template(messages, prompt = tokenizer.apply_chat_template(
tokenize=False, messages,
add_generation_prompt=True, tokenize=False,
chat_template=chat_template) add_generation_prompt=True,
chat_template=chat_template,
)
stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"] stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>"]
stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
...@@ -600,9 +594,11 @@ def load_qwen2_vl(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -600,9 +594,11 @@ def load_qwen2_vl(question: str, image_urls: list[str]) -> ModelRequestData:
try: try:
from qwen_vl_utils import process_vision_info from qwen_vl_utils import process_vision_info
except ModuleNotFoundError: except ModuleNotFoundError:
print('WARNING: `qwen-vl-utils` not installed, input images will not ' print(
'be automatically resized. You can enable this functionality by ' "WARNING: `qwen-vl-utils` not installed, input images will not "
'`pip install qwen-vl-utils`.') "be automatically resized. You can enable this functionality by "
"`pip install qwen-vl-utils`."
)
process_vision_info = None process_vision_info = None
model_name = "Qwen/Qwen2-VL-7B-Instruct" model_name = "Qwen/Qwen2-VL-7B-Instruct"
...@@ -616,26 +612,22 @@ def load_qwen2_vl(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -616,26 +612,22 @@ def load_qwen2_vl(question: str, image_urls: list[str]) -> ModelRequestData:
) )
placeholders = [{"type": "image", "image": url} for url in image_urls] placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{ messages = [
"role": "system", {"role": "system", "content": "You are a helpful assistant."},
"content": "You are a helpful assistant." {
}, { "role": "user",
"role": "content": [
"user", *placeholders,
"content": [ {"type": "text", "text": question},
*placeholders, ],
{ },
"type": "text", ]
"text": question
},
],
}]
processor = AutoProcessor.from_pretrained(model_name) processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages, prompt = processor.apply_chat_template(
tokenize=False, messages, tokenize=False, add_generation_prompt=True
add_generation_prompt=True) )
if process_vision_info is None: if process_vision_info is None:
image_data = [fetch_image(url) for url in image_urls] image_data = [fetch_image(url) for url in image_urls]
...@@ -653,9 +645,11 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -653,9 +645,11 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
try: try:
from qwen_vl_utils import process_vision_info from qwen_vl_utils import process_vision_info
except ModuleNotFoundError: except ModuleNotFoundError:
print('WARNING: `qwen-vl-utils` not installed, input images will not ' print(
'be automatically resized. You can enable this functionality by ' "WARNING: `qwen-vl-utils` not installed, input images will not "
'`pip install qwen-vl-utils`.') "be automatically resized. You can enable this functionality by "
"`pip install qwen-vl-utils`."
)
process_vision_info = None process_vision_info = None
model_name = "Qwen/Qwen2.5-VL-3B-Instruct" model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
...@@ -668,32 +662,27 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -668,32 +662,27 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
) )
placeholders = [{"type": "image", "image": url} for url in image_urls] placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{ messages = [
"role": "system", {"role": "system", "content": "You are a helpful assistant."},
"content": "You are a helpful assistant." {
}, { "role": "user",
"role": "content": [
"user", *placeholders,
"content": [ {"type": "text", "text": question},
*placeholders, ],
{ },
"type": "text", ]
"text": question
},
],
}]
processor = AutoProcessor.from_pretrained(model_name) processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages, prompt = processor.apply_chat_template(
tokenize=False, messages, tokenize=False, add_generation_prompt=True
add_generation_prompt=True) )
if process_vision_info is None: if process_vision_info is None:
image_data = [fetch_image(url) for url in image_urls] image_data = [fetch_image(url) for url in image_urls]
else: else:
image_data, _ = process_vision_info(messages, image_data, _ = process_vision_info(messages, return_video_kwargs=False)
return_video_kwargs=False)
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
...@@ -726,23 +715,20 @@ model_example_map = { ...@@ -726,23 +715,20 @@ model_example_map = {
} }
def run_generate(model, question: str, image_urls: list[str], def run_generate(model, question: str, image_urls: list[str], seed: Optional[int]):
seed: Optional[int]):
req_data = model_example_map[model](question, image_urls) req_data = model_example_map[model](question, image_urls)
engine_args = asdict(req_data.engine_args) | {"seed": args.seed} engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
llm = LLM(**engine_args) llm = LLM(**engine_args)
sampling_params = SamplingParams(temperature=0.0, sampling_params = SamplingParams(
max_tokens=256, temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
stop_token_ids=req_data.stop_token_ids) )
outputs = llm.generate( outputs = llm.generate(
{ {
"prompt": req_data.prompt, "prompt": req_data.prompt,
"multi_modal_data": { "multi_modal_data": {"image": req_data.image_data},
"image": req_data.image_data
},
}, },
sampling_params=sampling_params, sampling_params=sampling_params,
lora_request=req_data.lora_requests, lora_request=req_data.lora_requests,
...@@ -755,38 +741,40 @@ def run_generate(model, question: str, image_urls: list[str], ...@@ -755,38 +741,40 @@ def run_generate(model, question: str, image_urls: list[str],
print("-" * 50) print("-" * 50)
def run_chat(model: str, question: str, image_urls: list[str], def run_chat(model: str, question: str, image_urls: list[str], seed: Optional[int]):
seed: Optional[int]):
req_data = model_example_map[model](question, image_urls) req_data = model_example_map[model](question, image_urls)
# Disable other modalities to save memory # Disable other modalities to save memory
default_limits = {"image": 0, "video": 0, "audio": 0} default_limits = {"image": 0, "video": 0, "audio": 0}
req_data.engine_args.limit_mm_per_prompt = default_limits | dict( req_data.engine_args.limit_mm_per_prompt = default_limits | dict(
req_data.engine_args.limit_mm_per_prompt or {}) req_data.engine_args.limit_mm_per_prompt or {}
)
engine_args = asdict(req_data.engine_args) | {"seed": seed} engine_args = asdict(req_data.engine_args) | {"seed": seed}
llm = LLM(**engine_args) llm = LLM(**engine_args)
sampling_params = SamplingParams(temperature=0.0, sampling_params = SamplingParams(
max_tokens=256, temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
stop_token_ids=req_data.stop_token_ids) )
outputs = llm.chat( outputs = llm.chat(
[{ [
"role": {
"user", "role": "user",
"content": [ "content": [
{ {
"type": "text", "type": "text",
"text": question, "text": question,
},
*({
"type": "image_url",
"image_url": {
"url": image_url
}, },
} for image_url in image_urls), *(
], {
}], "type": "image_url",
"image_url": {"url": image_url},
}
for image_url in image_urls
),
],
}
],
sampling_params=sampling_params, sampling_params=sampling_params,
chat_template=req_data.chat_template, chat_template=req_data.chat_template,
lora_request=req_data.lora_requests, lora_request=req_data.lora_requests,
...@@ -801,32 +789,39 @@ def run_chat(model: str, question: str, image_urls: list[str], ...@@ -801,32 +789,39 @@ def run_chat(model: str, question: str, image_urls: list[str],
def parse_args(): def parse_args():
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with ' description="Demo on using vLLM for offline inference with "
'vision language models that support multi-image input for text ' "vision language models that support multi-image input for text "
'generation') "generation"
parser.add_argument('--model-type', )
'-m', parser.add_argument(
type=str, "--model-type",
default="phi3_v", "-m",
choices=model_example_map.keys(), type=str,
help='Huggingface "model_type".') default="phi3_v",
parser.add_argument("--method", choices=model_example_map.keys(),
type=str, help='Huggingface "model_type".',
default="generate", )
choices=["generate", "chat"], parser.add_argument(
help="The method to run in `vllm.LLM`.") "--method",
parser.add_argument("--seed", type=str,
type=int, default="generate",
default=None, choices=["generate", "chat"],
help="Set the seed when initializing `vllm.LLM`.") help="The method to run in `vllm.LLM`.",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.",
)
parser.add_argument( parser.add_argument(
"--num-images", "--num-images",
"-n", "-n",
type=int, type=int,
choices=list(range(1, choices=list(range(1, len(IMAGE_URLS) + 1)), # the max number of images
len(IMAGE_URLS) + 1)), # the max number of images
default=2, default=2,
help="Number of images to use for the demo.") help="Number of images to use for the demo.",
)
return parser.parse_args() return parser.parse_args()
...@@ -835,7 +830,7 @@ def main(args: Namespace): ...@@ -835,7 +830,7 @@ def main(args: Namespace):
method = args.method method = args.method
seed = args.seed seed = args.seed
image_urls = IMAGE_URLS[:args.num_images] image_urls = IMAGE_URLS[: args.num_images]
if method == "generate": if method == "generate":
run_generate(model, QUESTION, image_urls, seed) run_generate(model, QUESTION, image_urls, seed)
......
...@@ -17,16 +17,15 @@ import requests ...@@ -17,16 +17,15 @@ import requests
def clear_line(n: int = 1) -> None: def clear_line(n: int = 1) -> None:
LINE_UP = '\033[1A' LINE_UP = "\033[1A"
LINE_CLEAR = '\x1b[2K' LINE_CLEAR = "\x1b[2K"
for _ in range(n): for _ in range(n):
print(LINE_UP, end=LINE_CLEAR, flush=True) print(LINE_UP, end=LINE_CLEAR, flush=True)
def post_http_request(prompt: str, def post_http_request(
api_url: str, prompt: str, api_url: str, n: int = 1, stream: bool = False
n: int = 1, ) -> requests.Response:
stream: bool = False) -> requests.Response:
headers = {"User-Agent": "Test Client"} headers = {"User-Agent": "Test Client"}
pload = { pload = {
"prompt": prompt, "prompt": prompt,
...@@ -35,17 +34,14 @@ def post_http_request(prompt: str, ...@@ -35,17 +34,14 @@ def post_http_request(prompt: str,
"max_tokens": 16, "max_tokens": 16,
"stream": stream, "stream": stream,
} }
response = requests.post(api_url, response = requests.post(api_url, headers=headers, json=pload, stream=stream)
headers=headers,
json=pload,
stream=stream)
return response return response
def get_streaming_response(response: requests.Response) -> Iterable[list[str]]: def get_streaming_response(response: requests.Response) -> Iterable[list[str]]:
for chunk in response.iter_lines(chunk_size=8192, for chunk in response.iter_lines(
decode_unicode=False, chunk_size=8192, decode_unicode=False, delimiter=b"\n"
delimiter=b"\n"): ):
if chunk: if chunk:
data = json.loads(chunk.decode("utf-8")) data = json.loads(chunk.decode("utf-8"))
output = data["text"] output = data["text"]
......
...@@ -6,6 +6,7 @@ Note that `pip install cohere` is needed to run this example. ...@@ -6,6 +6,7 @@ Note that `pip install cohere` is needed to run this example.
run: vllm serve BAAI/bge-reranker-base run: vllm serve BAAI/bge-reranker-base
""" """
from typing import Union from typing import Union
import cohere import cohere
...@@ -16,28 +17,28 @@ model = "BAAI/bge-reranker-base" ...@@ -16,28 +17,28 @@ model = "BAAI/bge-reranker-base"
query = "What is the capital of France?" query = "What is the capital of France?"
documents = [ documents = [
"The capital of France is Paris", "Reranking is fun!", "The capital of France is Paris",
"vLLM is an open-source framework for fast AI serving" "Reranking is fun!",
"vLLM is an open-source framework for fast AI serving",
] ]
def cohere_rerank(client: Union[Client, ClientV2], model: str, query: str, def cohere_rerank(
documents: list[str]) -> dict: client: Union[Client, ClientV2], model: str, query: str, documents: list[str]
) -> dict:
return client.rerank(model=model, query=query, documents=documents) return client.rerank(model=model, query=query, documents=documents)
def main(): def main():
# cohere v1 client # cohere v1 client
cohere_v1 = cohere.Client(base_url="http://localhost:8000", cohere_v1 = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
api_key="sk-fake-key")
rerank_v1_result = cohere_rerank(cohere_v1, model, query, documents) rerank_v1_result = cohere_rerank(cohere_v1, model, query, documents)
print("-" * 50) print("-" * 50)
print("rerank_v1_result:\n", rerank_v1_result) print("rerank_v1_result:\n", rerank_v1_result)
print("-" * 50) print("-" * 50)
# or the v2 # or the v2
cohere_v2 = cohere.ClientV2("sk-fake-key", cohere_v2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000")
base_url="http://localhost:8000")
rerank_v2_result = cohere_rerank(cohere_v2, model, query, documents) rerank_v2_result = cohere_rerank(cohere_v2, model, query, documents)
print("rerank_v2_result:\n", rerank_v2_result) print("rerank_v2_result:\n", rerank_v2_result)
print("-" * 50) print("-" * 50)
......
...@@ -13,6 +13,7 @@ launch this proxy demo through: ...@@ -13,6 +13,7 @@ launch this proxy demo through:
Note: This demo will be removed once the PDController implemented in PR 15343 Note: This demo will be removed once the PDController implemented in PR 15343
(https://github.com/vllm-project/vllm/pull/15343) supports XpYd. (https://github.com/vllm-project/vllm/pull/15343) supports XpYd.
""" """
import argparse import argparse
import ipaddress import ipaddress
import itertools import itertools
...@@ -26,8 +27,7 @@ from typing import Callable, Optional ...@@ -26,8 +27,7 @@ from typing import Callable, Optional
import aiohttp import aiohttp
import requests import requests
import uvicorn import uvicorn
from fastapi import (APIRouter, Depends, FastAPI, Header, HTTPException, from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException, Request, status
Request, status)
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
...@@ -36,24 +36,24 @@ logging.basicConfig(level=logging.INFO) ...@@ -36,24 +36,24 @@ logging.basicConfig(level=logging.INFO)
class SchedulingPolicy(ABC): class SchedulingPolicy(ABC):
@abstractmethod @abstractmethod
def schedule(self, cycler: itertools.cycle): def schedule(self, cycler: itertools.cycle):
raise NotImplementedError("Scheduling Proxy is not set.") raise NotImplementedError("Scheduling Proxy is not set.")
class Proxy: class Proxy:
def __init__( def __init__(
self, self,
prefill_instances: list[str], prefill_instances: list[str],
decode_instances: list[str], decode_instances: list[str],
model: str, model: str,
scheduling_policy: SchedulingPolicy, scheduling_policy: SchedulingPolicy,
custom_create_completion: Optional[Callable[[Request], custom_create_completion: Optional[
StreamingResponse]] = None, Callable[[Request], StreamingResponse]
custom_create_chat_completion: Optional[Callable[ ] = None,
[Request], StreamingResponse]] = None, custom_create_chat_completion: Optional[
Callable[[Request], StreamingResponse]
] = None,
): ):
self.prefill_instances = prefill_instances self.prefill_instances = prefill_instances
self.decode_instances = decode_instances self.decode_instances = decode_instances
...@@ -68,30 +68,30 @@ class Proxy: ...@@ -68,30 +68,30 @@ class Proxy:
def setup_routes(self): def setup_routes(self):
self.router.post( self.router.post(
"/v1/completions", "/v1/completions", dependencies=[Depends(self.validate_json_request)]
dependencies=[ )(
Depends(self.validate_json_request) self.custom_create_completion
])(self.custom_create_completion if self. if self.custom_create_completion
custom_create_completion else self.create_completion) else self.create_completion
)
self.router.post( self.router.post(
"/v1/chat/completions", "/v1/chat/completions", dependencies=[Depends(self.validate_json_request)]
dependencies=[ )(
Depends(self.validate_json_request) self.custom_create_chat_completion
])(self.custom_create_chat_completion if self. if self.custom_create_chat_completion
custom_create_chat_completion else self.create_chat_completion) else self.create_chat_completion
self.router.get("/status", )
response_class=JSONResponse)(self.get_status) self.router.get("/status", response_class=JSONResponse)(self.get_status)
self.router.post("/instances/add", self.router.post(
dependencies=[Depends(self.api_key_authenticate) "/instances/add", dependencies=[Depends(self.api_key_authenticate)]
])(self.add_instance_endpoint) )(self.add_instance_endpoint)
async def validate_json_request(self, raw_request: Request): async def validate_json_request(self, raw_request: Request):
content_type = raw_request.headers.get("content-type", "").lower() content_type = raw_request.headers.get("content-type", "").lower()
if content_type != "application/json": if content_type != "application/json":
raise HTTPException( raise HTTPException(
status_code=415, status_code=415,
detail= detail="Unsupported Media Type: Only 'application/json' is allowed",
"Unsupported Media Type: Only 'application/json' is allowed",
) )
def api_key_authenticate(self, x_api_key: str = Header(...)): def api_key_authenticate(self, x_api_key: str = Header(...)):
...@@ -103,8 +103,7 @@ class Proxy: ...@@ -103,8 +103,7 @@ class Proxy:
detail="Server configuration error.", detail="Server configuration error.",
) )
if x_api_key != expected_api_key: if x_api_key != expected_api_key:
logger.warning("Unauthorized access attempt with API Key: %s", logger.warning("Unauthorized access attempt with API Key: %s", x_api_key)
x_api_key)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Forbidden: Invalid API Key.", detail="Forbidden: Invalid API Key.",
...@@ -113,8 +112,7 @@ class Proxy: ...@@ -113,8 +112,7 @@ class Proxy:
async def validate_instance(self, instance: str) -> bool: async def validate_instance(self, instance: str) -> bool:
url = f"http://{instance}/v1/models" url = f"http://{instance}/v1/models"
try: try:
async with aiohttp.ClientSession( async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as client:
timeout=AIOHTTP_TIMEOUT) as client:
logger.info("Verifying %s ...", instance) logger.info("Verifying %s ...", instance)
async with client.get(url) as response: async with client.get(url) as response:
if response.status == 200: if response.status == 200:
...@@ -122,12 +120,15 @@ class Proxy: ...@@ -122,12 +120,15 @@ class Proxy:
if "data" in data and len(data["data"]) > 0: if "data" in data and len(data["data"]) > 0:
model_cur = data["data"][0].get("id", "") model_cur = data["data"][0].get("id", "")
if model_cur == self.model: if model_cur == self.model:
logger.info("Instance: %s could be added.", logger.info("Instance: %s could be added.", instance)
instance)
return True return True
else: else:
logger.warning("Mismatch model %s : %s != %s", logger.warning(
instance, model_cur, self.model) "Mismatch model %s : %s != %s",
instance,
model_cur,
self.model,
)
return False return False
else: else:
return False return False
...@@ -147,48 +148,47 @@ class Proxy: ...@@ -147,48 +148,47 @@ class Proxy:
instance_type = data.get("type") instance_type = data.get("type")
instance = data.get("instance") instance = data.get("instance")
if instance_type not in ["prefill", "decode"]: if instance_type not in ["prefill", "decode"]:
raise HTTPException(status_code=400, raise HTTPException(status_code=400, detail="Invalid instance type.")
detail="Invalid instance type.")
if not instance or ":" not in instance: if not instance or ":" not in instance:
raise HTTPException(status_code=400, raise HTTPException(status_code=400, detail="Invalid instance format.")
detail="Invalid instance format.")
host, port_str = instance.split(":") host, port_str = instance.split(":")
try: try:
if host != "localhost": if host != "localhost":
ipaddress.ip_address(host) ipaddress.ip_address(host)
port = int(port_str) port = int(port_str)
if not (0 < port < 65536): if not (0 < port < 65536):
raise HTTPException(status_code=400, raise HTTPException(status_code=400, detail="Invalid port number.")
detail="Invalid port number.")
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, raise HTTPException(
detail="Invalid instance address.") from e status_code=400, detail="Invalid instance address."
) from e
is_valid = await self.validate_instance(instance) is_valid = await self.validate_instance(instance)
if not is_valid: if not is_valid:
raise HTTPException(status_code=400, raise HTTPException(
detail="Instance validation failed.") status_code=400, detail="Instance validation failed."
)
if instance_type == "prefill": if instance_type == "prefill":
if instance not in self.prefill_instances: if instance not in self.prefill_instances:
self.prefill_instances.append(instance) self.prefill_instances.append(instance)
self.prefill_cycler = itertools.cycle( self.prefill_cycler = itertools.cycle(self.prefill_instances)
self.prefill_instances)
else: else:
raise HTTPException(status_code=400, raise HTTPException(
detail="Instance already exists.") status_code=400, detail="Instance already exists."
)
else: else:
if instance not in self.decode_instances: if instance not in self.decode_instances:
self.decode_instances.append(instance) self.decode_instances.append(instance)
self.decode_cycler = itertools.cycle(self.decode_instances) self.decode_cycler = itertools.cycle(self.decode_instances)
else: else:
raise HTTPException(status_code=400, raise HTTPException(
detail="Instance already exists.") status_code=400, detail="Instance already exists."
)
return JSONResponse(content={ return JSONResponse(
"message": content={"message": f"Added {instance} to {instance_type}_instances."}
f"Added {instance} to {instance_type}_instances." )
})
except HTTPException as http_exc: except HTTPException as http_exc:
raise http_exc raise http_exc
except Exception as e: except Exception as e:
...@@ -197,16 +197,16 @@ class Proxy: ...@@ -197,16 +197,16 @@ class Proxy:
async def forward_request(self, url, data, use_chunked=True): async def forward_request(self, url, data, use_chunked=True):
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
headers = { headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
}
try: try:
async with session.post(url=url, json=data, async with session.post(
headers=headers) as response: url=url, json=data, headers=headers
) as response:
if 200 <= response.status < 300 or 400 <= response.status < 500: # noqa: E501 if 200 <= response.status < 300 or 400 <= response.status < 500: # noqa: E501
if use_chunked: if use_chunked:
async for chunk_bytes in response.content.iter_chunked( # noqa: E501 async for chunk_bytes in response.content.iter_chunked( # noqa: E501
1024): 1024
):
yield chunk_bytes yield chunk_bytes
else: else:
content = await response.read() content = await response.read()
...@@ -217,20 +217,21 @@ class Proxy: ...@@ -217,20 +217,21 @@ class Proxy:
error_content = json.loads(error_content) error_content = json.loads(error_content)
except json.JSONDecodeError: except json.JSONDecodeError:
error_content = error_content error_content = error_content
logger.error("Request failed with status %s: %s", logger.error(
response.status, error_content) "Request failed with status %s: %s",
response.status,
error_content,
)
raise HTTPException( raise HTTPException(
status_code=response.status, status_code=response.status,
detail= detail=f"Request failed with status {response.status}: "
f"Request failed with status {response.status}: "
f"{error_content}", f"{error_content}",
) )
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
logger.error("ClientError occurred: %s", str(e)) logger.error("ClientError occurred: %s", str(e))
raise HTTPException( raise HTTPException(
status_code=502, status_code=502,
detail= detail="Bad Gateway: Error communicating with upstream server.",
"Bad Gateway: Error communicating with upstream server.",
) from e ) from e
except Exception as e: except Exception as e:
logger.error("Unexpected error: %s", str(e)) logger.error("Unexpected error: %s", str(e))
...@@ -258,8 +259,8 @@ class Proxy: ...@@ -258,8 +259,8 @@ class Proxy:
prefill_instance = self.schedule(self.prefill_cycler) prefill_instance = self.schedule(self.prefill_cycler)
try: try:
async for _ in self.forward_request( async for _ in self.forward_request(
f"http://{prefill_instance}/v1/completions", f"http://{prefill_instance}/v1/completions", kv_prepare_request
kv_prepare_request): ):
continue continue
except HTTPException as http_exc: except HTTPException as http_exc:
self.remove_instance_endpoint("prefill", prefill_instance) self.remove_instance_endpoint("prefill", prefill_instance)
...@@ -270,7 +271,8 @@ class Proxy: ...@@ -270,7 +271,8 @@ class Proxy:
try: try:
generator = self.forward_request( generator = self.forward_request(
f"http://{decode_instance}/v1/completions", request) f"http://{decode_instance}/v1/completions", request
)
except HTTPException as http_exc: except HTTPException as http_exc:
self.remove_instance_endpoint("decode", decode_instance) self.remove_instance_endpoint("decode", decode_instance)
raise http_exc raise http_exc
...@@ -295,8 +297,8 @@ class Proxy: ...@@ -295,8 +297,8 @@ class Proxy:
prefill_instance = self.schedule(self.prefill_cycler) prefill_instance = self.schedule(self.prefill_cycler)
try: try:
async for _ in self.forward_request( async for _ in self.forward_request(
f"http://{prefill_instance}/v1/chat/completions", f"http://{prefill_instance}/v1/chat/completions", kv_prepare_request
kv_prepare_request): ):
continue continue
except HTTPException as http_exc: except HTTPException as http_exc:
self.remove_instance_endpoint("prefill", prefill_instance) self.remove_instance_endpoint("prefill", prefill_instance)
...@@ -306,8 +308,8 @@ class Proxy: ...@@ -306,8 +308,8 @@ class Proxy:
try: try:
generator = self.forward_request( generator = self.forward_request(
"http://" + decode_instance + "/v1/chat/completions", "http://" + decode_instance + "/v1/chat/completions", request
request) )
except HTTPException as http_exc: except HTTPException as http_exc:
self.remove_instance_endpoint("decode", decode_instance) self.remove_instance_endpoint("decode", decode_instance)
raise http_exc raise http_exc
...@@ -318,20 +320,20 @@ class Proxy: ...@@ -318,20 +320,20 @@ class Proxy:
error_messages = [str(e) for e in exc_info if e] error_messages = [str(e) for e in exc_info if e]
print("Error occurred in disagg proxy server") print("Error occurred in disagg proxy server")
print(error_messages) print(error_messages)
return StreamingResponse(content=iter(error_messages), return StreamingResponse(
media_type="text/event-stream") content=iter(error_messages), media_type="text/event-stream"
)
def remove_instance_endpoint(self, instance_type, instance): def remove_instance_endpoint(self, instance_type, instance):
if (instance_type == "decode" and instance in self.decode_instances): if instance_type == "decode" and instance in self.decode_instances:
self.decode_instances.remove(instance) self.decode_instances.remove(instance)
self.decode_cycler = itertools.cycle(self.decode_instances) self.decode_cycler = itertools.cycle(self.decode_instances)
if (instance_type == "prefill" and instance in self.decode_instances): if instance_type == "prefill" and instance in self.decode_instances:
self.prefill_instances.remove(instance) self.prefill_instances.remove(instance)
self.prefill_cycler = itertools.cycle(self.decode_instances) self.prefill_cycler = itertools.cycle(self.decode_instances)
class RoundRobinSchedulingPolicy(SchedulingPolicy): class RoundRobinSchedulingPolicy(SchedulingPolicy):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -340,15 +342,12 @@ class RoundRobinSchedulingPolicy(SchedulingPolicy): ...@@ -340,15 +342,12 @@ class RoundRobinSchedulingPolicy(SchedulingPolicy):
class ProxyServer: class ProxyServer:
def __init__( def __init__(
self, self,
args: argparse.Namespace, args: argparse.Namespace,
scheduling_policy: Optional[SchedulingPolicy] = None, scheduling_policy: Optional[SchedulingPolicy] = None,
create_completion: Optional[Callable[[Request], create_completion: Optional[Callable[[Request], StreamingResponse]] = None,
StreamingResponse]] = None, create_chat_completion: Optional[Callable[[Request], StreamingResponse]] = None,
create_chat_completion: Optional[Callable[[Request],
StreamingResponse]] = None,
): ):
self.validate_parsed_serve_args(args) self.validate_parsed_serve_args(args)
self.port = args.port self.port = args.port
...@@ -356,8 +355,11 @@ class ProxyServer: ...@@ -356,8 +355,11 @@ class ProxyServer:
prefill_instances=[] if args.prefill is None else args.prefill, prefill_instances=[] if args.prefill is None else args.prefill,
decode_instances=[] if args.decode is None else args.decode, decode_instances=[] if args.decode is None else args.decode,
model=args.model, model=args.model,
scheduling_policy=(scheduling_policy if scheduling_policy scheduling_policy=(
is not None else RoundRobinSchedulingPolicy()), scheduling_policy
if scheduling_policy is not None
else RoundRobinSchedulingPolicy()
),
custom_create_completion=create_completion, custom_create_completion=create_completion,
custom_create_chat_completion=create_chat_completion, custom_create_chat_completion=create_chat_completion,
) )
...@@ -382,11 +384,9 @@ class ProxyServer: ...@@ -382,11 +384,9 @@ class ProxyServer:
ipaddress.ip_address(host) ipaddress.ip_address(host)
port = int(port) port = int(port)
if not (0 < port < 65536): if not (0 < port < 65536):
raise ValueError( raise ValueError(f"Invalid port number in instance: {instance}")
f"Invalid port number in instance: {instance}")
except Exception as e: except Exception as e:
raise ValueError( raise ValueError(f"Invalid instance {instance}: {str(e)}") from e
f"Invalid instance {instance}: {str(e)}") from e
def verify_model_config(self, instances: list, model: str) -> None: def verify_model_config(self, instances: list, model: str) -> None:
model_suffix = model.split("/")[-1] model_suffix = model.split("/")[-1]
...@@ -399,12 +399,14 @@ class ProxyServer: ...@@ -399,12 +399,14 @@ class ProxyServer:
if model_cur_suffix != model_suffix: if model_cur_suffix != model_suffix:
raise ValueError( raise ValueError(
f"{instance} serves a different model: " f"{instance} serves a different model: "
f"{model_cur} != {model}") f"{model_cur} != {model}"
)
else: else:
raise ValueError(f"Cannot get model id from {instance}!") raise ValueError(f"Cannot get model id from {instance}!")
except requests.RequestException as e: except requests.RequestException as e:
raise ValueError( raise ValueError(
f"Error communicating with {instance}: {str(e)}") from e f"Error communicating with {instance}: {str(e)}"
) from e
def run_server(self): def run_server(self):
app = FastAPI() app = FastAPI()
...@@ -417,11 +419,7 @@ class ProxyServer: ...@@ -417,11 +419,7 @@ class ProxyServer:
def parse_args(): def parse_args():
# Todo: allow more config # Todo: allow more config
parser = argparse.ArgumentParser("vLLM disaggregated proxy server.") parser = argparse.ArgumentParser("vLLM disaggregated proxy server.")
parser.add_argument("--model", parser.add_argument("--model", "-m", type=str, required=True, help="Model name")
"-m",
type=str,
required=True,
help="Model name")
parser.add_argument( parser.add_argument(
"--prefill", "--prefill",
......
...@@ -17,6 +17,7 @@ you can install it manually by following these steps: ...@@ -17,6 +17,7 @@ you can install it manually by following these steps:
2. Rename the downloaded file to: frpc_linux_amd64_v0.3 2. Rename the downloaded file to: frpc_linux_amd64_v0.3
3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc 3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc
""" """
import argparse import argparse
import gradio as gr import gradio as gr
...@@ -24,16 +25,12 @@ from openai import OpenAI ...@@ -24,16 +25,12 @@ from openai import OpenAI
def format_history_to_openai(history): def format_history_to_openai(history):
history_openai_format = [{ history_openai_format = [
"role": "system", {"role": "system", "content": "You are a great AI assistant."}
"content": "You are a great AI assistant." ]
}]
for human, assistant in history: for human, assistant in history:
history_openai_format.append({"role": "user", "content": human}) history_openai_format.append({"role": "user", "content": human})
history_openai_format.append({ history_openai_format.append({"role": "assistant", "content": assistant})
"role": "assistant",
"content": assistant
})
return history_openai_format return history_openai_format
...@@ -49,17 +46,17 @@ def predict(message, history, client, model_name, temp, stop_token_ids): ...@@ -49,17 +46,17 @@ def predict(message, history, client, model_name, temp, stop_token_ids):
temperature=temp, temperature=temp,
stream=True, stream=True,
extra_body={ extra_body={
'repetition_penalty': "repetition_penalty": 1,
1, "stop_token_ids": [int(id.strip()) for id in stop_token_ids.split(",")]
'stop_token_ids': if stop_token_ids
[int(id.strip()) else [],
for id in stop_token_ids.split(',')] if stop_token_ids else [] },
}) )
# Collect all chunks and concatenate them into a full message # Collect all chunks and concatenate them into a full message
full_message = "" full_message = ""
for chunk in stream: for chunk in stream:
full_message += (chunk.choices[0].delta.content or "") full_message += chunk.choices[0].delta.content or ""
# Return the full message as a single response # Return the full message as a single response
return full_message return full_message
...@@ -67,38 +64,34 @@ def predict(message, history, client, model_name, temp, stop_token_ids): ...@@ -67,38 +64,34 @@ def predict(message, history, client, model_name, temp, stop_token_ids):
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Chatbot Interface with Customizable Parameters') description="Chatbot Interface with Customizable Parameters"
parser.add_argument('--model-url', )
type=str, parser.add_argument(
default='http://localhost:8000/v1', "--model-url", type=str, default="http://localhost:8000/v1", help="Model URL"
help='Model URL') )
parser.add_argument('-m', parser.add_argument(
'--model', "-m", "--model", type=str, required=True, help="Model name for the chatbot"
type=str, )
required=True, parser.add_argument(
help='Model name for the chatbot') "--temp", type=float, default=0.8, help="Temperature for text generation"
parser.add_argument('--temp', )
type=float, parser.add_argument(
default=0.8, "--stop-token-ids", type=str, default="", help="Comma-separated stop token IDs"
help='Temperature for text generation') )
parser.add_argument('--stop-token-ids',
type=str,
default='',
help='Comma-separated stop token IDs')
parser.add_argument("--host", type=str, default=None) parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8001) parser.add_argument("--port", type=int, default=8001)
return parser.parse_args() return parser.parse_args()
def build_gradio_interface(client, model_name, temp, stop_token_ids): def build_gradio_interface(client, model_name, temp, stop_token_ids):
def chat_predict(message, history): def chat_predict(message, history):
return predict(message, history, client, model_name, temp, return predict(message, history, client, model_name, temp, stop_token_ids)
stop_token_ids)
return gr.ChatInterface(fn=chat_predict, return gr.ChatInterface(
title="Chatbot Interface", fn=chat_predict,
description="A simple chatbot powered by vLLM") title="Chatbot Interface",
description="A simple chatbot powered by vLLM",
)
def main(): def main():
...@@ -113,12 +106,13 @@ def main(): ...@@ -113,12 +106,13 @@ def main():
client = OpenAI(api_key=openai_api_key, base_url=openai_api_base) client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)
# Define the Gradio chatbot interface using the predict function # Define the Gradio chatbot interface using the predict function
gradio_interface = build_gradio_interface(client, args.model, args.temp, gradio_interface = build_gradio_interface(
args.stop_token_ids) client, args.model, args.temp, args.stop_token_ids
)
gradio_interface.queue().launch(server_name=args.host, gradio_interface.queue().launch(
server_port=args.port, server_name=args.host, server_port=args.port, share=True
share=True) )
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -17,6 +17,7 @@ you can install it manually by following these steps: ...@@ -17,6 +17,7 @@ you can install it manually by following these steps:
2. Rename the downloaded file to: frpc_linux_amd64_v0.3 2. Rename the downloaded file to: frpc_linux_amd64_v0.3
3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc 3. Move the file to this location: /home/user/.cache/huggingface/gradio/frpc
""" """
import argparse import argparse
import json import json
...@@ -31,14 +32,11 @@ def http_bot(prompt): ...@@ -31,14 +32,11 @@ def http_bot(prompt):
"stream": True, "stream": True,
"max_tokens": 128, "max_tokens": 128,
} }
response = requests.post(args.model_url, response = requests.post(args.model_url, headers=headers, json=pload, stream=True)
headers=headers,
json=pload, for chunk in response.iter_lines(
stream=True) chunk_size=8192, decode_unicode=False, delimiter=b"\n"
):
for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b"\n"):
if chunk: if chunk:
data = json.loads(chunk.decode("utf-8")) data = json.loads(chunk.decode("utf-8"))
output = data["text"][0] output = data["text"][0]
...@@ -48,10 +46,10 @@ def http_bot(prompt): ...@@ -48,10 +46,10 @@ def http_bot(prompt):
def build_demo(): def build_demo():
with gr.Blocks() as demo: with gr.Blocks() as demo:
gr.Markdown("# vLLM text completion demo\n") gr.Markdown("# vLLM text completion demo\n")
inputbox = gr.Textbox(label="Input", inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")
placeholder="Enter text and press ENTER") outputbox = gr.Textbox(
outputbox = gr.Textbox(label="Output", label="Output", placeholder="Generated result from the model"
placeholder="Generated result from the model") )
inputbox.submit(http_bot, [inputbox], [outputbox]) inputbox.submit(http_bot, [inputbox], [outputbox])
return demo return demo
...@@ -60,17 +58,15 @@ def parse_args(): ...@@ -60,17 +58,15 @@ def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default=None) parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8001) parser.add_argument("--port", type=int, default=8001)
parser.add_argument("--model-url", parser.add_argument(
type=str, "--model-url", type=str, default="http://localhost:8000/generate"
default="http://localhost:8000/generate") )
return parser.parse_args() return parser.parse_args()
def main(args): def main(args):
demo = build_demo() demo = build_demo()
demo.queue().launch(server_name=args.host, demo.queue().launch(server_name=args.host, server_port=args.port, share=True)
server_port=args.port,
share=True)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -5,6 +5,7 @@ Jina and Cohere https://jina.ai/reranker ...@@ -5,6 +5,7 @@ Jina and Cohere https://jina.ai/reranker
run: vllm serve BAAI/bge-reranker-base run: vllm serve BAAI/bge-reranker-base
""" """
import json import json
import requests import requests
...@@ -14,14 +15,13 @@ url = "http://127.0.0.1:8000/rerank" ...@@ -14,14 +15,13 @@ url = "http://127.0.0.1:8000/rerank"
headers = {"accept": "application/json", "Content-Type": "application/json"} headers = {"accept": "application/json", "Content-Type": "application/json"}
data = { data = {
"model": "model": "BAAI/bge-reranker-base",
"BAAI/bge-reranker-base", "query": "What is the capital of France?",
"query":
"What is the capital of France?",
"documents": [ "documents": [
"The capital of Brazil is Brasilia.", "The capital of Brazil is Brasilia.",
"The capital of France is Paris.", "Horses and cows are both animals" "The capital of France is Paris.",
] "Horses and cows are both animals",
],
} }
......
...@@ -9,17 +9,14 @@ from msgspec.msgpack import Decoder ...@@ -9,17 +9,14 @@ from msgspec.msgpack import Decoder
# #
# Types copied from vllm.distributed.kv_events # Types copied from vllm.distributed.kv_events
# #
class EventBatch(msgspec.Struct, array_like=True, omit_defaults=True, class EventBatch(msgspec.Struct, array_like=True, omit_defaults=True, gc=False):
gc=False):
ts: float ts: float
events: list[Any] events: list[Any]
class KVCacheEvent(msgspec.Struct, class KVCacheEvent(
array_like=True, msgspec.Struct, array_like=True, omit_defaults=True, gc=False, tag=True
omit_defaults=True, ):
gc=False,
tag=True):
"""Base class for all KV cache-related events""" """Base class for all KV cache-related events"""
...@@ -77,8 +74,9 @@ def main(): ...@@ -77,8 +74,9 @@ def main():
if last_seq >= 0 and seq > last_seq + 1: if last_seq >= 0 and seq > last_seq + 1:
missed = seq - last_seq - 1 missed = seq - last_seq - 1
print(f"Missed {missed} messages" print(
f" (last: {last_seq}, current: {seq})") f"Missed {missed} messages (last: {last_seq}, current: {seq})"
)
replay.send((last_seq + 1).to_bytes(8, "big")) replay.send((last_seq + 1).to_bytes(8, "big"))
......
...@@ -12,26 +12,22 @@ from openai import OpenAI ...@@ -12,26 +12,22 @@ from openai import OpenAI
openai_api_key = "EMPTY" openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1" openai_api_base = "http://localhost:8000/v1"
messages = [{ messages = [
"role": "system", {"role": "system", "content": "You are a helpful assistant."},
"content": "You are a helpful assistant." {"role": "user", "content": "Who won the world series in 2020?"},
}, { {
"role": "user", "role": "assistant",
"content": "Who won the world series in 2020?" "content": "The Los Angeles Dodgers won the World Series in 2020.",
}, { },
"role": "assistant", {"role": "user", "content": "Where was it played?"},
"content": "The Los Angeles Dodgers won the World Series in 2020." ]
}, {
"role": "user",
"content": "Where was it played?"
}]
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description="Client for vLLM API server") parser = argparse.ArgumentParser(description="Client for vLLM API server")
parser.add_argument("--stream", parser.add_argument(
action="store_true", "--stream", action="store_true", help="Enable streaming response"
help="Enable streaming response") )
return parser.parse_args() return parser.parse_args()
......
...@@ -43,7 +43,7 @@ def encode_base64_content_from_url(content_url: str) -> str: ...@@ -43,7 +43,7 @@ def encode_base64_content_from_url(content_url: str) -> str:
with requests.get(content_url) as response: with requests.get(content_url) as response:
response.raise_for_status() response.raise_for_status()
result = base64.b64encode(response.content).decode('utf-8') result = base64.b64encode(response.content).decode("utf-8")
return result return result
...@@ -51,10 +51,7 @@ def encode_base64_content_from_url(content_url: str) -> str: ...@@ -51,10 +51,7 @@ def encode_base64_content_from_url(content_url: str) -> str:
# Text-only inference # Text-only inference
def run_text_only(model: str) -> None: def run_text_only(model: str) -> None:
chat_completion = client.chat.completions.create( chat_completion = client.chat.completions.create(
messages=[{ messages=[{"role": "user", "content": "What's the capital of France?"}],
"role": "user",
"content": "What's the capital of France?"
}],
model=model, model=model,
max_completion_tokens=64, max_completion_tokens=64,
) )
...@@ -65,26 +62,21 @@ def run_text_only(model: str) -> None: ...@@ -65,26 +62,21 @@ def run_text_only(model: str) -> None:
# Single-image input inference # Single-image input inference
def run_single_image(model: str) -> None: def run_single_image(model: str) -> None:
## Use image url in the payload ## Use image url in the payload
image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
chat_completion_from_url = client.chat.completions.create( chat_completion_from_url = client.chat.completions.create(
messages=[{ messages=[
"role": {
"user", "role": "user",
"content": [ "content": [
{ {"type": "text", "text": "What's in this image?"},
"type": "text", {
"text": "What's in this image?" "type": "image_url",
}, "image_url": {"url": image_url},
{
"type": "image_url",
"image_url": {
"url": image_url
}, },
}, ],
], }
}], ],
model=model, model=model,
max_completion_tokens=64, max_completion_tokens=64,
) )
...@@ -95,22 +87,18 @@ def run_single_image(model: str) -> None: ...@@ -95,22 +87,18 @@ def run_single_image(model: str) -> None:
## Use base64 encoded image in the payload ## Use base64 encoded image in the payload
image_base64 = encode_base64_content_from_url(image_url) image_base64 = encode_base64_content_from_url(image_url)
chat_completion_from_base64 = client.chat.completions.create( chat_completion_from_base64 = client.chat.completions.create(
messages=[{ messages=[
"role": {
"user", "role": "user",
"content": [ "content": [
{ {"type": "text", "text": "What's in this image?"},
"type": "text", {
"text": "What's in this image?" "type": "image_url",
}, "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
}, },
}, ],
], }
}], ],
model=model, model=model,
max_completion_tokens=64, max_completion_tokens=64,
) )
...@@ -124,28 +112,22 @@ def run_multi_image(model: str) -> None: ...@@ -124,28 +112,22 @@ def run_multi_image(model: str) -> None:
image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg" image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg"
image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg" image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg"
chat_completion_from_url = client.chat.completions.create( chat_completion_from_url = client.chat.completions.create(
messages=[{ messages=[
"role": {
"user", "role": "user",
"content": [ "content": [
{ {"type": "text", "text": "What are the animals in these images?"},
"type": "text", {
"text": "What are the animals in these images?" "type": "image_url",
}, "image_url": {"url": image_url_duck},
{
"type": "image_url",
"image_url": {
"url": image_url_duck
}, },
}, {
{ "type": "image_url",
"type": "image_url", "image_url": {"url": image_url_lion},
"image_url": {
"url": image_url_lion
}, },
}, ],
], }
}], ],
model=model, model=model,
max_completion_tokens=64, max_completion_tokens=64,
) )
...@@ -161,22 +143,18 @@ def run_video(model: str) -> None: ...@@ -161,22 +143,18 @@ def run_video(model: str) -> None:
## Use video url in the payload ## Use video url in the payload
chat_completion_from_url = client.chat.completions.create( chat_completion_from_url = client.chat.completions.create(
messages=[{ messages=[
"role": {
"user", "role": "user",
"content": [ "content": [
{ {"type": "text", "text": "What's in this video?"},
"type": "text", {
"text": "What's in this video?" "type": "video_url",
}, "video_url": {"url": video_url},
{
"type": "video_url",
"video_url": {
"url": video_url
}, },
}, ],
], }
}], ],
model=model, model=model,
max_completion_tokens=64, max_completion_tokens=64,
) )
...@@ -186,22 +164,18 @@ def run_video(model: str) -> None: ...@@ -186,22 +164,18 @@ def run_video(model: str) -> None:
## Use base64 encoded video in the payload ## Use base64 encoded video in the payload
chat_completion_from_base64 = client.chat.completions.create( chat_completion_from_base64 = client.chat.completions.create(
messages=[{ messages=[
"role": {
"user", "role": "user",
"content": [ "content": [
{ {"type": "text", "text": "What's in this video?"},
"type": "text", {
"text": "What's in this video?" "type": "video_url",
}, "video_url": {"url": f"data:video/mp4;base64,{video_base64}"},
{
"type": "video_url",
"video_url": {
"url": f"data:video/mp4;base64,{video_base64}"
}, },
}, ],
], }
}], ],
model=model, model=model,
max_completion_tokens=64, max_completion_tokens=64,
) )
...@@ -219,24 +193,22 @@ def run_audio(model: str) -> None: ...@@ -219,24 +193,22 @@ def run_audio(model: str) -> None:
# OpenAI-compatible schema (`input_audio`) # OpenAI-compatible schema (`input_audio`)
chat_completion_from_base64 = client.chat.completions.create( chat_completion_from_base64 = client.chat.completions.create(
messages=[{ messages=[
"role": {
"user", "role": "user",
"content": [ "content": [
{ {"type": "text", "text": "What's in this audio?"},
"type": "text", {
"text": "What's in this audio?" "type": "input_audio",
}, "input_audio": {
{ # Any format supported by librosa is supported
"type": "input_audio", "data": audio_base64,
"input_audio": { "format": "wav",
# Any format supported by librosa is supported },
"data": audio_base64,
"format": "wav"
}, },
}, ],
], }
}], ],
model=model, model=model,
max_completion_tokens=64, max_completion_tokens=64,
) )
...@@ -246,23 +218,21 @@ def run_audio(model: str) -> None: ...@@ -246,23 +218,21 @@ def run_audio(model: str) -> None:
# HTTP URL # HTTP URL
chat_completion_from_url = client.chat.completions.create( chat_completion_from_url = client.chat.completions.create(
messages=[{ messages=[
"role": {
"user", "role": "user",
"content": [ "content": [
{ {"type": "text", "text": "What's in this audio?"},
"type": "text", {
"text": "What's in this audio?" "type": "audio_url",
}, "audio_url": {
{ # Any format supported by librosa is supported
"type": "audio_url", "url": audio_url
"audio_url": { },
# Any format supported by librosa is supported
"url": audio_url
}, },
}, ],
], }
}], ],
model=model, model=model,
max_completion_tokens=64, max_completion_tokens=64,
) )
...@@ -272,23 +242,21 @@ def run_audio(model: str) -> None: ...@@ -272,23 +242,21 @@ def run_audio(model: str) -> None:
# base64 URL # base64 URL
chat_completion_from_base64 = client.chat.completions.create( chat_completion_from_base64 = client.chat.completions.create(
messages=[{ messages=[
"role": {
"user", "role": "user",
"content": [ "content": [
{ {"type": "text", "text": "What's in this audio?"},
"type": "text", {
"text": "What's in this audio?" "type": "audio_url",
}, "audio_url": {
{ # Any format supported by librosa is supported
"type": "audio_url", "url": f"data:audio/ogg;base64,{audio_base64}"
"audio_url": { },
# Any format supported by librosa is supported
"url": f"data:audio/ogg;base64,{audio_base64}"
}, },
}, ],
], }
}], ],
model=model, model=model,
max_completion_tokens=64, max_completion_tokens=64,
) )
...@@ -308,14 +276,17 @@ example_function_map = { ...@@ -308,14 +276,17 @@ example_function_map = {
def parse_args(): def parse_args():
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Demo on using OpenAI client for online serving with ' description="Demo on using OpenAI client for online serving with "
'multimodal language models served with vLLM.') "multimodal language models served with vLLM."
parser.add_argument('--chat-type', )
'-c', parser.add_argument(
type=str, "--chat-type",
default="single-image", "-c",
choices=list(example_function_map.keys()), type=str,
help='Conversation type with multimodal data.') default="single-image",
choices=list(example_function_map.keys()),
help="Conversation type with multimodal data.",
)
return parser.parse_args() return parser.parse_args()
......
...@@ -16,6 +16,7 @@ vllm serve NousResearch/Hermes-2-Pro-Llama-3-8B \ ...@@ -16,6 +16,7 @@ vllm serve NousResearch/Hermes-2-Pro-Llama-3-8B \
--chat-template examples/tool_chat_template_hermes.jinja \ --chat-template examples/tool_chat_template_hermes.jinja \
--enable-auto-tool-choice --tool-call-parser hermes --enable-auto-tool-choice --tool-call-parser hermes
""" """
import json import json
from typing import Any from typing import Any
...@@ -25,55 +26,55 @@ from openai import OpenAI ...@@ -25,55 +26,55 @@ from openai import OpenAI
openai_api_key = "EMPTY" openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1" openai_api_base = "http://localhost:8000/v1"
tools = [{ properties = {
"type": "function", "city": {
"function": { "type": "string",
"name": "get_current_weather", "description": "The city to find the weather for, e.g. 'San Francisco'",
"description": "Get the current weather in a given location", },
"parameters": { "state": {
"type": "object", "type": "string",
"properties": { "description": "the two-letter abbreviation for the state that the city is"
"city": { " in, e.g. 'CA' which would mean 'California'",
"type": },
"string", "unit": {
"description": "type": "string",
"The city to find the weather for, e.g. 'San Francisco'" "description": "The unit to fetch the temperature in",
}, "enum": ["celsius", "fahrenheit"],
"state": { },
"type": }
"string",
"description": tools = [
"the two-letter abbreviation for the state that the city is" {
" in, e.g. 'CA' which would mean 'California'" "type": "function",
}, "function": {
"unit": { "name": "get_current_weather",
"type": "string", "description": "Get the current weather in a given location",
"description": "The unit to fetch the temperature in", "parameters": {
"enum": ["celsius", "fahrenheit"] "type": "object",
} "properties": properties,
"required": ["city", "state", "unit"],
}, },
"required": ["city", "state", "unit"] },
}
} }
}] ]
messages = [{ messages = [
"role": "user", {"role": "user", "content": "Hi! How are you doing today?"},
"content": "Hi! How are you doing today?" {"role": "assistant", "content": "I'm doing well! How can I help you?"},
}, { {
"role": "assistant", "role": "user",
"content": "I'm doing well! How can I help you?" "content": (
}, { "Can you tell me what the temperate will be in Dallas, in fahrenheit?"
"role": ),
"user", },
"content": ]
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
}]
def get_current_weather(city: str, state: str, unit: "str"):
return (
def get_current_weather(city: str, state: str, unit: 'str'): "The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is " "partly cloudly, with highs in the 90's."
"partly cloudly, with highs in the 90's.") )
def handle_tool_calls_stream( def handle_tool_calls_stream(
...@@ -82,10 +83,9 @@ def handle_tool_calls_stream( ...@@ -82,10 +83,9 @@ def handle_tool_calls_stream(
model: str, model: str,
tools: list[dict[str, Any]], tools: list[dict[str, Any]],
) -> list[Any]: ) -> list[Any]:
tool_calls_stream = client.chat.completions.create(messages=messages, tool_calls_stream = client.chat.completions.create(
model=model, messages=messages, model=model, tools=tools, stream=True
tools=tools, )
stream=True)
chunks = [] chunks = []
print("chunks: ") print("chunks: ")
for chunk in tool_calls_stream: for chunk in tool_calls_stream:
...@@ -106,8 +106,7 @@ def handle_tool_calls_arguments(chunks: list[Any]) -> list[str]: ...@@ -106,8 +106,7 @@ def handle_tool_calls_arguments(chunks: list[Any]) -> list[str]:
tool_call = chunk.choices[0].delta.tool_calls[0] tool_call = chunk.choices[0].delta.tool_calls[0]
if tool_call.index != tool_call_idx: if tool_call.index != tool_call_idx:
if tool_call_idx >= 0: if tool_call_idx >= 0:
print(f"streamed tool call arguments: " print(f"streamed tool call arguments: {arguments[tool_call_idx]}")
f"{arguments[tool_call_idx]}")
tool_call_idx = chunk.choices[0].delta.tool_calls[0].index tool_call_idx = chunk.choices[0].delta.tool_calls[0].index
arguments.append("") arguments.append("")
if tool_call.id: if tool_call.id:
...@@ -115,8 +114,7 @@ def handle_tool_calls_arguments(chunks: list[Any]) -> list[str]: ...@@ -115,8 +114,7 @@ def handle_tool_calls_arguments(chunks: list[Any]) -> list[str]:
if tool_call.function: if tool_call.function:
if tool_call.function.name: if tool_call.function.name:
print( print(f"streamed tool call name: {tool_call.function.name}")
f"streamed tool call name: {tool_call.function.name}")
if tool_call.function.arguments: if tool_call.function.arguments:
arguments[tool_call_idx] += tool_call.function.arguments arguments[tool_call_idx] += tool_call.function.arguments
...@@ -136,9 +134,9 @@ def main(): ...@@ -136,9 +134,9 @@ def main():
models = client.models.list() models = client.models.list()
model = models.data[0].id model = models.data[0].id
chat_completion = client.chat.completions.create(messages=messages, chat_completion = client.chat.completions.create(
model=model, messages=messages, model=model, tools=tools
tools=tools) )
print("-" * 70) print("-" * 70)
print("Chat completion results:") print("Chat completion results:")
...@@ -158,10 +156,12 @@ def main(): ...@@ -158,10 +156,12 @@ def main():
print("-" * 70) print("-" * 70)
# Add tool call results to the conversation # Add tool call results to the conversation
messages.append({ messages.append(
"role": "assistant", {
"tool_calls": chat_completion.choices[0].message.tool_calls "role": "assistant",
}) "tool_calls": chat_completion.choices[0].message.tool_calls,
}
)
# Now, simulate a tool call # Now, simulate a tool call
available_tools = {"get_current_weather": get_current_weather} available_tools = {"get_current_weather": get_current_weather}
...@@ -172,17 +172,18 @@ def main(): ...@@ -172,17 +172,18 @@ def main():
args = json.loads(call.function.arguments) args = json.loads(call.function.arguments)
result = tool_to_call(**args) result = tool_to_call(**args)
print("tool_to_call result: ", result) print("tool_to_call result: ", result)
messages.append({ messages.append(
"role": "tool", {
"content": result, "role": "tool",
"tool_call_id": call.id, "content": result,
"name": call.function.name "tool_call_id": call.id,
}) "name": call.function.name,
}
chat_completion_2 = client.chat.completions.create(messages=messages, )
model=model,
tools=tools, chat_completion_2 = client.chat.completions.create(
stream=False) messages=messages, model=model, tools=tools, stream=False
)
print("Chat completion2 results:") print("Chat completion2 results:")
print(chat_completion_2) print(chat_completion_2)
print("-" * 70) print("-" * 70)
......
...@@ -28,18 +28,16 @@ tools = [ ...@@ -28,18 +28,16 @@ tools = [
"type": "object", "type": "object",
"properties": { "properties": {
"city": { "city": {
"type": "type": "string",
"string", "description": "The city to find the weather for"
"description":
"The city to find the weather for"
", e.g. 'San Francisco'", ", e.g. 'San Francisco'",
}, },
"state": { "state": {
"type": "type": "string",
"string", "description": (
"description": "the two-letter abbreviation for the state that the "
"the two-letter abbreviation for the state that the " "city is in, e.g. 'CA' which would mean 'California'"
"city is in, e.g. 'CA' which would mean 'California'", ),
}, },
"unit": { "unit": {
"type": "string", "type": "string",
...@@ -60,22 +58,20 @@ tools = [ ...@@ -60,22 +58,20 @@ tools = [
"type": "object", "type": "object",
"properties": { "properties": {
"city": { "city": {
"type": "type": "string",
"string", "description": (
"description": "The city to get the forecast for, e.g. 'New York'"
"The city to get the forecast for, e.g. 'New York'", ),
}, },
"state": { "state": {
"type": "type": "string",
"string", "description": (
"description": "The two-letter abbreviation for the state, e.g. 'NY'"
"The two-letter abbreviation for the state, e.g. 'NY'", ),
}, },
"days": { "days": {
"type": "type": "integer",
"integer", "description": "Number of days to get the forecast for (1-7)",
"description":
"Number of days to get the forecast for (1-7)",
}, },
"unit": { "unit": {
"type": "string", "type": "string",
...@@ -90,19 +86,11 @@ tools = [ ...@@ -90,19 +86,11 @@ tools = [
] ]
messages = [ messages = [
{"role": "user", "content": "Hi! How are you doing today?"},
{"role": "assistant", "content": "I'm doing well! How can I help you?"},
{ {
"role": "user", "role": "user",
"content": "Hi! How are you doing today?" "content": "Can you tell me what the current weather is in Dallas \
},
{
"role": "assistant",
"content": "I'm doing well! How can I help you?"
},
{
"role":
"user",
"content":
"Can you tell me what the current weather is in Dallas \
and the forecast for the next 5 days, in fahrenheit?", and the forecast for the next 5 days, in fahrenheit?",
}, },
] ]
...@@ -123,17 +111,16 @@ def main(): ...@@ -123,17 +111,16 @@ def main():
model=model, model=model,
tools=tools, tools=tools,
tool_choice="required", tool_choice="required",
stream=True # Enable streaming response stream=True, # Enable streaming response
) )
for chunk in chat_completion: for chunk in chat_completion:
if chunk.choices and chunk.choices[0].delta.tool_calls: if chunk.choices and chunk.choices[0].delta.tool_calls:
print(chunk.choices[0].delta.tool_calls) print(chunk.choices[0].delta.tool_calls)
chat_completion = client.chat.completions.create(messages=messages, chat_completion = client.chat.completions.create(
model=model, messages=messages, model=model, tools=tools, tool_choice="required"
tools=tools, )
tool_choice="required")
print(chat_completion.choices[0].message.tool_calls) print(chat_completion.choices[0].message.tool_calls)
......
...@@ -20,10 +20,9 @@ openai_api_base = "http://localhost:8000/v1" ...@@ -20,10 +20,9 @@ openai_api_base = "http://localhost:8000/v1"
def guided_choice_completion(client: OpenAI, model: str): def guided_choice_completion(client: OpenAI, model: str):
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
"role": "user", {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
"content": "Classify this sentiment: vLLM is wonderful!" ],
}],
extra_body={"guided_choice": ["positive", "negative"]}, extra_body={"guided_choice": ["positive", "negative"]},
) )
return completion.choices[0].message.content return completion.choices[0].message.content
...@@ -31,20 +30,21 @@ def guided_choice_completion(client: OpenAI, model: str): ...@@ -31,20 +30,21 @@ def guided_choice_completion(client: OpenAI, model: str):
# Guided decoding by Regex # Guided decoding by Regex
def guided_regex_completion(client: OpenAI, model: str): def guided_regex_completion(client: OpenAI, model: str):
prompt = ("Generate an email address for Alan Turing, who works in Enigma." prompt = (
"End in .com and new line. Example result:" "Generate an email address for Alan Turing, who works in Enigma."
"alan.turing@enigma.com\n") "End in .com and new line. Example result:"
"alan.turing@enigma.com\n"
)
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
"role": "user", {
"content": prompt, "role": "user",
}], "content": prompt,
extra_body={ }
"guided_regex": r"\w+@\w+\.com\n", ],
"stop": ["\n"] extra_body={"guided_regex": r"\w+@\w+\.com\n", "stop": ["\n"]},
},
) )
return completion.choices[0].message.content return completion.choices[0].message.content
...@@ -66,14 +66,18 @@ class CarDescription(BaseModel): ...@@ -66,14 +66,18 @@ class CarDescription(BaseModel):
def guided_json_completion(client: OpenAI, model: str): def guided_json_completion(client: OpenAI, model: str):
json_schema = CarDescription.model_json_schema() json_schema = CarDescription.model_json_schema()
prompt = ("Generate a JSON with the brand, model and car_type of" prompt = (
"the most iconic car from the 90's") "Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's"
)
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
"role": "user", {
"content": prompt, "role": "user",
}], "content": prompt,
}
],
extra_body={"guided_json": json_schema}, extra_body={"guided_json": json_schema},
) )
return completion.choices[0].message.content return completion.choices[0].message.content
...@@ -95,14 +99,18 @@ def guided_grammar_completion(client: OpenAI, model: str): ...@@ -95,14 +99,18 @@ def guided_grammar_completion(client: OpenAI, model: str):
number ::= "1 " | "2 " number ::= "1 " | "2 "
""" """
prompt = ("Generate an SQL query to show the 'username' and 'email'" prompt = (
"from the 'users' table.") "Generate an SQL query to show the 'username' and 'email'"
"from the 'users' table."
)
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
"role": "user", {
"content": prompt, "role": "user",
}], "content": prompt,
}
],
extra_body={"guided_grammar": simplified_sql_grammar}, extra_body={"guided_grammar": simplified_sql_grammar},
) )
return completion.choices[0].message.content return completion.choices[0].message.content
...@@ -110,19 +118,23 @@ def guided_grammar_completion(client: OpenAI, model: str): ...@@ -110,19 +118,23 @@ def guided_grammar_completion(client: OpenAI, model: str):
# Extra backend options # Extra backend options
def extra_backend_options_completion(client: OpenAI, model: str): def extra_backend_options_completion(client: OpenAI, model: str):
prompt = ("Generate an email address for Alan Turing, who works in Enigma." prompt = (
"End in .com and new line. Example result:" "Generate an email address for Alan Turing, who works in Enigma."
"alan.turing@enigma.com\n") "End in .com and new line. Example result:"
"alan.turing@enigma.com\n"
)
try: try:
# The guided_decoding_disable_fallback option forces vLLM to use # The guided_decoding_disable_fallback option forces vLLM to use
# xgrammar, so when it fails you get a 400 with the reason why # xgrammar, so when it fails you get a 400 with the reason why
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
"role": "user", {
"content": prompt, "role": "user",
}], "content": prompt,
}
],
extra_body={ extra_body={
"guided_regex": r"\w+@\w+\.com\n", "guided_regex": r"\w+@\w+\.com\n",
"stop": ["\n"], "stop": ["\n"],
......
...@@ -17,11 +17,10 @@ def main(): ...@@ -17,11 +17,10 @@ def main():
api_key=openai_api_key, api_key=openai_api_key,
) )
messages = [{ messages = [
"role": {
"user", "role": "user",
"content": "content": """
"""
You have access to the following function to retrieve the weather in a city: You have access to the following function to retrieve the weather in a city:
{ {
...@@ -58,29 +57,28 @@ You are a helpful assistant. ...@@ -58,29 +57,28 @@ You are a helpful assistant.
Given the previous instructions, what is the weather in New York City, Boston, Given the previous instructions, what is the weather in New York City, Boston,
and San Francisco? and San Francisco?
""" """,
}] }
]
response = client.chat.completions.create( response = client.chat.completions.create(
model=client.models.list().data[0].id, model=client.models.list().data[0].id,
messages=messages, messages=messages,
response_format={ response_format={
"type": "type": "structural_tag",
"structural_tag", "structures": [
"structures": [{ {
"begin": "<function=get_weather>", "begin": "<function=get_weather>",
"schema": { "schema": {
"type": "object", "type": "object",
"properties": { "properties": {"city": {"type": "string"}},
"city": { },
"type": "string" "end": "</function>",
} }
} ],
}, "triggers": ["<function="],
"end": "</function>" },
}], )
"triggers": ["<function="]
})
print(response) print(response)
......
...@@ -27,21 +27,22 @@ openai_api_base = "http://localhost:8000/v1" ...@@ -27,21 +27,22 @@ openai_api_base = "http://localhost:8000/v1"
def print_completion_details(completion): def print_completion_details(completion):
print("reasoning_content: ", print("reasoning_content: ", completion.choices[0].message.reasoning_content)
completion.choices[0].message.reasoning_content)
print("content: ", completion.choices[0].message.content) print("content: ", completion.choices[0].message.content)
# Guided decoding by Regex # Guided decoding by Regex
def guided_regex_completion(client: OpenAI, model: str): def guided_regex_completion(client: OpenAI, model: str):
prompt = ("What is the capital of France?") prompt = "What is the capital of France?"
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
"role": "user", {
"content": prompt, "role": "user",
}], "content": prompt,
}
],
extra_body={ extra_body={
"guided_regex": "(Paris|London)", "guided_regex": "(Paris|London)",
}, },
...@@ -57,13 +58,15 @@ class People(BaseModel): ...@@ -57,13 +58,15 @@ class People(BaseModel):
def guided_json_completion(client: OpenAI, model: str): def guided_json_completion(client: OpenAI, model: str):
json_schema = People.model_json_schema() json_schema = People.model_json_schema()
prompt = ("Generate a JSON with the name and age of one random person.") prompt = "Generate a JSON with the name and age of one random person."
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
"role": "user", {
"content": prompt, "role": "user",
}], "content": prompt,
}
],
extra_body={"guided_json": json_schema}, extra_body={"guided_json": json_schema},
) )
print_completion_details(completion) print_completion_details(completion)
...@@ -86,14 +89,18 @@ class CarDescription(BaseModel): ...@@ -86,14 +89,18 @@ class CarDescription(BaseModel):
def guided_car_json_completion(client: OpenAI, model: str): def guided_car_json_completion(client: OpenAI, model: str):
json_schema = CarDescription.model_json_schema() json_schema = CarDescription.model_json_schema()
prompt = ("Generate a JSON with the brand, model and car_type of" prompt = (
"the most iconic car from the 90's") "Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's"
)
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
"role": "user", {
"content": prompt, "role": "user",
}], "content": prompt,
}
],
extra_body={"guided_json": json_schema}, extra_body={"guided_json": json_schema},
) )
print_completion_details(completion) print_completion_details(completion)
...@@ -116,14 +123,18 @@ def guided_grammar_completion(client: OpenAI, model: str): ...@@ -116,14 +123,18 @@ def guided_grammar_completion(client: OpenAI, model: str):
""" """
# This may be very slow https://github.com/vllm-project/vllm/issues/12122 # This may be very slow https://github.com/vllm-project/vllm/issues/12122
prompt = ("Generate an SQL query to show the 'username' and 'email'" prompt = (
"from the 'users' table.") "Generate an SQL query to show the 'username' and 'email'"
"from the 'users' table."
)
completion = client.chat.completions.create( completion = client.chat.completions.create(
model=model, model=model,
messages=[{ messages=[
"role": "user", {
"content": prompt, "role": "user",
}], "content": prompt,
}
],
extra_body={"guided_grammar": simplified_sql_grammar}, extra_body={"guided_grammar": simplified_sql_grammar},
) )
print_completion_details(completion) print_completion_details(completion)
......
...@@ -20,9 +20,11 @@ from openai import OpenAI ...@@ -20,9 +20,11 @@ from openai import OpenAI
# Now, simulate a tool call # Now, simulate a tool call
def get_current_weather(city: str, state: str, unit: 'str'): def get_current_weather(city: str, state: str, unit: "str"):
return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is " return (
"partly cloudly, with highs in the 90's.") "The weather in Dallas, Texas is 85 degrees fahrenheit. It is "
"partly cloudly, with highs in the 90's."
)
available_tools = {"get_current_weather": get_current_weather} available_tools = {"get_current_weather": get_current_weather}
...@@ -31,49 +33,47 @@ available_tools = {"get_current_weather": get_current_weather} ...@@ -31,49 +33,47 @@ available_tools = {"get_current_weather": get_current_weather}
openai_api_key = "EMPTY" openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1" openai_api_base = "http://localhost:8000/v1"
tools = [{ properties = {
"type": "function", "city": {
"function": { "type": "string",
"name": "get_current_weather", "description": "The city to find the weather for, e.g. 'San Francisco'",
"description": "Get the current weather in a given location", },
"parameters": { "state": {
"type": "object", "type": "string",
"properties": { "description": "the two-letter abbreviation for the state that the city is"
"city": { " in, e.g. 'CA' which would mean 'California'",
"type": },
"string", "unit": {
"description": "type": "string",
"The city to find the weather for, e.g. 'San Francisco'" "description": "The unit to fetch the temperature in",
}, "enum": ["celsius", "fahrenheit"],
"state": { },
"type": }
"string",
"description": tools = [
"the two-letter abbreviation for the state that the city is" {
" in, e.g. 'CA' which would mean 'California'" "type": "function",
}, "function": {
"unit": { "name": "get_current_weather",
"type": "string", "description": "Get the current weather in a given location",
"description": "The unit to fetch the temperature in", "parameters": {
"enum": ["celsius", "fahrenheit"] "type": "object",
} "properties": properties,
"required": ["city", "state", "unit"],
}, },
"required": ["city", "state", "unit"] },
}
} }
}] ]
messages = [{ messages = [
"role": "user", {"role": "user", "content": "Hi! How are you doing today?"},
"content": "Hi! How are you doing today?" {"role": "assistant", "content": "I'm doing well! How can I help you?"},
}, { {
"role": "assistant", "role": "user",
"content": "I'm doing well! How can I help you?" "content": (
}, { "Can you tell me what the temperate will be in Dallas, in fahrenheit?"
"role": ),
"user", },
"content": ]
"Can you tell me what the temperate will be in Dallas, in fahrenheit?"
}]
def extract_reasoning_and_calls(chunks: list): def extract_reasoning_and_calls(chunks: list):
...@@ -110,73 +110,55 @@ def main(): ...@@ -110,73 +110,55 @@ def main():
models = client.models.list() models = client.models.list()
model = models.data[0].id model = models.data[0].id
print("---------Full Generate With Automatic Function Calling-------------")
tool_calls = client.chat.completions.create(
messages=messages, model=model, tools=tools
)
print(f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}")
print(f"function name: {tool_calls.choices[0].message.tool_calls[0].function.name}")
print( print(
"---------Full Generate With Automatic Function Calling-------------") f"function arguments: "
tool_calls = client.chat.completions.create(messages=messages, f"{tool_calls.choices[0].message.tool_calls[0].function.arguments}"
model=model,
tools=tools)
print(
f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}"
) )
print(f"function name: "
f"{tool_calls.choices[0].message.tool_calls[0].function.name}")
print(f"function arguments: "
f"{tool_calls.choices[0].message.tool_calls[0].function.arguments}")
print( print("----------Stream Generate With Automatic Function Calling-----------")
"----------Stream Generate With Automatic Function Calling-----------") tool_calls_stream = client.chat.completions.create(
tool_calls_stream = client.chat.completions.create(messages=messages, messages=messages, model=model, tools=tools, stream=True
model=model, )
tools=tools,
stream=True)
chunks = list(tool_calls_stream) chunks = list(tool_calls_stream)
reasoning_content, arguments, function_names = extract_reasoning_and_calls( reasoning_content, arguments, function_names = extract_reasoning_and_calls(chunks)
chunks)
print(f"reasoning_content: {reasoning_content}") print(f"reasoning_content: {reasoning_content}")
print(f"function name: {function_names[0]}") print(f"function name: {function_names[0]}")
print(f"function arguments: {arguments[0]}") print(f"function arguments: {arguments[0]}")
print( print("----------Full Generate With Named Function Calling-----------------")
"----------Full Generate With Named Function Calling-----------------") tool_calls = client.chat.completions.create(
tool_calls = client.chat.completions.create(messages=messages, messages=messages,
model=model, model=model,
tools=tools, tools=tools,
tool_choice={ tool_choice={"type": "function", "function": {"name": "get_current_weather"}},
"type": "function", )
"function": {
"name":
"get_current_weather"
}
})
tool_call = tool_calls.choices[0].message.tool_calls[0].function tool_call = tool_calls.choices[0].message.tool_calls[0].function
print( print(f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}")
f"reasoning_content: {tool_calls.choices[0].message.reasoning_content}"
)
print(f"function name: {tool_call.name}") print(f"function name: {tool_call.name}")
print(f"function arguments: {tool_call.arguments}") print(f"function arguments: {tool_call.arguments}")
print( print("----------Stream Generate With Named Function Calling--------------")
"----------Stream Generate With Named Function Calling--------------")
tool_calls_stream = client.chat.completions.create( tool_calls_stream = client.chat.completions.create(
messages=messages, messages=messages,
model=model, model=model,
tools=tools, tools=tools,
tool_choice={ tool_choice={"type": "function", "function": {"name": "get_current_weather"}},
"type": "function", stream=True,
"function": { )
"name": "get_current_weather"
}
},
stream=True)
chunks = list(tool_calls_stream) chunks = list(tool_calls_stream)
reasoning_content, arguments, function_names = extract_reasoning_and_calls( reasoning_content, arguments, function_names = extract_reasoning_and_calls(chunks)
chunks)
print(f"reasoning_content: {reasoning_content}") print(f"reasoning_content: {reasoning_content}")
print(f"function name: {function_names[0]}") print(f"function name: {function_names[0]}")
print(f"function arguments: {arguments[0]}") print(f"function arguments: {arguments[0]}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment