Unverified Commit da670d44 authored by Qi Wang's avatar Qi Wang Committed by GitHub
Browse files

fix(vllm): eliminate redundant image re-download on decode worker in disagg multimodal (#7827)


Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent ffe20629
......@@ -52,7 +52,7 @@ from dynamo.runtime import Client
from dynamo.runtime.logging import configure_dynamo_logging
from .args import Config
from .constants import EmbeddingTransferMode
from .constants import DisaggregationMode, EmbeddingTransferMode
from .engine_monitor import VllmEngineMonitor
from .multimodal_utils.hash_utils import compute_mm_uuids_from_images
from .multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model
......@@ -1591,17 +1591,46 @@ class DecodeWorkerHandler(BaseWorkerHandler):
kv_params = None
embedding_params = None
is_decode_only = self.config.disaggregation_mode == DisaggregationMode.DECODE
has_mm_data = (
"multi_modal_data" in request and request["multi_modal_data"] is not None
)
multi_modal_data = None
# The decode worker is handling disaggregated requests, the mm embedding will be synthetic
if prefill_result is not None and embedding_params is not None:
if is_decode_only:
# Decode mode: branch on model, not data.
if is_qwen_vl_model(self.config.model):
multi_modal_data = construct_qwen_decode_mm_data(
embedding_params["image_grid_thw"],
embedding_params["embeddings_shape"],
request_id,
)
# Qwen VL needs embedding_params for mRoPE initialization.
if embedding_params is not None:
multi_modal_data = construct_qwen_decode_mm_data(
embedding_params["image_grid_thw"],
embedding_params["embeddings_shape"],
request_id,
)
elif has_mm_data and request["multi_modal_data"].get(IMAGE_URL_KEY):
msg = (
"Decode worker received multimodal request without "
"prefill result"
if prefill_result is None
else "Prefill did not produce required multimodal "
"embedding metadata (image_grid_thw) for Qwen VL "
"decode. Use --route-to-encoder or the P/D launcher "
"with grid_thw computation support"
)
logger.error("Request %s: %s", request_id, msg)
yield {"status": "error", "message": msg}
return
# TODO(DIS-1661): video/audio re-downloaded on decode.
# TODO(DIS-1664): mixed image+video in disagg decode is not
# supported — synthetic image data would be overwritten.
if multi_modal_data is None and has_mm_data:
mm = request["multi_modal_data"]
if mm.get(VIDEO_URL_KEY) or mm.get("audio_url"):
multi_modal_data = await self._extract_multimodal_data(
request, request_id, context
)
else:
# Extract and decode multimodal data if present
# Aggregated mode: load images normally
multi_modal_data = await self._extract_multimodal_data(
request, request_id, context
)
......
......@@ -39,6 +39,7 @@ def _make_config(
is_prefill_worker: bool = False,
enable_multimodal: bool = True,
multimodal_embedding_cache_capacity_gb: float = 0,
disaggregation_mode: str | None = None,
) -> MagicMock:
"""Create a mock Config with the fields used by MultimodalPDWorkerHandler."""
from dynamo.vllm.constants import DisaggregationMode, EmbeddingTransferMode
......@@ -46,11 +47,12 @@ def _make_config(
config = MagicMock()
config.model = model
config.is_prefill_worker = is_prefill_worker
config.disaggregation_mode = (
DisaggregationMode.PREFILL
if is_prefill_worker
else DisaggregationMode.AGGREGATED
)
if disaggregation_mode is not None:
config.disaggregation_mode = getattr(DisaggregationMode, disaggregation_mode)
elif is_prefill_worker:
config.disaggregation_mode = DisaggregationMode.PREFILL
else:
config.disaggregation_mode = DisaggregationMode.AGGREGATED
# NIXL_WRITE / NIXL_READ modes require GPU, the tests may run in CPU-only environments,
# so set to LOCAL mode.
config.embedding_transfer_mode = EmbeddingTransferMode.LOCAL
......@@ -76,7 +78,7 @@ def _make_handler(
return mod.DecodeWorkerHandler(
runtime=MagicMock(),
config=config,
engine_client=MagicMock(),
engine=MagicMock(),
default_sampling_params={},
encode_worker_client=encode_worker_client,
)
......@@ -304,3 +306,206 @@ class TestGenerateDisagg:
assert isinstance(chunks[0], dict)
assert chunks[0]["token_ids"] == [42]
assert chunks[0]["finish_reason"] == "stop"
# ── Decode worker multimodal branching tests ───────────────────────
def _make_decode_handler(
model: str = "test-model",
disaggregation_mode: str = "DECODE",
) -> mod.DecodeWorkerHandler:
"""Construct a DecodeWorkerHandler with mocked internals."""
config = _make_config(model=model, disaggregation_mode=disaggregation_mode)
with patch.object(mod.BaseWorkerHandler, "__init__", return_value=None):
handler = mod.DecodeWorkerHandler(
runtime=MagicMock(),
config=config,
engine=MagicMock(),
default_sampling_params={},
)
handler.config = config
handler.enable_multimodal = True
handler.image_loader = MagicMock()
handler.embedding_loader = None
handler.model_max_len = 4096
handler.default_sampling_params = {}
handler.kv_event_publisher = None
handler.otel_tracing_enabled = False
handler.input_param_manager = MagicMock()
handler.input_param_manager.get_extra_params.return_value = {}
return handler
@pytest.mark.asyncio(loop_scope="function")
class TestDecodeWorkerMultimodalBranching:
"""Tests for the mode-aware multimodal branching in _generate_token_mode."""
async def test_decode_only_qwen_with_mm_data_no_prefill_result_errors(self):
"""Decode-only Qwen worker receiving mm request without prefill_result -> error."""
handler = _make_decode_handler(
model="Qwen/Qwen3-VL-2B-Instruct",
disaggregation_mode="DECODE",
)
request = {
"token_ids": [1, 2, 3],
"multi_modal_data": {"image_url": [{"Url": "http://img.png"}]},
"sampling_options": {},
"stop_conditions": {},
"output_options": {},
}
context = MagicMock()
chunks = []
async for chunk in handler._generate_token_mode(request, context, "req-1"):
chunks.append(chunk)
assert len(chunks) == 1
assert chunks[0]["status"] == "error"
assert "without prefill result" in chunks[0]["message"]
async def test_decode_only_qwen_missing_embedding_params_errors(self):
"""Decode-only Qwen VL with prefill_result but no embedding_params -> error."""
handler = _make_decode_handler(
model="Qwen/Qwen3-VL-2B-Instruct",
disaggregation_mode="DECODE",
)
request = {
"token_ids": [1, 2, 3],
"multi_modal_data": {"image_url": [{"Url": "http://img.png"}]},
"sampling_options": {},
"stop_conditions": {},
"output_options": {},
"prefill_result": {
"disaggregated_params": {
"kv_transfer_params": {"block_ids": [0]},
# embedding_params intentionally missing
},
},
}
context = MagicMock()
chunks = []
async for chunk in handler._generate_token_mode(request, context, "req-1"):
chunks.append(chunk)
assert len(chunks) == 1
assert chunks[0]["status"] == "error"
assert "embedding metadata" in chunks[0]["message"]
async def test_decode_only_non_qwen_proceeds_without_embedding_params(self):
"""Decode-only non-Qwen with prefill_result but no embedding_params -> proceeds.
Non-Qwen models don't need embedding_params — the KV cache from
prefill already contains the vision context.
"""
handler = _make_decode_handler(
model="llava-hf/llava-1.5-7b-hf",
disaggregation_mode="DECODE",
)
# Return an error from _build_prompt_from_request so we don't need
# to mock the full engine — just verify we get past the decode guard.
handler._build_prompt_from_request = MagicMock(
return_value=(None, None, {"status": "error", "message": "test stop"})
)
request = {
"token_ids": [1, 2, 3],
"multi_modal_data": {"image_url": [{"Url": "http://img.png"}]},
"sampling_options": {},
"stop_conditions": {},
"output_options": {},
"prefill_result": {
"disaggregated_params": {
"kv_transfer_params": {"block_ids": [0]},
},
},
}
context = MagicMock()
chunks = []
async for chunk in handler._generate_token_mode(request, context, "req-1"):
chunks.append(chunk)
# Should reach _build_prompt_from_request (not error at decode guard)
assert len(chunks) == 1
assert chunks[0]["message"] == "test stop"
async def test_aggregated_mode_calls_extract_multimodal_data(self):
"""Aggregated mode handler calls _extract_multimodal_data normally."""
handler = _make_decode_handler(disaggregation_mode="AGGREGATED")
handler._extract_multimodal_data = AsyncMock(return_value=None)
# Return an error from _build_prompt_from_request so _generate_token_mode
# yields it and returns early — no need to mock the engine.
handler._build_prompt_from_request = MagicMock(
return_value=(None, None, {"status": "error", "message": "test stop"})
)
request = {
"token_ids": [1, 2, 3],
"multi_modal_data": {"image_url": [{"Url": "http://img.png"}]},
"sampling_options": {},
"stop_conditions": {},
"output_options": {},
}
context = MagicMock()
chunks = []
async for chunk in handler._generate_token_mode(request, context, "req-1"):
chunks.append(chunk)
handler._extract_multimodal_data.assert_awaited_once()
assert len(chunks) == 1
assert chunks[0]["status"] == "error"
# ── Prefill _build_embedding_params tests ──────────────────────────
def _make_prefill_handler(model: str = "test-model") -> mod.PrefillWorkerHandler:
"""Construct a PrefillWorkerHandler with mocked internals."""
config = _make_config(
model=model, is_prefill_worker=True, disaggregation_mode="PREFILL"
)
with patch.object(mod.BaseWorkerHandler, "__init__", return_value=None):
handler = mod.PrefillWorkerHandler(
runtime=MagicMock(),
config=config,
engine=MagicMock(),
default_sampling_params={},
)
handler.config = config
return handler
class TestBuildEmbeddingParams:
"""Tests for PrefillWorkerHandler._build_embedding_params."""
def test_dict_image_data_produces_embedding_params(self):
"""Dict-style image data with image_embeds + image_grid_thw -> valid params."""
handler = _make_prefill_handler(model="Qwen/Qwen3-VL-2B-Instruct")
mm_data = {
"image": {
"image_embeds": torch.randn(1, 256, 1024),
"image_grid_thw": torch.tensor([[1, 16, 16]]),
}
}
result = handler._build_embedding_params(mm_data)
assert result is not None
assert "image_grid_thw" in result
assert "embeddings_shape" in result
assert result["embeddings_shape"] == [1, 256, 1024]
def test_pil_image_list_non_qwen_returns_none(self):
"""PIL image list for non-Qwen model -> returns None."""
handler = _make_prefill_handler(model="llava-hf/llava-1.5-7b-hf")
mm_data = {"image": [MagicMock()]}
result = handler._build_embedding_params(mm_data)
assert result is None
def test_no_image_data_returns_none(self):
"""No image data -> returns None."""
handler = _make_prefill_handler(model="Qwen/Qwen3-VL-2B-Instruct")
mm_data = {}
result = handler._build_embedding_params(mm_data)
assert result is None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment