Unverified Commit 755356b3 authored by milesial's avatar milesial Committed by GitHub
Browse files

feat: expose media_io_kwargs at runtime (#34778)


Signed-off-by: default avatarAlexandre Milesi <milesial@users.noreply.github.com>
parent 58928475
...@@ -35,6 +35,8 @@ def server(): ...@@ -35,6 +35,8 @@ def server():
"--trust-remote-code", "--trust-remote-code",
"--limit-mm-per-prompt", "--limit-mm-per-prompt",
json.dumps({"video": MAXIMUM_VIDEOS}), json.dumps({"video": MAXIMUM_VIDEOS}),
"--media-io-kwargs",
json.dumps({"video": {"num_frames": 32}}),
] ]
# ROCm: Increase timeouts to handle potential network delays and slower # ROCm: Increase timeouts to handle potential network delays and slower
...@@ -127,6 +129,73 @@ async def test_single_chat_session_video( ...@@ -127,6 +129,73 @@ async def test_single_chat_session_video(
assert message.content is not None and len(message.content) >= 0 assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("video_url", [TEST_VIDEO_URLS[0]])
async def test_request_media_io_kwargs_override_uses_fewer_video_frames(
client: openai.AsyncOpenAI, model_name: str, video_url: str
):
messages = dummy_messages_from_video_url(video_url)
default_resp = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=1,
temperature=0.0,
)
override_resp = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=1,
temperature=0.0,
extra_body={
"media_io_kwargs": {
"video": {
"num_frames": 4,
}
}
},
)
assert default_resp.usage is not None
assert override_resp.usage is not None
assert override_resp.usage.prompt_tokens < default_resp.usage.prompt_tokens
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("video_url", [TEST_VIDEO_URLS[0]])
async def test_invalid_num_frames_request_recoverable(
client: openai.AsyncOpenAI, model_name: str, video_url: str
):
messages = dummy_messages_from_video_url(video_url)
with pytest.raises((openai.BadRequestError, openai.APIStatusError)):
await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=1,
temperature=0.0,
extra_body={
"media_io_kwargs": {
"video": {
"num_frames": "invalid",
}
}
},
)
# Server should still handle subsequent requests after the failed one.
recovery_resp = await client.chat.completions.create(
model=model_name,
messages=messages,
max_completion_tokens=1,
temperature=0.0,
)
recovery_msg = recovery_resp.choices[0].message
assert recovery_msg.content is not None and len(recovery_msg.content) >= 0
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
......
...@@ -127,6 +127,39 @@ def test_chat_image_base64_request(server: RemoteOpenAIServer, model_name: str): ...@@ -127,6 +127,39 @@ def test_chat_image_base64_request(server: RemoteOpenAIServer, model_name: str):
assert output.usage.prompt_tokens == 767 assert output.usage.prompt_tokens == 767
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_chat_image_with_media_io_kwargs(server: RemoteOpenAIServer, model_name: str):
rgba_image_url = (
"https://vllm-public-assets.s3.us-west-2.amazonaws.com"
"/vision_model_images/RGBA_comp.png"
)
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Represent the user's input."},
{"type": "image_url", "image_url": {"url": rgba_image_url}},
],
}
]
response = requests.post(
server.url_for("v1/embeddings"),
json={
"model": model_name,
"messages": messages,
"media_io_kwargs": {
"image": {"rgba_background_color": [0, 0, 0]},
},
},
)
response.raise_for_status()
output = EmbeddingResponse.model_validate(response.json())
assert len(output.data) == 1
assert len(output.data[0].embedding) == 3072
def get_hf_prompt_tokens(model_name, content, image_url): def get_hf_prompt_tokens(model_name, content, image_url):
processor = AutoProcessor.from_pretrained( processor = AutoProcessor.from_pretrained(
model_name, trust_remote_code=True, num_crops=4 model_name, trust_remote_code=True, num_crops=4
......
...@@ -462,10 +462,15 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -462,10 +462,15 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
maximum per prompt. maximum per prompt.
""" """
def __init__(self, model_config: ModelConfig): def __init__(
self,
model_config: ModelConfig,
media_io_kwargs: dict[str, dict[str, Any]] | None = None,
):
super().__init__() super().__init__()
self._model_config = model_config self._model_config = model_config
self._media_io_kwargs = media_io_kwargs
self._items_by_modality = defaultdict[str, list[_T]](list) self._items_by_modality = defaultdict[str, list[_T]](list)
# Track original modality for each vision_chunk item (image or video) # Track original modality for each vision_chunk item (image or video)
...@@ -487,6 +492,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -487,6 +492,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
model_cls = get_model_cls(self.model_config) model_cls = get_model_cls(self.model_config)
return cast(type[SupportsMultiModal], model_cls) return cast(type[SupportsMultiModal], model_cls)
@property
def media_io_kwargs(self) -> dict[str, dict[str, Any]] | None:
return self._media_io_kwargs or (
self._model_config.multimodal_config.media_io_kwargs
if self._model_config.multimodal_config
else None
)
@property @property
def allowed_local_media_path(self): def allowed_local_media_path(self):
return self._model_config.allowed_local_media_path return self._model_config.allowed_local_media_path
...@@ -769,12 +782,10 @@ class MultiModalContentParser(BaseMultiModalContentParser): ...@@ -769,12 +782,10 @@ class MultiModalContentParser(BaseMultiModalContentParser):
super().__init__() super().__init__()
self._tracker = tracker self._tracker = tracker
multimodal_config = self._tracker.model_config.multimodal_config
media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load( self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
envs.VLLM_MEDIA_CONNECTOR, envs.VLLM_MEDIA_CONNECTOR,
media_io_kwargs=media_io_kwargs, media_io_kwargs=tracker.media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path, allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains, allowed_media_domains=tracker.allowed_media_domains,
) )
...@@ -881,11 +892,9 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): ...@@ -881,11 +892,9 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
super().__init__() super().__init__()
self._tracker = tracker self._tracker = tracker
multimodal_config = self._tracker.model_config.multimodal_config
media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load( self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
envs.VLLM_MEDIA_CONNECTOR, envs.VLLM_MEDIA_CONNECTOR,
media_io_kwargs=media_io_kwargs, media_io_kwargs=tracker.media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path, allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains, allowed_media_domains=tracker.allowed_media_domains,
) )
...@@ -1530,13 +1539,14 @@ def parse_chat_messages( ...@@ -1530,13 +1539,14 @@ def parse_chat_messages(
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
model_config: ModelConfig, model_config: ModelConfig,
content_format: ChatTemplateContentFormat, content_format: ChatTemplateContentFormat,
media_io_kwargs: dict[str, dict[str, Any]] | None = None,
) -> tuple[ ) -> tuple[
list[ConversationMessage], list[ConversationMessage],
MultiModalDataDict | None, MultiModalDataDict | None,
MultiModalUUIDDict | None, MultiModalUUIDDict | None,
]: ]:
conversation: list[ConversationMessage] = [] conversation: list[ConversationMessage] = []
mm_tracker = MultiModalItemTracker(model_config) mm_tracker = MultiModalItemTracker(model_config, media_io_kwargs=media_io_kwargs)
for msg in messages: for msg in messages:
sub_messages = _parse_chat_message_content( sub_messages = _parse_chat_message_content(
...@@ -1563,13 +1573,16 @@ async def parse_chat_messages_async( ...@@ -1563,13 +1573,16 @@ async def parse_chat_messages_async(
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
model_config: ModelConfig, model_config: ModelConfig,
content_format: ChatTemplateContentFormat, content_format: ChatTemplateContentFormat,
media_io_kwargs: dict[str, dict[str, Any]] | None = None,
) -> tuple[ ) -> tuple[
list[ConversationMessage], list[ConversationMessage],
MultiModalDataDict | None, MultiModalDataDict | None,
MultiModalUUIDDict | None, MultiModalUUIDDict | None,
]: ]:
conversation: list[ConversationMessage] = [] conversation: list[ConversationMessage] = []
mm_tracker = AsyncMultiModalItemTracker(model_config) mm_tracker = AsyncMultiModalItemTracker(
model_config, media_io_kwargs=media_io_kwargs
)
for msg in messages: for msg in messages:
sub_messages = _parse_chat_message_content( sub_messages = _parse_chat_message_content(
......
...@@ -268,6 +268,13 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -268,6 +268,13 @@ class ChatCompletionRequest(OpenAIBaseModel):
"Will be accessible by the chat template." "Will be accessible by the chat template."
), ),
) )
media_io_kwargs: dict[str, dict[str, Any]] | None = Field(
default=None,
description=(
"Additional kwargs to pass to the media IO connectors, "
"keyed by modality. Merged with engine-level media_io_kwargs."
),
)
mm_processor_kwargs: dict[str, Any] | None = Field( mm_processor_kwargs: dict[str, Any] | None = Field(
default=None, default=None,
description=("Additional kwargs to pass to the HF processor."), description=("Additional kwargs to pass to the HF processor."),
...@@ -366,6 +373,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -366,6 +373,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
reasoning_effort=self.reasoning_effort, reasoning_effort=self.reasoning_effort,
), ),
), ),
media_io_kwargs=self.media_io_kwargs,
) )
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
......
...@@ -900,10 +900,15 @@ class OpenAIServing: ...@@ -900,10 +900,15 @@ class OpenAIServing:
), ),
) )
mm_config = self.model_config.multimodal_config
tok_params = request.build_tok_params(self.model_config) tok_params = request.build_tok_params(self.model_config)
chat_params = request.build_chat_params( chat_params = request.build_chat_params(
default_template, default_template_content_format default_template, default_template_content_format
).with_defaults(default_template_kwargs) ).with_defaults(
default_template_kwargs,
default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
)
(conversation,), (engine_prompt,) = await renderer.render_chat_async( (conversation,), (engine_prompt,) = await renderer.render_chat_async(
[messages], [messages],
......
...@@ -197,6 +197,13 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -197,6 +197,13 @@ class ResponsesRequest(OpenAIBaseModel):
"through out the inference process and return in response." "through out the inference process and return in response."
), ),
) )
media_io_kwargs: dict[str, dict[str, Any]] | None = Field(
default=None,
description=(
"Additional kwargs to pass to the media IO connectors, "
"keyed by modality. Merged with engine-level media_io_kwargs."
),
)
mm_processor_kwargs: dict[str, Any] | None = Field( mm_processor_kwargs: dict[str, Any] | None = Field(
default=None, default=None,
description=("Additional kwargs to pass to the HF processor."), description=("Additional kwargs to pass to the HF processor."),
...@@ -276,6 +283,7 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -276,6 +283,7 @@ class ResponsesRequest(OpenAIBaseModel):
reasoning_effort=None if reasoning is None else reasoning.effort, reasoning_effort=None if reasoning is None else reasoning.effort,
), ),
), ),
media_io_kwargs=self.media_io_kwargs,
) )
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
......
...@@ -123,10 +123,15 @@ class PoolingIOProcessor: ...@@ -123,10 +123,15 @@ class PoolingIOProcessor:
), ),
) )
mm_config = self.model_config.multimodal_config
tok_params = request.build_tok_params(self.model_config) tok_params = request.build_tok_params(self.model_config)
chat_params = request.build_chat_params( chat_params = request.build_chat_params(
default_template, default_template_content_format default_template, default_template_content_format
).with_defaults(default_template_kwargs) ).with_defaults(
default_template_kwargs,
default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
)
(conversation,), (engine_prompt,) = renderer.render_chat( (conversation,), (engine_prompt,) = renderer.render_chat(
[messages], [messages],
......
...@@ -124,6 +124,13 @@ class ChatRequestMixin(OpenAIBaseModel): ...@@ -124,6 +124,13 @@ class ChatRequestMixin(OpenAIBaseModel):
"Will be accessible by the chat template." "Will be accessible by the chat template."
), ),
) )
media_io_kwargs: dict[str, dict[str, Any]] | None = Field(
default=None,
description=(
"Additional kwargs to pass to the media IO connectors, "
"keyed by modality. Merged with engine-level media_io_kwargs."
),
)
# --8<-- [end:chat-extra-params] # --8<-- [end:chat-extra-params]
@model_validator(mode="before") @model_validator(mode="before")
...@@ -151,6 +158,7 @@ class ChatRequestMixin(OpenAIBaseModel): ...@@ -151,6 +158,7 @@ class ChatRequestMixin(OpenAIBaseModel):
continue_final_message=self.continue_final_message, continue_final_message=self.continue_final_message,
), ),
), ),
media_io_kwargs=self.media_io_kwargs,
) )
......
...@@ -100,6 +100,13 @@ class TokenizeChatRequest(OpenAIBaseModel): ...@@ -100,6 +100,13 @@ class TokenizeChatRequest(OpenAIBaseModel):
"Will be accessible by the chat template." "Will be accessible by the chat template."
), ),
) )
media_io_kwargs: dict[str, dict[str, Any]] | None = Field(
default=None,
description=(
"Additional kwargs to pass to the media IO connectors, "
"keyed by modality. Merged with engine-level media_io_kwargs."
),
)
mm_processor_kwargs: dict[str, Any] | None = Field( mm_processor_kwargs: dict[str, Any] | None = Field(
default=None, default=None,
description="Additional kwargs to pass to the HF processor.", description="Additional kwargs to pass to the HF processor.",
...@@ -134,6 +141,7 @@ class TokenizeChatRequest(OpenAIBaseModel): ...@@ -134,6 +141,7 @@ class TokenizeChatRequest(OpenAIBaseModel):
continue_final_message=self.continue_final_message, continue_final_message=self.continue_final_message,
), ),
), ),
media_io_kwargs=self.media_io_kwargs,
) )
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
......
...@@ -83,11 +83,17 @@ def extract_audio_from_video_bytes( ...@@ -83,11 +83,17 @@ def extract_audio_from_video_bytes(
class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]): class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
"""Configuration values can be user-provided either by --media-io-kwargs or
by the runtime API field "media_io_kwargs". Ensure proper validation and
error handling.
"""
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
super().__init__() super().__init__()
# `kwargs` contains custom arguments from # `kwargs` contains custom arguments from
# --media-io-kwargs for this modality. # --media-io-kwargs for this modality, merged with
# per-request runtime media_io_kwargs via merge_kwargs().
# They can be passed to the underlying # They can be passed to the underlying
# media loaders (e.g. custom implementations) # media loaders (e.g. custom implementations)
# for flexible control. # for flexible control.
...@@ -122,6 +128,11 @@ class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]): ...@@ -122,6 +128,11 @@ class AudioMediaIO(MediaIO[tuple[npt.NDArray, float]]):
class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]): class AudioEmbeddingMediaIO(MediaIO[torch.Tensor]):
"""Configuration values can be user-provided either by --media-io-kwargs or
by the runtime API field "media_io_kwargs". Ensure proper validation and
error handling.
"""
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
......
...@@ -44,6 +44,28 @@ class MediaWithBytes(Generic[_T]): ...@@ -44,6 +44,28 @@ class MediaWithBytes(Generic[_T]):
class MediaIO(ABC, Generic[_T]): class MediaIO(ABC, Generic[_T]):
"""Configuration values can be user-provided either by --media-io-kwargs or
by the runtime API field "media_io_kwargs". Ensure proper validation and
error handling.
"""
@classmethod
def merge_kwargs(
cls,
default_kwargs: dict[str, Any] | None,
runtime_kwargs: dict[str, Any] | None,
) -> dict[str, Any]:
"""Merge config-level kwargs and request-level kwargs.
By default this performs a shallow merge where runtime kwargs override
keys in default kwargs. Subclasses may override to apply modality-
specific behavior.
"""
merged = dict(default_kwargs or {})
if runtime_kwargs:
merged.update(runtime_kwargs)
return merged
@abstractmethod @abstractmethod
def load_bytes(self, data: bytes) -> _T: def load_bytes(self, data: bytes) -> _T:
raise NotImplementedError raise NotImplementedError
......
...@@ -32,9 +32,43 @@ atexit.register(global_thread_pool.shutdown) ...@@ -32,9 +32,43 @@ atexit.register(global_thread_pool.shutdown)
MEDIA_CONNECTOR_REGISTRY = ExtensionManager() MEDIA_CONNECTOR_REGISTRY = ExtensionManager()
MODALITY_IO_MAP: dict[str, type[MediaIO]] = {
"audio": AudioMediaIO,
"image": ImageMediaIO,
"video": VideoMediaIO,
}
def merge_media_io_kwargs(
defaults: dict[str, dict[str, Any]] | None,
overrides: dict[str, dict[str, Any]] | None,
) -> dict[str, dict[str, Any]] | None:
"""Merge config-level and per-request media_io_kwargs per modality.
Each modality key is merged using the corresponding MediaIO subclass's
``merge_kwargs``, which may apply modality-specific logic (e.g.
VideoMediaIO clears cross-dependent fps/num_frames fields).
"""
if not defaults and not overrides:
return None
all_keys = set(defaults or {}) | set(overrides or {})
merged = {}
for key in all_keys:
io_cls = MODALITY_IO_MAP.get(key, MediaIO)
merged[key] = io_cls.merge_kwargs(
(defaults or {}).get(key),
(overrides or {}).get(key),
)
return merged or None
@MEDIA_CONNECTOR_REGISTRY.register("http") @MEDIA_CONNECTOR_REGISTRY.register("http")
class MediaConnector: class MediaConnector:
"""Configuration values can be user-provided either by --media-io-kwargs or
by the runtime API field "media_io_kwargs". Ensure proper validation and
error handling.
"""
def __init__( def __init__(
self, self,
media_io_kwargs: dict[str, dict[str, Any]] | None = None, media_io_kwargs: dict[str, dict[str, Any]] | None = None,
......
...@@ -15,12 +15,18 @@ from .base import MediaIO, MediaWithBytes ...@@ -15,12 +15,18 @@ from .base import MediaIO, MediaWithBytes
class ImageMediaIO(MediaIO[Image.Image]): class ImageMediaIO(MediaIO[Image.Image]):
"""Configuration values can be user-provided either by --media-io-kwargs or
by the runtime API field "media_io_kwargs". Ensure proper validation and
error handling.
"""
def __init__(self, image_mode: str = "RGB", **kwargs) -> None: def __init__(self, image_mode: str = "RGB", **kwargs) -> None:
super().__init__() super().__init__()
self.image_mode = image_mode self.image_mode = image_mode
# `kwargs` contains custom arguments from # `kwargs` contains custom arguments from
# --media-io-kwargs for this modality. # --media-io-kwargs for this modality, merged with
# per-request runtime media_io_kwargs via merge_kwargs().
# They can be passed to the underlying # They can be passed to the underlying
# media loaders (e.g. custom implementations) # media loaders (e.g. custom implementations)
# for flexible control. # for flexible control.
...@@ -88,6 +94,13 @@ class ImageMediaIO(MediaIO[Image.Image]): ...@@ -88,6 +94,13 @@ class ImageMediaIO(MediaIO[Image.Image]):
class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]): class ImageEmbeddingMediaIO(MediaIO[torch.Tensor]):
"""Image embedding MediaIO implementation.
Configuration values can be user-provided either by --media-io-kwargs or
by the runtime API field "media_io_kwargs". Ensure proper validation and
error handling.
"""
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
......
...@@ -17,6 +17,28 @@ from .image import ImageMediaIO ...@@ -17,6 +17,28 @@ from .image import ImageMediaIO
class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]): class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
"""Configuration values can be user-provided either by --media-io-kwargs or
by the runtime API field "media_io_kwargs". Ensure proper validation and
error handling.
"""
@classmethod
def merge_kwargs(
cls,
default_kwargs: dict[str, Any] | None,
runtime_kwargs: dict[str, Any] | None,
) -> dict[str, Any]:
merged = super().merge_kwargs(default_kwargs, runtime_kwargs)
# fps and num_frames interact with each other, so if either is
# overridden at request time, wipe the other from defaults to
# avoid unintuitive cross-field interactions.
if runtime_kwargs:
if "num_frames" in runtime_kwargs and "fps" not in runtime_kwargs:
merged.pop("fps", None)
elif "fps" in runtime_kwargs and "num_frames" not in runtime_kwargs:
merged.pop("num_frames", None)
return merged
def __init__( def __init__(
self, self,
image_io: ImageMediaIO, image_io: ImageMediaIO,
...@@ -28,7 +50,8 @@ class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]): ...@@ -28,7 +50,8 @@ class VideoMediaIO(MediaIO[tuple[npt.NDArray, dict[str, Any]]]):
self.image_io = image_io self.image_io = image_io
self.num_frames = num_frames self.num_frames = num_frames
# `kwargs` contains custom arguments from # `kwargs` contains custom arguments from
# --media-io-kwargs for this modality. # --media-io-kwargs for this modality, merged with
# per-request runtime media_io_kwargs via merge_kwargs().
# They can be passed to the underlying # They can be passed to the underlying
# media loaders (e.g. custom implementations) # media loaders (e.g. custom implementations)
# for flexible control. # for flexible control.
......
...@@ -49,6 +49,7 @@ class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]): ...@@ -49,6 +49,7 @@ class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
messages, messages,
self.model_config, self.model_config,
content_format="string", content_format="string",
media_io_kwargs=params.media_io_kwargs,
) )
prompt_raw = tokenizer.apply_chat_template( prompt_raw = tokenizer.apply_chat_template(
...@@ -75,6 +76,7 @@ class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]): ...@@ -75,6 +76,7 @@ class DeepseekV32Renderer(BaseRenderer[DeepseekV32Tokenizer]):
messages, messages,
self.model_config, self.model_config,
content_format="string", content_format="string",
media_io_kwargs=params.media_io_kwargs,
) )
prompt_raw = tokenizer.apply_chat_template( prompt_raw = tokenizer.apply_chat_template(
......
...@@ -49,6 +49,7 @@ class Grok2Renderer(BaseRenderer[Grok2Tokenizer]): ...@@ -49,6 +49,7 @@ class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
messages, messages,
self.model_config, self.model_config,
content_format="string", content_format="string",
media_io_kwargs=params.media_io_kwargs,
) )
prompt_raw = tokenizer.apply_chat_template( prompt_raw = tokenizer.apply_chat_template(
...@@ -75,6 +76,7 @@ class Grok2Renderer(BaseRenderer[Grok2Tokenizer]): ...@@ -75,6 +76,7 @@ class Grok2Renderer(BaseRenderer[Grok2Tokenizer]):
messages, messages,
self.model_config, self.model_config,
content_format="string", content_format="string",
media_io_kwargs=params.media_io_kwargs,
) )
prompt_raw = tokenizer.apply_chat_template( prompt_raw = tokenizer.apply_chat_template(
......
...@@ -635,6 +635,7 @@ class HfRenderer(BaseRenderer[HfTokenizer]): ...@@ -635,6 +635,7 @@ class HfRenderer(BaseRenderer[HfTokenizer]):
tokenizer=tokenizer, tokenizer=tokenizer,
model_config=model_config, model_config=model_config,
), ),
media_io_kwargs=params.media_io_kwargs,
) )
prompt_raw = safe_apply_chat_template( prompt_raw = safe_apply_chat_template(
...@@ -689,6 +690,7 @@ class HfRenderer(BaseRenderer[HfTokenizer]): ...@@ -689,6 +690,7 @@ class HfRenderer(BaseRenderer[HfTokenizer]):
tokenizer=tokenizer, tokenizer=tokenizer,
model_config=model_config, model_config=model_config,
), ),
media_io_kwargs=params.media_io_kwargs,
) )
prompt_raw = safe_apply_chat_template( prompt_raw = safe_apply_chat_template(
......
...@@ -90,6 +90,7 @@ class MistralRenderer(BaseRenderer[MistralTokenizer]): ...@@ -90,6 +90,7 @@ class MistralRenderer(BaseRenderer[MistralTokenizer]):
messages, messages,
self.model_config, self.model_config,
content_format="string", content_format="string",
media_io_kwargs=params.media_io_kwargs,
) )
prompt_raw = safe_apply_chat_template( prompt_raw = safe_apply_chat_template(
...@@ -116,6 +117,7 @@ class MistralRenderer(BaseRenderer[MistralTokenizer]): ...@@ -116,6 +117,7 @@ class MistralRenderer(BaseRenderer[MistralTokenizer]):
messages, messages,
self.model_config, self.model_config,
content_format="string", content_format="string",
media_io_kwargs=params.media_io_kwargs,
) )
prompt_raw = await self._apply_chat_template_async( prompt_raw = await self._apply_chat_template_async(
......
...@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, TypeVar ...@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, TypeVar
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal.media.connector import merge_media_io_kwargs
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
...@@ -52,8 +53,15 @@ class ChatParams: ...@@ -52,8 +53,15 @@ class ChatParams:
chat_template_kwargs: dict[str, Any] = field(default_factory=dict) chat_template_kwargs: dict[str, Any] = field(default_factory=dict)
"""The kwargs to pass to the chat template.""" """The kwargs to pass to the chat template."""
def with_defaults(self, default_chat_template_kwargs: dict[str, Any] | None): media_io_kwargs: dict[str, dict[str, Any]] | None = None
if not default_chat_template_kwargs: """Per-modality kwargs for media I/O (loading/decoding images, videos, etc.)."""
def with_defaults(
self,
default_chat_template_kwargs: dict[str, Any] | None = None,
default_media_io_kwargs: dict[str, dict[str, Any]] | None = None,
):
if not default_chat_template_kwargs and not default_media_io_kwargs:
return self return self
return ChatParams( return ChatParams(
...@@ -63,6 +71,10 @@ class ChatParams: ...@@ -63,6 +71,10 @@ class ChatParams:
default_chat_template_kwargs, default_chat_template_kwargs,
self.chat_template_kwargs, self.chat_template_kwargs,
), ),
media_io_kwargs=merge_media_io_kwargs(
default_media_io_kwargs,
self.media_io_kwargs,
),
) )
def get_apply_chat_template_kwargs(self) -> dict[str, Any]: def get_apply_chat_template_kwargs(self) -> dict[str, Any]:
......
...@@ -43,6 +43,7 @@ class TerratorchRenderer(BaseRenderer): ...@@ -43,6 +43,7 @@ class TerratorchRenderer(BaseRenderer):
messages, messages,
model_config, model_config,
content_format="string", content_format="string",
media_io_kwargs=params.media_io_kwargs,
) )
prompt = parse_dec_only_prompt([1]) # Dummy token IDs prompt = parse_dec_only_prompt([1]) # Dummy token IDs
...@@ -64,6 +65,7 @@ class TerratorchRenderer(BaseRenderer): ...@@ -64,6 +65,7 @@ class TerratorchRenderer(BaseRenderer):
messages, messages,
model_config, model_config,
content_format="string", content_format="string",
media_io_kwargs=params.media_io_kwargs,
) )
prompt = parse_dec_only_prompt([1]) # Dummy token IDs prompt = parse_dec_only_prompt([1]) # Dummy token IDs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment