Unverified commit 4105de62, authored by KrishnanPrash and committed by GitHub
Browse files

fix: reject multimodal requests when worker lacks --modality multimodal (#7065)


Signed-off-by: Krishnan Prashanth <kprashanth@nvidia.com>
parent b339c279
...@@ -550,6 +550,12 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -550,6 +550,12 @@ class HandlerBase(BaseGenerativeHandler):
processed_input["multi_modal_data"] = None processed_input["multi_modal_data"] = None
return processed_input return processed_input
if self.multimodal_processor is None and self._request_has_multimodal(request):
raise RuntimeError(
"Multimodal input received but worker started without --modality multimodal. "
"Restart the worker with --modality multimodal or remove image_url content."
)
# PREFILL/ENCODE/AGGREGATED: Process multimodal content if available # PREFILL/ENCODE/AGGREGATED: Process multimodal content if available
if self.multimodal_processor: if self.multimodal_processor:
processed_input = await self.multimodal_processor.process_openai_request( processed_input = await self.multimodal_processor.process_openai_request(
...@@ -570,6 +576,20 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -570,6 +576,20 @@ class HandlerBase(BaseGenerativeHandler):
# Fallback: text-only flow (no multimodal processor or no multimodal data) # Fallback: text-only flow (no multimodal processor or no multimodal data)
return request.get("token_ids") return request.get("token_ids")
def _request_has_multimodal(self, request: dict) -> bool:
if request.get("multi_modal_data"):
return True
extra_args = request.get("extra_args") or {}
messages = extra_args.get("messages") or request.get("messages") or []
for message in messages:
content = message.get("content", [])
if isinstance(content, list):
for part in content:
if isinstance(part, dict) and part.get("type") == "image_url":
return True
return False
def _normalize_request_format(self, request: dict) -> None: def _normalize_request_format(self, request: dict) -> None:
""" """
Convert OpenAI request format to TRT-LLM internal format. Convert OpenAI request format to TRT-LLM internal format.
......
...@@ -17,13 +17,14 @@ if not torch.cuda.is_available(): ...@@ -17,13 +17,14 @@ if not torch.cuda.is_available():
"CUDA/GPU not available, but tensorrt_llm import and the test require GPU.", "CUDA/GPU not available, but tensorrt_llm import and the test require GPU.",
allow_module_level=True, allow_module_level=True,
) )
from dynamo.trtllm.constants import DisaggregationMode
from dynamo.trtllm.request_handlers.handler_base import HandlerBase from dynamo.trtllm.request_handlers.handler_base import HandlerBase
pytestmark = [ pytestmark = [
pytest.mark.unit, pytest.mark.unit,
pytest.mark.trtllm, pytest.mark.trtllm,
pytest.mark.pre_merge, pytest.mark.pre_merge,
pytest.mark.gpu_0, pytest.mark.gpu_1,
] ]
...@@ -375,3 +376,67 @@ class TestHandleCancellationAbortToggle: ...@@ -375,3 +376,67 @@ class TestHandleCancellationAbortToggle:
await handler._handle_cancellation(generation_result, context) await handler._handle_cancellation(generation_result, context)
generation_result.abort.assert_not_called() generation_result.abort.assert_not_called()
class TestMultimodalGuard:
    """Checks for the multimodal guard on workers without --modality multimodal.

    Every handler built here has ``multimodal_processor=None``; the tests
    drive ``_prepare_input_for_generation`` and assert on what it returns
    or raises.
    """

    # OpenAI-style user message: one image_url part followed by one text part.
    IMAGE_MESSAGE = {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "http://example.com/a.jpg"}},
            {"type": "text", "text": "describe image"},
        ],
    }

    def _make_handler(self, multimodal_processor=None) -> HandlerBase:
        """Build a concrete handler from a mocked config.

        The mock pins ``multimodal_processor`` to the given value and
        ``shutdown_event`` to ``None``; all other attributes stay auto-mocked.
        """
        mock_config = MagicMock(
            multimodal_processor=multimodal_processor,
            shutdown_event=None,
        )
        return _ConcreteHandler(mock_config)

    async def _prepare(self, handler, request, epd_metadata=None):
        """Await ``handler._prepare_input_for_generation`` with fixed defaults.

        ``embeddings`` and ``ep_disaggregated_params`` are always ``None``;
        a falsy ``epd_metadata`` is replaced by an empty dict.
        """
        call_kwargs = {
            "request": request,
            "embeddings": None,
            "ep_disaggregated_params": None,
            "epd_metadata": epd_metadata if epd_metadata else {},
        }
        return await handler._prepare_input_for_generation(**call_kwargs)

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "request_factory",
        [
            pytest.param(
                lambda msg: {"token_ids": [1, 2, 3], "extra_args": {"messages": [msg]}},
                id="extra_args_messages",
            ),
            pytest.param(
                lambda msg: {"token_ids": [1, 2, 3], "messages": [msg]},
                id="top_level_messages",
            ),
        ],
    )
    async def test_raises_for_image_url(self, request_factory):
        """An image_url message in either location raises RuntimeError."""
        text_only_handler = self._make_handler(multimodal_processor=None)
        image_request = request_factory(self.IMAGE_MESSAGE)

        with pytest.raises(RuntimeError, match="--modality multimodal"):
            await self._prepare(text_only_handler, image_request)

    @pytest.mark.asyncio
    async def test_text_only_request_falls_back_to_token_ids(self):
        """A request holding only token_ids yields those token ids back."""
        text_only_handler = self._make_handler(multimodal_processor=None)

        prepared = await self._prepare(text_only_handler, {"token_ids": [10, 20, 30]})

        assert prepared == [10, 20, 30]

    @pytest.mark.asyncio
    async def test_decode_with_prefill_metadata_bypasses_guard(self):
        """In DECODE mode with prefill metadata, the image request does not raise."""
        decode_handler = self._make_handler(multimodal_processor=None)
        decode_handler.disaggregation_mode = DisaggregationMode.DECODE

        image_request = {"token_ids": [1, 2, 3], "messages": [self.IMAGE_MESSAGE]}
        prefill_metadata = {
            "_prefill_prompt": "describe image",
            "_prefill_prompt_token_ids": [1, 2, 3],
        }

        prepared = await self._prepare(decode_handler, image_request, prefill_metadata)

        # The prefill prompt is passed through and no multimodal payload is attached.
        assert prepared["prompt"] == "describe image"
        assert prepared["prompt_token_ids"] == [1, 2, 3]
        assert prepared["multi_modal_data"] is None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment