Unverified commit cb88fdc7, authored by Qi Wang, committed by GitHub
Browse files

feat: add encode client and embedding cache to PD worker (#6029)

parent 166e1f4d
...@@ -27,12 +27,17 @@ class DynamoRuntimeConfig(ConfigBase): ...@@ -27,12 +27,17 @@ class DynamoRuntimeConfig(ConfigBase):
custom_jinja_template: Optional[str] = None custom_jinja_template: Optional[str] = None
endpoint_types: str endpoint_types: str
dump_config_to: Optional[str] = None dump_config_to: Optional[str] = None
multimodal_embedding_cache_capacity_gb: float
def validate(self) -> None: def validate(self) -> None:
# TODO get a better way for spot fixes like this. # TODO get a better way for spot fixes like this.
self.enable_local_indexer = not self.durable_kv_events self.enable_local_indexer = not self.durable_kv_events
# For simplicity, flag names are not prefixed with "dyn-" unless it is absolutely
# necessary. Examples of exceptions:
# - Dynamo-specific args whose names could conflict with a backend's own args are
#   prefixed with "dyn-".
class DynamoRuntimeArgGroup(ArgGroup): class DynamoRuntimeArgGroup(ArgGroup):
"""Dynamo runtime configuration parameters (common to all backends).""" """Dynamo runtime configuration parameters (common to all backends)."""
...@@ -89,7 +94,6 @@ class DynamoRuntimeArgGroup(ArgGroup): ...@@ -89,7 +94,6 @@ class DynamoRuntimeArgGroup(ArgGroup):
) )
# Optional: tool/reasoning parsers (choices from dynamo._core when available) # Optional: tool/reasoning parsers (choices from dynamo._core when available)
# To avoid name conflicts with different backends, prefix "dyn-" for dynamo specific args
add_argument( add_argument(
g, g,
flag_name="--dyn-tool-call-parser", flag_name="--dyn-tool-call-parser",
...@@ -130,3 +134,12 @@ class DynamoRuntimeArgGroup(ArgGroup): ...@@ -130,3 +134,12 @@ class DynamoRuntimeArgGroup(ArgGroup):
default=None, default=None,
help="Dump resolved configuration to the specified file path.", help="Dump resolved configuration to the specified file path.",
) )
add_argument(
g,
flag_name="--multimodal-embedding-cache-capacity-gb",
env_var="DYN_MULTIMODAL_EMBEDDING_CACHE_CAPACITY_GB",
default=0,
arg_type=float,
help="Capacity of the multimodal embedding cache in GB. 0 = disabled.",
)
...@@ -79,7 +79,6 @@ def parse_args() -> Config: ...@@ -79,7 +79,6 @@ def parse_args() -> Config:
Returns: Returns:
Config: Parsed configuration object. Config: Parsed configuration object.
""" """
dynamo_runtime_argspec = DynamoRuntimeArgGroup() dynamo_runtime_argspec = DynamoRuntimeArgGroup()
dynamo_vllm_argspec = DynamoVllmArgGroup() dynamo_vllm_argspec = DynamoVllmArgGroup()
......
...@@ -5,14 +5,14 @@ from dynamo.vllm.multimodal_handlers.encode_worker_handler import ( ...@@ -5,14 +5,14 @@ from dynamo.vllm.multimodal_handlers.encode_worker_handler import (
EncodeWorkerHandler, EncodeWorkerHandler,
VLLMEncodeWorkerHandler, VLLMEncodeWorkerHandler,
) )
from dynamo.vllm.multimodal_handlers.multimodal_pd_worker_handler import (
MultimodalPDWorkerHandler,
)
from dynamo.vllm.multimodal_handlers.preprocessed_handler import ( from dynamo.vllm.multimodal_handlers.preprocessed_handler import (
ECProcessorHandler, ECProcessorHandler,
PreprocessedHandler, PreprocessedHandler,
) )
from dynamo.vllm.multimodal_handlers.worker_handler import ( from dynamo.vllm.multimodal_handlers.worker_handler import MultimodalDecodeWorkerHandler
MultimodalDecodeWorkerHandler,
MultimodalPDWorkerHandler,
)
__all__ = [ __all__ = [
"EncodeWorkerHandler", "EncodeWorkerHandler",
......
...@@ -35,19 +35,6 @@ from ..multimodal_utils.model import is_qwen_vl_model ...@@ -35,19 +35,6 @@ from ..multimodal_utils.model import is_qwen_vl_model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try:
import cupy as array_module
if not array_module.cuda.is_available():
raise ImportError("CUDA is not available.")
DEVICE = "cuda"
logger.info("Using cupy for array operations (GPU mode).")
except ImportError as e:
logger.warning(f"Failed to import cupy, falling back to numpy: {e}.")
import numpy as array_module
DEVICE = "cpu"
CACHE_SIZE_MAXIMUM = 8 CACHE_SIZE_MAXIMUM = 8
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1)) TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import copy
import logging
from collections import defaultdict
from typing import Any
import torch
from vllm.inputs.data import TokensPrompt
from vllm.v1.engine.async_llm import AsyncLLM
import dynamo.nixl_connect as connect
from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
from dynamo.runtime import Client, Component, DistributedRuntime
from ..args import Config
from ..handlers import BaseWorkerHandler
from ..multimodal_utils import ImageLoader, MyRequestOutput, vLLMMultimodalRequest
from ..multimodal_utils.model import is_qwen_vl_model
from ..multimodal_utils.prefill_worker_utils import (
accumulate_embeddings,
load_embeddings,
)
logger = logging.getLogger(__name__)
class MultimodalPDWorkerHandler(BaseWorkerHandler):
    """Prefill/Decode or Prefill-only worker for multimodal serving.

    The handler builds vLLM ``multi_modal_data`` from the request's multimodal
    inputs (images loaded by URL, or pre-computed embeddings), runs the request
    on the local engine, and either streams the engine output back directly
    or, when disaggregation is enabled and a decode client is configured,
    forwards the prefill result to a decode worker and streams that worker's
    output back.
    """

    def __init__(
        self,
        runtime,
        component: Component,
        engine_client: AsyncLLM,
        config: Config,
        encode_worker_client: Client | None = None,
        decode_worker_client: Client | None = None,
        shutdown_event=None,
    ):
        """Set up the worker's synchronous state.

        Args:
            runtime: Forwarded unchanged to ``BaseWorkerHandler.__init__``.
            component: Forwarded unchanged to ``BaseWorkerHandler.__init__``.
            engine_client: vLLM engine used by ``generate``.
            config: Worker configuration. Read here for the default sampling
                params, ``enable_multimodal``, ``is_prefill_worker``, ``model``
                and ``multimodal_embedding_cache_capacity_gb``.
            encode_worker_client: Optional client for an encode worker. It is
                stored on the instance; this class does not call it.
            decode_worker_client: Optional client for a decode worker. Remote
                decode is only used when this is set and
                ``config.is_prefill_worker`` is true.
            shutdown_event: Forwarded unchanged to ``BaseWorkerHandler.__init__``.
        """
        # Get default_sampling_params from config
        default_sampling_params = (
            config.engine_args.create_model_config().get_diff_sampling_param()
        )

        # Call BaseWorkerHandler.__init__ with proper parameters
        super().__init__(
            runtime,
            component,
            engine_client,
            default_sampling_params,
            enable_multimodal=config.enable_multimodal,
            shutdown_event=shutdown_event,
        )

        self.config = config
        self.encode_worker_client = encode_worker_client
        self.decode_worker_client = decode_worker_client
        self.enable_disagg = config.is_prefill_worker

        # The embedding cache is opt-in: a capacity of 0 (or less) leaves it
        # disabled. The configured capacity is in GB and is converted to bytes.
        self.embedding_cache_manager: MultimodalEmbeddingCacheManager | None = None
        if config.multimodal_embedding_cache_capacity_gb > 0:
            capacity_bytes = int(
                config.multimodal_embedding_cache_capacity_gb * 1024**3
            )
            self.embedding_cache_manager = MultimodalEmbeddingCacheManager(
                capacity_bytes
            )

        # Initialize multimodal-specific components
        logger.info("Multimodal PD Worker startup started.")

        # The embeddings dtype is picked from the model name: names containing
        # "video" use uint8, everything else float16.
        self.EMBEDDINGS_DTYPE = (
            torch.uint8 if "video" in self.config.model.lower() else torch.float16
        )
        self.EMBEDDINGS_DEVICE = "cpu"

        # Create and initialize a dynamo connector for this worker.
        # We'll need this to move data between this worker and remote workers efficiently.
        # Note: This is synchronous initialization, async initialization happens in async_init
        self._connector: connect.Connector | None = (
            None  # Will be initialized in async_init
        )

        self.image_loader = ImageLoader()

        logger.info("Multimodal PD Worker has been initialized")

    async def async_init(self, runtime: DistributedRuntime):
        """Create the connector that ``generate`` passes to ``load_embeddings``.

        Args:
            runtime: Accepted for interface compatibility; not used here.
        """
        # Initialize the connector asynchronously
        self._connector = connect.Connector()
        logger.info("Multimodal PD Worker async initialization completed.")

    @staticmethod
    def _serialize_output(output) -> str:
        """Re-wrap an engine or decode-worker output as ``MyRequestOutput`` JSON."""
        return MyRequestOutput(
            request_id=output.request_id,
            prompt=output.prompt,
            prompt_token_ids=output.prompt_token_ids,
            prompt_logprobs=output.prompt_logprobs,
            outputs=output.outputs,
            finished=output.finished,
            metrics=output.metrics,
            kv_transfer_params=output.kv_transfer_params,
        ).model_dump_json()

    async def generate(self, request: vLLMMultimodalRequest, context):
        """Run a multimodal request and stream the serialized outputs.

        Args:
            request: A ``vLLMMultimodalRequest``, or a JSON string / other
                payload that validates into one.
            context: Not used by this method.

        Yields:
            JSON strings produced by ``MyRequestOutput.model_dump_json()``.
        """
        logger.debug("Got raw request: %s", request)
        if not isinstance(request, vLLMMultimodalRequest):
            if isinstance(request, str):
                request = vLLMMultimodalRequest.model_validate_json(request)
            else:
                request = vLLMMultimodalRequest.model_validate(request)
        logger.debug("Received PD request: { id: %s }.", request.request_id)

        multi_modal_data: dict[str, Any] = defaultdict(list)
        for mi in request.multimodal_inputs:
            if mi.multimodal_input.image_url:
                # PIL image path — used by both EC consumer mode
                # (vLLM looks up cached embeddings via mm_hash) and
                # non-disaggregated mode (vLLM encodes inline).
                multi_modal_data["image"].append(
                    await self.image_loader.load_image(mi.multimodal_input.image_url)
                )
            else:
                # Pre-computed embeddings via NIXL RDMA or local safetensors
                embeddings = await load_embeddings(
                    mi,
                    self.EMBEDDINGS_DTYPE,
                    self.EMBEDDINGS_DEVICE,
                    self._connector,
                )
                accumulate_embeddings(
                    multi_modal_data,
                    self.config.model,
                    self.EMBEDDINGS_DTYPE,
                    embeddings,
                    mi.image_grid_thw,
                )

        # For Qwen VL (mRoPE), capture the accumulated image grid + embedding shape
        # from the constructed multimodal data so decode can reconstruct its
        # multi_modal_data consistently for multiple images.
        if is_qwen_vl_model(self.config.model) and isinstance(
            multi_modal_data.get("image"), dict
        ):
            image_data = multi_modal_data["image"]
            image_grid_thw = image_data.get("image_grid_thw")
            image_embeds = image_data.get("image_embeds")
            if image_grid_thw is not None:
                request.image_grid_thw = (
                    image_grid_thw.tolist()
                    if isinstance(image_grid_thw, torch.Tensor)
                    else image_grid_thw
                )
            if image_embeds is not None:
                request.embeddings_shape = list(image_embeds.shape)

        # Remove the image features from the request as they are not required
        # Use empty list instead of None to satisfy Pydantic validation on decode worker after vllm upgrade
        request.multimodal_inputs = []

        # NOTE: indexing the defaultdict here (rather than .get) also creates an
        # empty "image" entry when the request carried no image inputs; this is
        # kept so the data handed to the engine is unchanged.
        logger.info(
            "Prepared multimodal data size: %s", len(multi_modal_data["image"])
        )
        logger.debug("Multimodal data keys: %s", list(multi_modal_data.keys()))

        # Deepcopy the request to avoid modifying the original
        # when we adjust sampling params for prefill
        pd_request = copy.deepcopy(request)

        # Do prefill and remote decode if enable_disagg is true
        remote_decode = bool(self.enable_disagg and self.decode_worker_client)
        if remote_decode:
            extra_args = pd_request.sampling_params.extra_args or {}
            extra_args["kv_transfer_params"] = {
                "do_remote_decode": True,
            }
            pd_request.sampling_params.extra_args = extra_args
            # Prefill only needs to produce a single token locally.
            pd_request.sampling_params.max_tokens = 1
            pd_request.sampling_params.min_tokens = 1

        logger.debug("Prefill request: %s", pd_request)

        gen = self.engine_client.generate(
            prompt=TokensPrompt(
                prompt_token_ids=pd_request.engine_prompt["prompt_token_ids"],
                multi_modal_data=multi_modal_data,
            ),
            sampling_params=pd_request.sampling_params,
            request_id=pd_request.request_id,
        )

        if remote_decode:
            decode_request = copy.deepcopy(request)
            async for prefill_response in gen:
                # For Qwen VL models with mRoPE: Keep the ORIGINAL unexpanded prompt.
                # The decode worker will pass multi_modal_data which causes vLLM to
                # expand the prompt identically to prefill, ensuring block counts match.
                #
                # For other models: Use the expanded prompt from prefill response.
                # These models don't pass multi_modal_data in decode, so they need
                # the already-expanded prompt to match the KV cache layout.
                if not is_qwen_vl_model(self.config.model):
                    decode_request.engine_prompt[
                        "prompt_token_ids"
                    ] = prefill_response.prompt_token_ids
                logger.debug(
                    "Prefill response kv_transfer_params: %s",
                    prefill_response.kv_transfer_params,
                )
                extra_args = decode_request.sampling_params.extra_args or {}
                extra_args["kv_transfer_params"] = prefill_response.kv_transfer_params
                extra_args.pop("serialized_request", None)
                decode_request.sampling_params.extra_args = extra_args
                logger.debug("Decode request: %s", decode_request)
                async for (
                    decode_response
                ) in await self.decode_worker_client.round_robin(
                    decode_request.model_dump_json()
                ):
                    output = MyRequestOutput.model_validate_json(decode_response.data())  # type: ignore[attr-defined]
                    yield self._serialize_output(output)
        else:
            async for response in gen:
                logger.debug(
                    "Response kv_transfer_params: %s", response.kv_transfer_params
                )
                logger.debug(
                    "length of expanded prompt ids: %s",
                    len(response.prompt_token_ids),
                )
                yield self._serialize_output(response)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import copy
import logging import logging
from collections import defaultdict
from typing import Any
import torch
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.v1.engine.async_llm import AsyncLLM
import dynamo.nixl_connect as connect import dynamo.nixl_connect as connect
from dynamo.runtime import Client, Component, DistributedRuntime from dynamo.runtime import DistributedRuntime
from ..args import Config
from ..handlers import BaseWorkerHandler from ..handlers import BaseWorkerHandler
from ..multimodal_utils import ImageLoader, MyRequestOutput, vLLMMultimodalRequest from ..multimodal_utils import MyRequestOutput, vLLMMultimodalRequest
from ..multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model from ..multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model
from ..multimodal_utils.prefill_worker_utils import (
accumulate_embeddings,
load_embeddings,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -32,7 +24,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler): ...@@ -32,7 +24,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
runtime, runtime,
component, component,
engine_client, engine_client,
config, config: Config,
shutdown_event=None, shutdown_event=None,
): ):
# Get default_sampling_params from config # Get default_sampling_params from config
...@@ -111,204 +103,3 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler): ...@@ -111,204 +103,3 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
metrics=response.metrics, metrics=response.metrics,
kv_transfer_params=response.kv_transfer_params, kv_transfer_params=response.kv_transfer_params,
).model_dump_json() ).model_dump_json()
class MultimodalPDWorkerHandler(BaseWorkerHandler):
    """Prefill/Decode or Prefill-only worker for multimodal serving.

    The handler builds vLLM ``multi_modal_data`` from the request's multimodal
    inputs (images loaded by URL, or pre-computed embeddings), runs the request
    on the local engine, and either streams the engine output back directly
    or, when disaggregation is enabled and a decode client is configured,
    forwards the prefill result to a decode worker and streams that worker's
    output back.
    """

    def __init__(
        self,
        runtime,
        component: Component,
        engine_client: AsyncLLM,
        config,
        decode_worker_client: Client | None = None,
        shutdown_event=None,
    ):
        """Set up the worker's synchronous state.

        Args:
            runtime: Forwarded unchanged to ``BaseWorkerHandler.__init__``.
            component: Forwarded unchanged to ``BaseWorkerHandler.__init__``.
            engine_client: vLLM engine used by ``generate``.
            config: Worker configuration. Read here for the default sampling
                params, ``enable_multimodal``, ``is_prefill_worker`` and
                ``model``.
            decode_worker_client: Optional client for a decode worker. Remote
                decode is only used when this is set and
                ``config.is_prefill_worker`` is true.
            shutdown_event: Forwarded unchanged to ``BaseWorkerHandler.__init__``.
        """
        # Get default_sampling_params from config
        default_sampling_params = (
            config.engine_args.create_model_config().get_diff_sampling_param()
        )

        # Call BaseWorkerHandler.__init__ with proper parameters
        super().__init__(
            runtime,
            component,
            engine_client,
            default_sampling_params,
            enable_multimodal=config.enable_multimodal,
            shutdown_event=shutdown_event,
        )

        self.config = config
        self.decode_worker_client = decode_worker_client
        self.enable_disagg = config.is_prefill_worker

        # Initialize multimodal-specific components
        logger.info("Multimodal PD Worker startup started.")

        # The embeddings dtype is picked from the model name: names containing
        # "video" use uint8, everything else float16.
        self.EMBEDDINGS_DTYPE = (
            torch.uint8 if "video" in self.config.model.lower() else torch.float16
        )
        self.EMBEDDINGS_DEVICE = "cpu"

        # Create and initialize a dynamo connector for this worker.
        # We'll need this to move data between this worker and remote workers efficiently.
        # Note: This is synchronous initialization, async initialization happens in async_init
        self._connector: connect.Connector | None = (
            None  # Will be initialized in async_init
        )

        self.image_loader = ImageLoader()

        logger.info("Multimodal PD Worker has been initialized")

    async def async_init(self, runtime: DistributedRuntime):
        """Create the connector that ``generate`` passes to ``load_embeddings``.

        Args:
            runtime: Accepted for interface compatibility; not used here.
        """
        # Initialize the connector asynchronously
        self._connector = connect.Connector()
        logger.info("Multimodal PD Worker async initialization completed.")

    @staticmethod
    def _serialize_output(output) -> str:
        """Re-wrap an engine or decode-worker output as ``MyRequestOutput`` JSON."""
        return MyRequestOutput(
            request_id=output.request_id,
            prompt=output.prompt,
            prompt_token_ids=output.prompt_token_ids,
            prompt_logprobs=output.prompt_logprobs,
            outputs=output.outputs,
            finished=output.finished,
            metrics=output.metrics,
            kv_transfer_params=output.kv_transfer_params,
        ).model_dump_json()

    async def generate(self, request: vLLMMultimodalRequest, context):
        """Run a multimodal request and stream the serialized outputs.

        Args:
            request: A ``vLLMMultimodalRequest``, or a JSON string / other
                payload that validates into one.
            context: Not used by this method.

        Yields:
            JSON strings produced by ``MyRequestOutput.model_dump_json()``.
        """
        logger.debug("Got raw request: %s", request)
        if not isinstance(request, vLLMMultimodalRequest):
            if isinstance(request, str):
                request = vLLMMultimodalRequest.model_validate_json(request)
            else:
                request = vLLMMultimodalRequest.model_validate(request)
        logger.debug("Received PD request: { id: %s }.", request.request_id)

        multi_modal_data: dict[str, Any] = defaultdict(list)
        for mi in request.multimodal_inputs:
            if mi.multimodal_input.image_url:
                # PIL image path — used by both EC consumer mode
                # (vLLM looks up cached embeddings via mm_hash) and
                # non-disaggregated mode (vLLM encodes inline).
                multi_modal_data["image"].append(
                    await self.image_loader.load_image(mi.multimodal_input.image_url)
                )
            else:
                # Pre-computed embeddings via NIXL RDMA or local safetensors
                embeddings = await load_embeddings(
                    mi,
                    self.EMBEDDINGS_DTYPE,
                    self.EMBEDDINGS_DEVICE,
                    self._connector,
                )
                accumulate_embeddings(
                    multi_modal_data,
                    self.config.model,
                    self.EMBEDDINGS_DTYPE,
                    embeddings,
                    mi.image_grid_thw,
                )

        # For Qwen VL (mRoPE), capture the accumulated image grid + embedding shape
        # from the constructed multimodal data so decode can reconstruct its
        # multi_modal_data consistently for multiple images.
        if is_qwen_vl_model(self.config.model) and isinstance(
            multi_modal_data.get("image"), dict
        ):
            image_data = multi_modal_data["image"]
            image_grid_thw = image_data.get("image_grid_thw")
            image_embeds = image_data.get("image_embeds")
            if image_grid_thw is not None:
                request.image_grid_thw = (
                    image_grid_thw.tolist()
                    if isinstance(image_grid_thw, torch.Tensor)
                    else image_grid_thw
                )
            if image_embeds is not None:
                request.embeddings_shape = list(image_embeds.shape)

        # Remove the image features from the request as they are not required
        # Use empty list instead of None to satisfy Pydantic validation on decode worker after vllm upgrade
        request.multimodal_inputs = []

        # NOTE: indexing the defaultdict here (rather than .get) also creates an
        # empty "image" entry when the request carried no image inputs; this is
        # kept so the data handed to the engine is unchanged.
        logger.info(
            "Prepared multimodal data size: %s", len(multi_modal_data["image"])
        )
        # Log only the keys, at debug level: the values are loaded images or
        # embeddings and should not be rendered into INFO logs on every request.
        logger.debug("Multimodal data keys: %s", list(multi_modal_data.keys()))

        # Deepcopy the request to avoid modifying the original
        # when we adjust sampling params for prefill
        pd_request = copy.deepcopy(request)

        # Do prefill and remote decode if enable_disagg is true
        remote_decode = bool(self.enable_disagg and self.decode_worker_client)
        if remote_decode:
            extra_args = pd_request.sampling_params.extra_args or {}
            extra_args["kv_transfer_params"] = {
                "do_remote_decode": True,
            }
            pd_request.sampling_params.extra_args = extra_args
            # Prefill only needs to produce a single token locally.
            pd_request.sampling_params.max_tokens = 1
            pd_request.sampling_params.min_tokens = 1

        logger.debug("Prefill request: %s", pd_request)

        gen = self.engine_client.generate(
            prompt=TokensPrompt(
                prompt_token_ids=pd_request.engine_prompt["prompt_token_ids"],
                multi_modal_data=multi_modal_data,
            ),
            sampling_params=pd_request.sampling_params,
            request_id=pd_request.request_id,
        )

        if remote_decode:
            decode_request = copy.deepcopy(request)
            async for prefill_response in gen:
                # For Qwen VL models with mRoPE: Keep the ORIGINAL unexpanded prompt.
                # The decode worker will pass multi_modal_data which causes vLLM to
                # expand the prompt identically to prefill, ensuring block counts match.
                #
                # For other models: Use the expanded prompt from prefill response.
                # These models don't pass multi_modal_data in decode, so they need
                # the already-expanded prompt to match the KV cache layout.
                if not is_qwen_vl_model(self.config.model):
                    decode_request.engine_prompt[
                        "prompt_token_ids"
                    ] = prefill_response.prompt_token_ids
                logger.debug(
                    "Prefill response kv_transfer_params: %s",
                    prefill_response.kv_transfer_params,
                )
                extra_args = decode_request.sampling_params.extra_args or {}
                extra_args["kv_transfer_params"] = prefill_response.kv_transfer_params
                extra_args.pop("serialized_request", None)
                decode_request.sampling_params.extra_args = extra_args
                logger.debug("Decode request: %s", decode_request)
                async for (
                    decode_response
                ) in await self.decode_worker_client.round_robin(
                    decode_request.model_dump_json()
                ):
                    output = MyRequestOutput.model_validate_json(decode_response.data())  # type: ignore[attr-defined]
                    yield self._serialize_output(output)
        else:
            async for response in gen:
                logger.debug(
                    "Response kv_transfer_params: %s", response.kv_transfer_params
                )
                logger.debug(
                    "length of expanded prompt ids: %s",
                    len(response.prompt_token_ids),
                )
                yield self._serialize_output(response)
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for MultimodalPDWorkerHandler.__init__."""
from unittest.mock import MagicMock, patch
import pytest
from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
from dynamo.vllm.multimodal_handlers import multimodal_pd_worker_handler as mod
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.vllm,
pytest.mark.gpu_0,
pytest.mark.multimodal,
]
def _make_config(
model: str = "test-model",
is_prefill_worker: bool = False,
enable_multimodal: bool = True,
multimodal_embedding_cache_capacity_gb: float = 0,
) -> MagicMock:
"""Create a mock Config with the fields used by MultimodalPDWorkerHandler.__init__."""
config = MagicMock()
config.model = model
config.is_prefill_worker = is_prefill_worker
config.enable_multimodal = enable_multimodal
config.multimodal_embedding_cache_capacity_gb = (
multimodal_embedding_cache_capacity_gb
)
config.engine_args.create_model_config.return_value.get_diff_sampling_param.return_value = (
{}
)
return config
class TestMultimodalPDWorkerHandlerInit:
    """Checks how MultimodalPDWorkerHandler.__init__ sets up the embedding cache."""

    def test_init_with_embedding_cache(self):
        """A positive capacity yields a MultimodalEmbeddingCacheManager sized in bytes."""
        gb = 0.1
        cfg = _make_config(multimodal_embedding_cache_capacity_gb=gb)

        # Stub out the base-class constructor and the image loader so that
        # only the handler's own __init__ logic runs.
        base_init_patch = patch.object(
            mod.BaseWorkerHandler, "__init__", return_value=None
        )
        image_loader_patch = patch.object(mod, "ImageLoader", new_callable=MagicMock)
        with base_init_patch, image_loader_patch:
            handler = mod.MultimodalPDWorkerHandler(
                runtime=MagicMock(),
                component=MagicMock(),
                engine_client=MagicMock(),
                config=cfg,
            )

        cache = handler.embedding_cache_manager
        assert isinstance(cache, MultimodalEmbeddingCacheManager)
        assert cache._capacity_bytes == int(gb * 1024**3)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment