Unverified commit d2a57839, authored by Qi Wang and committed by GitHub
Browse files

refactor: delete handlers and disagg EC producer/consumer (#6051)

parent 976bb70a
...@@ -279,7 +279,7 @@ def cmd_line_args(): ...@@ -279,7 +279,7 @@ def cmd_line_args():
"--encode-endpoint", "--encode-endpoint",
type=str, type=str,
default="", default="",
help=f"Endpoint(in 'dyn://namespace.component.endpoint' format) for the encode worker. Default: {DEFAULT_ENCODE_ENDPOINT}", help=f"Endpoint(in 'dyn://namespace.component.endpoint' format) for the encode worker. e.g. {DEFAULT_ENCODE_ENDPOINT}",
) )
parser.add_argument( parser.add_argument(
"--allowed-local-media-path", "--allowed-local-media-path",
......
...@@ -153,12 +153,11 @@ def update_dynamo_config_with_engine( ...@@ -153,12 +153,11 @@ def update_dynamo_config_with_engine(
dynamo_config.served_model_name = None dynamo_config.served_model_name = None
# TODO: move to "disaggregation_mode" as the other engines. # TODO: move to "disaggregation_mode" as the other engines.
if dynamo_config.multimodal_processor or dynamo_config.ec_processor: if dynamo_config.route_to_encoder:
dynamo_config.component = "processor" dynamo_config.component = "processor"
dynamo_config.endpoint = "generate" dynamo_config.endpoint = "generate"
elif ( elif (
dynamo_config.vllm_native_encoder_worker dynamo_config.multimodal_encode_worker
or dynamo_config.multimodal_encode_worker
or dynamo_config.multimodal_encode_prefill_worker or dynamo_config.multimodal_encode_prefill_worker
): ):
dynamo_config.component = "encoder" dynamo_config.component = "encoder"
......
...@@ -62,17 +62,10 @@ class DynamoVllmArgGroup(ArgGroup): ...@@ -62,17 +62,10 @@ class DynamoVllmArgGroup(ArgGroup):
# Multimodal # Multimodal
add_negatable_bool_argument( add_negatable_bool_argument(
g, g,
flag_name="--multimodal-processor", flag_name="--route-to-encoder",
env_var="DYN_VLLM_MULTIMODAL_PROCESSOR", env_var="DYN_VLLM_ROUTE_TO_ENCODER",
default=False, default=False,
help="Run as multimodal processor component for handling multimodal requests.", help="Enable routing to separate encoder workers for multimodal processing.",
)
add_negatable_bool_argument(
g,
flag_name="--ec-processor",
env_var="DYN_VLLM_EC_PROCESSOR",
default=False,
help="Run as ECConnector processor (routes multimodal requests to encoder then PD workers).",
) )
add_negatable_bool_argument( add_negatable_bool_argument(
g, g,
...@@ -136,43 +129,6 @@ class DynamoVllmArgGroup(ArgGroup): ...@@ -136,43 +129,6 @@ class DynamoVllmArgGroup(ArgGroup):
), ),
) )
# vLLM-native encoder (ECConnector)
add_negatable_bool_argument(
g,
flag_name="--vllm-native-encoder-worker",
env_var="DYN_VLLM_NATIVE_ENCODER_WORKER",
default=False,
help="Run as vLLM-native encoder worker using ECConnector for encoder disaggregation (requires shared storage). The following flags only work when this flag is enabled: --ec-connector-backend, --ec-storage-path, --ec-extra-config, --ec-consumer-mode.",
)
add_argument(
g,
flag_name="--ec-connector-backend",
env_var="DYN_VLLM_EC_CONNECTOR_BACKEND",
default="ECExampleConnector",
help="ECConnector implementation class for encoder disaggregation.",
)
add_argument(
g,
flag_name="--ec-storage-path",
env_var="DYN_VLLM_EC_STORAGE_PATH",
default=None,
help="Storage path for ECConnector (required for ECExampleConnector, optional for other backends).",
)
add_argument(
g,
flag_name="--ec-extra-config",
env_var="DYN_VLLM_EC_EXTRA_CONFIG",
default=None,
help="Additional ECConnector configuration as JSON string.",
)
add_negatable_bool_argument(
g,
flag_name="--ec-consumer-mode",
env_var="DYN_VLLM_EC_CONSUMER_MODE",
default=False,
help="Configure as ECConnector consumer for receiving encoder embeddings (for PD workers).",
)
# vLLM-Omni # vLLM-Omni
add_negatable_bool_argument( add_negatable_bool_argument(
g, g,
...@@ -210,8 +166,7 @@ class DynamoVllmConfig(ConfigBase): ...@@ -210,8 +166,7 @@ class DynamoVllmConfig(ConfigBase):
sleep_mode_level: int sleep_mode_level: int
# Multimodal # Multimodal
multimodal_processor: bool route_to_encoder: bool
ec_processor: bool
multimodal_encode_worker: bool multimodal_encode_worker: bool
multimodal_worker: bool multimodal_worker: bool
multimodal_decode_worker: bool multimodal_decode_worker: bool
...@@ -220,13 +175,6 @@ class DynamoVllmConfig(ConfigBase): ...@@ -220,13 +175,6 @@ class DynamoVllmConfig(ConfigBase):
mm_prompt_template: str mm_prompt_template: str
frontend_decoding: bool frontend_decoding: bool
# vLLM-native encoder (ECConnector)
vllm_native_encoder_worker: bool
ec_connector_backend: str
ec_storage_path: Optional[str] = None
ec_extra_config: Optional[str] = None
ec_consumer_mode: bool
# vLLM-Omni # vLLM-Omni
omni: bool omni: bool
stage_configs_path: Optional[str] = None stage_configs_path: Optional[str] = None
...@@ -239,7 +187,6 @@ class DynamoVllmConfig(ConfigBase): ...@@ -239,7 +187,6 @@ class DynamoVllmConfig(ConfigBase):
self._validate_prefill_decode_exclusive() self._validate_prefill_decode_exclusive()
self._validate_multimodal_role_exclusivity() self._validate_multimodal_role_exclusivity()
self._validate_multimodal_requires_flag() self._validate_multimodal_requires_flag()
self._validate_ec_connector_storage()
self._validate_omni_stage_config() self._validate_omni_stage_config()
def _validate_prefill_decode_exclusive(self) -> None: def _validate_prefill_decode_exclusive(self) -> None:
...@@ -250,16 +197,16 @@ class DynamoVllmConfig(ConfigBase): ...@@ -250,16 +197,16 @@ class DynamoVllmConfig(ConfigBase):
) )
def _count_multimodal_roles(self) -> int: def _count_multimodal_roles(self) -> int:
"""Return the number of multimodal roles set (0 or 1 allowed).""" """Return the number of multimodal worker roles set (0 or 1 allowed).
Note: --route-to-encoder is a modifier flag, not a worker type.
"""
return sum( return sum(
[ [
bool(self.multimodal_processor),
bool(self.ec_processor),
bool(self.multimodal_encode_worker), bool(self.multimodal_encode_worker),
bool(self.multimodal_worker), bool(self.multimodal_worker),
bool(self.multimodal_decode_worker), bool(self.multimodal_decode_worker),
bool(self.multimodal_encode_prefill_worker), bool(self.multimodal_encode_prefill_worker),
bool(self.vllm_native_encoder_worker),
] ]
) )
...@@ -267,10 +214,8 @@ class DynamoVllmConfig(ConfigBase): ...@@ -267,10 +214,8 @@ class DynamoVllmConfig(ConfigBase):
"""Ensure only one multimodal role is set at a time.""" """Ensure only one multimodal role is set at a time."""
if self._count_multimodal_roles() > 1: if self._count_multimodal_roles() > 1:
raise ValueError( raise ValueError(
"Only one multimodal role can be set at a time: " "Use only one of --multimodal-encode-worker, --multimodal-worker, "
"multimodal-processor, ec-processor, multimodal-encode-worker, " "--multimodal-decode-worker, --multimodal-encode-prefill-worker"
"multimodal-worker, multimodal-decode-worker, "
"multimodal-encode-prefill-worker, vllm-native-encoder-worker"
) )
def _validate_multimodal_requires_flag(self) -> None: def _validate_multimodal_requires_flag(self) -> None:
...@@ -280,18 +225,6 @@ class DynamoVllmConfig(ConfigBase): ...@@ -280,18 +225,6 @@ class DynamoVllmConfig(ConfigBase):
"Use --enable-multimodal when enabling any multimodal component" "Use --enable-multimodal when enabling any multimodal component"
) )
def _validate_ec_connector_storage(self) -> None:
    """Reject an ECExampleConnector encoder worker that has no storage path.

    The check applies only when ``vllm_native_encoder_worker`` is set and the
    configured ``ec_connector_backend`` is ``"ECExampleConnector"``; any other
    combination passes without inspection.

    Raises:
        ValueError: If the conditions above hold and ``ec_storage_path`` is
            unset or empty.
    """
    # Guard clauses: bail out as soon as the rule is known not to apply.
    if not self.vllm_native_encoder_worker:
        return
    if self.ec_connector_backend != "ECExampleConnector":
        return
    if self.ec_storage_path:
        return
    raise ValueError(
        "--ec-storage-path is required when using ECExampleConnector backend. "
        "Specify a shared storage path for encoder cache."
    )
def _validate_omni_stage_config(self) -> None: def _validate_omni_stage_config(self) -> None:
"""Require stage_configs_path when using --omni.""" """Require stage_configs_path when using --omni."""
if self.stage_configs_path and not self.omni: if self.stage_configs_path and not self.omni:
......
...@@ -45,14 +45,10 @@ except ImportError: ...@@ -45,14 +45,10 @@ except ImportError:
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.vllm.multimodal_handlers import ( from dynamo.vllm.multimodal_handlers import (
ECProcessorHandler,
EncodeWorkerHandler, EncodeWorkerHandler,
MultimodalDecodeWorkerHandler, MultimodalDecodeWorkerHandler,
MultimodalPDWorkerHandler, MultimodalPDWorkerHandler,
PreprocessedHandler,
VLLMEncodeWorkerHandler,
) )
from dynamo.vllm.multimodal_utils.encode_utils import create_ec_transfer_config
from .args import Config, parse_args from .args import Config, parse_args
from .chrek import get_checkpoint_config from .chrek import get_checkpoint_config
...@@ -151,16 +147,7 @@ async def worker(): ...@@ -151,16 +147,7 @@ async def worker():
) )
# Route to appropriate initialization based on config flags # Route to appropriate initialization based on config flags
if config.vllm_native_encoder_worker: if config.multimodal_encode_worker:
await init_vllm_native_encoder(runtime, config, shutdown_event)
logger.debug("init_vllm_native_encoder completed")
elif config.ec_processor:
await init_ec_processor(runtime, config, shutdown_event)
logger.debug("init_ec_processor completed")
elif config.multimodal_processor:
await init_multimodal_processor(runtime, config, shutdown_event)
logger.debug("init_multimodal_processor completed")
elif config.multimodal_encode_worker:
await init_multimodal_encode_worker(runtime, config, shutdown_event) await init_multimodal_encode_worker(runtime, config, shutdown_event)
logger.debug("init_multimodal_encode_worker completed") logger.debug("init_multimodal_encode_worker completed")
elif ( elif (
...@@ -936,67 +923,6 @@ def get_engine_cache_info(engine: AsyncLLM): ...@@ -936,67 +923,6 @@ def get_engine_cache_info(engine: AsyncLLM):
raise raise
async def init_multimodal_processor(
    runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
    """Serve the multimodal processor endpoint.

    Resolves this component's endpoint, opens clients to the ``generate``
    endpoints of the ``encoder`` and ``backend`` components, waits for encoder
    instances, registers the model, and then serves
    ``PreprocessedHandler.generate``. ``shutdown_event`` is accepted but not
    used in this function.

    Raises:
        Exception: Any error raised while serving is logged and re-raised;
            ``handler.cleanup()`` runs in either case.
    """
    own_component = runtime.namespace(config.namespace).component(config.component)
    serve_target = own_component.endpoint(config.endpoint)

    # Client for the encode workers.
    encoder_generate = (
        runtime.namespace(config.namespace).component("encoder").endpoint("generate")
    )
    encode_worker_client = await encoder_generate.client()

    # Client for the prefill/decode workers.
    backend_generate = (
        runtime.namespace(config.namespace).component("backend").endpoint("generate")
    )
    pd_worker_client = await backend_generate.client()

    handler = PreprocessedHandler(
        config.engine_args,
        encode_worker_client,
        pd_worker_client,
    )

    # Only the encoder client is awaited here; the PD client is not.
    logger.info("Waiting for Encoder Worker Instances ...")
    await encode_worker_client.wait_for_instances()

    # Register this endpoint as the entrypoint to the model.
    await register_model(
        ModelInput.Tokens,
        ModelType.Chat,
        serve_target,
        config.model,
        config.served_model_name,
        kv_cache_block_size=config.engine_args.block_size,
    )

    logger.info("Starting to serve the processor endpoint...")
    try:
        labels = [
            (prometheus_names.labels.MODEL, config.model),
            (prometheus_names.labels.MODEL_NAME, config.model),
        ]
        await asyncio.gather(
            serve_target.serve_endpoint(handler.generate, metrics_labels=labels),
        )
    except Exception as exc:
        logger.error(f"Failed to serve endpoints: {exc}")
        raise
    finally:
        handler.cleanup()
async def init_multimodal_encode_worker( async def init_multimodal_encode_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
): ):
...@@ -1005,23 +931,10 @@ async def init_multimodal_encode_worker( ...@@ -1005,23 +931,10 @@ async def init_multimodal_encode_worker(
generate_endpoint = component.endpoint(config.endpoint) generate_endpoint = component.endpoint(config.endpoint)
# Get PD worker client
# In multimodal mode, the PD worker always registers as "backend"
# (even in disaggregated mode with prefill/decode split, we still connect to "backend")
pd_worker_client = (
await runtime.namespace(config.namespace)
.component("backend")
.endpoint("generate")
.client()
)
handler = EncodeWorkerHandler( handler = EncodeWorkerHandler(
config.engine_args, config.engine_args,
pd_worker_client,
) )
await handler.async_init(runtime) await handler.async_init(runtime)
logger.info("Waiting for PD Worker Instances ...")
await pd_worker_client.wait_for_instances()
logger.info("Starting to serve the encode worker endpoint...") logger.info("Starting to serve the encode worker endpoint...")
try: try:
...@@ -1041,150 +954,6 @@ async def init_multimodal_encode_worker( ...@@ -1041,150 +954,6 @@ async def init_multimodal_encode_worker(
handler.cleanup() handler.cleanup()
async def init_vllm_native_encoder(
    runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
    """Serve the vLLM-native encoder worker endpoint (ECConnector mode).

    Builds an ``ec_producer`` transfer config, attaches it to
    ``config.engine_args``, creates the vLLM engine, and serves
    ``VLLMEncodeWorkerHandler.generate`` on this component's endpoint.
    ``shutdown_event`` is accepted but not used in this function.

    Raises:
        Exception: Any error raised while serving is logged and re-raised;
            ``handler.cleanup()`` runs in either case.
    """
    component = runtime.namespace(config.namespace).component(config.component)
    serve_target = component.endpoint(config.endpoint)

    # Producer role: this worker creates the embeddings and stores them in the
    # shared storage. The instance id is fixed at 0.
    instance_id = 0
    engine_id = f"{config.namespace}.{config.component}.encoder.{instance_id}"
    producer_config = create_ec_transfer_config(
        engine_id=engine_id,
        ec_role="ec_producer",
        ec_connector_backend=config.ec_connector_backend,
        ec_storage_path=config.ec_storage_path,
        ec_extra_config=config.ec_extra_config,
    )
    config.engine_args.ec_transfer_config = producer_config

    # Only the engine client and the Prometheus temp dir are used below; the
    # remaining values returned by setup_vllm_engine are unpacked and ignored.
    engine_parts = setup_vllm_engine(config)
    (
        engine_client,
        vllm_config,
        default_sampling_params,
        prometheus_temp_dir,
        _component_gauges,
    ) = engine_parts

    handler = VLLMEncodeWorkerHandler(runtime, component, engine_client, config)
    handler.add_temp_dir(prometheus_temp_dir)

    # No handler.async_init() call is made in ECConnector mode.
    logger.info("Starting to serve vLLM-native encoder endpoint...")
    try:
        labels = [
            (prometheus_names.labels.MODEL, config.model),
            (prometheus_names.labels.MODEL_NAME, config.model),
        ]
        await asyncio.gather(
            serve_target.serve_endpoint(handler.generate, metrics_labels=labels),
        )
    except Exception as exc:
        logger.error(f"Failed to serve vLLM-native encoder endpoint: {exc}")
        raise
    finally:
        handler.cleanup()
async def init_ec_processor(
    runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
    """Serve the ECConnector processor endpoint.

    The processor routes multimodal requests using the ECConnector pattern:

    1. Preprocess the request.
    2. Send multimodal items to encoder workers (stored to shared storage).
    3. Forward the preprocessed request to the PD worker (loads from shared
       storage).
    4. Stream the response back to the client.

    This function wires up the clients, waits for encoder and PD instances,
    registers the model, and serves ``ECProcessorHandler.generate``.
    ``shutdown_event`` is accepted but not used in this function.

    Raises:
        Exception: Any error raised while serving is logged and re-raised;
            ``handler.cleanup()`` runs in either case.
    """
    component = runtime.namespace(config.namespace).component(config.component)
    serve_target = component.endpoint(config.endpoint)

    # Client for the encoder workers.
    encoder_generate = (
        runtime.namespace(config.namespace).component("encoder").endpoint("generate")
    )
    encoder_client = await encoder_generate.client()

    # Client for the prefill/decode workers.
    backend_generate = (
        runtime.namespace(config.namespace).component("backend").endpoint("generate")
    )
    pd_client = await backend_generate.client()

    # The prompt template comes from the parsed config.
    template = config.mm_prompt_template

    handler = ECProcessorHandler(
        config.engine_args,
        encoder_worker_client=encoder_client,
        pd_worker_client=pd_client,
        prompt_template=template,
    )

    logger.info("Waiting for encoder and PD worker instances...")
    for worker_client in (encoder_client, pd_client):
        await worker_client.wait_for_instances()

    # Register this endpoint as the entrypoint to the model, with token input.
    await register_model(
        ModelInput.Tokens,
        ModelType.Chat,
        serve_target,
        config.model,
        config.served_model_name,
        kv_cache_block_size=config.engine_args.block_size,
    )

    logger.info("Starting to serve EC processor endpoint...")
    try:
        labels = [
            (prometheus_names.labels.MODEL, config.model),
            (prometheus_names.labels.MODEL_NAME, config.model),
        ]
        await asyncio.gather(
            serve_target.serve_endpoint(handler.generate, metrics_labels=labels),
        )
    except Exception as exc:
        logger.error(f"Failed to serve EC processor endpoint: {exc}")
        raise
    finally:
        handler.cleanup()
async def init_multimodal_worker( async def init_multimodal_worker(
runtime: DistributedRuntime, runtime: DistributedRuntime,
config: Config, config: Config,
...@@ -1194,39 +963,20 @@ async def init_multimodal_worker( ...@@ -1194,39 +963,20 @@ async def init_multimodal_worker(
""" """
Initialize multimodal worker component. Initialize multimodal worker component.
Supports two modes: Supports three modes:
1. --multimodal-worker: Receives embeddings from separate encoder 1. --multimodal-worker: Prefill+decode worker for multimodal LLM; can route
2. --multimodal-encode-prefill-worker: Handles inline encoding (e.g., Llama 4) to a separate encoder (--route-to-encoder) for embeddings. Runs
aggregated (P+D) or disaggregated (P→D).
Both can operate in aggregated (P+D) or disaggregated (P→D) mode. 2. --multimodal-decode-worker: Decode-only worker in disaggregated (P→D)
mode.
When --ec-consumer-mode is enabled, configures as ECConnector consumer 3. --multimodal-encode-prefill-worker: Unified encode+prefill+decode in one
to load encoder embeddings from shared storage. worker for models with integrated image encoding (e.g., Llama 4).
""" """
component = runtime.namespace(config.namespace).component(config.component) component = runtime.namespace(config.namespace).component(config.component)
generate_endpoint = component.endpoint(config.endpoint) generate_endpoint = component.endpoint(config.endpoint)
clear_endpoint = component.endpoint("clear_kv_blocks") clear_endpoint = component.endpoint("clear_kv_blocks")
# Configure ECConnector consumer mode if enabled
if config.ec_consumer_mode:
logger.info("Configuring as ECConnector consumer for encoder embeddings")
instance_id = 0
engine_id = f"{config.namespace}.{config.component}.backend.{instance_id}"
# The PD Worker just load the embeddings from the shared storage, so it is a consumer
ec_transfer_config = create_ec_transfer_config(
engine_id=engine_id,
ec_role="ec_consumer",
ec_connector_backend=config.ec_connector_backend,
ec_storage_path=config.ec_storage_path,
ec_extra_config=config.ec_extra_config,
)
# Set ECTransferConfig on engine args
config.engine_args.ec_transfer_config = ec_transfer_config
logger.info(f"Configured as ECConnector consumer with engine_id={engine_id}")
# Use pre-created engine if provided (checkpoint mode), otherwise create new # Use pre-created engine if provided (checkpoint mode), otherwise create new
if pre_created_engine is not None: if pre_created_engine is not None:
( (
...@@ -1245,6 +995,20 @@ async def init_multimodal_worker( ...@@ -1245,6 +995,20 @@ async def init_multimodal_worker(
_component_gauges, _component_gauges,
) = setup_vllm_engine(config) ) = setup_vllm_engine(config)
# Set up encode worker client when routing to encoder is enabled
# (PD worker handles encode routing directly instead of a separate processor)
encode_worker_client = None
if config.route_to_encoder:
encode_worker_client = (
await runtime.namespace(config.namespace)
.component("encoder")
.endpoint("generate")
.client()
)
logger.info("Waiting for Encoder Worker Instances ...")
await encode_worker_client.wait_for_instances()
logger.info("Connected to encoder workers")
# Set up decode worker client for disaggregated mode # Set up decode worker client for disaggregated mode
decode_worker_client = None decode_worker_client = None
if config.is_prefill_worker: if config.is_prefill_worker:
...@@ -1269,6 +1033,7 @@ async def init_multimodal_worker( ...@@ -1269,6 +1033,7 @@ async def init_multimodal_worker(
component, component,
engine_client, engine_client,
config, config,
encode_worker_client,
decode_worker_client, decode_worker_client,
shutdown_event, shutdown_event,
) )
...@@ -1283,9 +1048,21 @@ async def init_multimodal_worker( ...@@ -1283,9 +1048,21 @@ async def init_multimodal_worker(
if kv_publisher: if kv_publisher:
handler.kv_publisher = kv_publisher handler.kv_publisher = kv_publisher
# Register model with the frontend so it can route requests
model_type = parse_endpoint_types(config.endpoint_types)
model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
await register_vllm_model(
model_input,
model_type,
generate_endpoint,
config,
engine_client,
vllm_config,
)
metrics_labels = [ metrics_labels = [
(prometheus_names.labels.MODEL, config.model), (prometheus_names.labels.MODEL, config.served_model_name or config.model),
(prometheus_names.labels.MODEL_NAME, config.model), (prometheus_names.labels.MODEL_NAME, config.served_model_name or config.model),
] ]
try: try:
await asyncio.gather( await asyncio.gather(
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from dynamo.vllm.multimodal_handlers.encode_worker_handler import ( from dynamo.vllm.multimodal_handlers.encode_worker_handler import EncodeWorkerHandler
EncodeWorkerHandler,
VLLMEncodeWorkerHandler,
)
from dynamo.vllm.multimodal_handlers.multimodal_pd_worker_handler import ( from dynamo.vllm.multimodal_handlers.multimodal_pd_worker_handler import (
MultimodalPDWorkerHandler, MultimodalPDWorkerHandler,
) )
from dynamo.vllm.multimodal_handlers.preprocessed_handler import (
ECProcessorHandler,
PreprocessedHandler,
)
from dynamo.vllm.multimodal_handlers.worker_handler import MultimodalDecodeWorkerHandler from dynamo.vllm.multimodal_handlers.worker_handler import MultimodalDecodeWorkerHandler
__all__ = [ __all__ = [
"EncodeWorkerHandler", "EncodeWorkerHandler",
"VLLMEncodeWorkerHandler",
"PreprocessedHandler",
"ECProcessorHandler",
"MultimodalPDWorkerHandler", "MultimodalPDWorkerHandler",
"MultimodalDecodeWorkerHandler", "MultimodalDecodeWorkerHandler",
] ]
...@@ -4,27 +4,21 @@ ...@@ -4,27 +4,21 @@
import asyncio import asyncio
import logging import logging
import os import os
import shutil
import tempfile import tempfile
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import AsyncGenerator, AsyncIterator from typing import AsyncIterator
import safetensors import safetensors
import torch import torch
from transformers import AutoImageProcessor from transformers import AutoImageProcessor
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import TokensPrompt
from vllm.multimodal.hasher import MultiModalHasher
from vllm.sampling_params import SamplingParams
import dynamo.nixl_connect as connect import dynamo.nixl_connect as connect
from dynamo.runtime import Client, DistributedRuntime from dynamo.runtime import DistributedRuntime
from ..multimodal_utils import ( from ..multimodal_utils import (
ImageLoader, ImageLoader,
VLLMNativeEncoderRequest,
VLLMNativeEncoderResponse,
encode_image_embeddings, encode_image_embeddings,
get_encoder_components, get_encoder_components,
load_vision_model, load_vision_model,
...@@ -51,9 +45,7 @@ class EncodeWorkerHandler: ...@@ -51,9 +45,7 @@ class EncodeWorkerHandler:
def __init__( def __init__(
self, self,
engine_args: AsyncEngineArgs, engine_args: AsyncEngineArgs,
pd_worker_client: Client,
) -> None: ) -> None:
self.pd_worker_client = pd_worker_client
self.engine_args = engine_args self.engine_args = engine_args
self.model = self.engine_args.model self.model = self.engine_args.model
...@@ -266,148 +258,3 @@ class EncodeWorkerHandler: ...@@ -266,148 +258,3 @@ class EncodeWorkerHandler:
except Exception as e: except Exception as e:
logger.error(f"Error processing request {request_id}: {e}") logger.error(f"Error processing request {request_id}: {e}")
raise raise
class VLLMEncodeWorkerHandler:
    """Handler for a vLLM-native encoder worker using ECConnector.

    ``generate`` loads the images named in a request, runs them through the
    vLLM engine in a single call, and yields one JSON response per image
    carrying that image's multimodal hash and the connector metadata.
    """

    def __init__(self, runtime, component, engine_client, config):
        """Initialize the handler.

        Args:
            runtime: Dynamo distributed runtime.
            component: Dynamo component instance.
            engine_client: vLLM AsyncLLM instance.
            config: Dynamo Config object with CLI arguments.
        """
        self.runtime = runtime
        self.component = component
        self.engine_client = engine_client
        self.config = config
        # Directories removed by cleanup().
        self.temp_dirs = []
        self.image_loader = ImageLoader()

        logger.info(
            f"VLLMNativeEncoderWorkerHandler initialized with "
            f"backend={config.ec_connector_backend}, "
            f"storage_path={config.ec_storage_path}"
        )

    def add_temp_dir(self, temp_dir):
        """Register a temporary directory for removal in ``cleanup``.

        Falsy values (``None``, ``""``) are ignored.
        """
        if temp_dir:
            self.temp_dirs.append(temp_dir)

    async def generate(self, request, context) -> AsyncGenerator[str, None]:
        """Process an encoder request and trigger vLLM encoder execution.

        Args:
            request: A ``VLLMNativeEncoderRequest``, or a JSON string / mapping
                that validates into one. Its ``multimodal_inputs`` must be
                non-empty.
            context: Request context from the Dynamo runtime (not used here).

        Yields:
            JSON-encoded ``VLLMNativeEncoderResponse``, one per loaded image.

        Raises:
            ValueError: If the request has no multimodal inputs, or an input
                has neither ``image_url`` nor ``video_url``.
            NotImplementedError: If an input carries a ``video_url``.
        """
        # Normalize the incoming payload into the request model.
        if not isinstance(request, VLLMNativeEncoderRequest):
            if isinstance(request, str):
                request = VLLMNativeEncoderRequest.model_validate_json(request)
            else:
                request = VLLMNativeEncoderRequest.model_validate(request)

        if not request.multimodal_inputs:
            raise ValueError("No multimodal inputs provided in request")

        logger.info(
            f"Processing {len(request.multimodal_inputs)} multimodal item(s) "
            f"for request_id={request.request_id}"
        )

        # Load all images.
        # TODO: support video and audio encoding later
        media_list = []
        modality = "image"
        for idx, mm_group in enumerate(request.multimodal_inputs):
            mm_input = mm_group.multimodal_input
            if mm_input.image_url:
                media = await self.image_loader.load_image(mm_input.image_url)
                media_list.append(media)
            elif mm_input.video_url:
                raise NotImplementedError("Video encoding not yet supported")
            else:
                raise ValueError(
                    f"No media URL provided in multimodal_input[{idx}]. "
                    "Specify image_url or video_url."
                )

        # Process all images in one vLLM request.
        prompt_dict = TokensPrompt(
            prompt_token_ids=request.token_ids,
            multi_modal_data={"image": media_list},
        )

        try:
            gen = self.engine_client.generate(
                prompt=prompt_dict,
                sampling_params=SamplingParams(max_tokens=1, min_tokens=0),
                request_id=request.request_id,
            )
            # The generated output is discarded; the stream is drained only to
            # trigger encoder execution.
            async for _ in gen:
                pass
            logger.info(
                f"[{request.request_id}] Encoder execution completed for all {len(media_list)} image(s)"
            )
        except Exception as e:
            logger.error(f"[{request.request_id}] Encoder execution failed: {e}")
            raise

        # Compute mm_hash for each image and yield one response per image.
        for idx, media in enumerate(media_list):
            item_request_id = f"{request.request_id}_mm_{idx}"
            try:
                mm_hash = MultiModalHasher.hash_kwargs(
                    model_id=self.config.model, image=media
                )
                logger.debug(f"[{item_request_id}] Computed mm_hash: {mm_hash}")
            except Exception as e:
                logger.error(f"[{item_request_id}] Failed to compute mm_hash: {e}")
                raise

            response = VLLMNativeEncoderResponse(
                request_id=item_request_id,
                mm_hash=mm_hash,
                modality=modality,
                connector_metadata={
                    "ec_connector": self.config.ec_connector_backend,
                    "storage_path": self.config.ec_storage_path,
                },
            )
            logger.debug(f"[{item_request_id}] Returning response: {response}")
            yield response.model_dump_json()

        logger.info(
            f"All {len(request.multimodal_inputs)} multimodal items processed "
            f"for request_id={request.request_id}"
        )

    def cleanup(self):
        """Remove every registered temporary directory, best-effort.

        A directory that cannot be removed is reported with a warning and the
        loop continues; this method does not raise for removal failures.
        """
        logger.info("Cleaning up VLLMNativeEncoderWorkerHandler")
        for temp_dir in self.temp_dirs:
            try:
                # Deliberately no ignore_errors=True: it would swallow the
                # failure, make the except branch unreachable, and let the
                # success message below be logged for a directory that is
                # still on disk.
                shutil.rmtree(temp_dir)
            except Exception as e:
                # Broad on purpose: teardown must never crash the worker.
                logger.warning(f"Failed to cleanup {temp_dir}: {e}")
            else:
                logger.debug(f"Cleaned up temp directory: {temp_dir}")
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import copy import copy
import logging import logging
import uuid
from collections import defaultdict from collections import defaultdict
from typing import Any from typing import Any
...@@ -17,11 +18,18 @@ from dynamo.common.memory.multimodal_embedding_cache_manager import ( ...@@ -17,11 +18,18 @@ from dynamo.common.memory.multimodal_embedding_cache_manager import (
from dynamo.runtime import Client, Component, DistributedRuntime from dynamo.runtime import Client, Component, DistributedRuntime
from ..args import Config from ..args import Config
from ..handlers import BaseWorkerHandler from ..handlers import BaseWorkerHandler, build_sampling_params
from ..multimodal_utils import ImageLoader, MyRequestOutput, vLLMMultimodalRequest from ..multimodal_utils import (
MultiModalGroup,
MyRequestOutput,
PatchedTokensPrompt,
vLLMMultimodalRequest,
)
from ..multimodal_utils.model import is_qwen_vl_model from ..multimodal_utils.model import is_qwen_vl_model
from ..multimodal_utils.prefill_worker_utils import ( from ..multimodal_utils.prefill_worker_utils import (
IMAGE_URL_KEY,
accumulate_embeddings, accumulate_embeddings,
fetch_embeddings_from_encode_workers,
load_embeddings, load_embeddings,
) )
...@@ -85,7 +93,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -85,7 +93,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self._connector: connect.Connector | None = ( self._connector: connect.Connector | None = (
None # Will be initialized in async_init None # Will be initialized in async_init
) )
self.image_loader = ImageLoader()
logger.info("Multimodal PD Worker has been initialized") logger.info("Multimodal PD Worker has been initialized")
...@@ -95,43 +102,116 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -95,43 +102,116 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self._connector = connect.Connector() self._connector = connect.Connector()
logger.info("Multimodal PD Worker async initialization completed.") logger.info("Multimodal PD Worker async initialization completed.")
async def generate(self, request: vLLMMultimodalRequest, context): async def _build_request_from_frontend(
logger.debug(f"Got raw request: {request}") self, raw_request: dict
if type(request) is not vLLMMultimodalRequest: ) -> vLLMMultimodalRequest:
if type(request) is str: """Convert a raw frontend dict into a vLLMMultimodalRequest.
request = vLLMMultimodalRequest.model_validate_json(request)
else: When the PD worker is the direct frontend endpoint (no separate
request = vLLMMultimodalRequest.model_validate(request) processor), the Rust frontend sends a dict representation of PreprocessedRequest.
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.") This method extracts image URLs, routes them to encode workers if available,
and assembles the standard request object that the rest of ``generate()`` expects.
"""
request_id = str(uuid.uuid4().hex)
# Extract image URLs from the raw frontend dict
image_urls: list[str] = []
mm_data = raw_request.get("multi_modal_data")
if mm_data is not None:
for item in mm_data.get(IMAGE_URL_KEY, []):
if isinstance(item, dict) and "Url" in item:
image_urls.append(item["Url"])
multimodal_groups: list[MultiModalGroup] = []
if self.encode_worker_client and image_urls:
multimodal_groups = await fetch_embeddings_from_encode_workers(
self.encode_worker_client,
image_urls,
request_id,
)
sampling_params = build_sampling_params(
raw_request, self.default_sampling_params
)
return vLLMMultimodalRequest(
engine_prompt=PatchedTokensPrompt(
prompt_token_ids=raw_request["token_ids"]
),
sampling_params=sampling_params,
request_id=request_id,
multimodal_inputs=multimodal_groups,
)
# ── Request parsing ────────────────────────────────────────────────
async def _parse_request(self, request) -> vLLMMultimodalRequest:
    """Coerce an incoming payload into a ``vLLMMultimodalRequest``.

    Accepted payload shapes:

    * a ``str`` -- parsed as JSON via ``model_validate_json``;
    * an object whose exact type is ``vLLMMultimodalRequest`` -- returned
      unchanged;
    * a ``dict`` containing a ``token_ids`` key -- treated as a raw
      frontend request and handed to ``_build_request_from_frontend``;
    * anything else -- passed to ``model_validate``.

    The type tests are exact (``type(x) is T``), so subclasses of ``str``
    or of ``vLLMMultimodalRequest`` are not short-circuited.
    """
    exact_type = type(request)
    if exact_type is str:
        parsed = vLLMMultimodalRequest.model_validate_json(request)
    elif exact_type is vLLMMultimodalRequest:
        parsed = request
    elif isinstance(request, dict) and "token_ids" in request:
        # Raw frontend dicts are recognised by their ``token_ids`` key.
        parsed = await self._build_request_from_frontend(request)
    else:
        parsed = vLLMMultimodalRequest.model_validate(request)
    return parsed
# ── Multimodal data loading ──────────────────────────────────────
async def _load_multimodal_data(
self, request: vLLMMultimodalRequest
) -> dict[str, Any]:
"""Load pre-computed embeddings into an engine-ready dict.
Each ``MultiModalGroup`` carries embeddings from encode workers,
loaded via NIXL RDMA or local safetensors.
No-op when --route-to-encoder is not set.
"""
multimodal_inputs: list[MultiModalGroup] = request.multimodal_inputs or []
multi_modal_data: dict[str, Any] = defaultdict(list) multi_modal_data: dict[str, Any] = defaultdict(list)
for mi in request.multimodal_inputs:
if mi.multimodal_input.image_url:
# PIL image path — used by both EC consumer mode
# (vLLM looks up cached embeddings via mm_hash) and
# non-disaggregated mode (vLLM encodes inline).
multi_modal_data["image"].append(
await self.image_loader.load_image(mi.multimodal_input.image_url)
)
else:
# Pre-computed embeddings via NIXL RDMA or local safetensors
embeddings = await load_embeddings(
mi,
self.EMBEDDINGS_DTYPE,
self.EMBEDDINGS_DEVICE,
self._connector,
)
accumulate_embeddings(
multi_modal_data,
self.config.model,
self.EMBEDDINGS_DTYPE,
embeddings,
mi.image_grid_thw,
)
# For Qwen VL (mRoPE), capture the accumulated image grid + embedding shape for mi in multimodal_inputs:
# from the constructed multimodal data so decode can reconstruct its embeddings = await load_embeddings(
# multi_modal_data consistently for multiple images. mi,
self.EMBEDDINGS_DTYPE,
self.EMBEDDINGS_DEVICE,
self._connector,
)
accumulate_embeddings(
multi_modal_data,
self.config.model,
self.EMBEDDINGS_DTYPE,
embeddings,
mi.image_grid_thw,
)
return multi_modal_data
# ── Request metadata finalization ────────────────────────────────
def _finalize_request_metadata(
self,
request: vLLMMultimodalRequest,
multi_modal_data: dict[str, Any],
) -> None:
"""Attach model-specific metadata and strip heavy fields from request.
For Qwen VL (mRoPE) models, captures image grid dimensions and
embedding shapes so the decode worker can reconstruct
``multi_modal_data`` consistently for multiple images.
Also clears ``multimodal_inputs`` — the raw embeddings / URLs are no
longer needed once ``multi_modal_data`` is built.
"""
if is_qwen_vl_model(self.config.model) and isinstance( if is_qwen_vl_model(self.config.model) and isinstance(
multi_modal_data.get("image"), dict multi_modal_data.get("image"), dict
): ):
...@@ -147,93 +227,175 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -147,93 +227,175 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
if image_embeds is not None: if image_embeds is not None:
request.embeddings_shape = list(image_embeds.shape) request.embeddings_shape = list(image_embeds.shape)
# Remove the image features from the request as they are not required # Use empty list instead of None to satisfy Pydantic validation
# Use empty list instead of None to satisfy Pydantic validation on decode worker after vllm upgrade # on decode worker after vllm upgrade.
request.multimodal_inputs = [] request.multimodal_inputs = []
logger.info(f"Prepared multimodal data size: {len(multi_modal_data['image'])}") logger.info(f"Prepared multimodal data size: {len(multi_modal_data['image'])}")
logger.debug("Multimodal data keys: %s", list(multi_modal_data.keys())) logger.debug("Multimodal data keys: %s", list(multi_modal_data.keys()))
# Deepcopy the request to avoid modifying the original # ── Response serialization ───────────────────────────────────────
# when we adjust sampling params for prefill
pd_request = copy.deepcopy(request) @staticmethod
# Do prefill and remote decode if enable_disagg is true def _serialize_response(response) -> str:
if self.enable_disagg and self.decode_worker_client: """Build a JSON-serialized ``MyRequestOutput`` from an engine response."""
extra_args = pd_request.sampling_params.extra_args or {} return MyRequestOutput(
extra_args["kv_transfer_params"] = { request_id=response.request_id,
"do_remote_decode": True, prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
metrics=response.metrics,
kv_transfer_params=response.kv_transfer_params,
).model_dump_json()
@staticmethod
def _format_engine_output(
response, num_output_tokens_so_far: int
) -> dict[str, Any]:
"""Format a vLLM RequestOutput as an LLMEngineOutput-compatible dict.
This produces the same incremental dict format that the regular
(non-multimodal) handler yields, which the Rust frontend expects
after model registration.
"""
if not response.outputs:
return {
"finish_reason": "error: No outputs from vLLM engine",
"token_ids": [],
} }
pd_request.sampling_params.extra_args = extra_args
pd_request.sampling_params.max_tokens = 1
pd_request.sampling_params.min_tokens = 1
logger.debug("Prefill request: %s", pd_request) output = response.outputs[0]
out: dict[str, Any] = {
"token_ids": output.token_ids[num_output_tokens_so_far:],
}
if output.finish_reason:
# Inline normalization: map vLLM's "abort" to Dynamo's "cancelled"
finish_reason = output.finish_reason
if finish_reason.startswith("abort"):
finish_reason = "cancelled"
out["finish_reason"] = finish_reason
out["completion_usage"] = BaseWorkerHandler._build_completion_usage(
request_output=response,
)
if output.stop_reason:
out["stop_reason"] = output.stop_reason
return out
# ── Aggregated generation (prefill + decode locally) ─────────────
async def _generate_agg(
self,
request: vLLMMultimodalRequest,
multi_modal_data: dict[str, Any],
):
"""Run prefill and decode on this worker (aggregated mode)."""
gen = self.engine_client.generate( gen = self.engine_client.generate(
prompt=TokensPrompt( prompt=TokensPrompt(
prompt_token_ids=pd_request.engine_prompt["prompt_token_ids"], prompt_token_ids=request.engine_prompt["prompt_token_ids"],
multi_modal_data=multi_modal_data, multi_modal_data=multi_modal_data,
), ),
sampling_params=pd_request.sampling_params, sampling_params=request.sampling_params,
request_id=pd_request.request_id, request_id=request.request_id,
) )
if self.enable_disagg and self.decode_worker_client: num_output_tokens_so_far = 0
decode_request = copy.deepcopy(request) async for response in gen:
async for prefill_response in gen: logger.debug(f"Response kv_transfer_params: {response.kv_transfer_params}")
# For Qwen VL models with mRoPE: Keep the ORIGINAL unexpanded prompt. logger.debug(
# The decode worker will pass multi_modal_data which causes vLLM to f"length of expanded prompt ids: {len(response.prompt_token_ids)}"
# expand the prompt identically to prefill, ensuring block counts match. )
# yield self._format_engine_output(response, num_output_tokens_so_far)
# For other models: Use the expanded prompt from prefill response. if response.outputs:
# These models don't pass multi_modal_data in decode, so they need num_output_tokens_so_far = len(response.outputs[0].token_ids)
# the already-expanded prompt to match the KV cache layout.
if not is_qwen_vl_model(self.config.model): # ── Disaggregated generation (prefill here, decode remote) ───────
decode_request.engine_prompt[
"prompt_token_ids" async def _generate_disagg(
] = prefill_response.prompt_token_ids self,
logger.debug( request: vLLMMultimodalRequest,
f"Prefill response kv_transfer_params: {prefill_response.kv_transfer_params}" multi_modal_data: dict[str, Any],
) ):
extra_args = decode_request.sampling_params.extra_args or {} """Prefill locally, then forward to a remote decode worker."""
extra_args["kv_transfer_params"] = prefill_response.kv_transfer_params # Prepare prefill-only request
extra_args.pop("serialized_request", None) prefill_only_request = copy.deepcopy(request)
decode_request.sampling_params.extra_args = extra_args extra_args = prefill_only_request.sampling_params.extra_args or {}
logger.debug("Decode request: %s", decode_request) extra_args["kv_transfer_params"] = {"do_remote_decode": True}
async for ( prefill_only_request.sampling_params.extra_args = extra_args
decode_response prefill_only_request.sampling_params.max_tokens = 1
) in await self.decode_worker_client.round_robin( prefill_only_request.sampling_params.min_tokens = 1
decode_request.model_dump_json() logger.debug("Prefill request: %s", prefill_only_request)
):
output = MyRequestOutput.model_validate_json(decode_response.data()) # type: ignore[attr-defined] gen = self.engine_client.generate(
yield MyRequestOutput( prompt=TokensPrompt(
request_id=output.request_id, prompt_token_ids=prefill_only_request.engine_prompt["prompt_token_ids"],
prompt=output.prompt, multi_modal_data=multi_modal_data,
prompt_token_ids=output.prompt_token_ids, ),
prompt_logprobs=output.prompt_logprobs, sampling_params=prefill_only_request.sampling_params,
outputs=output.outputs, request_id=prefill_only_request.request_id,
finished=output.finished, )
metrics=output.metrics,
kv_transfer_params=output.kv_transfer_params,
).model_dump_json()
# Drain prefill generator (max_tokens=1, expect a single response)
async for prefill_response in gen:
pass
# Qwen VL (mRoPE): keep the ORIGINAL unexpanded prompt.
# The decode worker passes multi_modal_data which causes vLLM to
# expand the prompt identically to prefill, ensuring block counts match.
#
# Other models: use the expanded prompt from prefill response.
# They don't pass multi_modal_data in decode, so they need the
# already-expanded prompt to match the KV cache layout.
if not is_qwen_vl_model(self.config.model):
request.engine_prompt[
"prompt_token_ids"
] = prefill_response.prompt_token_ids
logger.debug(
f"Prefill response kv_transfer_params: {prefill_response.kv_transfer_params}"
)
extra_args = request.sampling_params.extra_args or {}
extra_args["kv_transfer_params"] = prefill_response.kv_transfer_params
extra_args.pop("serialized_request", None)
request.sampling_params.extra_args = extra_args
logger.debug("Decode request: %s", request)
# Serialized request is lightweight: token IDs, sampling params with
# kv_transfer_params, and small Qwen metadata (image_grid_thw,
# embeddings_shape). Heavy multimodal data was consumed locally by
# engine_client.generate() and multimodal_inputs was cleared by
# `_finalize_request_metadata`.
async for (
decode_response
) in await self.decode_worker_client.round_robin( # type: ignore[union-attr]
request.model_dump_json()
):
output = MyRequestOutput.model_validate_json(decode_response.data()) # type: ignore[attr-defined]
yield self._serialize_response(output)
# ── Public entry point ───────────────────────────────────────────
async def generate(self, request, context):
"""Parse the request, load multimodal data, and run inference."""
logger.debug(f"Got raw request: {request}")
request = await self._parse_request(request)
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")
multi_modal_data = await self._load_multimodal_data(request)
self._finalize_request_metadata(request, multi_modal_data)
logger.info(
f"Prepared multimodal data size: {len(multi_modal_data.get('image', []))}"
)
logger.debug(f"{multi_modal_data}")
if self.enable_disagg and self.decode_worker_client:
async for chunk in self._generate_disagg(request, multi_modal_data):
yield chunk
else: else:
async for response in gen: async for chunk in self._generate_agg(request, multi_modal_data):
logger.debug( yield chunk
f"Response kv_transfer_params: {response.kv_transfer_params}"
)
logger.debug(
f"length of expanded prompt ids: {len(response.prompt_token_ids)}"
)
# logger.info(f"Response outputs: {response.outputs}")
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
metrics=response.metrics,
kv_transfer_params=response.kv_transfer_params,
).model_dump_json()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
import uuid
from collections import defaultdict
from enum import Enum
from typing import AsyncIterator, Final
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams as VllmSamplingParams
from dynamo.runtime import Client
from ..handlers import BaseWorkerHandler, build_sampling_params
from ..multimodal_utils import (
MultiModalGroup,
MultiModalInput,
MyRequestOutput,
PatchedTokensPrompt,
ProcessMixIn,
VLLMNativeEncoderRequest,
vLLMMultimodalRequest,
)
logger = logging.getLogger(__name__)
# Multimodal data dictionary keys
IMAGE_URL_KEY: Final = "image_url"
VIDEO_URL_KEY: Final = "video_url"
URL_VARIANT_KEY: Final = "Url"
DECODED_VARIANT_KEY: Final = "Decoded"
class RequestType(Enum):
    """Kinds of request, identified by their lowercase string value."""

    # Value: "chat"
    CHAT = "chat"
    # Value: "completion"
    COMPLETION = "completion"
class PreprocessedHandler(ProcessMixIn):
    """
    vLLM pre and post processing for multimodal requests.

    For each pre-tokenized request this handler extracts the media URLs,
    sends them to the encode workers in batches, forwards the collected
    encoder outputs together with the prompt token IDs to a PD worker, and
    re-emits the PD worker's responses as incremental token dicts.
    """

    def __init__(
        self,
        engine_args: AsyncEngineArgs,
        encode_worker_client: Client,
        pd_worker_client: Client,
    ):
        """
        Store the worker clients and start tracking the encode worker count.

        Args:
            engine_args: vLLM engine arguments; used to create the model
                config and the default sampling parameters.
            encode_worker_client: Client for the encode worker endpoint.
            pd_worker_client: Client for the PD worker endpoint.
        """
        self.encode_worker_client = encode_worker_client
        self.encode_worker_count = 0
        self.pd_worker_client = pd_worker_client
        self.engine_args = engine_args
        self.model_config = self.engine_args.create_model_config()
        self.default_sampling_params = self.model_config.get_diff_sampling_param()
        self.stop = False
        # NOTE: asyncio.create_task() needs a running event loop, so the
        # handler has to be constructed from within async code.
        self._worker_count_task = asyncio.create_task(
            self._update_encode_worker_count()
        )

    async def _update_encode_worker_count(self):
        """
        Periodically updates the count of available encode workers.

        Runs until ``self.stop`` is set. Failures while querying the client
        are logged and the previous count is kept.
        """
        while not self.stop:
            try:
                self.encode_worker_count = len(self.encode_worker_client.instance_ids())
                logger.debug(f"Updated encode worker count: {self.encode_worker_count}")
            except Exception as e:
                logger.error(f"Failed to update encode worker count: {e}")
            await asyncio.sleep(1)  # Update every 1 second

    def cleanup(self):
        """Stop the worker-count polling loop and cancel its task."""
        self.stop = True
        # The attribute is missing if __init__ failed before creating the task.
        if hasattr(self, "_worker_count_task"):
            self._worker_count_task.cancel()

    # Main method to parse the request and send the request to the vllm worker.
    async def _generate(
        self,
        raw_request,
        multimodal_inputs,
        context,
    ):
        """
        Encode the multimodal inputs, then stream the PD worker's output.

        Args:
            raw_request: Preprocessed request mapping; ``token_ids`` is read
                from it and it is passed to ``build_sampling_params``.
            multimodal_inputs: Mapping of multimodal type key to a list of
                URLs.
            context: Request context forwarded to the PD worker client.

        Yields:
            Incremental output dicts produced by ``_generate_responses``.

        Raises:
            ValueError: If any video URL is present.
            RuntimeError: If no encode workers are currently known.
        """
        # [gluo NOTE] panic for now as encoder here is for image only.
        # Membership is tested before indexing: ``multimodal_inputs`` is
        # normally a defaultdict, so a bare lookup would insert an empty
        # video entry, and a plain dict would raise KeyError.
        if VIDEO_URL_KEY in multimodal_inputs and multimodal_inputs[VIDEO_URL_KEY]:
            raise ValueError("Video URL not supported in encode worker yet")

        request_id = str(uuid.uuid4().hex)

        # Build sampling params from request using shared utility
        sampling_params = build_sampling_params(
            raw_request, self.default_sampling_params
        )

        # [gluo WIP] encoder doesn't really need any of this
        encode_request = vLLMMultimodalRequest(
            engine_prompt=PatchedTokensPrompt(prompt_token_ids=[]),
            sampling_params=VllmSamplingParams(),
            request_id=request_id,
            multimodal_inputs=[],
        )

        # [gluo WIP] batching helps for encoding step to fully utilize GPU,
        # should handle dispatch in a more intelligent way, i.e. splitting
        # jobs based on availability of encode worker, rather than fixed
        # mm item size per request. Also need to consider encoding load and
        # balancing it between encoders.
        if self.encode_worker_count == 0:
            raise RuntimeError(
                "No encode workers available to process multimodal input"
            )
        total_items = sum(len(urls) for urls in multimodal_inputs.values())
        encode_batch_size = max(1, total_items // self.encode_worker_count)

        encode_res_gen = []
        for mm_type, urls in multimodal_inputs.items():
            for url in urls:
                multimodal_input = MultiModalInput()
                if mm_type == IMAGE_URL_KEY:
                    multimodal_input.image_url = url
                elif mm_type == VIDEO_URL_KEY:
                    multimodal_input.video_url = url
                    # [gluo NOTE] should not reach here due to earlier check
                    continue
                encode_request.multimodal_inputs.append(
                    MultiModalGroup(multimodal_input=multimodal_input)
                )

                if len(encode_request.multimodal_inputs) >= encode_batch_size:
                    # model_dump_json() serializes the request to JSON string
                    # This API could accept Pydantic class, but SamplingParams
                    # in vLLMMultimodalRequest is not a Pydantic class and will
                    # cause TypeError: unsupported type SamplingParams
                    encode_res_gen.append(
                        await self.encode_worker_client.round_robin(
                            encode_request.model_dump_json()
                        )
                    )
                    encode_request.multimodal_inputs = []

        # Flush the final, partially filled batch.
        if encode_request.multimodal_inputs:
            encode_res_gen.append(
                await self.encode_worker_client.round_robin(
                    encode_request.model_dump_json()
                )
            )

        # Gather transformed requests
        worker_request = vLLMMultimodalRequest(
            engine_prompt=PatchedTokensPrompt(
                prompt_token_ids=raw_request["token_ids"]
            ),
            sampling_params=sampling_params,
            request_id=request_id,
            multimodal_inputs=[],  # will be filled in next
        )
        for encode_res in encode_res_gen:
            async for response in encode_res:
                logger.debug(f"Received response from encode worker: {response}")
                output = vLLMMultimodalRequest.model_validate_json(response.data())  # type: ignore[attr-defined]
                worker_request.multimodal_inputs.extend(output.multimodal_inputs)

        response_generator = await self.pd_worker_client.round_robin(  # type: ignore[call-arg]
            worker_request.model_dump_json(), context=context
        )

        # [gluo FIXME] <im_end> being returned
        async for output in self._generate_responses(response_generator):
            yield output

    # This method is used to process the responses from the engine generator.
    async def _generate_responses(
        self,
        response_generator: AsyncIterator[RequestOutput],
    ):
        """
        Convert streamed worker responses into incremental token dicts.

        Each response's ``data()`` is parsed as a JSON ``MyRequestOutput``.
        Only the first entry of ``outputs`` is used, and only the tokens
        produced since the previous response are emitted.

        Yields:
            Dicts with ``token_ids`` and, when available, ``log_probs``,
            ``top_logprobs``, ``finish_reason``, ``completion_usage`` and
            ``stop_reason``.

        Raises:
            GeneratorExit: If the surrounding task is cancelled while
                streaming.
        """
        # [gluo WIP] modified from handler.py (BaseWorkerHandler.generate_tokens)
        num_output_tokens_so_far = 0
        try:
            async for resp in response_generator:
                # Deserialize the response from the engine
                # Creates correct vLLM objects for each field
                parsed = MyRequestOutput.model_validate_json(resp.data())  # type: ignore[attr-defined]

                # OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
                res = RequestOutput(
                    request_id=parsed.request_id,
                    prompt=parsed.prompt,
                    prompt_token_ids=parsed.prompt_token_ids,
                    prompt_logprobs=parsed.prompt_logprobs,
                    outputs=parsed.outputs,
                    finished=parsed.finished,
                    metrics=parsed.metrics,
                )

                if not res.outputs:
                    continue
                completion = res.outputs[0]
                next_total_toks = len(completion.token_ids)
                out = {"token_ids": completion.token_ids[num_output_tokens_so_far:]}

                # Extract logprobs for new tokens if available
                log_probs, top_logprobs = BaseWorkerHandler._extract_logprobs(
                    completion, num_output_tokens_so_far
                )
                if log_probs is not None:
                    out["log_probs"] = log_probs
                if top_logprobs is not None:
                    out["top_logprobs"] = top_logprobs

                if completion.finish_reason:
                    out["finish_reason"] = completion.finish_reason
                    out["completion_usage"] = BaseWorkerHandler._build_completion_usage(
                        request_output=res
                    )
                if completion.stop_reason:
                    out["stop_reason"] = completion.stop_reason
                yield out
                num_output_tokens_so_far = next_total_toks
        except asyncio.CancelledError:
            # Raise GeneratorExit when the engine exits so that the frontend
            # can migrate the request.
            raise GeneratorExit(
                "Decode engine was shut down during token generation"
            ) from None

    def _extract_multimodal_data(self, request):
        """
        Extract and decode multimodal data from PreprocessedRequest.

        Returns:
            An empty dict when the request carries no ``multi_modal_data``;
            otherwise a ``defaultdict(list)`` mapping each multimodal type
            key to the URLs found under it. Items in the "Decoded" variant
            are not supported yet and are skipped with a warning.
        """
        # [gluo NOTE] modified from components/src/dynamo/vllm/handlers.py
        if "multi_modal_data" not in request or request["multi_modal_data"] is None:
            return {}

        # [gluo FIXME] add this security option
        # Security check: reject multimodal data if not explicitly enabled
        # if not self.enable_multimodal:
        #     raise ValueError(
        #         "Received multimodal data but multimodal processing is not enabled. "
        #         "Use --enable-multimodal flag to enable multimodal processing."
        #     )

        mm_map = request["multi_modal_data"]
        multimodal_inputs = defaultdict(list)

        for mm_type in [IMAGE_URL_KEY, VIDEO_URL_KEY]:
            for item in mm_map.get(mm_type, []):
                if isinstance(item, dict) and URL_VARIANT_KEY in item:
                    multimodal_inputs[mm_type].append(item[URL_VARIANT_KEY])
                elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
                    # Decoded support from PRs #3971/#3988 (frontend decoding + NIXL transfer)
                    # Will contain NIXL metadata for direct memory access
                    # TODO: Implement NIXL read when PRs merge
                    logger.warning(
                        "Decoded multimodal data not yet supported in standard worker"
                    )

        return multimodal_inputs

    # The generate endpoint will be used by the frontend to handle incoming requests.
    async def generate(self, request, context):
        """
        Handle one preprocessed request and stream its output dicts.

        Raises:
            ValueError: If the request has no multimodal URL, or has URLs of
                more than one multimodal type.
        """
        logger.debug(f"Got preprocessed request: {request}")

        # Extract multimodal inputs for dispatching to encode worker
        multimodal_inputs = self._extract_multimodal_data(request)
        if not multimodal_inputs:
            raise ValueError("Either image URL or video URL is required")
        elif len(multimodal_inputs) > 1:
            raise ValueError(
                "Only one of image URL or video URL is supported per request"
            )

        async for response in self._generate(request, multimodal_inputs, context):
            yield response
class ECProcessorHandler(PreprocessedHandler):
    """
    Processor handler for ECConnector-based encoder with pre-tokenized input support.

    Reuses the pre-tokenized processing of ``PreprocessedHandler`` but talks
    to the encoder with ``VLLMNativeEncoderRequest`` (vLLM-native encoder /
    ECConnector) instead of the custom RDMA-based encoder protocol.
    """

    def __init__(
        self,
        engine_args: AsyncEngineArgs,
        encoder_worker_client: Client,
        pd_worker_client: Client,
        prompt_template: str | None = None,
    ):
        """
        Set up the ECConnector processor.

        Args:
            engine_args: vLLM engine arguments for model config
            encoder_worker_client: Client for vLLM-native encoder worker endpoints
            pd_worker_client: Client for PD worker endpoints (ECConnector consumer)
            prompt_template: Optional prompt template; only stored here
                (tokenization is done by the Rust frontend)
        """
        super().__init__(engine_args, encoder_worker_client, pd_worker_client)
        self.prompt_template = prompt_template
        logger.info(
            "ECProcessorHandler initialized (inherits PreprocessedHandler, uses ECConnector)"
        )

    async def _generate(
        self,
        raw_request,
        multimodal_inputs,
        context,
    ):
        """
        Run the encoder on all multimodal items, then stream the PD output.

        Replaces ``PreprocessedHandler._generate``: every multimodal item is
        sent to the encoder in one ``VLLMNativeEncoderRequest``, the encoder
        stream is drained, and the same items are then forwarded to the PD
        worker in a single ``vLLMMultimodalRequest``.

        Raises:
            ValueError: If ``token_ids`` is missing/empty or a video URL is
                present.
        """
        # token_ids carry the placeholder tokens (e.g. 32000 for <image>)
        token_ids = raw_request.get("token_ids", [])
        if not token_ids:
            raise ValueError("token_ids not found in request")

        sample = token_ids[:20]
        logger.info(
            f"ECProcessor using token_ids (length={len(token_ids)}) with placeholders. "
            f"Sample: {sample}"
        )

        # Video is not supported yet
        if VIDEO_URL_KEY in multimodal_inputs and multimodal_inputs[VIDEO_URL_KEY]:
            raise ValueError("Video URL not supported in ECConnector encoder yet")

        request_id = uuid.uuid4().hex

        # Sampling params derived from the request
        sampling_params = build_sampling_params(
            raw_request, self.default_sampling_params
        )

        # One MultiModalGroup per URL; the type key selects which URL
        # attribute gets populated (unknown keys leave the input empty).
        url_attr_by_type = {IMAGE_URL_KEY: "image_url", VIDEO_URL_KEY: "video_url"}
        groups = []
        for mm_type, urls in multimodal_inputs.items():
            url_attr = url_attr_by_type.get(mm_type)
            for url in urls:
                mm_input = MultiModalInput()
                if url_attr is not None:
                    setattr(mm_input, url_attr, url)
                groups.append(MultiModalGroup(multimodal_input=mm_input))

        logger.info(
            f"[{request_id}] Encoding {len(groups)} multimodal item(s) "
            f"via vLLM-native encoder (ECConnector)..."
        )

        # The encoder receives the pre-tokenized prompt so vLLM can match
        # the placeholder token IDs through TokensPrompt.
        try:
            encoder_payload = VLLMNativeEncoderRequest(
                request_id=request_id,
                token_ids=token_ids,
                multimodal_inputs=groups,
            ).model_dump_json()
            encoder_stream = await self.encode_worker_client.round_robin(
                encoder_payload
            )

            # Drain the stream; embeddings are written to the ECConnector cache
            async for _ in encoder_stream:
                logger.debug(
                    f"[{request_id}] Received encoder response (embeddings cached)"
                )

            logger.info(f"[{request_id}] Encoder completed successfully for all items")
        except Exception as e:
            logger.error(f"[{request_id}] Encoder processing failed: {e}")
            raise

        # PD request: pre-tokenized prompt plus ALL multimodal inputs
        worker_request = vLLMMultimodalRequest(
            engine_prompt=PatchedTokensPrompt(prompt_token_ids=token_ids),
            sampling_params=sampling_params,
            request_id=request_id,
            multimodal_inputs=groups,
        )

        logger.info(
            f"[{request_id}] Sending request with {len(groups)} "
            f"multimodal item(s) to PD worker (ECConnector consumer)..."
        )

        # A single request carries every image to the PD worker
        pd_stream = await self.pd_worker_client.round_robin(  # type: ignore[call-arg]
            worker_request.model_dump_json(), context=context
        )

        # Reuse the base-class streaming conversion
        async for output in self._generate_responses(pd_stream):
            yield output

        logger.info(
            f"[{request_id}] Completed processing all {len(groups)} item(s)"
        )
...@@ -21,6 +21,7 @@ from dynamo.vllm.multimodal_utils.model import ( ...@@ -21,6 +21,7 @@ from dynamo.vllm.multimodal_utils.model import (
) )
from dynamo.vllm.multimodal_utils.prefill_worker_utils import ( from dynamo.vllm.multimodal_utils.prefill_worker_utils import (
accumulate_embeddings, accumulate_embeddings,
fetch_embeddings_from_encode_workers,
load_embeddings, load_embeddings,
) )
from dynamo.vllm.multimodal_utils.protocol import ( from dynamo.vllm.multimodal_utils.protocol import (
...@@ -29,8 +30,6 @@ from dynamo.vllm.multimodal_utils.protocol import ( ...@@ -29,8 +30,6 @@ from dynamo.vllm.multimodal_utils.protocol import (
MultiModalRequest, MultiModalRequest,
MyRequestOutput, MyRequestOutput,
PatchedTokensPrompt, PatchedTokensPrompt,
VLLMNativeEncoderRequest,
VLLMNativeEncoderResponse,
vLLMMultimodalRequest, vLLMMultimodalRequest,
) )
...@@ -53,8 +52,7 @@ __all__ = [ ...@@ -53,8 +52,7 @@ __all__ = [
"MultiModalRequest", "MultiModalRequest",
"MyRequestOutput", "MyRequestOutput",
"vLLMMultimodalRequest", "vLLMMultimodalRequest",
"VLLMNativeEncoderRequest",
"VLLMNativeEncoderResponse",
"accumulate_embeddings", "accumulate_embeddings",
"fetch_embeddings_from_encode_workers",
"load_embeddings", "load_embeddings",
] ]
...@@ -14,13 +14,11 @@ ...@@ -14,13 +14,11 @@
# limitations under the License. # limitations under the License.
import hashlib import hashlib
import json
import logging import logging
import os import os
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import torch import torch
from vllm.config import ECTransferConfig
from .model import SupportedModels, is_model_supported, is_qwen_vl_model from .model import SupportedModels, is_model_supported, is_qwen_vl_model
...@@ -160,51 +158,3 @@ def get_encoder_components( ...@@ -160,51 +158,3 @@ def get_encoder_components(
else: else:
raise NotImplementedError(f"Model not supported: {model_name}") raise NotImplementedError(f"Model not supported: {model_name}")
def create_ec_transfer_config(
    engine_id: str,
    ec_role: str,
    ec_connector_backend: str = "ECExampleConnector",
    ec_storage_path: Optional[str] = None,
    ec_extra_config: Optional[str] = None,
) -> ECTransferConfig:
    """
    Create ECTransferConfig for vLLM encoder disaggregation.

    Args:
        engine_id: Unique identifier for this engine instance
        ec_role: Role of this instance - "ec_producer" (encoder) or "ec_consumer" (PD worker)
        ec_connector_backend: ECConnector implementation class name
        ec_storage_path: Storage path for disk-based connectors. Despite the
            ``None`` default it is required; it is stored in the connector
            config under ``shared_storage_path``.
        ec_extra_config: Additional connector config as a JSON object string

    Returns:
        ECTransferConfig configured for the specified role

    Raises:
        ValueError: If ``ec_extra_config`` is not valid JSON or does not
            decode to a JSON object, or if ``ec_storage_path`` is missing
            or empty.
    """
    # Parse extra config if provided
    extra_config: Dict[str, Any] = {}
    if ec_extra_config:
        try:
            parsed = json.loads(ec_extra_config)
        except json.JSONDecodeError as e:
            # Chain the decode error so the offending position is preserved.
            raise ValueError(f"Invalid JSON in --ec-extra-config: {e}") from e
        # Only a JSON object can take the shared_storage_path key below;
        # reject lists, strings, numbers and null with a clear message.
        if not isinstance(parsed, dict):
            raise ValueError(
                "--ec-extra-config must be a JSON object, "
                f"got {type(parsed).__name__}"
            )
        extra_config = parsed
        logger.debug(f"Parsed ec_extra_config: {extra_config}")

    # Add storage path to config; it is mandatory
    if not ec_storage_path:
        raise ValueError("ec_storage_path is not provided")
    extra_config["shared_storage_path"] = ec_storage_path

    logger.info(
        f"Creating ECTransferConfig: engine_id={engine_id}, role={ec_role}, "
        f"backend={ec_connector_backend}, config={extra_config}"
    )

    return ECTransferConfig(
        engine_id=engine_id,
        ec_role=ec_role,
        ec_connector=ec_connector_backend,
        ec_connector_extra_config=extra_config,
    )
...@@ -3,18 +3,28 @@ ...@@ -3,18 +3,28 @@
import logging import logging
import os import os
from typing import Any, Dict from typing import Any, Dict, List
import safetensors import safetensors
import torch import torch
from vllm.sampling_params import SamplingParams as VllmSamplingParams
import dynamo.nixl_connect as connect import dynamo.nixl_connect as connect
from dynamo.runtime import Client
from .model import construct_mm_data from .model import construct_mm_data
from .protocol import MultiModalGroup from .protocol import (
MultiModalGroup,
MultiModalInput,
PatchedTokensPrompt,
vLLMMultimodalRequest,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
IMAGE_URL_KEY = "image_url"
VIDEO_URL_KEY = "video_url"
TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1)) TRANSFER_LOCAL = int(os.getenv("TRANSFER_LOCAL", 1))
...@@ -115,3 +125,62 @@ def accumulate_embeddings( ...@@ -115,3 +125,62 @@ def accumulate_embeddings(
multi_modal_data["image"] = torch.cat( multi_modal_data["image"] = torch.cat(
(multi_modal_data["image"], mm_data["image"]) (multi_modal_data["image"], mm_data["image"])
) )
async def fetch_embeddings_from_encode_workers(
    encode_worker_client: Client,
    image_urls: List[str],
    request_id: str,
) -> List[MultiModalGroup]:
    """Dispatch image URLs to encode workers and gather their results.

    The URLs are cut into consecutive chunks of
    ``max(1, len(image_urls) // worker_count)`` items. Each chunk is sent as
    one JSON-serialized ``vLLMMultimodalRequest`` through ``round_robin``;
    once every chunk has been dispatched, the response streams are read in
    dispatch order and their ``multimodal_inputs`` are concatenated.

    Raises:
        RuntimeError: If the client reports no encode worker instances.
    """
    worker_total = len(encode_worker_client.instance_ids())
    if worker_total == 0:
        raise RuntimeError("No encode workers available to process multimodal input")

    chunk_len = max(1, len(image_urls) // worker_total)

    # Reusable envelope: only ``multimodal_inputs`` changes between chunks,
    # and it is serialized right away so mutating it afterwards is safe.
    envelope = vLLMMultimodalRequest(
        engine_prompt=PatchedTokensPrompt(prompt_token_ids=[]),
        sampling_params=VllmSamplingParams(),
        request_id=request_id,
        multimodal_inputs=[],
    )

    pending_streams = []
    for start in range(0, len(image_urls), chunk_len):
        groups: List[MultiModalGroup] = []
        for url in image_urls[start : start + chunk_len]:
            image_input = MultiModalInput()
            image_input.image_url = url
            groups.append(MultiModalGroup(multimodal_input=image_input))

        envelope.multimodal_inputs = groups
        serialized = envelope.model_dump_json()
        stream = await encode_worker_client.round_robin(serialized)  # type: ignore[arg-type]
        pending_streams.append(stream)

    # Read every stream and merge the groups returned by the workers
    collected: List[MultiModalGroup] = []
    for stream in pending_streams:
        async for response in stream:
            logger.debug(f"Received response from encode worker: {response}")
            decoded = vLLMMultimodalRequest.model_validate_json(response.data())  # type: ignore[attr-defined]
            collected.extend(decoded.multimodal_inputs or [])

    return collected
...@@ -182,28 +182,6 @@ class vLLMMultimodalRequest(vLLMGenerateRequest): ...@@ -182,28 +182,6 @@ class vLLMMultimodalRequest(vLLMGenerateRequest):
embeddings_shape: Optional[List[int]] = None embeddings_shape: Optional[List[int]] = None
class VLLMNativeEncoderRequest(BaseModel):
    """Request for vLLM-native encoder worker using ECConnector"""

    # Identifier of the request this encoder call belongs to.
    request_id: str
    # Pre-tokenized prompt with placeholder tokens (for TokensPrompt).
    token_ids: List[int]
    # Multimodal items to encode; defaults to a fresh empty list per instance.
    multimodal_inputs: List[MultiModalGroup] = Field(default_factory=list)
    # Optional because the modality can be inferred from the inputs.
    modality: Optional[Literal["image", "video", "audio"]] = None
class VLLMNativeEncoderResponse(BaseModel):
    """Response from vLLM-native encoder worker (ECConnector mode)"""

    # Identifier of the request this response answers.
    request_id: str
    # vLLM's multimodal hash identifier.
    mm_hash: str
    # "image", "video", "audio"
    modality: str
    # ECConnector config info for PD workers.
    connector_metadata: dict[str, Any]
class MyRequestOutput(BaseModel): class MyRequestOutput(BaseModel):
""" """
RequestOutput from vLLM is not serializable by default RequestOutput from vLLM is not serializable by default
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Unit tests for MultimodalPDWorkerHandler.__init__.""" """Unit tests for MultimodalPDWorkerHandler."""
from unittest.mock import MagicMock, patch import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
...@@ -11,6 +12,13 @@ from dynamo.common.memory.multimodal_embedding_cache_manager import ( ...@@ -11,6 +12,13 @@ from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager, MultimodalEmbeddingCacheManager,
) )
from dynamo.vllm.multimodal_handlers import multimodal_pd_worker_handler as mod from dynamo.vllm.multimodal_handlers import multimodal_pd_worker_handler as mod
from dynamo.vllm.multimodal_utils.protocol import (
MultiModalGroup,
MultiModalInput,
MyRequestOutput,
PatchedTokensPrompt,
vLLMMultimodalRequest,
)
pytestmark = [ pytestmark = [
pytest.mark.pre_merge, pytest.mark.pre_merge,
...@@ -20,13 +28,16 @@ pytestmark = [ ...@@ -20,13 +28,16 @@ pytestmark = [
] ]
# ── Helpers ──────────────────────────────────────────────────────────
def _make_config( def _make_config(
model: str = "test-model", model: str = "test-model",
is_prefill_worker: bool = False, is_prefill_worker: bool = False,
enable_multimodal: bool = True, enable_multimodal: bool = True,
multimodal_embedding_cache_capacity_gb: float = 0, multimodal_embedding_cache_capacity_gb: float = 0,
) -> MagicMock: ) -> MagicMock:
"""Create a mock Config with the fields used by MultimodalPDWorkerHandler.__init__.""" """Create a mock Config with the fields used by MultimodalPDWorkerHandler."""
config = MagicMock() config = MagicMock()
config.model = model config.model = model
config.is_prefill_worker = is_prefill_worker config.is_prefill_worker = is_prefill_worker
...@@ -40,27 +51,179 @@ def _make_config( ...@@ -40,27 +51,179 @@ def _make_config(
return config return config
class TestMultimodalPDWorkerHandlerInit: def _make_handler(
"""Tests for MultimodalPDWorkerHandler.__init__ focusing on embedding cache.""" config: MagicMock | None = None,
encode_worker_client: MagicMock | None = None,
decode_worker_client: MagicMock | None = None,
) -> mod.MultimodalPDWorkerHandler:
"""Construct a handler with BaseWorkerHandler.__init__ bypassed."""
if config is None:
config = _make_config()
with patch.object(mod.BaseWorkerHandler, "__init__", return_value=None):
return mod.MultimodalPDWorkerHandler(
runtime=MagicMock(),
component=MagicMock(),
engine_client=MagicMock(),
config=config,
encode_worker_client=encode_worker_client,
decode_worker_client=decode_worker_client,
)
def test_init_with_embedding_cache(self):
"""When capacity > 0, a MultimodalEmbeddingCacheManager is created with correct byte size."""
capacity_gb = 0.1
config = _make_config(multimodal_embedding_cache_capacity_gb=capacity_gb)
with (
patch.object(mod.BaseWorkerHandler, "__init__", return_value=None),
patch.object(mod, "ImageLoader", new_callable=MagicMock),
):
handler = mod.MultimodalPDWorkerHandler(
runtime=MagicMock(),
component=MagicMock(),
engine_client=MagicMock(),
config=config,
)
def _make_raw_frontend_request(image_urls: list[str] | None = None) -> dict:
"""Build a raw dict that mimics what the Rust frontend sends."""
mm_data = None
if image_urls:
mm_data = {
"image_url": [{"Url": url} for url in image_urls],
}
return {
"token_ids": [1, 2, 3],
"multi_modal_data": mm_data,
"sampling_options": {},
"stop_conditions": {},
"output_options": {},
}
def _make_vllm_request(request_id: str = "req-1") -> vLLMMultimodalRequest:
    """Build a minimal vLLMMultimodalRequest.

    Args:
        request_id: Value stored in the request's ``request_id`` field.

    Returns:
        A request with prompt token ids ``[1, 2, 3]``, default-constructed
        sampling params, and no multimodal inputs.
    """
    # Imported inside the helper, as in the original, rather than at module level.
    from vllm.sampling_params import SamplingParams

    return vLLMMultimodalRequest(
        engine_prompt=PatchedTokensPrompt(prompt_token_ids=[1, 2, 3]),
        sampling_params=SamplingParams(),
        request_id=request_id,
        multimodal_inputs=[],
    )
def _make_engine_response(request_id: str = "req-1", finished: bool = True):
    """Create a mock engine response with the fields _serialize_response needs.

    Args:
        request_id: Value assigned to the mock's ``request_id`` attribute.
        finished: Value assigned to the mock's ``finished`` attribute.

    Returns:
        A ``MagicMock`` with explicit prompt, output, metrics and
        ``kv_transfer_params`` attributes set.
    """
    resp = MagicMock()
    resp.request_id = request_id
    resp.prompt = "test"
    resp.prompt_token_ids = [1, 2, 3]
    resp.prompt_logprobs = None
    # Empty by default; individual tests replace this with their own outputs.
    resp.outputs = []
    resp.finished = finished
    resp.metrics = None
    resp.kv_transfer_params = {"do_remote_decode": False}
    return resp
# ── Tests ────────────────────────────────────────────────────────────
class TestInit:
    # Constructor behaviour of MultimodalPDWorkerHandler.

    def test_embedding_cache_created_when_capacity_set(self):
        """A positive cache capacity yields a cache manager sized in bytes."""
        capacity_gb = 0.1
        handler = _make_handler(
            config=_make_config(multimodal_embedding_cache_capacity_gb=capacity_gb)
        )
        assert isinstance(
            handler.embedding_cache_manager, MultimodalEmbeddingCacheManager
        )
        # The capacity is configured in GiB; the manager is checked in bytes.
        expected_bytes = int(capacity_gb * 1024**3)
        assert handler.embedding_cache_manager._capacity_bytes == expected_bytes
class TestBuildRequestFromFrontend:
    # Covers handler._build_request_from_frontend.

    @pytest.mark.asyncio
    async def test_with_encode_worker_calls_fetch(self):
        """With encode client -> delegates to fetch_embeddings_from_encode_workers."""
        mock_client = MagicMock()
        handler = _make_handler(encode_worker_client=mock_client)
        handler.default_sampling_params = {}
        fake_group = MultiModalGroup(multimodal_input=MultiModalInput())
        # Replace the module-level fetch helper so no real encode worker is contacted.
        with patch.object(
            mod,
            "fetch_embeddings_from_encode_workers",
            new_callable=AsyncMock,
            return_value=[fake_group],
        ) as mock_fetch:
            raw = _make_raw_frontend_request(image_urls=["http://img.png"])
            result = await handler._build_request_from_frontend(raw)
        mock_fetch.assert_awaited_once()
        assert result.multimodal_inputs == [fake_group]
class TestGenerateAgg:
    # Covers handler._generate_agg.

    @pytest.mark.asyncio
    async def test_streams_serialized_responses(self):
        """_generate_agg yields dicts formatted by _format_engine_output."""
        handler = _make_handler()
        request = _make_vllm_request()
        engine_resp = _make_engine_response()
        # Add a proper output so we exercise the happy path
        output = MagicMock()
        output.token_ids = [10, 11]
        output.finish_reason = "stop"
        output.stop_reason = None
        engine_resp.outputs = [output]

        # Stand-in for the engine's generate: emits the single prepared response.
        async def fake_generate(**kwargs):
            yield engine_resp

        handler.engine_client = MagicMock()
        handler.engine_client.generate = fake_generate
        chunks = []
        async for chunk in handler._generate_agg(request, {"image": []}):
            chunks.append(chunk)
        assert len(chunks) == 1
        assert chunks[0]["token_ids"] == [10, 11]
        assert chunks[0]["finish_reason"] == "stop"
class TestGenerateDisagg:
    # Covers handler._generate_disagg.

    @pytest.mark.asyncio
    async def test_prefills_then_forwards_to_decode(self):
        """_generate_disagg prefills locally, then round-robins to decode worker."""
        config = _make_config(model="test-model", is_prefill_worker=True)
        decode_client = MagicMock()
        handler = _make_handler(config=config, decode_worker_client=decode_client)
        handler.engine_client = MagicMock()

        # Mock prefill engine response
        prefill_resp = _make_engine_response()
        prefill_resp.kv_transfer_params = {"block_ids": [0, 1]}

        async def fake_generate(**kwargs):
            yield prefill_resp

        handler.engine_client.generate = fake_generate

        # Mock decode worker response
        decode_output = MyRequestOutput(
            request_id="req-1",
            prompt="test",
            prompt_token_ids=[1, 2, 3],
            outputs=[],
            finished=True,
            kv_transfer_params={"block_ids": [0, 1]},
        )
        decode_resp = MagicMock()
        decode_resp.data.return_value = decode_output.model_dump_json()

        # round_robin is awaited and must return an async iterator of responses.
        async def fake_round_robin(payload):
            async def _stream():
                yield decode_resp

            return _stream()

        decode_client.round_robin = fake_round_robin
        request = _make_vllm_request()
        chunks = []
        async for chunk in handler._generate_disagg(request, {"image": []}):
            chunks.append(chunk)
        assert len(chunks) == 1
        # The test parses each yielded chunk as a JSON string.
        parsed = json.loads(chunks[0])
        assert parsed["request_id"] == "req-1"
        assert parsed["finished"] is True
...@@ -50,7 +50,6 @@ vLLM supports all multimodal deployment patterns. See [Architecture Patterns](RE ...@@ -50,7 +50,6 @@ vLLM supports all multimodal deployment patterns. See [Architecture Patterns](RE
| Prefill Worker | `--multimodal-worker --is-prefill-worker` | Prefill only | | Prefill Worker | `--multimodal-worker --is-prefill-worker` | Prefill only |
| Decode Worker | `--multimodal-decode-worker` | Decode only | | Decode Worker | `--multimodal-decode-worker` | Decode only |
| Encode+Prefill Worker | `--multimodal-encode-prefill-worker --is-prefill-worker` | Combined (Llama 4) | | Encode+Prefill Worker | `--multimodal-encode-prefill-worker --is-prefill-worker` | Combined (Llama 4) |
| vLLM Native Encoder | `--vllm-native-encoder-worker` | vLLM-native encoding with ECConnector |
## Use the Latest Release ## Use the Latest Release
......
...@@ -6,8 +6,7 @@ trap 'echo Cleaning up...; kill 0' EXIT ...@@ -6,8 +6,7 @@ trap 'echo Cleaning up...; kill 0' EXIT
# Default values # Default values
MODEL_NAME="llava-hf/llava-1.5-7b-hf" MODEL_NAME="llava-hf/llava-1.5-7b-hf"
EC_STORAGE_PATH="/tmp/dynamo_ec_cache" EC_CONNECTOR_BACKEND="DynamoEcConnector"
EC_CONNECTOR_BACKEND="ECExampleConnector"
# Parse command line arguments # Parse command line arguments
while [[ $# -gt 0 ]]; do while [[ $# -gt 0 ]]; do
...@@ -16,35 +15,22 @@ while [[ $# -gt 0 ]]; do ...@@ -16,35 +15,22 @@ while [[ $# -gt 0 ]]; do
MODEL_NAME=$2 MODEL_NAME=$2
shift 2 shift 2
;; ;;
--ec-storage-path)
EC_STORAGE_PATH=$2
shift 2
;;
--ec-connector-backend)
EC_CONNECTOR_BACKEND=$2
shift 2
;;
-h|--help) -h|--help)
echo "Usage: $0 [OPTIONS]" echo "Usage: $0 [OPTIONS]"
echo "" echo ""
echo "Aggregated multimodal serving with vLLM-native encoder (ECConnector mode)" echo "Aggregated multimodal serving with ECConnector (ec_both mode)"
echo "" echo ""
echo "This script launches:" echo "This script launches:"
echo " - Frontend server" echo " - Frontend server"
echo " - Processor component (uses pre-tokenized input with ModelInput.Tokens)" echo " - Aggregated multimodal worker (ec_both: produces and consumes encoder cache)"
echo " - vLLM-native encoder worker (producer using ECConnector)"
echo " - Multimodal worker (consumer using ECConnector, aggregated P+D)"
echo "" echo ""
echo "Options:" echo "Options:"
echo " --model <model_name> Specify the VLM model to use (default: $MODEL_NAME)" echo " --model <model_name> Specify the VLM model to use (default: $MODEL_NAME)"
echo " --ec-storage-path <path> Path for ECConnector storage (default: $EC_STORAGE_PATH)" echo " -h, --help Show this help message"
echo " --ec-connector-backend <backend> ECConnector backend class (default: $EC_CONNECTOR_BACKEND)"
echo " -h, --help Show this help message"
echo "" echo ""
echo "Examples:" echo "Examples:"
echo " $0" echo " $0"
echo " $0 --model llava-hf/llava-1.5-7b-hf" echo " $0 --model llava-hf/llava-1.5-7b-hf"
echo " $0 --ec-storage-path /shared/encoder-cache"
echo "" echo ""
exit 0 exit 0
;; ;;
...@@ -56,54 +42,34 @@ while [[ $# -gt 0 ]]; do ...@@ -56,54 +42,34 @@ while [[ $# -gt 0 ]]; do
esac esac
done done
# Create storage directory if it doesn't exist
mkdir -p "$EC_STORAGE_PATH"
echo "==================================================" echo "=================================================="
echo "Aggregated Multimodal Serving (vLLM-Native Encoder with ECConnector)" echo "Aggregated Multimodal Serving (ECConnector ec_both)"
echo "==================================================" echo "=================================================="
echo "Model: $MODEL_NAME" echo "Model: $MODEL_NAME"
echo "ECConnector Backend: $EC_CONNECTOR_BACKEND" echo "ECConnector Backend: $EC_CONNECTOR_BACKEND"
echo "Storage Path: $EC_STORAGE_PATH"
echo "==================================================" echo "=================================================="
# Start frontend # GPU assignment (override via environment variable)
echo "Starting frontend..." DYN_WORKER_GPU=${DYN_WORKER_GPU:-0}
python -m dynamo.frontend &
# Start EC Processor (uses pre-tokenized input with placeholder tokens) # GPU memory utilization
echo "Starting EC Processor..." DYN_GPU_MEM=${DYN_GPU_MEM:-0.85}
python -m dynamo.vllm \
--ec-processor \
--enable-multimodal \
--model $MODEL_NAME &
# Start vLLM-native encoder worker (ECConnector producer) # Start frontend
echo "Starting vLLM-native encoder worker (ECConnector producer) on GPU 0..." echo "Starting frontend..."
CUDA_VISIBLE_DEVICES=0 python -m dynamo.vllm \ python -m dynamo.frontend &
--vllm-native-encoder-worker \
--enable-multimodal \
--model $MODEL_NAME \
--ec-connector-backend $EC_CONNECTOR_BACKEND \
--ec-storage-path $EC_STORAGE_PATH \
--connector none \
--enforce-eager \
--max-num-batched-tokens 114688 \
--no-enable-prefix-caching &
# Start aggregated multimodal worker (ECConnector consumer, P+D combined) # Start aggregated multimodal worker (ec_both: produces and consumes encoder cache)
echo "Starting aggregated multimodal worker (ECConnector consumer) on GPU 1..." echo "Starting aggregated multimodal worker (ec_both) on GPU $DYN_WORKER_GPU (mem: $DYN_GPU_MEM)..."
CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm \ CUDA_VISIBLE_DEVICES=$DYN_WORKER_GPU python -m dynamo.vllm \
--multimodal-worker \ --multimodal-worker \
--enable-multimodal \ --enable-multimodal \
--model $MODEL_NAME \ --model $MODEL_NAME \
--ec-consumer-mode \
--ec-connector-backend $EC_CONNECTOR_BACKEND \
--ec-storage-path $EC_STORAGE_PATH \
--enable-mm-embeds \ --enable-mm-embeds \
--connector none \ --connector none \
--enforce-eager & --enforce-eager \
--gpu-memory-utilization $DYN_GPU_MEM \
--ec-transfer-config "{\"ec_connector\":\"$EC_CONNECTOR_BACKEND\",\"ec_role\":\"ec_both\"}" &
# Wait for all background processes to complete # Wait for all background processes to complete
wait wait
...@@ -15,7 +15,7 @@ set -e ...@@ -15,7 +15,7 @@ set -e
trap 'echo Cleaning up...; kill 0' EXIT trap 'echo Cleaning up...; kill 0' EXIT
# Default values # Default values
MODEL_NAME="Qwen/Qwen2.5-VL-7B-Instruct" MODEL_NAME="Qwen/Qwen3-VL-30B-A3B-Instruct-FP8"
# Parse command line arguments # Parse command line arguments
# Extra arguments are passed through to the vLLM worker # Extra arguments are passed through to the vLLM worker
...@@ -53,7 +53,7 @@ export DYN_REQUEST_PLANE=tcp ...@@ -53,7 +53,7 @@ export DYN_REQUEST_PLANE=tcp
python -m dynamo.frontend & python -m dynamo.frontend &
# Configure GPU memory optimization for specific models (if no extra args override) # Configure GPU memory optimization for specific models (if no extra args override)
MODEL_SPECIFIC_ARGS="" MODEL_SPECIFIC_ARGS="--gpu-memory-utilization 0.85 --max-model-len 16384"
if [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then if [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
MODEL_SPECIFIC_ARGS="--gpu-memory-utilization 0.85 --max-model-len 4096" MODEL_SPECIFIC_ARGS="--gpu-memory-utilization 0.85 --max-model-len 4096"
elif [[ "$MODEL_NAME" == "llava-hf/llava-1.5-7b-hf" ]]; then elif [[ "$MODEL_NAME" == "llava-hf/llava-1.5-7b-hf" ]]; then
...@@ -67,6 +67,7 @@ fi ...@@ -67,6 +67,7 @@ fi
# --enforce-eager: Quick deployment (remove for production) # --enforce-eager: Quick deployment (remove for production)
# --connector none: No KV transfer needed for aggregated serving # --connector none: No KV transfer needed for aggregated serving
# Extra args from command line come last to allow overrides # Extra args from command line come last to allow overrides
CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0} \
DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT:-8081} \ DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT:-8081} \
python -m dynamo.vllm --enable-multimodal --model $MODEL_NAME --connector none $MODEL_SPECIFIC_ARGS "${EXTRA_ARGS[@]}" python -m dynamo.vllm --enable-multimodal --model $MODEL_NAME --connector none $MODEL_SPECIFIC_ARGS "${EXTRA_ARGS[@]}"
......
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# EPD (Encode-Prefill-Decode) multimodal deployment
#
# Architecture: 3-component disaggregation
# - Processor: Python-based preprocessor (bypasses Rust OpenAIPreprocessor)
# - Encode Worker: Dedicated vision encoder that extracts image embeddings
# - PD Worker: Standard prefill/decode worker that receives embeddings via NIXL
#
# Benefits: Decouples encoding from inference, enables independent scaling
# For standard single-worker deployment, see agg_multimodal.sh

set -e
# On any exit, signal the whole process group so background components stop too.
trap 'echo Cleaning up...; kill 0' EXIT

# Default values
MODEL_NAME="Qwen/Qwen2.5-VL-7B-Instruct"
SINGLE_GPU=false

# Parse command line arguments
while [[ $# -gt 0 ]]; do
    case "$1" in
        --model)
            # Fail with a message instead of dying silently in `shift 2`.
            if [[ $# -lt 2 ]]; then
                echo "Error: --model requires a value" >&2
                exit 1
            fi
            MODEL_NAME="$2"
            shift 2
            ;;
        --single-gpu)
            SINGLE_GPU=true
            shift
            ;;
        -h|--help)
            echo "Usage: $0 [OPTIONS]"
            echo "Options:"
            echo " --model <model_name> Specify the model to use (default: $MODEL_NAME)"
            echo " --single-gpu Run both encode and PD workers on GPU 0 (for pre-merge CI)"
            echo " -h, --help Show this help message"
            exit 0
            ;;
        *)
            echo "Unknown option: $1"
            echo "Use --help for usage information"
            exit 1
            ;;
    esac
done

# Start frontend (HTTP endpoint)
# dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000)
python -m dynamo.frontend &

# Set max model length based on model name
MAX_MODEL_LEN=""
if [[ "$MODEL_NAME" == "Qwen/Qwen2.5-VL-7B-Instruct" ]]; then
    MAX_MODEL_LEN="4096"
elif [[ "$MODEL_NAME" == "llava-hf/llava-1.5-7b-hf" ]]; then
    MAX_MODEL_LEN="2048"
else
    MAX_MODEL_LEN="30426"
fi

# Set GPU memory utilization and model length based on deployment mode
# Single-GPU mode: Both workers share GPU 0, so use reduced memory settings
# Multi-GPU mode: Each worker gets its own GPU, so use higher memory settings
# Kept as an array so every flag and value is passed as its own argument.
EXTRA_ARGS=()
if [[ "$SINGLE_GPU" == "true" ]]; then
    EXTRA_ARGS=(--gpu-memory-utilization 0.4 --enforce-eager --max-model-len "$MAX_MODEL_LEN")
else
    # Multi-GPU mode: standard memory settings
    EXTRA_ARGS=(--gpu-memory-utilization 0.85 --max-model-len "$MAX_MODEL_LEN")
fi

# Start processor (Python-based preprocessing, handles prompt templating)
python -m dynamo.vllm --multimodal-processor --enable-multimodal --model "$MODEL_NAME" &

# run E/P/D workers
# Use single GPU (GPU 0) for pre-merge CI, otherwise use GPU 0 for PD and GPU 1 for encode
if [[ "$SINGLE_GPU" == "true" ]]; then
    # Single GPU mode: both workers share GPU 0 with reduced memory
    CUDA_VISIBLE_DEVICES=0 python -m dynamo.vllm --multimodal-encode-worker --enable-multimodal --model "$MODEL_NAME" "${EXTRA_ARGS[@]}" &
    # Both the encode worker and the PD worker are vLLM engines, so they must not
    # initialize concurrently on the same GPU; otherwise their startup checks and
    # memory allocations would interfere with each other.
    sleep 60
    CUDA_VISIBLE_DEVICES=0 python -m dynamo.vllm --multimodal-worker --enable-multimodal --enable-mm-embeds --model "$MODEL_NAME" "${EXTRA_ARGS[@]}" &
else
    CUDA_VISIBLE_DEVICES=0 python -m dynamo.vllm --multimodal-worker --enable-multimodal --enable-mm-embeds --model "$MODEL_NAME" "${EXTRA_ARGS[@]}" &
    CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-encode-worker --enable-multimodal --model "$MODEL_NAME" "${EXTRA_ARGS[@]}" &
fi

# Wait for all background processes to complete
wait
...@@ -52,25 +52,31 @@ echo "Starting frontend..." ...@@ -52,25 +52,31 @@ echo "Starting frontend..."
# dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000) # dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000)
python -m dynamo.frontend & python -m dynamo.frontend &
# Start processor
echo "Starting processor..."
python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME &
EXTRA_ARGS="" EXTRA_ARGS=""
# GPU assignments (override via environment variables)
DYN_ENCODE_WORKER_GPU=${DYN_ENCODE_WORKER_GPU:-0}
DYN_PREFILL_WORKER_GPU=${DYN_PREFILL_WORKER_GPU:-1}
DYN_DECODE_WORKER_GPU=${DYN_DECODE_WORKER_GPU:-2}
# GPU memory utilization for workers
DYN_ENCODE_GPU_MEM=${DYN_ENCODE_GPU_MEM:-0.9}
DYN_PREFILL_GPU_MEM=${DYN_PREFILL_GPU_MEM:-0.9}
DYN_DECODE_GPU_MEM=${DYN_DECODE_GPU_MEM:-0.9}
# Start encode worker # Start encode worker
echo "Starting encode worker on GPU 0..." echo "Starting encode worker on GPU $DYN_ENCODE_WORKER_GPU (GPU mem: $DYN_ENCODE_GPU_MEM)..."
VLLM_NIXL_SIDE_CHANNEL_PORT=20097 CUDA_VISIBLE_DEVICES=0 python -m dynamo.vllm --multimodal-encode-worker --enable-multimodal --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20080"}' & VLLM_NIXL_SIDE_CHANNEL_PORT=20097 CUDA_VISIBLE_DEVICES=$DYN_ENCODE_WORKER_GPU python -m dynamo.vllm --multimodal-encode-worker --enable-multimodal --model $MODEL_NAME --gpu-memory-utilization $DYN_ENCODE_GPU_MEM $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20080"}' &
# Start prefill worker # Start prefill worker (also handles encode routing via --route-to-encoder)
echo "Starting prefill worker on GPU 1..." echo "Starting prefill worker on GPU $DYN_PREFILL_WORKER_GPU (GPU mem: $DYN_PREFILL_GPU_MEM)..."
VLLM_NIXL_SIDE_CHANNEL_PORT=20098 \ VLLM_NIXL_SIDE_CHANNEL_PORT=20098 \
CUDA_VISIBLE_DEVICES=1 python -m dynamo.vllm --multimodal-worker --is-prefill-worker --enable-multimodal --enable-mm-embeds --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20081"}' & CUDA_VISIBLE_DEVICES=$DYN_PREFILL_WORKER_GPU python -m dynamo.vllm --route-to-encoder --is-prefill-worker --enable-multimodal --enable-mm-embeds --model $MODEL_NAME --gpu-memory-utilization $DYN_PREFILL_GPU_MEM $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20081"}' &
# Start decode worker # Start decode worker
echo "Starting decode worker on GPU 2..." echo "Starting decode worker on GPU $DYN_DECODE_WORKER_GPU (GPU mem: $DYN_DECODE_GPU_MEM)..."
VLLM_NIXL_SIDE_CHANNEL_PORT=20099 \ VLLM_NIXL_SIDE_CHANNEL_PORT=20099 \
CUDA_VISIBLE_DEVICES=2 python -m dynamo.vllm --multimodal-decode-worker --enable-multimodal --enable-mm-embeds --model $MODEL_NAME $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' & CUDA_VISIBLE_DEVICES=$DYN_DECODE_WORKER_GPU python -m dynamo.vllm --multimodal-decode-worker --enable-multimodal --enable-mm-embeds --model $MODEL_NAME --gpu-memory-utilization $DYN_DECODE_GPU_MEM $EXTRA_ARGS --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:20082"}' &
echo "==================================================" echo "=================================================="
echo "All components started. Waiting for initialization..." echo "All components started. Waiting for initialization..."
......
...@@ -63,7 +63,7 @@ if [[ $HEAD_NODE -eq 1 ]]; then ...@@ -63,7 +63,7 @@ if [[ $HEAD_NODE -eq 1 ]]; then
# run processor (CPU-only to avoid competing for GPU memory with workers) # run processor (CPU-only to avoid competing for GPU memory with workers)
CUDA_VISIBLE_DEVICES="" \ CUDA_VISIBLE_DEVICES="" \
python -m dynamo.vllm --multimodal-processor --enable-multimodal --model $MODEL_NAME & python -m dynamo.vllm --route-to-encoder --enable-multimodal --model $MODEL_NAME &
# Prefill worker handles prompt processing and image encoding # Prefill worker handles prompt processing and image encoding
# Uses all 8 GPUs for tensor-parallel # Uses all 8 GPUs for tensor-parallel
......
...@@ -276,34 +276,37 @@ vllm_configs = { ...@@ -276,34 +276,37 @@ vllm_configs = {
completion_payload_default(), completion_payload_default(),
], ],
), ),
"multimodal_agg_qwen2vl_2b_epd": VLLMConfig( # The original script name is misleading: agg_multimodal_epd.sh is actually a disagg
name="multimodal_agg_qwen2vl_2b_epd", # case which uses a disagg encoder. We are bringing this test back shortly
directory=vllm_dir, # TODO(qiwa): enable this in https://github.com/ai-dynamo/dynamo/pull/6061/
script_name="agg_multimodal_epd.sh", # "multimodal_agg_qwen2vl_2b_epd": VLLMConfig(
marks=[pytest.mark.gpu_1, pytest.mark.pre_merge], # name="multimodal_agg_qwen2vl_2b_epd",
model="Qwen/Qwen2-VL-2B-Instruct", # directory=vllm_dir,
script_args=["--model", "Qwen/Qwen2-VL-2B-Instruct", "--single-gpu"], # script_name="agg_multimodal_epd.sh",
request_payloads=[ # marks=[pytest.mark.gpu_1, pytest.mark.pre_merge],
chat_payload( # model="Qwen/Qwen2-VL-2B-Instruct",
[ # script_args=["--model", "Qwen/Qwen2-VL-2B-Instruct", "--single-gpu"],
{ # request_payloads=[
"type": "text", # chat_payload(
"text": "What colors are in the following image? Respond only with the colors.", # [
}, # {
{ # "type": "text",
"type": "image_url", # "text": "What colors are in the following image? Respond only with the colors.",
"image_url": {"url": MULTIMODAL_IMG_URL}, # },
}, # {
], # "type": "image_url",
repeat_count=1, # "image_url": {"url": MULTIMODAL_IMG_URL},
# With proper prompt templating, the model actually only returns "green", # },
# verified behavior with native vLLM. # ],
expected_response=["green"], # repeat_count=1,
temperature=0.0, # # With proper prompt templating, the model actually only returns "green",
max_tokens=100, # # verified behavior with native vLLM.
) # expected_response=["green"],
], # temperature=0.0,
), # max_tokens=100,
# )
# ],
# ),
"multimodal_agg_frontend_decoding": VLLMConfig( "multimodal_agg_frontend_decoding": VLLMConfig(
name="multimodal_agg_frontend_decoding", name="multimodal_agg_frontend_decoding",
directory=vllm_dir, directory=vllm_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment