Commit 006693ed authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.11.2' into v0.11.2-ori

parents 4b51e6f1 275de341
...@@ -13,7 +13,7 @@ from tqdm import tqdm ...@@ -13,7 +13,7 @@ from tqdm import tqdm
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptType from vllm.inputs import PromptType
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
DURATION_MS = int(os.getenv("VLLM_TPU_PROFILE_DURATION_MS", 3000)) DURATION_MS = int(os.getenv("VLLM_TPU_PROFILE_DURATION_MS", 3000))
DELAY_MS = int(os.getenv("VLLM_TPU_PROFILE_DELAY_MS", 0)) DELAY_MS = int(os.getenv("VLLM_TPU_PROFILE_DELAY_MS", 0))
......
...@@ -11,12 +11,10 @@ python examples/offline_inference/qwen2_5_omni/only_thinker.py \ ...@@ -11,12 +11,10 @@ python examples/offline_inference/qwen2_5_omni/only_thinker.py \
# Read vision and audio inputs from a single video file # Read vision and audio inputs from a single video file
# NOTE: V1 engine does not support interleaved modalities yet. # NOTE: V1 engine does not support interleaved modalities yet.
VLLM_USE_V1=0 \
python examples/offline_inference/qwen2_5_omni/only_thinker.py \ python examples/offline_inference/qwen2_5_omni/only_thinker.py \
-q use_audio_in_video -q use_audio_in_video
# Multiple audios # Multiple audios
VLLM_USE_V1=0 \
python examples/offline_inference/qwen2_5_omni/only_thinker.py \ python examples/offline_inference/qwen2_5_omni/only_thinker.py \
-q multi_audios -q multi_audios
``` ```
......
...@@ -7,13 +7,12 @@ with the correct prompt format on Qwen2.5-Omni (thinker only). ...@@ -7,13 +7,12 @@ with the correct prompt format on Qwen2.5-Omni (thinker only).
from typing import NamedTuple from typing import NamedTuple
import vllm.envs as envs
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset from vllm.assets.video import VideoAsset
from vllm.multimodal.image import convert_image_mode from vllm.multimodal.image import convert_image_mode
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
class QueryResult(NamedTuple): class QueryResult(NamedTuple):
...@@ -72,11 +71,7 @@ def get_use_audio_in_video_query() -> QueryResult: ...@@ -72,11 +71,7 @@ def get_use_audio_in_video_query() -> QueryResult:
) )
asset = VideoAsset(name="baby_reading", num_frames=16) asset = VideoAsset(name="baby_reading", num_frames=16)
audio = asset.get_audio(sampling_rate=16000) audio = asset.get_audio(sampling_rate=16000)
assert not envs.VLLM_USE_V1, (
"V1 does not support use_audio_in_video. "
"Please launch this example with "
"`VLLM_USE_V1=0`."
)
return QueryResult( return QueryResult(
inputs={ inputs={
"prompt": prompt, "prompt": prompt,
......
...@@ -38,7 +38,7 @@ from rlhf_utils import stateless_init_process_group ...@@ -38,7 +38,7 @@ from rlhf_utils import stateless_init_process_group
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.utils import get_ip, get_open_port from vllm.utils.network_utils import get_ip, get_open_port
class MyLLM(LLM): class MyLLM(LLM):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc import gc
from typing import Callable, Optional, TypedDict from collections.abc import Callable
from typing import TypedDict
import torch import torch
import zmq import zmq
...@@ -71,7 +72,7 @@ class WorkerExtension: ...@@ -71,7 +72,7 @@ class WorkerExtension:
def rebuild_ipc( def rebuild_ipc(
handle: tuple[Callable, tuple], device_id: Optional[int] = None handle: tuple[Callable, tuple], device_id: int | None = None
) -> torch.Tensor: ) -> torch.Tensor:
func, args = handle func, args = handle
list_args = list(args) list_args = list(args)
...@@ -109,7 +110,7 @@ class ColocateWorkerExtension: ...@@ -109,7 +110,7 @@ class ColocateWorkerExtension:
self._zmq_ctx = zmq.Context() self._zmq_ctx = zmq.Context()
socket = self._zmq_ctx.socket(zmq.REP) socket = self._zmq_ctx.socket(zmq.REP)
socket.connect(zmq_handles[self.report_device_id()]) socket.connect(zmq_handles[self.report_device_id()])
buffer: Optional[torch.Tensor] = None buffer: torch.Tensor | None = None
while True: while True:
payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = ( payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = (
socket.recv_pyobj() socket.recv_pyobj()
......
...@@ -30,7 +30,7 @@ from pathlib import Path ...@@ -30,7 +30,7 @@ from pathlib import Path
from vllm import LLM, EngineArgs from vllm import LLM, EngineArgs
from vllm.model_executor.model_loader import ShardedStateLoader from vllm.model_executor.model_loader import ShardedStateLoader
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
def parse_args(): def parse_args():
......
...@@ -9,25 +9,25 @@ from vllm.inputs import TokensPrompt ...@@ -9,25 +9,25 @@ from vllm.inputs import TokensPrompt
from vllm.v1.metrics.reader import Counter, Vector from vllm.v1.metrics.reader import Counter, Vector
try: try:
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
except ImportError: except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser from argparse import ArgumentParser as FlexibleArgumentParser
QUESTION = "What is the content of each image?" QUESTION = "What is the content of each image?"
IMAGE_URLS = [ IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg",
"https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg",
"https://upload.wikimedia.org/wikipedia/commons/2/26/Ultramarine_Flycatcher_%28Ficedula_superciliaris%29_Naggar%2C_Himachal_Pradesh%2C_2013_%28cropped%29.JPG", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/flycatcher.jpeg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/e/e5/Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg/2560px-Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/somefish.jpg",
"https://upload.wikimedia.org/wikipedia/commons/d/d4/Starfish%2C_Caswell_Bay_-_geograph.org.uk_-_409413.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/starfish.jpg",
"https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/snail.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Texas_invasive_Musk_Thistle_1.jpg/1920px-Texas_invasive_Musk_Thistle_1.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/thistle.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/7/7a/Huskiesatrest.jpg/2880px-Huskiesatrest.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/husky.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg/1920px-Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/orangetabbycat.jpg",
"https://upload.wikimedia.org/wikipedia/commons/3/30/George_the_amazing_guinea_pig.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/guineapig.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/Oryctolagus_cuniculus_Rcdo.jpg/1920px-Oryctolagus_cuniculus_Rcdo.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/rabbit.jpg",
"https://upload.wikimedia.org/wikipedia/commons/9/98/Horse-and-pony.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/horsepony.jpg",
] ]
......
...@@ -4,17 +4,88 @@ ...@@ -4,17 +4,88 @@
experimental support for data-parallel inference with torchrun experimental support for data-parallel inference with torchrun
Note the data load balancing and distribution is done out of the vllm engine, Note the data load balancing and distribution is done out of the vllm engine,
no internal lb supported in external_launcher mode. no internal lb supported in external_launcher mode.
To run this example:
```bash
$ torchrun --nproc-per-node=2 examples/offline_inference/torchrun_dp_example.py
```
With custom parallelism settings:
```bash
$ torchrun --nproc-per-node=8 examples/offline_inference/torchrun_dp_example.py \
--tp-size=2 --pp-size=1 --dp-size=4 --enable-ep
```
""" """
import argparse
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
def parse_args(argv=None):
    """Parse the command-line options of the torchrun data-parallel example.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behavior of calling ``parse_args()`` with no arguments.

    Returns:
        An ``argparse.Namespace`` with the attributes ``tp_size``,
        ``pp_size``, ``dp_size``, ``enable_ep``, ``model``,
        ``max_model_len``, ``gpu_memory_utilization`` and ``seed``.
    """
    parser = argparse.ArgumentParser(
        description="Data-parallel inference with torchrun"
    )

    # Parallelism layout.
    parser.add_argument(
        "--tp-size",
        type=int,
        default=1,
        help="Tensor parallel size (default: 1)",
    )
    parser.add_argument(
        "--pp-size",
        type=int,
        default=1,
        help="Pipeline parallel size (default: 1)",
    )
    parser.add_argument(
        "--dp-size",
        type=int,
        default=2,
        help="Data parallel size (default: 2)",
    )
    parser.add_argument(
        "--enable-ep",
        action="store_true",
        help="Enable expert parallel (default: False)",
    )

    # Model and engine settings.
    parser.add_argument(
        "--model",
        type=str,
        default="microsoft/Phi-mini-MoE-instruct",
        help="Model name or path (default: microsoft/Phi-mini-MoE-instruct)",
    )
    parser.add_argument(
        "--max-model-len",
        type=int,
        default=4096,
        help="Maximum model length (default: 4096)",
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.6,
        help="GPU memory utilization (default: 0.6)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        help="Random seed (default: 1)",
    )

    # argparse treats ``None`` as "use sys.argv[1:]", so the zero-argument
    # call used by the script keeps working unchanged.
    return parser.parse_args(argv)
args = parse_args()
# Create prompts, the same across all ranks # Create prompts, the same across all ranks
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
"The president of the United States is", "The president of the United States is",
"The capital of France is", "The capital of France is",
"The future of AI is", "The future of AI is",
] * 50 ]
# Create sampling parameters, the same across all ranks # Create sampling parameters, the same across all ranks
sampling_params = SamplingParams(temperature=0.8, top_p=0.95) sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
...@@ -25,15 +96,15 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) ...@@ -25,15 +96,15 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# all ranks have the same random seed, so that sampling can be # all ranks have the same random seed, so that sampling can be
# deterministic across ranks. # deterministic across ranks.
llm = LLM( llm = LLM(
model="microsoft/Phi-mini-MoE-instruct", model=args.model,
tensor_parallel_size=1, tensor_parallel_size=args.tp_size,
data_parallel_size=2, data_parallel_size=args.dp_size,
pipeline_parallel_size=1, pipeline_parallel_size=args.pp_size,
enable_expert_parallel=False, enable_expert_parallel=args.enable_ep,
distributed_executor_backend="external_launcher", distributed_executor_backend="external_launcher",
max_model_len=4096, max_model_len=args.max_model_len,
gpu_memory_utilization=0.6, gpu_memory_utilization=args.gpu_memory_utilization,
seed=1, seed=args.seed,
) )
dp_rank = llm.llm_engine.vllm_config.parallel_config.data_parallel_rank dp_rank = llm.llm_engine.vllm_config.parallel_config.data_parallel_rank
...@@ -45,14 +116,13 @@ prompts = [ ...@@ -45,14 +116,13 @@ prompts = [
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
# all ranks will have the same outputs
print("-" * 50)
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n") print(
print("-" * 50) f"DP Rank: {dp_rank} Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n"
)
""" """
Further tips: Further tips:
......
...@@ -12,7 +12,7 @@ import os ...@@ -12,7 +12,7 @@ import os
import random import random
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import asdict from dataclasses import asdict
from typing import NamedTuple, Optional from typing import NamedTuple
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from transformers import AutoTokenizer from transformers import AutoTokenizer
...@@ -22,14 +22,15 @@ from vllm.assets.image import ImageAsset ...@@ -22,14 +22,15 @@ from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset from vllm.assets.video import VideoAsset
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.image import convert_image_mode from vllm.multimodal.image import convert_image_mode
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
class ModelRequestData(NamedTuple): class ModelRequestData(NamedTuple):
engine_args: EngineArgs engine_args: EngineArgs
prompts: list[str] prompts: list[str]
stop_token_ids: Optional[list[int]] = None stop_token_ids: list[int] | None = None
lora_requests: Optional[list[LoRARequest]] = None lora_requests: list[LoRARequest] | None = None
sampling_params: list[SamplingParams] | None = None
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
...@@ -90,16 +91,25 @@ def run_aya_vision(questions: list[str], modality: str) -> ModelRequestData: ...@@ -90,16 +91,25 @@ def run_aya_vision(questions: list[str], modality: str) -> ModelRequestData:
) )
# BLIP-2 # Bee-8B
def run_blip2(questions: list[str], modality: str) -> ModelRequestData: def run_bee(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image" assert modality == "image"
model_name = "Open-Bee/Bee-8B-RL"
prompts = [
(
f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n<image>\n{question}<|im_end|>"
f"<|im_start|>assistant\n<think>\n"
)
for question in questions
]
# BLIP-2 prompt format is inaccurate on HuggingFace model repository.
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
prompts = [f"Question: {question} Answer:" for question in questions]
engine_args = EngineArgs( engine_args = EngineArgs(
model="Salesforce/blip2-opt-2.7b", model=model_name,
max_model_len=16384,
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
trust_remote_code=True,
) )
return ModelRequestData( return ModelRequestData(
...@@ -108,15 +118,15 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData: ...@@ -108,15 +118,15 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
) )
# Chameleon # BLIP-2
def run_chameleon(questions: list[str], modality: str) -> ModelRequestData: def run_blip2(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image" assert modality == "image"
prompts = [f"{question}<image>" for question in questions] # BLIP-2 prompt format is inaccurate on HuggingFace model repository.
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
prompts = [f"Question: {question} Answer:" for question in questions]
engine_args = EngineArgs( engine_args = EngineArgs(
model="facebook/chameleon-7b", model="Salesforce/blip2-opt-2.7b",
max_model_len=4096,
max_num_seqs=2,
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
) )
...@@ -126,15 +136,16 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData: ...@@ -126,15 +136,16 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData:
) )
# Dots-OCR # Chameleon
def run_dots_ocr(questions: list[str], modality: str) -> ModelRequestData: def run_chameleon(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image" assert modality == "image"
prompts = [f"<|img|><|imgpad|><|endofimg|>{question}" for question in questions] prompts = [f"{question}<image>" for question in questions]
engine_args = EngineArgs( engine_args = EngineArgs(
model="rednote-hilab/dots.ocr", model="facebook/chameleon-7b",
max_model_len=4096,
max_num_seqs=2,
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
trust_remote_code=True,
) )
return ModelRequestData( return ModelRequestData(
...@@ -190,6 +201,66 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData: ...@@ -190,6 +201,66 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData:
) )
def run_deepseek_ocr(questions: list[str], modality: str) -> ModelRequestData:
    """Build the request data for the DeepSeek-OCR single-image example.

    Only the ``"image"`` modality is accepted. The returned request data
    carries one prompt and one ``SamplingParams`` instance per question.
    """
    from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor

    assert modality == "image"

    ocr_engine_args = EngineArgs(
        model="deepseek-ai/DeepSeek-OCR",
        limit_mm_per_prompt={modality: 1},
        logits_processors=[NGramPerReqLogitsProcessor],
    )

    # DeepSeek-OCR uses a plain prompt template.
    ocr_prompts = [f"<image>\n{question}" for question in questions]

    # The following sampling params config is taken from
    # the official Deepseek-OCR inference example.
    # (IMPORTANT) Use the custom logits processor and avoid skipping
    # special tokens for this model for the optimal OCR performance.
    per_prompt_params = []
    for _ in questions:
        # Arguments consumed by the n-gram logits processor.
        ngram_args = {
            "ngram_size": 30,
            "window_size": 90,
            # whitelist: <td>, </td>
            "whitelist_token_ids": {128821, 128822},
        }
        per_prompt_params.append(
            SamplingParams(
                temperature=0.0,
                max_tokens=8192,
                extra_args=ngram_args,
                skip_special_tokens=False,
            )
        )

    return ModelRequestData(
        engine_args=ocr_engine_args,
        prompts=ocr_prompts,
        sampling_params=per_prompt_params,
    )
# Dots-OCR
def run_dots_ocr(questions: list[str], modality: str) -> ModelRequestData:
    """Build the request data for the Dots-OCR single-image example.

    Only the ``"image"`` modality is accepted. Each question is appended
    to the image placeholder tokens to form its prompt.
    """
    assert modality == "image"

    dots_engine_args = EngineArgs(
        model="rednote-hilab/dots.ocr",
        limit_mm_per_prompt={modality: 1},
        trust_remote_code=True,
    )

    dots_prompts = []
    for question in questions:
        dots_prompts.append(f"<|img|><|imgpad|><|endofimg|>{question}")

    return ModelRequestData(
        engine_args=dots_engine_args,
        prompts=dots_prompts,
    )
# Ernie4.5-VL # Ernie4.5-VL
def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData: def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData:
model_name = "baidu/ERNIE-4.5-VL-28B-A3B-PT" model_name = "baidu/ERNIE-4.5-VL-28B-A3B-PT"
...@@ -576,7 +647,7 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData: ...@@ -576,7 +647,7 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
# Intern-S1 # Intern-S1
def run_interns1(questions: list[str], modality: str) -> ModelRequestData: def run_interns1(questions: list[str], modality: str) -> ModelRequestData:
model_name = "internlm/Intern-S1" model_name = "internlm/Intern-S1-mini"
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
...@@ -733,6 +804,26 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData: ...@@ -733,6 +804,26 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
) )
# LightOnOCR
def run_lightonocr(questions: list[str], modality: str) -> ModelRequestData:
    """Build the request data for the LightOnOCR single-image example.

    Only the ``"image"`` modality is accepted. The question text is not
    inserted into the prompt: the same fixed prompt is emitted once per
    question.
    """
    assert modality == "image"

    # Fixed chat-formatted prompt containing only the image placeholder.
    fixed_prompt = (
        "<|im_start|>system<|im_end|>\n"
        "<|im_start|>user\n<|image_pad|><|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    ocr_prompts = [fixed_prompt for _ in questions]

    ocr_engine_args = EngineArgs(
        model="lightonai/LightOnOCR-1B",
        limit_mm_per_prompt={modality: 1},
    )

    return ModelRequestData(
        engine_args=ocr_engine_args,
        prompts=ocr_prompts,
    )
def run_llama4(questions: list[str], modality: str) -> ModelRequestData: def run_llama4(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image" assert modality == "image"
...@@ -1140,15 +1231,37 @@ def run_ovis2_5(questions: list[str], modality: str) -> ModelRequestData: ...@@ -1140,15 +1231,37 @@ def run_ovis2_5(questions: list[str], modality: str) -> ModelRequestData:
elif modality == "video": elif modality == "video":
placeholder = "<video>" placeholder = "<video>"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) prompts = [
messages = [ f"<|im_start|>user\n\n{placeholder}\n{question}<|im_end|>\n<|im_start|>assistant\n"
[{"role": "user", "content": f"{placeholder}\n{question}"}]
for question in questions for question in questions
] ]
prompts = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
) )
# PaddleOCR-VL
def run_paddleocr_vl(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
model_name = "PaddlePaddle/PaddleOCR-VL"
engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=2,
limit_mm_per_prompt={modality: 1},
trust_remote_code=True,
)
placeholder = "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>"
prompts = [
(f"<|begin_of_sentence|>User: {question}{placeholder}\nAssistant: ")
for question in questions
]
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
prompts=prompts, prompts=prompts,
...@@ -1423,7 +1536,7 @@ def run_qwen2_5_omni(questions: list[str], modality: str): ...@@ -1423,7 +1536,7 @@ def run_qwen2_5_omni(questions: list[str], modality: str):
mm_processor_kwargs={ mm_processor_kwargs={
"min_pixels": 28 * 28, "min_pixels": 28 * 28,
"max_pixels": 1280 * 28 * 28, "max_pixels": 1280 * 28 * 28,
"fps": [1], "fps": 1,
}, },
limit_mm_per_prompt={modality: 1}, limit_mm_per_prompt={modality: 1},
) )
...@@ -1691,11 +1804,13 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData: ...@@ -1691,11 +1804,13 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
model_example_map = { model_example_map = {
"aria": run_aria, "aria": run_aria,
"aya_vision": run_aya_vision, "aya_vision": run_aya_vision,
"bee": run_bee,
"blip-2": run_blip2, "blip-2": run_blip2,
"chameleon": run_chameleon, "chameleon": run_chameleon,
"dots_ocr": run_dots_ocr,
"command_a_vision": run_command_a_vision, "command_a_vision": run_command_a_vision,
"deepseek_vl_v2": run_deepseek_vl2, "deepseek_vl_v2": run_deepseek_vl2,
"deepseek_ocr": run_deepseek_ocr,
"dots_ocr": run_dots_ocr,
"ernie45_vl": run_ernie45_vl, "ernie45_vl": run_ernie45_vl,
"fuyu": run_fuyu, "fuyu": run_fuyu,
"gemma3": run_gemma3, "gemma3": run_gemma3,
...@@ -1712,6 +1827,7 @@ model_example_map = { ...@@ -1712,6 +1827,7 @@ model_example_map = {
"keye_vl": run_keye_vl, "keye_vl": run_keye_vl,
"keye_vl1_5": run_keye_vl1_5, "keye_vl1_5": run_keye_vl1_5,
"kimi_vl": run_kimi_vl, "kimi_vl": run_kimi_vl,
"lightonocr": run_lightonocr,
"llama4": run_llama4, "llama4": run_llama4,
"llava": run_llava, "llava": run_llava,
"llava-next": run_llava_next, "llava-next": run_llava_next,
...@@ -1727,6 +1843,7 @@ model_example_map = { ...@@ -1727,6 +1843,7 @@ model_example_map = {
"NVLM_D": run_nvlm_d, "NVLM_D": run_nvlm_d,
"ovis": run_ovis, "ovis": run_ovis,
"ovis2_5": run_ovis2_5, "ovis2_5": run_ovis2_5,
"paddleocr_vl": run_paddleocr_vl,
"paligemma": run_paligemma, "paligemma": run_paligemma,
"paligemma2": run_paligemma2, "paligemma2": run_paligemma2,
"phi3_v": run_phi3v, "phi3_v": run_phi3v,
...@@ -1957,8 +2074,12 @@ def main(args): ...@@ -1957,8 +2074,12 @@ def main(args):
# We set temperature to 0.2 so that outputs can be different # We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference. # even when all prompts are identical when running batch inference.
sampling_params = SamplingParams( sampling_params = (
temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids SamplingParams(
temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
)
if req_data.sampling_params is None
else req_data.sampling_params
) )
assert args.num_prompts > 0 assert args.num_prompts > 0
......
...@@ -9,7 +9,7 @@ using the chat template defined by the model. ...@@ -9,7 +9,7 @@ using the chat template defined by the model.
import os import os
from argparse import Namespace from argparse import Namespace
from dataclasses import asdict from dataclasses import asdict
from typing import NamedTuple, Optional from typing import NamedTuple
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from PIL.Image import Image from PIL.Image import Image
...@@ -18,22 +18,22 @@ from transformers import AutoProcessor, AutoTokenizer ...@@ -18,22 +18,22 @@ from transformers import AutoProcessor, AutoTokenizer
from vllm import LLM, EngineArgs, SamplingParams from vllm import LLM, EngineArgs, SamplingParams
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.utils import fetch_image from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
QUESTION = "What is the content of each image?" QUESTION = "What is the content of each image?"
IMAGE_URLS = [ IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/duck.jpg",
"https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/lion.jpg",
"https://upload.wikimedia.org/wikipedia/commons/2/26/Ultramarine_Flycatcher_%28Ficedula_superciliaris%29_Naggar%2C_Himachal_Pradesh%2C_2013_%28cropped%29.JPG", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/flycatcher.jpeg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/e/e5/Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg/2560px-Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/somefish.jpg",
"https://upload.wikimedia.org/wikipedia/commons/d/d4/Starfish%2C_Caswell_Bay_-_geograph.org.uk_-_409413.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/starfish.jpg",
"https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/snail.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Texas_invasive_Musk_Thistle_1.jpg/1920px-Texas_invasive_Musk_Thistle_1.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/thistle.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/7/7a/Huskiesatrest.jpg/2880px-Huskiesatrest.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/husky.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg/1920px-Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/orangetabbycat.jpg",
"https://upload.wikimedia.org/wikipedia/commons/3/30/George_the_amazing_guinea_pig.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/guineapig.jpg",
"https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/Oryctolagus_cuniculus_Rcdo.jpg/1920px-Oryctolagus_cuniculus_Rcdo.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/rabbit.jpg",
"https://upload.wikimedia.org/wikipedia/commons/9/98/Horse-and-pony.jpg", "https://vllm-public-assets.s3.us-west-2.amazonaws.com/multimodal_asset/horsepony.jpg",
] ]
...@@ -41,9 +41,10 @@ class ModelRequestData(NamedTuple): ...@@ -41,9 +41,10 @@ class ModelRequestData(NamedTuple):
engine_args: EngineArgs engine_args: EngineArgs
prompt: str prompt: str
image_data: list[Image] image_data: list[Image]
stop_token_ids: Optional[list[int]] = None stop_token_ids: list[int] | None = None
chat_template: Optional[str] = None chat_template: str | None = None
lora_requests: Optional[list[LoRARequest]] = None lora_requests: list[LoRARequest] | None = None
sampling_params: SamplingParams | None = None
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
...@@ -107,6 +108,41 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -107,6 +108,41 @@ def load_aya_vision(question: str, image_urls: list[str]) -> ModelRequestData:
) )
def load_bee(question: str, image_urls: list[str]) -> ModelRequestData:
    """Build the multi-image request data for the Bee-8B example.

    The prompt is produced by the model's own chat template: a single user
    message holding one image entry per URL followed by the question text.
    Every URL in ``image_urls`` is fetched to build the image data.
    """
    model_name = "Open-Bee/Bee-8B-RL"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=16384,
        max_num_seqs=16,
        limit_mm_per_prompt={"image": len(image_urls)},
        trust_remote_code=True,
    )

    # One image entry per URL, then the question as the final text entry.
    user_content = [{"type": "image", "image": url} for url in image_urls]
    user_content.append({"type": "text", "text": question})
    messages = [{"role": "user", "content": user_content}]

    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    images = [fetch_image(url) for url in image_urls]
    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        image_data=images,
    )
def load_command_a_vision(question: str, image_urls: list[str]) -> ModelRequestData: def load_command_a_vision(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "CohereLabs/command-a-vision-07-2025" model_name = "CohereLabs/command-a-vision-07-2025"
...@@ -166,6 +202,46 @@ def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -166,6 +202,46 @@ def load_deepseek_vl2(question: str, image_urls: list[str]) -> ModelRequestData:
) )
def load_deepseek_ocr(question: str, image_urls: list[str]) -> ModelRequestData:
    """Build the multi-image request data for the DeepSeek-OCR example.

    The prompt is one ``<image>`` line per URL followed by the question.
    Every URL in ``image_urls`` is fetched to build the image data, and a
    model-specific ``SamplingParams`` is attached to the request.
    """
    from vllm.model_executor.models.deepseek_ocr import NGramPerReqLogitsProcessor

    ocr_engine_args = EngineArgs(
        model="deepseek-ai/DeepSeek-OCR",
        max_num_seqs=2,
        limit_mm_per_prompt={"image": len(image_urls)},
        logits_processors=[NGramPerReqLogitsProcessor],
    )

    # One image placeholder line per input image, then the question.
    image_tags = "".join("<image>\n" for _ in image_urls)
    ocr_prompt = f"{image_tags}{question}"

    # The following sampling params config is taken from
    # the official Deepseek-OCR inference example.
    # (IMPORTANT) Use the custom logits processor and avoid skipping
    # special tokens for this model for the optimal OCR performance.
    ngram_args = {
        # Arguments consumed by the n-gram logits processor.
        "ngram_size": 30,
        "window_size": 90,
        # whitelist: <td>, </td>
        "whitelist_token_ids": {128821, 128822},
    }
    ocr_sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=8192,
        extra_args=ngram_args,
        skip_special_tokens=False,
    )

    images = [fetch_image(url) for url in image_urls]
    return ModelRequestData(
        engine_args=ocr_engine_args,
        prompt=ocr_prompt,
        image_data=images,
        sampling_params=ocr_sampling_params,
    )
def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData: def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "google/gemma-3-4b-it" model_name = "google/gemma-3-4b-it"
...@@ -309,7 +385,7 @@ def load_idefics3(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -309,7 +385,7 @@ def load_idefics3(question: str, image_urls: list[str]) -> ModelRequestData:
def load_interns1(question: str, image_urls: list[str]) -> ModelRequestData: def load_interns1(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "internlm/Intern-S1" model_name = "internlm/Intern-S1-mini"
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
...@@ -371,13 +447,14 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -371,13 +447,14 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
) )
def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct" model_name = "Kwai-Keye/Keye-VL-8B-Preview"
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
max_model_len=131072, trust_remote_code=True,
tensor_parallel_size=8, max_model_len=8192,
max_num_seqs=5,
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
) )
...@@ -389,29 +466,32 @@ def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -389,29 +466,32 @@ def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
*placeholders, *placeholders,
{"type": "text", "text": question}, {"type": "text", "text": question},
], ],
} },
] ]
processor = AutoProcessor.from_pretrained(model_name) processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
prompt = processor.apply_chat_template( prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, tokenize=False, add_generation_prompt=True
) )
image_data = [fetch_image(url) for url in image_urls]
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
prompt=prompt, prompt=prompt,
image_data=[fetch_image(url) for url in image_urls], image_data=image_data,
) )
def load_llava(question: str, image_urls: list[str]) -> ModelRequestData: def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
# NOTE: CAUTION! Original Llava models wasn't really trained on multi-image inputs, model_name = "Kwai-Keye/Keye-VL-1_5-8B"
# it will generate poor response for multi-image inputs!
model_name = "llava-hf/llava-1.5-7b-hf"
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
max_num_seqs=16, trust_remote_code=True,
max_model_len=32768,
max_num_seqs=5,
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
) )
...@@ -423,28 +503,32 @@ def load_llava(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -423,28 +503,32 @@ def load_llava(question: str, image_urls: list[str]) -> ModelRequestData:
*placeholders, *placeholders,
{"type": "text", "text": question}, {"type": "text", "text": question},
], ],
} },
] ]
processor = AutoProcessor.from_pretrained(model_name) processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
prompt = processor.apply_chat_template( prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, tokenize=False, add_generation_prompt=True
) )
image_data = [fetch_image(url) for url in image_urls]
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
prompt=prompt, prompt=prompt,
image_data=[fetch_image(url) for url in image_urls], image_data=image_data,
) )
def load_llava_next(question: str, image_urls: list[str]) -> ModelRequestData: def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "llava-hf/llava-v1.6-mistral-7b-hf" model_name = "moonshotai/Kimi-VL-A3B-Instruct"
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
max_model_len=8192, trust_remote_code=True,
max_num_seqs=16, max_model_len=4096,
max_num_seqs=4,
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
) )
...@@ -459,7 +543,7 @@ def load_llava_next(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -459,7 +543,7 @@ def load_llava_next(question: str, image_urls: list[str]) -> ModelRequestData:
} }
] ]
processor = AutoProcessor.from_pretrained(model_name) processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
prompt = processor.apply_chat_template( prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, tokenize=False, add_generation_prompt=True
...@@ -472,12 +556,13 @@ def load_llava_next(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -472,12 +556,13 @@ def load_llava_next(question: str, image_urls: list[str]) -> ModelRequestData:
) )
def load_llava_onevision(question: str, image_urls: list[str]) -> ModelRequestData: def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "llava-hf/llava-onevision-qwen2-7b-ov-hf" model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
max_model_len=16384, max_model_len=131072,
max_num_seqs=16, tensor_parallel_size=8,
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
) )
...@@ -505,14 +590,13 @@ def load_llava_onevision(question: str, image_urls: list[str]) -> ModelRequestDa ...@@ -505,14 +590,13 @@ def load_llava_onevision(question: str, image_urls: list[str]) -> ModelRequestDa
) )
def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData: def load_llava(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "Kwai-Keye/Keye-VL-8B-Preview" # NOTE: CAUTION! Original Llava models wasn't really trained on multi-image inputs,
# it will generate poor response for multi-image inputs!
model_name = "llava-hf/llava-1.5-7b-hf"
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
trust_remote_code=True, max_num_seqs=16,
max_model_len=8192,
max_num_seqs=5,
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
) )
...@@ -524,32 +608,28 @@ def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -524,32 +608,28 @@ def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData:
*placeholders, *placeholders,
{"type": "text", "text": question}, {"type": "text", "text": question},
], ],
}, }
] ]
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template( prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, tokenize=False, add_generation_prompt=True
) )
image_data = [fetch_image(url) for url in image_urls]
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
prompt=prompt, prompt=prompt,
image_data=image_data, image_data=[fetch_image(url) for url in image_urls],
) )
def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData: def load_llava_next(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "Kwai-Keye/Keye-VL-1_5-8B" model_name = "llava-hf/llava-v1.6-mistral-7b-hf"
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
trust_remote_code=True,
max_model_len=8192, max_model_len=8192,
max_num_seqs=5, max_num_seqs=16,
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
) )
...@@ -561,32 +641,28 @@ def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -561,32 +641,28 @@ def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
*placeholders, *placeholders,
{"type": "text", "text": question}, {"type": "text", "text": question},
], ],
}, }
] ]
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template( prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, tokenize=False, add_generation_prompt=True
) )
image_data = [fetch_image(url) for url in image_urls]
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
prompt=prompt, prompt=prompt,
image_data=image_data, image_data=[fetch_image(url) for url in image_urls],
) )
def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData: def load_llava_onevision(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "moonshotai/Kimi-VL-A3B-Instruct" model_name = "llava-hf/llava-onevision-qwen2-7b-ov-hf"
engine_args = EngineArgs( engine_args = EngineArgs(
model=model_name, model=model_name,
trust_remote_code=True, max_model_len=16384,
max_model_len=4096, max_num_seqs=16,
max_num_seqs=4,
limit_mm_per_prompt={"image": len(image_urls)}, limit_mm_per_prompt={"image": len(image_urls)},
) )
...@@ -601,7 +677,7 @@ def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -601,7 +677,7 @@ def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
} }
] ]
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template( prompt = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True messages, tokenize=False, add_generation_prompt=True
...@@ -713,13 +789,32 @@ def load_ovis2_5(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -713,13 +789,32 @@ def load_ovis2_5(question: str, image_urls: list[str]) -> ModelRequestData:
placeholders = "\n".join( placeholders = "\n".join(
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1) f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
) )
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}] prompt = (
f"<|im_start|>user\n\n{placeholders}\n{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) return ModelRequestData(
prompt = tokenizer.apply_chat_template( engine_args=engine_args,
messages, tokenize=False, add_generation_prompt=True prompt=prompt,
image_data=[fetch_image(url) for url in image_urls],
)
def load_paddleocr_vl(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "PaddlePaddle/PaddleOCR-VL"
engine_args = EngineArgs(
model=model_name,
trust_remote_code=True,
max_model_len=8192,
max_num_seqs=2,
limit_mm_per_prompt={"image": len(image_urls)},
) )
placeholders = "<|IMAGE_START|><|IMAGE_PLACEHOLDER|><|IMAGE_END|>" * len(image_urls)
prompt = f"<|begin_of_sentence|>User: {question}{placeholders}\nAssistant: "
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
prompt=prompt, prompt=prompt,
...@@ -1217,8 +1312,10 @@ def load_glm4_5v_fp8(question: str, image_urls: list[str]) -> ModelRequestData: ...@@ -1217,8 +1312,10 @@ def load_glm4_5v_fp8(question: str, image_urls: list[str]) -> ModelRequestData:
model_example_map = { model_example_map = {
"aria": load_aria, "aria": load_aria,
"aya_vision": load_aya_vision, "aya_vision": load_aya_vision,
"bee": load_bee,
"command_a_vision": load_command_a_vision, "command_a_vision": load_command_a_vision,
"deepseek_vl_v2": load_deepseek_vl2, "deepseek_vl_v2": load_deepseek_vl2,
"deepseek_ocr": load_deepseek_ocr,
"gemma3": load_gemma3, "gemma3": load_gemma3,
"h2ovl_chat": load_h2ovl, "h2ovl_chat": load_h2ovl,
"hyperclovax_seed_vision": load_hyperclovax_seed_vision, "hyperclovax_seed_vision": load_hyperclovax_seed_vision,
...@@ -1236,6 +1333,7 @@ model_example_map = { ...@@ -1236,6 +1333,7 @@ model_example_map = {
"NVLM_D": load_nvlm_d, "NVLM_D": load_nvlm_d,
"ovis": load_ovis, "ovis": load_ovis,
"ovis2_5": load_ovis2_5, "ovis2_5": load_ovis2_5,
"paddleocr_vl": load_paddleocr_vl,
"phi3_v": load_phi3v, "phi3_v": load_phi3v,
"phi4_mm": load_phi4mm, "phi4_mm": load_phi4mm,
"phi4_multimodal": load_phi4_multimodal, "phi4_multimodal": load_phi4_multimodal,
...@@ -1253,7 +1351,7 @@ model_example_map = { ...@@ -1253,7 +1351,7 @@ model_example_map = {
} }
def run_generate(model, question: str, image_urls: list[str], seed: Optional[int]): def run_generate(model, question: str, image_urls: list[str], seed: int | None):
req_data = model_example_map[model](question, image_urls) req_data = model_example_map[model](question, image_urls)
engine_args = asdict(req_data.engine_args) | {"seed": args.seed} engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
...@@ -1279,7 +1377,7 @@ def run_generate(model, question: str, image_urls: list[str], seed: Optional[int ...@@ -1279,7 +1377,7 @@ def run_generate(model, question: str, image_urls: list[str], seed: Optional[int
print("-" * 50) print("-" * 50)
def run_chat(model: str, question: str, image_urls: list[str], seed: Optional[int]): def run_chat(model: str, question: str, image_urls: list[str], seed: int | None):
req_data = model_example_map[model](question, image_urls) req_data = model_example_map[model](question, image_urls)
# Disable other modalities to save memory # Disable other modalities to save memory
...@@ -1291,8 +1389,12 @@ def run_chat(model: str, question: str, image_urls: list[str], seed: Optional[in ...@@ -1291,8 +1389,12 @@ def run_chat(model: str, question: str, image_urls: list[str], seed: Optional[in
engine_args = asdict(req_data.engine_args) | {"seed": seed} engine_args = asdict(req_data.engine_args) | {"seed": seed}
llm = LLM(**engine_args) llm = LLM(**engine_args)
sampling_params = SamplingParams( sampling_params = (
temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids SamplingParams(
temperature=0.0, max_tokens=256, stop_token_ids=req_data.stop_token_ids
)
if req_data.sampling_params is None
else req_data.sampling_params
) )
outputs = llm.chat( outputs = llm.chat(
[ [
......
...@@ -10,14 +10,18 @@ on HuggingFace model repository. ...@@ -10,14 +10,18 @@ on HuggingFace model repository.
from argparse import Namespace from argparse import Namespace
from dataclasses import asdict from dataclasses import asdict
from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args from pathlib import Path
from typing import Literal, NamedTuple, TypeAlias, TypedDict, get_args
from PIL.Image import Image from PIL.Image import Image
from vllm import LLM, EngineArgs from vllm import LLM, EngineArgs
from vllm.entrypoints.score_utils import ScoreMultiModalParam from vllm.entrypoints.score_utils import ScoreMultiModalParam
from vllm.multimodal.utils import fetch_image from vllm.multimodal.utils import fetch_image
from vllm.utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
ROOT_DIR = Path(__file__).parent.parent.parent
EXAMPLES_DIR = ROOT_DIR / "examples"
class TextQuery(TypedDict): class TextQuery(TypedDict):
...@@ -43,15 +47,39 @@ class TextImagesQuery(TypedDict): ...@@ -43,15 +47,39 @@ class TextImagesQuery(TypedDict):
QueryModality = Literal["text", "image", "text+image", "text+images"] QueryModality = Literal["text", "image", "text+image", "text+images"]
Query = Union[TextQuery, ImageQuery, TextImageQuery, TextImagesQuery] Query: TypeAlias = TextQuery | ImageQuery | TextImageQuery | TextImagesQuery
class ModelRequestData(NamedTuple): class ModelRequestData(NamedTuple):
engine_args: EngineArgs engine_args: EngineArgs
prompt: Optional[str] = None prompt: str | None = None
image: Optional[Image] = None image: Image | None = None
query: Optional[str] = None query: str | None = None
documents: Optional[ScoreMultiModalParam] = None documents: ScoreMultiModalParam | None = None
def run_clip(query: Query) -> ModelRequestData:
    """Build the pooling request for ``openai/clip-vit-base-patch32``.

    Args:
        query: Query mapping; ``query["modality"]`` must be ``"text"`` or
            ``"image"``.

    Returns:
        A ``ModelRequestData`` with the engine arguments plus the prompt and
        image taken from the query.

    Raises:
        ValueError: If the query modality is neither ``"text"`` nor
            ``"image"``.
    """
    modality = query["modality"]
    if modality not in ("text", "image"):
        raise ValueError(f"Unsupported query modality: '{modality}'")

    if modality == "text":
        prompt, image = query["text"], None
    else:
        # For image input, make sure that the prompt text is empty
        prompt, image = "", query["image"]

    clip_engine_args = EngineArgs(
        model="openai/clip-vit-base-patch32",
        runner="pooling",
        limit_mm_per_prompt={"image": 1},
    )
    return ModelRequestData(engine_args=clip_engine_args, prompt=prompt, image=image)
def run_e5_v(query: Query) -> ModelRequestData: def run_e5_v(query: Query) -> ModelRequestData:
...@@ -82,23 +110,74 @@ def run_e5_v(query: Query) -> ModelRequestData: ...@@ -82,23 +110,74 @@ def run_e5_v(query: Query) -> ModelRequestData:
) )
def run_jinavl_reranker(query: Query) -> ModelRequestData:
    """Build the scoring request for ``jinaai/jina-reranker-m0``.

    Args:
        query: Query mapping whose ``"modality"`` must be ``"text+images"``.
            Its ``"text"`` entry becomes the scoring query and its ``"image"``
            entry is passed through as the documents to score.

    Returns:
        A ``ModelRequestData`` with ``query`` and ``documents`` populated
        (``prompt`` and ``image`` keep their defaults).

    Raises:
        ValueError: If the query modality is not ``"text+images"``.
    """
    # Validate before constructing any engine arguments so an unsupported
    # modality fails fast.
    if query["modality"] != "text+images":
        raise ValueError(f"Unsupported query modality: '{query['modality']}'")

    engine_args = EngineArgs(
        model="jinaai/jina-reranker-m0",
        runner="pooling",
        max_model_len=32768,
        trust_remote_code=True,
        mm_processor_kwargs={
            "min_pixels": 3136,
            "max_pixels": 602112,
        },
        limit_mm_per_prompt={"image": 1},
    )

    return ModelRequestData(
        engine_args=engine_args,
        query=query["text"],
        documents=query["image"],
    )
def run_siglip(query: Query) -> ModelRequestData:
    """Build the pooling request for ``google/siglip-base-patch16-224``.

    Args:
        query: Query mapping; only the ``"text"`` and ``"image"`` modalities
            are accepted.

    Returns:
        A ``ModelRequestData`` holding the engine arguments, the prompt text
        and the image (``None`` for a text query).

    Raises:
        ValueError: If the query modality is anything other than ``"text"``
            or ``"image"``.
    """
    is_text = query["modality"] == "text"
    if not is_text and query["modality"] != "image":
        modality = query["modality"]
        raise ValueError(f"Unsupported query modality: '{modality}'")

    # For image input, make sure that the prompt text is empty
    prompt = query["text"] if is_text else ""
    image = None if is_text else query["image"]

    siglip_engine_args = EngineArgs(
        model="google/siglip-base-patch16-224",
        runner="pooling",
        limit_mm_per_prompt={"image": 1},
    )
    return ModelRequestData(
        engine_args=siglip_engine_args,
        prompt=prompt,
        image=image,
    )
def _get_vlm2vec_prompt_image(query: Query, image_token: str):
if query["modality"] == "text": if query["modality"] == "text":
text = query["text"] text = query["text"]
prompt = f"Find me an everyday image that matches the given caption: {text}" # noqa: E501 prompt = f"Find me an everyday image that matches the given caption: {text}"
image = None image = None
elif query["modality"] == "image": elif query["modality"] == "image":
prompt = "<|image_1|> Find a day-to-day image that looks similar to the provided image." # noqa: E501 prompt = f"{image_token} Find a day-to-day image that looks similar to the provided image." # noqa: E501
image = query["image"] image = query["image"]
elif query["modality"] == "text+image": elif query["modality"] == "text+image":
text = query["text"] text = query["text"]
prompt = ( prompt = f"{image_token} Represent the given image with the following question: {text}" # noqa: E501
f"<|image_1|> Represent the given image with the following question: {text}" # noqa: E501
)
image = query["image"] image = query["image"]
else: else:
modality = query["modality"] modality = query["modality"]
raise ValueError(f"Unsupported query modality: '{modality}'") raise ValueError(f"Unsupported query modality: {modality!r}")
return prompt, image
def run_vlm2vec_phi3v(query: Query) -> ModelRequestData:
prompt, image = _get_vlm2vec_prompt_image(query, "<|image_1|>")
engine_args = EngineArgs( engine_args = EngineArgs(
model="TIGER-Lab/VLM2Vec-Full", model="TIGER-Lab/VLM2Vec-Full",
...@@ -116,26 +195,66 @@ def run_vlm2vec(query: Query) -> ModelRequestData: ...@@ -116,26 +195,66 @@ def run_vlm2vec(query: Query) -> ModelRequestData:
) )
def run_jinavl_reranker(query: Query) -> ModelRequestData: def run_vlm2vec_qwen2vl(query: Query) -> ModelRequestData:
if query["modality"] != "text+images": # vLLM does not support LoRA adapters on multi-modal encoder,
raise ValueError(f"Unsupported query modality: '{query['modality']}'") # so we merge the weights first
from huggingface_hub.constants import HF_HUB_CACHE
from peft import PeftConfig, PeftModel
from transformers import AutoModelForImageTextToText, AutoProcessor
from vllm.entrypoints.chat_utils import load_chat_template
model_id = "TIGER-Lab/VLM2Vec-Qwen2VL-2B"
base_model = AutoModelForImageTextToText.from_pretrained(model_id)
lora_model = PeftModel.from_pretrained(
base_model,
model_id,
config=PeftConfig.from_pretrained(model_id),
)
model = lora_model.merge_and_unload().to(dtype=base_model.dtype)
model._hf_peft_config_loaded = False # Needed to save the merged model
processor = AutoProcessor.from_pretrained(
model_id,
# `min_pixels` and `max_pixels` are deprecated for
# transformers `preprocessor_config.json`
size={"shortest_edge": 3136, "longest_edge": 12845056},
)
processor.chat_template = load_chat_template(
# The original chat template is not correct
EXAMPLES_DIR / "template_vlm2vec_qwen2vl.jinja",
)
merged_path = str(
Path(HF_HUB_CACHE) / ("models--" + model_id.replace("/", "--") + "-vllm")
)
print(f"Saving merged model to {merged_path}...")
print(
"NOTE: This directory is not tracked by `huggingface_hub` "
"so you have to delete this manually if you don't want it anymore."
)
model.save_pretrained(merged_path)
processor.save_pretrained(merged_path)
print("Done!")
prompt, image = _get_vlm2vec_prompt_image(query, "<|image_pad|>")
engine_args = EngineArgs( engine_args = EngineArgs(
model="jinaai/jina-reranker-m0", model=merged_path,
runner="pooling", runner="pooling",
max_model_len=32768, max_model_len=4096,
trust_remote_code=True,
mm_processor_kwargs={ mm_processor_kwargs={
"min_pixels": 3136, "min_pixels": 3136,
"max_pixels": 602112, "max_pixels": 12845056,
}, },
limit_mm_per_prompt={"image": 1}, limit_mm_per_prompt={"image": 1},
) )
return ModelRequestData( return ModelRequestData(
engine_args=engine_args, engine_args=engine_args,
query=query["text"], prompt=prompt,
documents=query["image"], image=image,
) )
...@@ -186,7 +305,7 @@ def get_query(modality: QueryModality): ...@@ -186,7 +305,7 @@ def get_query(modality: QueryModality):
raise ValueError(msg) raise ValueError(msg)
def run_encode(model: str, modality: QueryModality, seed: Optional[int]): def run_encode(model: str, modality: QueryModality, seed: int | None):
query = get_query(modality) query = get_query(modality)
req_data = model_example_map[model](query) req_data = model_example_map[model](query)
...@@ -216,7 +335,7 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]): ...@@ -216,7 +335,7 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]):
print("-" * 50) print("-" * 50)
def run_score(model: str, modality: QueryModality, seed: Optional[int]): def run_score(model: str, modality: QueryModality, seed: int | None):
query = get_query(modality) query = get_query(modality)
req_data = model_example_map[model](query) req_data = model_example_map[model](query)
...@@ -231,9 +350,12 @@ def run_score(model: str, modality: QueryModality, seed: Optional[int]): ...@@ -231,9 +350,12 @@ def run_score(model: str, modality: QueryModality, seed: Optional[int]):
model_example_map = { model_example_map = {
"clip": run_clip,
"e5_v": run_e5_v, "e5_v": run_e5_v,
"vlm2vec": run_vlm2vec,
"jinavl_reranker": run_jinavl_reranker, "jinavl_reranker": run_jinavl_reranker,
"siglip": run_siglip,
"vlm2vec_phi3v": run_vlm2vec_phi3v,
"vlm2vec_qwen2vl": run_vlm2vec_qwen2vl,
} }
...@@ -246,7 +368,7 @@ def parse_args(): ...@@ -246,7 +368,7 @@ def parse_args():
"--model-name", "--model-name",
"-m", "-m",
type=str, type=str,
default="vlm2vec", default="vlm2vec_phi3v",
choices=model_example_map.keys(), choices=model_example_map.keys(),
help="The name of the embedding model.", help="The name of the embedding model.",
) )
......
...@@ -19,3 +19,15 @@ This directory contains a Helm chart for deploying the vllm application. The cha ...@@ -19,3 +19,15 @@ This directory contains a Helm chart for deploying the vllm application. The cha
- templates/pvc.yaml: Template for Persistent Volume Claims. - templates/pvc.yaml: Template for Persistent Volume Claims.
- templates/secrets.yaml: Template for Kubernetes Secrets. - templates/secrets.yaml: Template for Kubernetes Secrets.
- templates/service.yaml: Template for creating Services. - templates/service.yaml: Template for creating Services.
## Running Tests
This chart includes unit tests written with [helm-unittest](https://github.com/helm-unittest/helm-unittest). Install the plugin, then run the tests from the chart directory:
```bash
# Install plugin
helm plugin install https://github.com/helm-unittest/helm-unittest
# Run tests
helm unittest .
```
...@@ -123,9 +123,6 @@ runAsUser: ...@@ -123,9 +123,6 @@ runAsUser:
{{- end }} {{- end }}
{{- end }} {{- end }}
{{- define "chart.extraInitImage" -}}
"amazon/aws-cli:2.6.4"
{{- end }}
{{- define "chart.extraInitEnv" -}} {{- define "chart.extraInitEnv" -}}
- name: S3_ENDPOINT_URL - name: S3_ENDPOINT_URL
...@@ -148,11 +145,15 @@ runAsUser: ...@@ -148,11 +145,15 @@ runAsUser:
secretKeyRef: secretKeyRef:
name: {{ .Release.Name }}-secrets name: {{ .Release.Name }}-secrets
key: s3accesskey key: s3accesskey
{{- if .Values.extraInit.s3modelpath }}
- name: S3_PATH - name: S3_PATH
value: "{{ .Values.extraInit.s3modelpath }}" value: "{{ .Values.extraInit.s3modelpath }}"
{{- end }}
{{- if hasKey .Values.extraInit "awsEc2MetadataDisabled" }}
- name: AWS_EC2_METADATA_DISABLED - name: AWS_EC2_METADATA_DISABLED
value: "{{ .Values.extraInit.awsEc2MetadataDisabled }}" value: "{{ .Values.extraInit.awsEc2MetadataDisabled }}"
{{- end }} {{- end }}
{{- end }}
{{/* {{/*
Define chart labels Define chart labels
......
...@@ -72,16 +72,21 @@ spec: ...@@ -72,16 +72,21 @@ spec:
{{ toYaml . | nindent 8 }} {{ toYaml . | nindent 8 }}
{{- end }} {{- end }}
{{- if .Values.extraInit }} {{- if and .Values.extraInit (or .Values.extraInit.modelDownload.enabled .Values.extraInit.initContainers) }}
initContainers: initContainers:
{{- if .Values.extraInit.modelDownload.enabled }}
- name: wait-download-model - name: wait-download-model
image: {{ include "chart.extraInitImage" . }} image: {{ .Values.extraInit.modelDownload.image.repository }}:{{ .Values.extraInit.modelDownload.image.tag }}
command: imagePullPolicy: {{ .Values.extraInit.modelDownload.image.pullPolicy }}
- /bin/bash command: {{ .Values.extraInit.modelDownload.waitContainer.command | toJson }}
args: args:
- -eucx {{- toYaml .Values.extraInit.modelDownload.waitContainer.args | nindent 10 }}
- while aws --endpoint-url $S3_ENDPOINT_URL s3 sync --dryrun s3://$S3_BUCKET_NAME/$S3_PATH /data | grep -q download; do sleep 10; done env:
env: {{- include "chart.extraInitEnv" . | nindent 10 }} {{- if .Values.extraInit.modelDownload.waitContainer.env }}
{{- toYaml .Values.extraInit.modelDownload.waitContainer.env | nindent 10 }}
{{- else }}
{{- include "chart.extraInitEnv" . | nindent 10 }}
{{- end }}
resources: resources:
requests: requests:
cpu: 200m cpu: 200m
...@@ -93,6 +98,10 @@ spec: ...@@ -93,6 +98,10 @@ spec:
- name: {{ .Release.Name }}-storage - name: {{ .Release.Name }}-storage
mountPath: /data mountPath: /data
{{- end }} {{- end }}
{{- with .Values.extraInit.initContainers }}
{{- toYaml . | nindent 6 }}
{{- end }}
{{- end }}
volumes: volumes:
- name: {{ .Release.Name }}-storage - name: {{ .Release.Name }}-storage
persistentVolumeClaim: persistentVolumeClaim:
......
{{- if .Values.extraInit }} {{- if and .Values.extraInit .Values.extraInit.modelDownload.enabled }}
apiVersion: batch/v1 apiVersion: batch/v1
kind: Job kind: Job
metadata: metadata:
...@@ -12,13 +12,17 @@ spec: ...@@ -12,13 +12,17 @@ spec:
spec: spec:
containers: containers:
- name: job-download-model - name: job-download-model
image: {{ include "chart.extraInitImage" . }} image: {{ .Values.extraInit.modelDownload.image.repository }}:{{ .Values.extraInit.modelDownload.image.tag }}
command: imagePullPolicy: {{ .Values.extraInit.modelDownload.image.pullPolicy }}
- /bin/bash command: {{ .Values.extraInit.modelDownload.downloadJob.command | toJson }}
args: args:
- -eucx {{- toYaml .Values.extraInit.modelDownload.downloadJob.args | nindent 8 }}
- aws --endpoint-url $S3_ENDPOINT_URL s3 sync s3://$S3_BUCKET_NAME/$S3_PATH /data env:
env: {{- include "chart.extraInitEnv" . | nindent 8 }} {{- if .Values.extraInit.modelDownload.downloadJob.env }}
{{- toYaml .Values.extraInit.modelDownload.downloadJob.env | nindent 8 }}
{{- else }}
{{- include "chart.extraInitEnv" . | nindent 8 }}
{{- end }}
volumeMounts: volumeMounts:
- name: {{ .Release.Name }}-storage - name: {{ .Release.Name }}-storage
mountPath: /data mountPath: /data
......
suite: test deployment
templates:
- deployment.yaml
tests:
- it: should create wait-download-model init container when modelDownload is enabled
set:
extraInit:
modelDownload:
enabled: true
image:
repository: "amazon/aws-cli"
tag: "2.6.4"
pullPolicy: "IfNotPresent"
waitContainer:
command: [ "/bin/bash" ]
args:
- "-eucx"
- "while aws --endpoint-url $S3_ENDPOINT_URL s3 sync --dryrun s3://$S3_BUCKET_NAME/$S3_PATH /data | grep -q download; do sleep 10; done"
downloadJob:
command: [ "/bin/bash" ]
args:
- "-eucx"
- "aws --endpoint-url $S3_ENDPOINT_URL s3 sync s3://$S3_BUCKET_NAME/$S3_PATH /data"
initContainers: [ ]
pvcStorage: "1Gi"
s3modelpath: "relative_s3_model_path/opt-125m"
awsEc2MetadataDisabled: true
asserts:
- hasDocuments:
count: 1
- isKind:
of: Deployment
- isNotEmpty:
path: spec.template.spec.initContainers
- equal:
path: spec.template.spec.initContainers[0].name
value: wait-download-model
- equal:
path: spec.template.spec.initContainers[0].image
value: amazon/aws-cli:2.6.4
- equal:
path: spec.template.spec.initContainers[0].imagePullPolicy
value: IfNotPresent
- it: should only create custom init containers when modelDownload is disabled
set:
extraInit:
modelDownload:
enabled: false
image:
repository: "amazon/aws-cli"
tag: "2.6.4"
pullPolicy: "IfNotPresent"
waitContainer:
command: [ "/bin/bash" ]
args: [ "-c", "echo test" ]
downloadJob:
command: [ "/bin/bash" ]
args: [ "-c", "echo test" ]
initContainers:
- name: llm-d-routing-proxy
image: ghcr.io/llm-d/llm-d-routing-sidecar:v0.2.0
imagePullPolicy: IfNotPresent
ports:
- containerPort: 8080
name: proxy
pvcStorage: "10Gi"
asserts:
- hasDocuments:
count: 1
- isKind:
of: Deployment
- lengthEqual:
path: spec.template.spec.initContainers
count: 1
- equal:
path: spec.template.spec.initContainers[0].name
value: llm-d-routing-proxy
- equal:
path: spec.template.spec.initContainers[0].image
value: ghcr.io/llm-d/llm-d-routing-sidecar:v0.2.0
- equal:
path: spec.template.spec.initContainers[0].ports[0].containerPort
value: 8080
- it: should create both wait-download-model and custom init containers when both are enabled
set:
extraInit:
modelDownload:
enabled: true
image:
repository: "amazon/aws-cli"
tag: "2.6.4"
pullPolicy: "IfNotPresent"
waitContainer:
command: [ "/bin/bash" ]
args:
- "-eucx"
- "while aws --endpoint-url $S3_ENDPOINT_URL s3 sync --dryrun s3://$S3_BUCKET_NAME/$S3_PATH /data | grep -q download; do sleep 10; done"
downloadJob:
command: [ "/bin/bash" ]
args:
- "-eucx"
- "aws --endpoint-url $S3_ENDPOINT_URL s3 sync s3://$S3_BUCKET_NAME/$S3_PATH /data"
initContainers:
- name: llm-d-routing-proxy
image: ghcr.io/llm-d/llm-d-routing-sidecar:v0.2.0
imagePullPolicy: IfNotPresent
ports:
- containerPort: 8080
name: proxy
pvcStorage: "10Gi"
asserts:
- hasDocuments:
count: 1
- isKind:
of: Deployment
- lengthEqual:
path: spec.template.spec.initContainers
count: 2
- equal:
path: spec.template.spec.initContainers[0].name
value: wait-download-model
- equal:
path: spec.template.spec.initContainers[0].image
value: amazon/aws-cli:2.6.4
- equal:
path: spec.template.spec.initContainers[1].name
value: llm-d-routing-proxy
- equal:
path: spec.template.spec.initContainers[1].image
value: ghcr.io/llm-d/llm-d-routing-sidecar:v0.2.0
- equal:
path: spec.template.spec.initContainers[1].ports[0].containerPort
value: 8080
\ No newline at end of file
suite: test job
templates:
- job.yaml
tests:
- it: should create job when modelDownload is enabled
set:
extraInit:
modelDownload:
enabled: true
image:
repository: "amazon/aws-cli"
tag: "2.6.4"
pullPolicy: "IfNotPresent"
waitContainer:
command: [ "/bin/bash" ]
args: [ "-c", "wait" ]
downloadJob:
command: [ "/bin/bash" ]
args:
- "-eucx"
- "aws --endpoint-url $S3_ENDPOINT_URL s3 sync s3://$S3_BUCKET_NAME/$S3_PATH /data"
pvcStorage: "1Gi"
s3modelpath: "relative_s3_model_path/opt-125m"
awsEc2MetadataDisabled: true
asserts:
- hasDocuments:
count: 1
- isKind:
of: Job
- equal:
path: spec.template.spec.containers[0].name
value: job-download-model
- equal:
path: spec.template.spec.containers[0].image
value: amazon/aws-cli:2.6.4
- equal:
path: spec.template.spec.restartPolicy
value: OnFailure
- it: should not create job when modelDownload is disabled
set:
extraInit:
modelDownload:
enabled: false
image:
repository: "amazon/aws-cli"
tag: "2.6.4"
pullPolicy: "IfNotPresent"
waitContainer:
command: [ "/bin/bash" ]
args: [ "-c", "wait" ]
downloadJob:
command: [ "/bin/bash" ]
args: [ "-c", "download" ]
initContainers:
- name: llm-d-routing-proxy
image: ghcr.io/llm-d/llm-d-routing-sidecar:v0.2.0
pvcStorage: "10Gi"
asserts:
- hasDocuments:
count: 0
suite: test pvc
templates:
- pvc.yaml
tests:
# Test Case: PVC Created When extraInit Defined
- it: should create pvc when extraInit is defined
set:
extraInit:
modelDownload:
enabled: true
image:
repository: "amazon/aws-cli"
tag: "2.6.4"
pullPolicy: "IfNotPresent"
waitContainer:
command: ["/bin/bash"]
args: ["-c", "wait"]
downloadJob:
command: ["/bin/bash"]
args: ["-c", "download"]
pvcStorage: "10Gi"
asserts:
- hasDocuments:
count: 1
- isKind:
of: PersistentVolumeClaim
- equal:
path: spec.accessModes[0]
value: ReadWriteOnce
- equal:
path: spec.resources.requests.storage
value: 10Gi
\ No newline at end of file
...@@ -136,6 +136,70 @@ ...@@ -136,6 +136,70 @@
"extraInit": { "extraInit": {
"type": "object", "type": "object",
"properties": { "properties": {
"modelDownload": {
"type": "object",
"properties": {
"enabled": {
"type": "boolean"
},
"image": {
"type": "object",
"properties": {
"repository": {
"type": "string"
},
"tag": {
"type": "string"
},
"pullPolicy": {
"type": "string"
}
},
"required": ["repository", "tag", "pullPolicy"]
},
"waitContainer": {
"type": "object",
"properties": {
"command": {
"type": "array",
"items": {"type": "string"}
},
"args": {
"type": "array",
"items": {"type": "string"}
},
"env": {
"type": "array",
"items": {"type": "object"}
}
},
"required": ["command", "args"]
},
"downloadJob": {
"type": "object",
"properties": {
"command": {
"type": "array",
"items": {"type": "string"}
},
"args": {
"type": "array",
"items": {"type": "string"}
},
"env": {
"type": "array",
"items": {"type": "object"}
}
},
"required": ["command", "args"]
}
},
"required": ["enabled", "image", "waitContainer", "downloadJob"]
},
"initContainers": {
"type": "array",
"items": {"type": "object"}
},
"s3modelpath": { "s3modelpath": {
"type": "string" "type": "string"
}, },
...@@ -147,9 +211,9 @@ ...@@ -147,9 +211,9 @@
} }
}, },
"required": [ "required": [
"pvcStorage", "modelDownload",
"s3modelpath", "initContainers",
"awsEc2MetadataDisabled" "pvcStorage"
] ]
}, },
"extraContainers": { "extraContainers": {
......
...@@ -75,10 +75,65 @@ maxUnavailablePodDisruptionBudget: "" ...@@ -75,10 +75,65 @@ maxUnavailablePodDisruptionBudget: ""
# -- Additional configuration for the init container # -- Additional configuration for the init container
extraInit: extraInit:
# -- Path of the model on the s3 which hosts model weights and config files # -- Model download functionality (optional)
modelDownload:
# -- Enable model download job and wait container
enabled: true
# -- Image configuration for model download operations
image:
# -- Image repository
repository: "amazon/aws-cli"
# -- Image tag
tag: "2.6.4"
# -- Image pull policy
pullPolicy: "IfNotPresent"
# -- Wait container configuration (init container that waits for the model to be ready)
waitContainer:
# -- Command to execute
command: ["/bin/bash"]
# -- Arguments for the wait container
args:
- "-eucx"
- "while aws --endpoint-url $S3_ENDPOINT_URL s3 sync --dryrun s3://$S3_BUCKET_NAME/$S3_PATH /data | grep -q download; do sleep 10; done"
# -- Environment variables (optional; if specified, they replace the S3 defaults entirely)
# env:
# - name: HUGGING_FACE_HUB_TOKEN
# value: "your-token"
# - name: MODEL_ID
# value: "meta-llama/Llama-2-7b"
# -- Download job configuration (job that actually downloads the model)
downloadJob:
# -- Command to execute
command: ["/bin/bash"]
# -- Arguments for the download job
args:
- "-eucx"
- "aws --endpoint-url $S3_ENDPOINT_URL s3 sync s3://$S3_BUCKET_NAME/$S3_PATH /data"
# -- Environment variables (optional; if specified, they replace the S3 defaults entirely)
# env:
# - name: HUGGING_FACE_HUB_TOKEN
# value: "your-token"
# - name: MODEL_ID
# value: "meta-llama/Llama-2-7b"
# -- Custom init containers (appended after wait-download-model if modelDownload is enabled)
initContainers: []
# Example for llm-d sidecar:
# initContainers:
# - name: llm-d-routing-proxy
# image: ghcr.io/llm-d/llm-d-routing-sidecar:v0.2.0
# imagePullPolicy: IfNotPresent
# ports:
# - containerPort: 8080
# name: proxy
# securityContext:
# runAsUser: 1000
# -- Path of the model in the S3 bucket that hosts the model weights and config files
s3modelpath: "relative_s3_model_path/opt-125m" s3modelpath: "relative_s3_model_path/opt-125m"
# -- Storage size of the s3 # -- Storage size for the PVC
pvcStorage: "1Gi" pvcStorage: "1Gi"
# -- Disable AWS EC2 metadata service
awsEc2MetadataDisabled: true awsEc2MetadataDisabled: true
# -- Additional containers configuration # -- Additional containers configuration
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment