Unverified Commit 7893f268 authored by Alec's avatar Alec Committed by GitHub
Browse files

feat: add --disaggregation-mode enum to vLLM backend (#6483)


Signed-off-by: default avataralec-flowers <aflowers@nvidia.com>
Co-authored-by: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parent 6d3e0137
...@@ -186,10 +186,10 @@ if [ "$USE_MOCKERS" = true ]; then ...@@ -186,10 +186,10 @@ if [ "$USE_MOCKERS" = true ]; then
# Set endpoint based on worker mode # Set endpoint based on worker mode
if [ "$MODE" = "prefill" ]; then if [ "$MODE" = "prefill" ]; then
MOCKER_ARGS+=("--endpoint" "dyn://test.prefill.generate") MOCKER_ARGS+=("--endpoint" "dyn://test.prefill.generate")
MOCKER_ARGS+=("--is-prefill-worker") MOCKER_ARGS+=("--disaggregation-mode" "prefill")
elif [ "$MODE" = "decode" ]; then elif [ "$MODE" = "decode" ]; then
MOCKER_ARGS+=("--endpoint" "dyn://test.mocker.generate") MOCKER_ARGS+=("--endpoint" "dyn://test.mocker.generate")
MOCKER_ARGS+=("--is-decode-worker") MOCKER_ARGS+=("--disaggregation-mode" "decode")
else else
MOCKER_ARGS+=("--endpoint" "dyn://test.mocker.generate") MOCKER_ARGS+=("--endpoint" "dyn://test.mocker.generate")
fi fi
...@@ -254,9 +254,9 @@ else ...@@ -254,9 +254,9 @@ else
VLLM_ARGS+=("--data-parallel-size" "$DATA_PARALLEL_SIZE") VLLM_ARGS+=("--data-parallel-size" "$DATA_PARALLEL_SIZE")
fi fi
if [ "$MODE" = "prefill" ]; then if [ "$MODE" = "prefill" ]; then
VLLM_ARGS+=("--is-prefill-worker") VLLM_ARGS+=("--disaggregation-mode" "prefill")
elif [ "$MODE" = "decode" ]; then elif [ "$MODE" = "decode" ]; then
VLLM_ARGS+=("--is-decode-worker") VLLM_ARGS+=("--disaggregation-mode" "decode")
fi fi
VLLM_ARGS+=("${EXTRA_ARGS[@]}") VLLM_ARGS+=("${EXTRA_ARGS[@]}")
......
...@@ -12,7 +12,7 @@ Main submodules: ...@@ -12,7 +12,7 @@ Main submodules:
- utils: Common utilities including environment and prometheus helpers - utils: Common utilities including environment and prometheus helpers
""" """
from dynamo.common import config_dump, utils from dynamo.common import config_dump, constants, utils
try: try:
from ._version import __version__ from ._version import __version__
...@@ -24,4 +24,4 @@ except Exception: ...@@ -24,4 +24,4 @@ except Exception:
except Exception: except Exception:
__version__ = "0.0.0+unknown" __version__ = "0.0.0+unknown"
__all__ = ["__version__", "config_dump", "utils"] __all__ = ["__version__", "config_dump", "constants", "utils"]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Shared constants for Dynamo backends."""
from enum import Enum
class DisaggregationMode(Enum):
"""Disaggregation mode for LLM workers."""
AGGREGATED = "agg"
PREFILL = "prefill"
DECODE = "decode"
...@@ -26,7 +26,7 @@ The mocker engine now supports a vLLM-style CLI interface with individual argume ...@@ -26,7 +26,7 @@ The mocker engine now supports a vLLM-style CLI interface with individual argume
- `--data-parallel-size`: Number of data parallel workers to simulate (default: 1) - `--data-parallel-size`: Number of data parallel workers to simulate (default: 1)
- `--num-workers`: Number of mocker workers to launch in the same process (default: 1). All workers share the same tokio runtime and thread pool - `--num-workers`: Number of mocker workers to launch in the same process (default: 1). All workers share the same tokio runtime and thread pool
- `--stagger-delay`: Delay in seconds between launching each worker to avoid overwhelming etcd/NATS/frontend. Set to 0 to disable staggering. Use -1 for auto mode (stagger dependent on number of workers). Default: -1 (auto) - `--stagger-delay`: Delay in seconds between launching each worker to avoid overwhelming etcd/NATS/frontend. Set to 0 to disable staggering. Use -1 for auto mode (stagger dependent on number of workers). Default: -1 (auto)
- `--is-prefill-worker` / `--is-decode-worker`: Whether the worker is a prefill or decode worker for disaggregated deployment. If not specified, mocker will be in aggregated mode. - `--disaggregation-mode prefill` / `--disaggregation-mode decode`: Whether the worker is a prefill or decode worker for disaggregated deployment. If not specified, mocker will be in aggregated mode.
**Environment variables:** **Environment variables:**
......
...@@ -142,14 +142,48 @@ def create_temp_engine_args_file(args) -> Path: ...@@ -142,14 +142,48 @@ def create_temp_engine_args_file(args) -> Path:
def validate_worker_type_args(args): def validate_worker_type_args(args):
""" """
Validate that is_prefill_worker and is_decode_worker are not both True. Resolve disaggregation mode from --disaggregation-mode or legacy boolean flags.
Raises ValueError if validation fails. Raises ValueError if validation fails.
""" """
import warnings
explicit_mode = args.disaggregation_mode is not None
has_legacy = args.is_prefill_worker or args.is_decode_worker
if has_legacy and explicit_mode:
raise ValueError(
"Cannot combine --is-prefill-worker/--is-decode-worker with "
"--disaggregation-mode. Use only --disaggregation-mode."
)
if has_legacy:
if args.is_prefill_worker and args.is_decode_worker: if args.is_prefill_worker and args.is_decode_worker:
raise ValueError( raise ValueError(
"Cannot specify both --is-prefill-worker and --is-decode-worker. " "Cannot specify both --is-prefill-worker and --is-decode-worker. "
"A worker must be either prefill, decode, or aggregated (neither flag set)." "A worker must be either prefill, decode, or aggregated (neither flag set)."
) )
if args.is_prefill_worker:
warnings.warn(
"--is-prefill-worker is deprecated, use --disaggregation-mode=prefill",
DeprecationWarning,
stacklevel=2,
)
args.disaggregation_mode = "prefill"
elif args.is_decode_worker:
warnings.warn(
"--is-decode-worker is deprecated, use --disaggregation-mode=decode",
DeprecationWarning,
stacklevel=2,
)
args.disaggregation_mode = "decode"
# Apply default if neither new flag nor legacy flags were provided
if args.disaggregation_mode is None:
args.disaggregation_mode = "agg"
# Sync booleans from disaggregation_mode
args.is_prefill_worker = args.disaggregation_mode == "prefill"
args.is_decode_worker = args.disaggregation_mode == "decode"
def parse_bootstrap_ports(ports_str: str | None) -> list[int]: def parse_bootstrap_ports(ports_str: str | None) -> list[int]:
...@@ -305,17 +339,27 @@ def parse_args(): ...@@ -305,17 +339,27 @@ def parse_args():
) )
# Worker type configuration # Worker type configuration
parser.add_argument(
"--disaggregation-mode",
type=str,
default=None,
choices=["agg", "prefill", "decode"],
help="Worker disaggregation mode: 'agg' (default, aggregated), "
"'prefill' (prefill-only worker), or 'decode' (decode-only worker).",
)
parser.add_argument( parser.add_argument(
"--is-prefill-worker", "--is-prefill-worker",
action="store_true", action="store_true",
default=False, default=False,
help="Register as Prefill model type instead of Chat+Completions (default: False)", help="DEPRECATED: use --disaggregation-mode=prefill. "
"Register as Prefill model type instead of Chat+Completions (default: False)",
) )
parser.add_argument( parser.add_argument(
"--is-decode-worker", "--is-decode-worker",
action="store_true", action="store_true",
default=False, default=False,
help="Mark this as a decode worker which does not publish KV events and skips prefill cost estimation (default: False)", help="DEPRECATED: use --disaggregation-mode=decode. "
"Mark this as a decode worker which does not publish KV events (default: False)",
) )
parser.add_argument( parser.add_argument(
"--durable-kv-events", "--durable-kv-events",
......
...@@ -108,7 +108,8 @@ class VllmV1ConfigModifier(BaseConfigModifier): ...@@ -108,7 +108,8 @@ class VllmV1ConfigModifier(BaseConfigModifier):
args = validate_and_get_worker_args(worker_service, backend="vllm") args = validate_and_get_worker_args(worker_service, backend="vllm")
args = break_arguments(args) args = break_arguments(args)
# remove --is-prefill-worker flag # remove --disaggregation-mode and its value (or legacy --is-prefill-worker)
args = remove_valued_arguments(args, "--disaggregation-mode")
if "--is-prefill-worker" in args: if "--is-prefill-worker" in args:
args.remove("--is-prefill-worker") args.remove("--is-prefill-worker")
......
...@@ -67,7 +67,7 @@ python -m dynamo.router \ ...@@ -67,7 +67,7 @@ python -m dynamo.router \
python -m dynamo.vllm --model MODEL_NAME --block-size 64 & python -m dynamo.vllm --model MODEL_NAME --block-size 64 &
# Start prefill workers # Start prefill workers
python -m dynamo.vllm --model MODEL_NAME --block-size 64 --is-prefill-worker & python -m dynamo.vllm --model MODEL_NAME --block-size 64 --disaggregation-mode prefill &
``` ```
>[!Note] >[!Note]
......
...@@ -9,7 +9,6 @@ import socket ...@@ -9,7 +9,6 @@ import socket
import sys import sys
import tempfile import tempfile
from argparse import Namespace from argparse import Namespace
from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Generator, Optional from typing import Any, Dict, Generator, Optional
...@@ -20,6 +19,7 @@ from sglang.srt.server_args_config_parser import ConfigArgumentMerger ...@@ -20,6 +19,7 @@ from sglang.srt.server_args_config_parser import ConfigArgumentMerger
from dynamo.common.config_dump import register_encoder from dynamo.common.config_dump import register_encoder
from dynamo.common.configuration.groups import DynamoRuntimeConfig from dynamo.common.configuration.groups import DynamoRuntimeConfig
from dynamo.common.configuration.groups.runtime_args import DynamoRuntimeArgGroup from dynamo.common.configuration.groups.runtime_args import DynamoRuntimeArgGroup
from dynamo.common.constants import DisaggregationMode
from dynamo.common.utils.runtime import parse_endpoint from dynamo.common.utils.runtime import parse_endpoint
from dynamo.llm import fetch_model from dynamo.llm import fetch_model
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -28,12 +28,6 @@ from dynamo.sglang.backend_args import DynamoSGLangArgGroup, DynamoSGLangConfig ...@@ -28,12 +28,6 @@ from dynamo.sglang.backend_args import DynamoSGLangArgGroup, DynamoSGLangConfig
configure_dynamo_logging() configure_dynamo_logging()
class DisaggregationMode(Enum):
AGGREGATED = "agg"
PREFILL = "prefill"
DECODE = "decode"
class DynamoConfig(DynamoRuntimeConfig, DynamoSGLangConfig): class DynamoConfig(DynamoRuntimeConfig, DynamoSGLangConfig):
"""Combined configuration container for SGLang server and Dynamo args.""" """Combined configuration container for SGLang server and Dynamo args."""
......
...@@ -16,6 +16,7 @@ import uvloop ...@@ -16,6 +16,7 @@ import uvloop
from dynamo import prometheus_names from dynamo import prometheus_names
from dynamo.common.config_dump import dump_config from dynamo.common.config_dump import dump_config
from dynamo.common.constants import DisaggregationMode
from dynamo.common.storage import get_fs from dynamo.common.storage import get_fs
from dynamo.common.utils.endpoint_types import parse_endpoint_types from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.common.utils.graceful_shutdown import graceful_shutdown_with_discovery from dynamo.common.utils.graceful_shutdown import graceful_shutdown_with_discovery
...@@ -23,7 +24,7 @@ from dynamo.common.utils.runtime import create_runtime ...@@ -23,7 +24,7 @@ from dynamo.common.utils.runtime import create_runtime
from dynamo.llm import ModelInput, ModelType from dynamo.llm import ModelInput, ModelType
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.args import Config, DisaggregationMode, parse_args from dynamo.sglang.args import Config, parse_args
from dynamo.sglang.health_check import ( from dynamo.sglang.health_check import (
ImageDiffusionHealthCheckPayload, ImageDiffusionHealthCheckPayload,
SglangHealthCheckPayload, SglangHealthCheckPayload,
......
...@@ -9,8 +9,9 @@ from typing import Any, AsyncGenerator, Dict, Optional ...@@ -9,8 +9,9 @@ from typing import Any, AsyncGenerator, Dict, Optional
import sglang as sgl import sglang as sgl
from dynamo._core import Context from dynamo._core import Context
from dynamo.common.constants import DisaggregationMode
from dynamo.common.utils.engine_response import normalize_finish_reason from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.sglang.args import Config, DisaggregationMode from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
......
...@@ -11,8 +11,9 @@ import torch ...@@ -11,8 +11,9 @@ import torch
import dynamo.nixl_connect as connect import dynamo.nixl_connect as connect
from dynamo._core import Client, Context from dynamo._core import Client, Context
from dynamo.common.constants import DisaggregationMode
from dynamo.common.utils.engine_response import normalize_finish_reason from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.sglang.args import Config, DisaggregationMode from dynamo.sglang.args import Config
from dynamo.sglang.protocol import ( from dynamo.sglang.protocol import (
DisaggSglangMultimodalRequest, DisaggSglangMultimodalRequest,
SglangMultimodalRequest, SglangMultimodalRequest,
......
...@@ -24,6 +24,7 @@ from dynamo.common.configuration.groups.runtime_args import ( ...@@ -24,6 +24,7 @@ from dynamo.common.configuration.groups.runtime_args import (
) )
from dynamo.common.utils.runtime import parse_endpoint from dynamo.common.utils.runtime import parse_endpoint
from dynamo.vllm.backend_args import DynamoVllmArgGroup, DynamoVllmConfig from dynamo.vllm.backend_args import DynamoVllmArgGroup, DynamoVllmConfig
from dynamo.vllm.constants import DisaggregationMode
from . import envs from . import envs
...@@ -35,8 +36,6 @@ VALID_CONNECTORS = {"nixl", "lmcache", "kvbm", "null", "none"} ...@@ -35,8 +36,6 @@ VALID_CONNECTORS = {"nixl", "lmcache", "kvbm", "null", "none"}
class Config(DynamoRuntimeConfig, DynamoVllmConfig): class Config(DynamoRuntimeConfig, DynamoVllmConfig):
component: str component: str
is_prefill_worker: bool
is_decode_worker: bool
custom_jinja_template: Optional[str] = None custom_jinja_template: Optional[str] = None
discovery_backend: str discovery_backend: str
request_plane: str request_plane: str
...@@ -157,7 +156,6 @@ def update_dynamo_config_with_engine( ...@@ -157,7 +156,6 @@ def update_dynamo_config_with_engine(
# Capture user-provided --endpoint before defaults overwrite it # Capture user-provided --endpoint before defaults overwrite it
user_endpoint = dynamo_config.endpoint user_endpoint = dynamo_config.endpoint
# TODO: move to "disaggregation_mode" as the other engines.
if dynamo_config.route_to_encoder: if dynamo_config.route_to_encoder:
dynamo_config.component = "processor" dynamo_config.component = "processor"
dynamo_config.endpoint = "generate" dynamo_config.endpoint = "generate"
...@@ -167,13 +165,16 @@ def update_dynamo_config_with_engine( ...@@ -167,13 +165,16 @@ def update_dynamo_config_with_engine(
elif dynamo_config.multimodal_decode_worker: elif dynamo_config.multimodal_decode_worker:
dynamo_config.component = "decoder" dynamo_config.component = "decoder"
dynamo_config.endpoint = "generate" dynamo_config.endpoint = "generate"
elif dynamo_config.multimodal_worker and dynamo_config.is_prefill_worker: elif (
dynamo_config.multimodal_worker
and dynamo_config.disaggregation_mode == DisaggregationMode.PREFILL
):
dynamo_config.component = "backend" dynamo_config.component = "backend"
dynamo_config.endpoint = "generate" dynamo_config.endpoint = "generate"
elif dynamo_config.omni: elif dynamo_config.omni:
dynamo_config.component = "backend" dynamo_config.component = "backend"
dynamo_config.endpoint = "generate" dynamo_config.endpoint = "generate"
elif dynamo_config.is_prefill_worker: elif dynamo_config.disaggregation_mode == DisaggregationMode.PREFILL:
dynamo_config.component = "prefill" dynamo_config.component = "prefill"
dynamo_config.endpoint = "generate" dynamo_config.endpoint = "generate"
else: else:
...@@ -320,10 +321,10 @@ def create_kv_events_config( ...@@ -320,10 +321,10 @@ def create_kv_events_config(
dynamo_config: Config, engine_config: AsyncEngineArgs dynamo_config: Config, engine_config: AsyncEngineArgs
) -> Optional[KVEventsConfig]: ) -> Optional[KVEventsConfig]:
"""Create KVEventsConfig for prefix caching if needed.""" """Create KVEventsConfig for prefix caching if needed."""
if dynamo_config.is_decode_worker: if dynamo_config.disaggregation_mode == DisaggregationMode.DECODE:
logger.info( logger.info(
f"Decode worker detected (is_decode_worker={dynamo_config.is_decode_worker}): " "Decode worker detected (disaggregation_mode=decode): "
f"kv_events_config disabled (decode workers don't publish KV events)" "kv_events_config disabled (decode workers don't publish KV events)"
) )
return None return None
......
...@@ -3,13 +3,15 @@ ...@@ -3,13 +3,15 @@
"""Dynamo vLLM wrapper configuration ArgGroup.""" """Dynamo vLLM wrapper configuration ArgGroup."""
from typing import Optional import warnings
from typing import Optional, Union
from dynamo.common.configuration.arg_group import ArgGroup from dynamo.common.configuration.arg_group import ArgGroup
from dynamo.common.configuration.config_base import ConfigBase from dynamo.common.configuration.config_base import ConfigBase
from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument
from . import __version__ from . import __version__
from .constants import DisaggregationMode
class DynamoVllmArgGroup(ArgGroup): class DynamoVllmArgGroup(ArgGroup):
...@@ -25,12 +27,23 @@ class DynamoVllmArgGroup(ArgGroup): ...@@ -25,12 +27,23 @@ class DynamoVllmArgGroup(ArgGroup):
) )
g = parser.add_argument_group("Dynamo vLLM Options") g = parser.add_argument_group("Dynamo vLLM Options")
add_argument(
g,
flag_name="--disaggregation-mode",
env_var="DYN_VLLM_DISAGGREGATION_MODE",
default=None,
help="Worker disaggregation mode: 'agg' (default, aggregated), "
"'prefill' (prefill-only worker), or 'decode' (decode-only worker).",
choices=[m.value for m in DisaggregationMode],
)
add_negatable_bool_argument( add_negatable_bool_argument(
g, g,
flag_name="--is-prefill-worker", flag_name="--is-prefill-worker",
env_var="DYN_VLLM_IS_PREFILL_WORKER", env_var="DYN_VLLM_IS_PREFILL_WORKER",
default=False, default=False,
help="Enable prefill functionality for this worker. Uses the provided namespace to construct dyn://namespace.prefill.generate", help="DEPRECATED: use --disaggregation-mode=prefill. "
"Enable prefill functionality for this worker.",
) )
add_negatable_bool_argument( add_negatable_bool_argument(
...@@ -38,7 +51,8 @@ class DynamoVllmArgGroup(ArgGroup): ...@@ -38,7 +51,8 @@ class DynamoVllmArgGroup(ArgGroup):
flag_name="--is-decode-worker", flag_name="--is-decode-worker",
env_var="DYN_VLLM_IS_DECODE_WORKER", env_var="DYN_VLLM_IS_DECODE_WORKER",
default=False, default=False,
help="Mark this as a decode worker which does not publish KV events", help="DEPRECATED: use --disaggregation-mode=decode. "
"Mark this as a decode worker which does not publish KV events.",
) )
add_negatable_bool_argument( add_negatable_bool_argument(
...@@ -295,6 +309,9 @@ class DynamoVllmArgGroup(ArgGroup): ...@@ -295,6 +309,9 @@ class DynamoVllmArgGroup(ArgGroup):
class DynamoVllmConfig(ConfigBase): class DynamoVllmConfig(ConfigBase):
"""Configuration for Dynamo vLLM wrapper (vLLM-specific only). All fields optional.""" """Configuration for Dynamo vLLM wrapper (vLLM-specific only). All fields optional."""
disaggregation_mode: Union[
None, str, DisaggregationMode
] # None when not provided; resolved to enum in validate()
is_prefill_worker: bool is_prefill_worker: bool
is_decode_worker: bool is_decode_worker: bool
use_vllm_tokenizer: bool use_vllm_tokenizer: bool
...@@ -344,17 +361,63 @@ class DynamoVllmConfig(ConfigBase): ...@@ -344,17 +361,63 @@ class DynamoVllmConfig(ConfigBase):
def validate(self) -> None: def validate(self) -> None:
"""Validate vLLM wrapper configuration.""" """Validate vLLM wrapper configuration."""
self._validate_prefill_decode_exclusive() self._resolve_disaggregation_mode()
self._validate_multimodal_role_exclusivity() self._validate_multimodal_role_exclusivity()
self._validate_multimodal_requires_flag() self._validate_multimodal_requires_flag()
self._validate_omni_stage_config() self._validate_omni_stage_config()
def _validate_prefill_decode_exclusive(self) -> None: def _resolve_disaggregation_mode(self) -> None:
"""Ensure at most one of is_prefill_worker and is_decode_worker is set.""" """Resolve disaggregation_mode from new enum or legacy boolean flags.
Priority:
1. If --disaggregation-mode was explicitly provided, use it.
Raise if legacy booleans are also set.
2. If legacy --is-prefill-worker or --is-decode-worker is set,
emit DeprecationWarning and translate to enum.
3. Apply default (AGGREGATED) if nothing was provided.
4. Sync boolean fields from the resolved enum value.
"""
# Convert string to enum (non-None means explicitly provided)
explicit_mode = self.disaggregation_mode is not None
if isinstance(self.disaggregation_mode, str):
self.disaggregation_mode = DisaggregationMode(self.disaggregation_mode)
# Check for legacy boolean flags
has_legacy = self.is_prefill_worker or self.is_decode_worker
if has_legacy and explicit_mode:
raise ValueError(
"Cannot combine --is-prefill-worker/--is-decode-worker with "
"--disaggregation-mode. Use only --disaggregation-mode."
)
if has_legacy:
if self.is_prefill_worker and self.is_decode_worker: if self.is_prefill_worker and self.is_decode_worker:
raise ValueError( raise ValueError(
"Cannot set both --is-prefill-worker and --is-decode-worker" "Cannot set both --is-prefill-worker and --is-decode-worker"
) )
if self.is_prefill_worker:
warnings.warn(
"--is-prefill-worker is deprecated, use --disaggregation-mode=prefill",
DeprecationWarning,
stacklevel=2,
)
self.disaggregation_mode = DisaggregationMode.PREFILL
elif self.is_decode_worker:
warnings.warn(
"--is-decode-worker is deprecated, use --disaggregation-mode=decode",
DeprecationWarning,
stacklevel=2,
)
self.disaggregation_mode = DisaggregationMode.DECODE
# Apply default if neither new flag nor legacy flags were provided
if self.disaggregation_mode is None:
self.disaggregation_mode = DisaggregationMode.AGGREGATED
# Sync booleans from enum (canonical source of truth)
self.is_prefill_worker = self.disaggregation_mode == DisaggregationMode.PREFILL
self.is_decode_worker = self.disaggregation_mode == DisaggregationMode.DECODE
def _count_multimodal_roles(self) -> int: def _count_multimodal_roles(self) -> int:
"""Return the number of multimodal worker roles set (0 or 1 allowed). """Return the number of multimodal worker roles set (0 or 1 allowed).
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Constants for vLLM backend.
DisaggregationMode is defined in dynamo.common.constants and re-exported here
so that existing imports from dynamo.vllm.constants continue to work.
"""
from dynamo.common.constants import DisaggregationMode
__all__ = ["DisaggregationMode"]
...@@ -54,6 +54,7 @@ from dynamo.vllm.worker_factory import WorkerFactory ...@@ -54,6 +54,7 @@ from dynamo.vllm.worker_factory import WorkerFactory
from .args import Config, parse_args from .args import Config, parse_args
from .checkpoint_restore import get_checkpoint_config from .checkpoint_restore import get_checkpoint_config
from .constants import DisaggregationMode
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
from .health_check import ( from .health_check import (
VllmHealthCheckPayload, VllmHealthCheckPayload,
...@@ -190,7 +191,7 @@ async def worker(): ...@@ -190,7 +191,7 @@ async def worker():
elif config.omni: elif config.omni:
await init_omni(runtime, config, shutdown_event) await init_omni(runtime, config, shutdown_event)
logger.debug("init_omni completed") logger.debug("init_omni completed")
elif config.is_prefill_worker: elif config.disaggregation_mode == DisaggregationMode.PREFILL:
await init_prefill( await init_prefill(
runtime, config, shutdown_event, pre_created_engine=pre_created_engine runtime, config, shutdown_event, pre_created_engine=pre_created_engine
) )
...@@ -318,7 +319,7 @@ def setup_kv_event_publisher( ...@@ -318,7 +319,7 @@ def setup_kv_event_publisher(
return None return None
# Skip KV event publishing for decode workers # Skip KV event publishing for decode workers
if config.is_decode_worker: if config.disaggregation_mode == DisaggregationMode.DECODE:
logger.info("Skipping KV event publisher setup for decode worker") logger.info("Skipping KV event publisher setup for decode worker")
return None return None
...@@ -516,7 +517,8 @@ async def register_vllm_model( ...@@ -516,7 +517,8 @@ async def register_vllm_model(
runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"] runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
# Decode workers don't create the WorkerKvQuery endpoint, so don't advertise local indexer # Decode workers don't create the WorkerKvQuery endpoint, so don't advertise local indexer
runtime_config.enable_local_indexer = ( runtime_config.enable_local_indexer = (
config.enable_local_indexer and not config.is_decode_worker config.enable_local_indexer
and config.disaggregation_mode != DisaggregationMode.DECODE
) )
# Add tool/reasoning parsers for decode models # Add tool/reasoning parsers for decode models
......
...@@ -23,6 +23,7 @@ from dynamo.common.multimodal.embedding_transfer import ( ...@@ -23,6 +23,7 @@ from dynamo.common.multimodal.embedding_transfer import (
from dynamo.runtime import Client, DistributedRuntime from dynamo.runtime import Client, DistributedRuntime
from ..args import Config from ..args import Config
from ..constants import DisaggregationMode
from ..handlers import BaseWorkerHandler, build_sampling_params from ..handlers import BaseWorkerHandler, build_sampling_params
from ..multimodal_utils import ( from ..multimodal_utils import (
MyRequestOutput, MyRequestOutput,
...@@ -70,7 +71,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -70,7 +71,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self.config = config self.config = config
self.encode_worker_client = encode_worker_client self.encode_worker_client = encode_worker_client
self.decode_worker_client = decode_worker_client self.decode_worker_client = decode_worker_client
self.enable_disagg = config.is_prefill_worker self.enable_disagg = config.disaggregation_mode == DisaggregationMode.PREFILL
self.embedding_cache_manager: MultimodalEmbeddingCacheManager | None = None self.embedding_cache_manager: MultimodalEmbeddingCacheManager | None = None
if config.multimodal_embedding_cache_capacity_gb > 0: if config.multimodal_embedding_cache_capacity_gb > 0:
capacity_bytes = int( capacity_bytes = int(
......
...@@ -9,6 +9,7 @@ import dynamo.nixl_connect as connect ...@@ -9,6 +9,7 @@ import dynamo.nixl_connect as connect
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from ..args import Config from ..args import Config
from ..constants import DisaggregationMode
from ..handlers import BaseWorkerHandler from ..handlers import BaseWorkerHandler
from ..multimodal_utils import MyRequestOutput, vLLMMultimodalRequest from ..multimodal_utils import MyRequestOutput, vLLMMultimodalRequest
from ..multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model from ..multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model
...@@ -44,7 +45,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler): ...@@ -44,7 +45,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
) )
self.config = config self.config = config
self.enable_disagg = config.is_prefill_worker self.enable_disagg = config.disaggregation_mode == DisaggregationMode.PREFILL
async def async_init(self, runtime: DistributedRuntime): async def async_init(self, runtime: DistributedRuntime):
"""Async initialization - connector needs async setup""" """Async initialization - connector needs async setup"""
......
...@@ -38,9 +38,16 @@ def _make_config( ...@@ -38,9 +38,16 @@ def _make_config(
multimodal_embedding_cache_capacity_gb: float = 0, multimodal_embedding_cache_capacity_gb: float = 0,
) -> MagicMock: ) -> MagicMock:
"""Create a mock Config with the fields used by MultimodalPDWorkerHandler.""" """Create a mock Config with the fields used by MultimodalPDWorkerHandler."""
from dynamo.vllm.constants import DisaggregationMode
config = MagicMock() config = MagicMock()
config.model = model config.model = model
config.is_prefill_worker = is_prefill_worker config.is_prefill_worker = is_prefill_worker
config.disaggregation_mode = (
DisaggregationMode.PREFILL
if is_prefill_worker
else DisaggregationMode.AGGREGATED
)
config.enable_multimodal = enable_multimodal config.enable_multimodal = enable_multimodal
config.multimodal_embedding_cache_capacity_gb = ( config.multimodal_embedding_cache_capacity_gb = (
multimodal_embedding_cache_capacity_gb multimodal_embedding_cache_capacity_gb
......
...@@ -4,11 +4,13 @@ ...@@ -4,11 +4,13 @@
"""Unit tests for vLLM backend components.""" """Unit tests for vLLM backend components."""
import re import re
import warnings
from pathlib import Path from pathlib import Path
import pytest import pytest
from dynamo.vllm.args import parse_args from dynamo.vllm.args import parse_args
from dynamo.vllm.constants import DisaggregationMode
from dynamo.vllm.tests.conftest import make_cli_args_fixture from dynamo.vllm.tests.conftest import make_cli_args_fixture
# Get path relative to this test file # Get path relative to this test file
...@@ -169,13 +171,14 @@ def test_endpoint_not_provided_preserves_defaults(mock_vllm_cli): ...@@ -169,13 +171,14 @@ def test_endpoint_not_provided_preserves_defaults(mock_vllm_cli):
def test_endpoint_overrides_with_prefill_worker(mock_vllm_cli): def test_endpoint_overrides_with_prefill_worker(mock_vllm_cli):
"""Test that --endpoint overrides even with --is-prefill-worker.""" """Test that --endpoint overrides even with --disaggregation-mode prefill."""
mock_vllm_cli( mock_vllm_cli(
"--model", "--model",
"Qwen/Qwen3-0.6B", "Qwen/Qwen3-0.6B",
"--endpoint", "--endpoint",
"dyn://custom.worker.serve", "dyn://custom.worker.serve",
"--is-prefill-worker", "--disaggregation-mode",
"prefill",
) )
config = parse_args() config = parse_args()
assert config.namespace == "custom" assert config.namespace == "custom"
...@@ -216,3 +219,86 @@ def test_headless_namespace_has_required_fields(mock_vllm_cli): ...@@ -216,3 +219,86 @@ def test_headless_namespace_has_required_fields(mock_vllm_cli):
# Core engine fields must survive the round-trip # Core engine fields must survive the round-trip
assert hasattr(ns, "model") assert hasattr(ns, "model")
assert hasattr(ns, "tensor_parallel_size") assert hasattr(ns, "tensor_parallel_size")
# --disaggregation-mode tests
def test_disaggregation_mode_default(mock_vllm_cli):
"""Test that default disaggregation mode is AGGREGATED."""
mock_vllm_cli("--model", "Qwen/Qwen3-0.6B")
config = parse_args()
assert config.disaggregation_mode == DisaggregationMode.AGGREGATED
assert config.is_prefill_worker is False
assert config.is_decode_worker is False
def test_disaggregation_mode_prefill(mock_vllm_cli):
"""Test --disaggregation-mode prefill sets correct state."""
mock_vllm_cli("--model", "Qwen/Qwen3-0.6B", "--disaggregation-mode", "prefill")
config = parse_args()
assert config.disaggregation_mode == DisaggregationMode.PREFILL
assert config.is_prefill_worker is True
assert config.is_decode_worker is False
assert config.component == "prefill"
def test_disaggregation_mode_decode(mock_vllm_cli):
"""Test --disaggregation-mode decode sets correct state."""
mock_vllm_cli("--model", "Qwen/Qwen3-0.6B", "--disaggregation-mode", "decode")
config = parse_args()
assert config.disaggregation_mode == DisaggregationMode.DECODE
assert config.is_prefill_worker is False
assert config.is_decode_worker is True
def test_legacy_is_prefill_worker_emits_deprecation(mock_vllm_cli):
"""Test that --is-prefill-worker still works but emits DeprecationWarning."""
mock_vllm_cli("--model", "Qwen/Qwen3-0.6B", "--is-prefill-worker")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
config = parse_args()
deprecation_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)]
assert len(deprecation_warnings) >= 1
assert "deprecated" in str(deprecation_warnings[0].message).lower()
assert config.disaggregation_mode == DisaggregationMode.PREFILL
assert config.is_prefill_worker is True
def test_legacy_is_decode_worker_emits_deprecation(mock_vllm_cli):
"""Test that --is-decode-worker still works but emits DeprecationWarning."""
mock_vllm_cli("--model", "Qwen/Qwen3-0.6B", "--is-decode-worker")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
config = parse_args()
deprecation_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)]
assert len(deprecation_warnings) >= 1
assert "deprecated" in str(deprecation_warnings[0].message).lower()
assert config.disaggregation_mode == DisaggregationMode.DECODE
assert config.is_decode_worker is True
def test_conflicting_legacy_and_new_flags_raises(mock_vllm_cli):
"""Test that combining legacy flags with explicit --disaggregation-mode raises ValueError."""
mock_vllm_cli(
"--model",
"Qwen/Qwen3-0.6B",
"--disaggregation-mode",
"prefill",
"--is-decode-worker",
)
with pytest.raises(ValueError, match="Cannot combine"):
parse_args()
def test_explicit_default_mode_with_legacy_flag_raises(mock_vllm_cli):
"""Test that --disaggregation-mode agg --is-decode-worker raises ValueError."""
mock_vllm_cli(
"--model",
"Qwen/Qwen3-0.6B",
"--disaggregation-mode",
"agg",
"--is-decode-worker",
)
with pytest.raises(ValueError, match="Cannot combine"):
parse_args()
...@@ -13,6 +13,7 @@ from dynamo.llm import ModelInput ...@@ -13,6 +13,7 @@ from dynamo.llm import ModelInput
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from .args import Config from .args import Config
from .constants import DisaggregationMode
from .multimodal_handlers import ( from .multimodal_handlers import (
EncodeWorkerHandler, EncodeWorkerHandler,
MultimodalDecodeWorkerHandler, MultimodalDecodeWorkerHandler,
...@@ -149,7 +150,7 @@ class WorkerFactory: ...@@ -149,7 +150,7 @@ class WorkerFactory:
# Set up decode worker client for disaggregated mode # Set up decode worker client for disaggregated mode
decode_worker_client = None decode_worker_client = None
if config.is_prefill_worker: if config.disaggregation_mode == DisaggregationMode.PREFILL:
decode_worker_client = await runtime.endpoint( decode_worker_client = await runtime.endpoint(
f"{config.namespace}.decoder.generate" f"{config.namespace}.decoder.generate"
).client() ).client()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment