Unverified Commit 4d0380d5 authored by Qi Wang's avatar Qi Wang Committed by GitHub
Browse files

refactor: introduce worker factory in vLLM multimodal (#6060)

parent 638d8e68
......@@ -160,10 +160,7 @@ def update_dynamo_config_with_engine(
if dynamo_config.route_to_encoder:
dynamo_config.component = "processor"
dynamo_config.endpoint = "generate"
elif (
dynamo_config.multimodal_encode_worker
or dynamo_config.multimodal_encode_prefill_worker
):
elif dynamo_config.multimodal_encode_worker:
dynamo_config.component = "encoder"
dynamo_config.endpoint = "generate"
elif dynamo_config.multimodal_decode_worker:
......
......@@ -88,13 +88,6 @@ class DynamoVllmArgGroup(ArgGroup):
default=False,
help="Run as multimodal decode worker in disaggregated mode.",
)
add_negatable_bool_argument(
g,
flag_name="--multimodal-encode-prefill-worker",
env_var="DYN_VLLM_MULTIMODAL_ENCODE_PREFILL_WORKER",
default=False,
help="Run as unified encode+prefill+decode worker for models requiring integrated image encoding (e.g., Llama 4).",
)
add_negatable_bool_argument(
g,
flag_name="--enable-multimodal",
......@@ -170,7 +163,6 @@ class DynamoVllmConfig(ConfigBase):
multimodal_encode_worker: bool
multimodal_worker: bool
multimodal_decode_worker: bool
multimodal_encode_prefill_worker: bool
enable_multimodal: bool
mm_prompt_template: str
frontend_decoding: bool
......@@ -206,7 +198,6 @@ class DynamoVllmConfig(ConfigBase):
bool(self.multimodal_encode_worker),
bool(self.multimodal_worker),
bool(self.multimodal_decode_worker),
bool(self.multimodal_encode_prefill_worker),
]
)
......@@ -215,7 +206,7 @@ class DynamoVllmConfig(ConfigBase):
if self._count_multimodal_roles() > 1:
raise ValueError(
"Use only one of --multimodal-encode-worker, --multimodal-worker, "
"--multimodal-decode-worker, --multimodal-encode-prefill-worker"
"--multimodal-decode-worker"
)
def _validate_multimodal_requires_flag(self) -> None:
......
......@@ -45,11 +45,7 @@ except ImportError:
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.vllm.multimodal_handlers import (
EncodeWorkerHandler,
MultimodalDecodeWorkerHandler,
MultimodalPDWorkerHandler,
)
from dynamo.vllm.worker_factory import WorkerFactory
from .args import Config, parse_args
from .chrek import get_checkpoint_config
......@@ -148,18 +144,17 @@ async def worker():
)
# Route to appropriate initialization based on config flags
if config.multimodal_encode_worker:
await init_multimodal_encode_worker(runtime, config, shutdown_event)
logger.debug("init_multimodal_encode_worker completed")
elif (
config.multimodal_worker
or config.multimodal_decode_worker
or config.multimodal_encode_prefill_worker
):
await init_multimodal_worker(
if WorkerFactory.handles(config):
# Create worker factory with setup functions
factory = WorkerFactory(
setup_vllm_engine_fn=setup_vllm_engine,
setup_kv_event_publisher_fn=setup_kv_event_publisher,
register_vllm_model_fn=register_vllm_model,
)
await factory.create(
runtime, config, shutdown_event, pre_created_engine=pre_created_engine
)
logger.debug("init_multimodal_worker completed")
logger.debug("multimodal worker completed")
elif config.omni:
await init_omni(runtime, config, shutdown_event)
logger.debug("init_omni completed")
......@@ -924,165 +919,6 @@ def get_engine_cache_info(engine: AsyncLLM):
raise
async def init_multimodal_encode_worker(
    runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
    """Serve the generate endpoint of a multimodal encode worker.

    The endpoint is resolved from ``config.namespace``, ``config.component``
    and ``config.endpoint``. An ``EncodeWorkerHandler`` is built from
    ``config.engine_args``, initialised with ``runtime`` and its ``generate``
    method is served on that endpoint.

    ``shutdown_event`` is accepted but not read by this function.

    Raises:
        Exception: anything raised while serving is logged and re-raised.
            ``handler.cleanup()`` runs in either case.
    """
    worker_component = runtime.namespace(config.namespace).component(
        config.component
    )
    endpoint = worker_component.endpoint(config.endpoint)

    encode_handler = EncodeWorkerHandler(config.engine_args)
    await encode_handler.async_init(runtime)

    logger.info("Starting to serve the encode worker endpoint...")
    try:
        # Both metric labels carry the same value: config.model.
        serving = endpoint.serve_endpoint(
            encode_handler.generate,
            metrics_labels=[
                (prometheus_names.labels.MODEL, config.model),
                (prometheus_names.labels.MODEL_NAME, config.model),
            ],
        )
        await asyncio.gather(serving)
    except Exception as e:
        logger.error(f"Failed to serve encode worker endpoint: {e}")
        raise
    finally:
        encode_handler.cleanup()
async def init_multimodal_worker(
    runtime: DistributedRuntime,
    config: Config,
    shutdown_event: asyncio.Event,
    pre_created_engine=None,
):
    """
    Set up and serve a multimodal worker component.

    Worker flavours covered by this function:
    1. --multimodal-worker: prefill+decode worker for a multimodal LLM. With
       --route-to-encoder it obtains a client for a separate encoder
       component. Runs aggregated (P+D) or disaggregated (P→D).
    2. --multimodal-decode-worker: decode-only worker for disaggregated
       (P→D) mode.
    3. --multimodal-encode-prefill-worker: unified encode+prefill+decode in a
       single worker for models with integrated image encoding
       (e.g., Llama 4).

    Args:
        runtime: Distributed runtime used to resolve components/endpoints.
        config: Worker configuration (flags, namespace, endpoint names, ...).
        shutdown_event: Passed through to the handler that is created.
        pre_created_engine: Optional 5-tuple of (engine_client, vllm_config,
            default_sampling_params, prometheus_temp_dir, component_gauges).
            When ``None`` the tuple is obtained from
            ``setup_vllm_engine(config)``.

    Raises:
        Exception: errors raised while serving are logged and re-raised;
            ``handler.cleanup()`` always runs afterwards.
    """
    worker_component = runtime.namespace(config.namespace).component(
        config.component
    )
    generate_ep = worker_component.endpoint(config.endpoint)
    clear_kv_ep = worker_component.endpoint("clear_kv_blocks")

    # Reuse the caller-supplied engine (checkpoint mode) or build a new one.
    engine_bundle = (
        pre_created_engine
        if pre_created_engine is not None
        else setup_vllm_engine(config)
    )
    (
        engine_client,
        vllm_config,
        _sampling_defaults,
        prometheus_temp_dir,
        _gauges,
    ) = engine_bundle

    # With encoder routing enabled the PD worker talks to the encoder itself
    # instead of relying on a separate processor.
    encoder_client = None
    if config.route_to_encoder:
        encoder_endpoint = (
            runtime.namespace(config.namespace)
            .component("encoder")
            .endpoint("generate")
        )
        encoder_client = await encoder_endpoint.client()
        logger.info("Waiting for Encoder Worker Instances ...")
        await encoder_client.wait_for_instances()
        logger.info("Connected to encoder workers")

    # A prefill worker needs a client for the decode worker (disaggregated mode).
    decoder_client = None
    if config.is_prefill_worker:
        decoder_endpoint = (
            runtime.namespace(config.namespace)
            .component("decoder")
            .endpoint("generate")
        )
        decoder_client = await decoder_endpoint.client()
        await decoder_client.wait_for_instances()
        logger.info("Connected to decode worker for disaggregated mode")

    # Pick the handler class from the worker type.
    if not config.multimodal_decode_worker:
        handler = MultimodalPDWorkerHandler(
            runtime,
            worker_component,
            engine_client,
            config,
            encoder_client,
            decoder_client,
            shutdown_event,
        )
    else:
        handler = MultimodalDecodeWorkerHandler(
            runtime, worker_component, engine_client, config, shutdown_event
        )
    handler.add_temp_dir(prometheus_temp_dir)
    await handler.async_init(runtime)

    # Attach a KV event publisher when the setup function returns one.
    publisher = setup_kv_event_publisher(
        config, worker_component, generate_ep, vllm_config
    )
    if publisher:
        handler.kv_publisher = publisher

    # Register the model with the frontend so it can route requests.
    model_type = parse_endpoint_types(config.endpoint_types)
    if config.use_vllm_tokenizer:
        model_input = ModelInput.Text
    else:
        model_input = ModelInput.Tokens
    await register_vllm_model(
        model_input,
        model_type,
        generate_ep,
        config,
        engine_client,
        vllm_config,
    )

    metrics_labels = [
        (prometheus_names.labels.MODEL, config.served_model_name or config.model),
        (prometheus_names.labels.MODEL_NAME, config.served_model_name or config.model),
    ]
    try:
        # Serve both endpoints concurrently with the same metric labels.
        await asyncio.gather(
            generate_ep.serve_endpoint(
                handler.generate,
                metrics_labels=metrics_labels,
            ),
            clear_kv_ep.serve_endpoint(
                handler.clear_kv_blocks,
                metrics_labels=metrics_labels,
            ),
        )
    except Exception as e:
        logger.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()
async def init_omni(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for worker_factory.py"""
import asyncio
from unittest.mock import AsyncMock, Mock
import pytest
from dynamo.vllm.worker_factory import EngineSetupResult, WorkerFactory
def _make_config(**overrides) -> Mock:
"""Create a mock Config with all multimodal flags defaulting to False."""
defaults = {
"multimodal_encode_worker": False,
"multimodal_worker": False,
"multimodal_decode_worker": False,
"omni": False,
"is_prefill_worker": False,
}
defaults.update(overrides)
return Mock(**defaults)
class TestHandles:
    """Check which config flags make WorkerFactory.handles() return truthy."""

    def test_multimodal_encode_worker(self) -> None:
        """The encode-worker flag alone is handled."""
        assert WorkerFactory.handles(_make_config(multimodal_encode_worker=True))

    def test_multimodal_worker(self) -> None:
        """The multimodal-worker flag alone is handled."""
        assert WorkerFactory.handles(_make_config(multimodal_worker=True))

    def test_multimodal_decode_worker(self) -> None:
        """The decode-worker flag alone is handled."""
        assert WorkerFactory.handles(_make_config(multimodal_decode_worker=True))

    def test_no_multimodal_flags(self) -> None:
        """A config with every flag False is not handled."""
        assert not WorkerFactory.handles(_make_config())

    def test_omni_not_handled(self) -> None:
        """The omni flag on its own is not handled by this factory."""
        assert not WorkerFactory.handles(_make_config(omni=True))

    def test_prefill_only_not_handled(self) -> None:
        """is_prefill_worker without a multimodal role flag is not handled."""
        assert not WorkerFactory.handles(_make_config(is_prefill_worker=True))
class TestCreate:
    """Test WorkerFactory.create() routing.

    The two private ``_create_*`` coroutines are replaced with ``AsyncMock``
    in the ``factory`` fixture, so each test only observes which creator
    ``create()`` dispatches to and with which arguments.
    """

    @pytest.fixture
    def factory(self) -> WorkerFactory:
        """Return a WorkerFactory with mocked dependencies and creators."""
        factory = WorkerFactory(
            setup_vllm_engine_fn=Mock(),
            setup_kv_event_publisher_fn=Mock(),
            register_vllm_model_fn=AsyncMock(),
        )
        factory._create_multimodal_encode_worker = AsyncMock()  # type: ignore[assignment]
        factory._create_multimodal_worker = AsyncMock()  # type: ignore[assignment]
        return factory

    @pytest.mark.asyncio
    async def test_routes_to_multimodal_encode(self, factory: WorkerFactory) -> None:
        """multimodal_encode_worker dispatches to the encode-worker creator."""
        config = _make_config(multimodal_encode_worker=True)
        shutdown_event = asyncio.Event()

        await factory.create(Mock(), config, shutdown_event)

        factory._create_multimodal_encode_worker.assert_called_once()  # type: ignore[union-attr]

    @pytest.mark.asyncio
    async def test_routes_to_multimodal_worker(self, factory: WorkerFactory) -> None:
        """multimodal_worker dispatches to the multimodal-worker creator."""
        config = _make_config(multimodal_worker=True)
        shutdown_event = asyncio.Event()

        await factory.create(Mock(), config, shutdown_event)

        factory._create_multimodal_worker.assert_called_once()  # type: ignore[union-attr]

    @pytest.mark.asyncio
    async def test_routes_multimodal_decode_worker(
        self, factory: WorkerFactory
    ) -> None:
        """multimodal_decode_worker shares the multimodal-worker creator."""
        config = _make_config(multimodal_decode_worker=True)
        shutdown_event = asyncio.Event()

        await factory.create(Mock(), config, shutdown_event)

        factory._create_multimodal_worker.assert_called_once()  # type: ignore[union-attr]

    @pytest.mark.asyncio
    async def test_passes_pre_created_engine(self, factory: WorkerFactory) -> None:
        """pre_created_engine is forwarded unchanged as a keyword argument."""
        config = _make_config(multimodal_worker=True)
        runtime = Mock()
        shutdown_event = asyncio.Event()
        # Five elements, the same arity the multimodal worker creator unpacks:
        # (engine_client, vllm_config, default_sampling_params,
        #  prometheus_temp_dir, component_gauges). A shorter tuple would only
        # pass here because the creator is mocked.
        pre_created_engine: EngineSetupResult = (
            Mock(),
            Mock(),
            Mock(),
            "/tmp/prometheus",
            Mock(),
        )

        await factory.create(
            runtime, config, shutdown_event, pre_created_engine=pre_created_engine
        )

        factory._create_multimodal_worker.assert_called_once_with(  # type: ignore[union-attr]
            runtime, config, shutdown_event, pre_created_engine=pre_created_engine
        )

    @pytest.mark.asyncio
    async def test_raises_when_no_multimodal_flag(self, factory: WorkerFactory) -> None:
        """create() raises ValueError when no multimodal role flag is set."""
        config = _make_config()

        with pytest.raises(ValueError, match="no multimodal worker type set"):
            await factory.create(Mock(), config, asyncio.Event())
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Worker initialization factory for vLLM workers."""
import asyncio
import logging
from collections.abc import Awaitable, Callable
from typing import Any, Optional
from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.llm import ModelInput
from dynamo.runtime import DistributedRuntime
from .args import Config
from .multimodal_handlers import (
EncodeWorkerHandler,
MultimodalDecodeWorkerHandler,
MultimodalPDWorkerHandler,
)
logger = logging.getLogger(__name__)
# Result of engine setup, unpacked by the multimodal worker as five values:
# (engine_client, vllm_config, default_sampling_params, prometheus_temp_dir,
#  component_gauges)
EngineSetupResult = tuple[Any, Any, Any, Any, Any]
# Injected callables; argument lists are left open (``...``) on purpose.
SetupVllmEngineFn = Callable[..., EngineSetupResult]
SetupKvEventPublisherFn = Callable[..., Optional[Any]]
RegisterVllmModelFn = Callable[..., Awaitable[None]]
class WorkerFactory:
    """Create and run the multimodal vLLM worker selected by config flags.

    The engine setup, KV event publisher setup and model registration
    callables are injected through the constructor and stored on the
    instance; the worker creators call them via ``self``.
    """

    def __init__(
        self,
        setup_vllm_engine_fn: SetupVllmEngineFn,
        setup_kv_event_publisher_fn: SetupKvEventPublisherFn,
        register_vllm_model_fn: RegisterVllmModelFn,
    ):
        """Store the three injected callables for later use."""
        self.setup_vllm_engine = setup_vllm_engine_fn
        self.setup_kv_event_publisher = setup_kv_event_publisher_fn
        self.register_vllm_model = register_vllm_model_fn

    @staticmethod
    def handles(config: Config) -> bool:
        """Tell whether ``config`` selects a worker this factory creates.

        True when any of ``multimodal_encode_worker``, ``multimodal_worker``
        or ``multimodal_decode_worker`` is truthy; flags are checked in that
        order and later ones are not read once one is truthy.
        """
        if config.multimodal_encode_worker:
            return True
        if config.multimodal_worker:
            return True
        return bool(config.multimodal_decode_worker)

    async def create(
        self,
        runtime: DistributedRuntime,
        config: Config,
        shutdown_event: asyncio.Event,
        pre_created_engine: Optional[EngineSetupResult] = None,
    ) -> None:
        """Dispatch to the worker creator that matches the config flags.

        ``multimodal_encode_worker`` takes precedence; ``multimodal_worker``
        and ``multimodal_decode_worker`` share one creator, which also
        receives ``pre_created_engine``.

        Raises:
            ValueError: if none of the three multimodal flags is set.
        """
        if config.multimodal_encode_worker:
            await self._create_multimodal_encode_worker(runtime, config, shutdown_event)
            return

        if config.multimodal_worker or config.multimodal_decode_worker:
            await self._create_multimodal_worker(
                runtime, config, shutdown_event, pre_created_engine=pre_created_engine
            )
            return

        raise ValueError(
            "WorkerFactory.create() called but no multimodal worker type set in config"
        )

    async def _create_multimodal_worker(
        self,
        runtime: DistributedRuntime,
        config: Config,
        shutdown_event: asyncio.Event,
        pre_created_engine: Optional[EngineSetupResult] = None,
    ) -> None:
        """
        Set up and serve a multimodal worker component.

        Worker flavours:
        - --multimodal-worker: PD worker that may receive embeddings from an
          encoder
        - --multimodal-decode-worker: decode-only worker

        Modes:
        - Aggregated (P+D): prefill and decode on the same worker
        - Disaggregated (P→D): prefill forwards to a separate decode worker

        ``pre_created_engine``, when given, is unpacked as five values
        (engine_client, vllm_config, default_sampling_params,
        prometheus_temp_dir, component_gauges); otherwise the same five
        values come from ``self.setup_vllm_engine(config)``.

        Errors raised while serving are logged and re-raised, and
        ``handler.cleanup()`` always runs afterwards.
        """
        worker_component = runtime.namespace(config.namespace).component(
            config.component
        )
        generate_ep = worker_component.endpoint(config.endpoint)
        clear_kv_ep = worker_component.endpoint("clear_kv_blocks")

        # Reuse the supplied engine (checkpoint mode) or build a new one.
        engine_bundle = (
            pre_created_engine
            if pre_created_engine is not None
            else self.setup_vllm_engine(config)
        )
        (
            engine_client,
            vllm_config,
            _sampling_defaults,
            prometheus_temp_dir,
            _gauges,
        ) = engine_bundle

        # Client for the encoder component, only when encoder routing is on.
        encoder_client = None
        if config.route_to_encoder:
            encoder_endpoint = (
                runtime.namespace(config.namespace)
                .component("encoder")
                .endpoint("generate")
            )
            encoder_client = await encoder_endpoint.client()
            logger.info("Waiting for Encoder Worker Instances ...")
            await encoder_client.wait_for_instances()
            logger.info("Connected to encoder workers")

        # Client for the decoder component, only for a prefill worker
        # (disaggregated mode).
        decoder_client = None
        if config.is_prefill_worker:
            decoder_endpoint = (
                runtime.namespace(config.namespace)
                .component("decoder")
                .endpoint("generate")
            )
            decoder_client = await decoder_endpoint.client()
            await decoder_client.wait_for_instances()
            logger.info("Connected to decode worker for disaggregated mode")

        # Pick the handler class from the worker type.
        if not config.multimodal_decode_worker:
            handler = MultimodalPDWorkerHandler(
                runtime,
                worker_component,
                engine_client,
                config,
                encoder_client,
                decoder_client,
                shutdown_event,
            )
        else:
            handler = MultimodalDecodeWorkerHandler(
                runtime, worker_component, engine_client, config, shutdown_event
            )
        handler.add_temp_dir(prometheus_temp_dir)
        await handler.async_init(runtime)

        # Attach a KV event publisher when the setup callable returns one.
        publisher = self.setup_kv_event_publisher(
            config, worker_component, generate_ep, vllm_config
        )
        if publisher:
            handler.kv_publisher = publisher

        # Register the model with the frontend so it can route requests.
        model_type = parse_endpoint_types(config.endpoint_types)
        if config.use_vllm_tokenizer:
            model_input = ModelInput.Text
        else:
            model_input = ModelInput.Tokens
        await self.register_vllm_model(
            model_input,
            model_type,
            generate_ep,
            config,
            engine_client,
            vllm_config,
        )

        metrics_labels = [("model", config.served_model_name or config.model)]
        try:
            # Serve both endpoints concurrently with the same metric labels.
            await asyncio.gather(
                generate_ep.serve_endpoint(
                    handler.generate,
                    metrics_labels=metrics_labels,
                ),
                clear_kv_ep.serve_endpoint(
                    handler.clear_kv_blocks,
                    metrics_labels=metrics_labels,
                ),
            )
        except Exception as e:
            logger.error(f"Failed to serve endpoints: {e}")
            raise
        finally:
            handler.cleanup()

    async def _create_multimodal_encode_worker(
        self,
        runtime: DistributedRuntime,
        config: Config,
        shutdown_event: asyncio.Event,
    ) -> None:
        """Set up and serve a standalone multimodal encode worker.

        Builds an ``EncodeWorkerHandler`` from ``config.engine_args`` and
        serves its ``generate`` method on the configured endpoint.
        ``shutdown_event`` is accepted but not read here. Errors raised while
        serving are logged and re-raised; ``handler.cleanup()`` always runs.
        """
        worker_component = runtime.namespace(config.namespace).component(
            config.component
        )
        endpoint = worker_component.endpoint(config.endpoint)

        encode_handler = EncodeWorkerHandler(config.engine_args)
        await encode_handler.async_init(runtime)

        logger.info("Starting to serve the encode worker endpoint...")
        try:
            serving = endpoint.serve_endpoint(
                encode_handler.generate, metrics_labels=[("model", config.model)]
            )
            await asyncio.gather(serving)
        except Exception as e:
            logger.error(f"Failed to serve encode worker endpoint: {e}")
            raise
        finally:
            encode_handler.cleanup()
......@@ -49,7 +49,6 @@ vLLM supports all multimodal deployment patterns. See [Architecture Patterns](RE
| PD Worker | `--multimodal-worker` | Prefill + Decode |
| Prefill Worker | `--multimodal-worker --is-prefill-worker` | Prefill only |
| Decode Worker | `--multimodal-decode-worker` | Decode only |
| Encode+Prefill Worker | `--multimodal-encode-prefill-worker --is-prefill-worker` | Combined (Llama 4) |
## Use the Latest Release
......
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -e
# When the script exits, signal every process in its process group so the
# background frontend and worker started below do not outlive it.
trap 'echo Cleaning up...; kill 0' EXIT

# Default values
MODEL_NAME="llava-hf/llava-1.5-7b-hf"
EC_CONNECTOR_BACKEND="DynamoEcConnector"

# Parse command line arguments
while [[ $# -gt 0 ]]; do
    case "$1" in
        --model)
            # Report a missing value explicitly; otherwise `shift 2` fails
            # and `set -e` aborts the script without any message.
            if [[ $# -lt 2 ]]; then
                echo "Option --model requires a value"
                echo "Use --help for usage information"
                exit 1
            fi
            MODEL_NAME="$2"
            shift 2
            ;;
        -h|--help)
            echo "Usage: $0 [OPTIONS]"
            echo ""
            echo "Aggregated multimodal serving with ECConnector (ec_both mode)"
            echo ""
            echo "This script launches:"
            echo " - Frontend server"
            echo " - Aggregated multimodal worker (ec_both: produces and consumes encoder cache)"
            echo ""
            echo "Options:"
            echo " --model <model_name> Specify the VLM model to use (default: $MODEL_NAME)"
            echo " -h, --help Show this help message"
            echo ""
            echo "Examples:"
            echo " $0"
            echo " $0 --model llava-hf/llava-1.5-7b-hf"
            echo ""
            exit 0
            ;;
        *)
            echo "Unknown option: $1"
            echo "Use --help for usage information"
            exit 1
            ;;
    esac
done

echo "=================================================="
echo "Aggregated Multimodal Serving (ECConnector ec_both)"
echo "=================================================="
echo "Model: $MODEL_NAME"
echo "ECConnector Backend: $EC_CONNECTOR_BACKEND"
echo "=================================================="

# GPU assignment (override via environment variable)
DYN_WORKER_GPU=${DYN_WORKER_GPU:-0}
# GPU memory utilization
DYN_GPU_MEM=${DYN_GPU_MEM:-0.85}

# Start frontend
echo "Starting frontend..."
python -m dynamo.frontend &

# Start aggregated multimodal worker (ec_both: produces and consumes encoder cache)
# Expansions are quoted so values containing spaces or glob characters are
# passed to the worker as single arguments.
echo "Starting aggregated multimodal worker (ec_both) on GPU $DYN_WORKER_GPU (mem: $DYN_GPU_MEM)..."
CUDA_VISIBLE_DEVICES="$DYN_WORKER_GPU" python -m dynamo.vllm \
    --multimodal-worker \
    --enable-multimodal \
    --model "$MODEL_NAME" \
    --enable-mm-embeds \
    --connector none \
    --enforce-eager \
    --gpu-memory-utilization "$DYN_GPU_MEM" \
    --ec-transfer-config "{\"ec_connector\":\"$EC_CONNECTOR_BACKEND\",\"ec_role\":\"ec_both\"}" &

# Wait for all background processes to complete
wait
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment