Unverified Commit 7893f268 authored by Alec's avatar Alec Committed by GitHub
Browse files

feat: add --disaggregation-mode enum to vLLM backend (#6483)


Signed-off-by: default avataralec-flowers <aflowers@nvidia.com>
Co-authored-by: default avatarClaude Opus 4.6 <noreply@anthropic.com>
parent 6d3e0137
......@@ -186,10 +186,10 @@ if [ "$USE_MOCKERS" = true ]; then
# Set endpoint based on worker mode
if [ "$MODE" = "prefill" ]; then
MOCKER_ARGS+=("--endpoint" "dyn://test.prefill.generate")
MOCKER_ARGS+=("--is-prefill-worker")
MOCKER_ARGS+=("--disaggregation-mode" "prefill")
elif [ "$MODE" = "decode" ]; then
MOCKER_ARGS+=("--endpoint" "dyn://test.mocker.generate")
MOCKER_ARGS+=("--is-decode-worker")
MOCKER_ARGS+=("--disaggregation-mode" "decode")
else
MOCKER_ARGS+=("--endpoint" "dyn://test.mocker.generate")
fi
......@@ -254,9 +254,9 @@ else
VLLM_ARGS+=("--data-parallel-size" "$DATA_PARALLEL_SIZE")
fi
if [ "$MODE" = "prefill" ]; then
VLLM_ARGS+=("--is-prefill-worker")
VLLM_ARGS+=("--disaggregation-mode" "prefill")
elif [ "$MODE" = "decode" ]; then
VLLM_ARGS+=("--is-decode-worker")
VLLM_ARGS+=("--disaggregation-mode" "decode")
fi
VLLM_ARGS+=("${EXTRA_ARGS[@]}")
......
......@@ -12,7 +12,7 @@ Main submodules:
- utils: Common utilities including environment and prometheus helpers
"""
from dynamo.common import config_dump, utils
from dynamo.common import config_dump, constants, utils
try:
from ._version import __version__
......@@ -24,4 +24,4 @@ except Exception:
except Exception:
__version__ = "0.0.0+unknown"
__all__ = ["__version__", "config_dump", "utils"]
__all__ = ["__version__", "config_dump", "constants", "utils"]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Shared constants for Dynamo backends."""
from enum import Enum
class DisaggregationMode(Enum):
"""Disaggregation mode for LLM workers."""
AGGREGATED = "agg"
PREFILL = "prefill"
DECODE = "decode"
......@@ -26,7 +26,7 @@ The mocker engine now supports a vLLM-style CLI interface with individual argume
- `--data-parallel-size`: Number of data parallel workers to simulate (default: 1)
- `--num-workers`: Number of mocker workers to launch in the same process (default: 1). All workers share the same tokio runtime and thread pool
- `--stagger-delay`: Delay in seconds between launching each worker to avoid overwhelming etcd/NATS/frontend. Set to 0 to disable staggering. Use -1 for auto mode (stagger dependent on number of workers). Default: -1 (auto)
- `--is-prefill-worker` / `--is-decode-worker`: Whether the worker is a prefill or decode worker for disaggregated deployment. If not specified, mocker will be in aggregated mode.
- `--disaggregation-mode prefill` / `--disaggregation-mode decode`: Whether the worker is a prefill or decode worker for disaggregated deployment. If not specified, mocker will be in aggregated mode.
**Environment variables:**
......
......@@ -142,15 +142,49 @@ def create_temp_engine_args_file(args) -> Path:
def validate_worker_type_args(args):
"""
Validate that is_prefill_worker and is_decode_worker are not both True.
Resolve disaggregation mode from --disaggregation-mode or legacy boolean flags.
Raises ValueError if validation fails.
"""
if args.is_prefill_worker and args.is_decode_worker:
import warnings
explicit_mode = args.disaggregation_mode is not None
has_legacy = args.is_prefill_worker or args.is_decode_worker
if has_legacy and explicit_mode:
raise ValueError(
"Cannot specify both --is-prefill-worker and --is-decode-worker. "
"A worker must be either prefill, decode, or aggregated (neither flag set)."
"Cannot combine --is-prefill-worker/--is-decode-worker with "
"--disaggregation-mode. Use only --disaggregation-mode."
)
if has_legacy:
if args.is_prefill_worker and args.is_decode_worker:
raise ValueError(
"Cannot specify both --is-prefill-worker and --is-decode-worker. "
"A worker must be either prefill, decode, or aggregated (neither flag set)."
)
if args.is_prefill_worker:
warnings.warn(
"--is-prefill-worker is deprecated, use --disaggregation-mode=prefill",
DeprecationWarning,
stacklevel=2,
)
args.disaggregation_mode = "prefill"
elif args.is_decode_worker:
warnings.warn(
"--is-decode-worker is deprecated, use --disaggregation-mode=decode",
DeprecationWarning,
stacklevel=2,
)
args.disaggregation_mode = "decode"
# Apply default if neither new flag nor legacy flags were provided
if args.disaggregation_mode is None:
args.disaggregation_mode = "agg"
# Sync booleans from disaggregation_mode
args.is_prefill_worker = args.disaggregation_mode == "prefill"
args.is_decode_worker = args.disaggregation_mode == "decode"
def parse_bootstrap_ports(ports_str: str | None) -> list[int]:
"""Parse comma-separated bootstrap ports string into list of integers."""
......@@ -305,17 +339,27 @@ def parse_args():
)
# Worker type configuration
parser.add_argument(
"--disaggregation-mode",
type=str,
default=None,
choices=["agg", "prefill", "decode"],
help="Worker disaggregation mode: 'agg' (default, aggregated), "
"'prefill' (prefill-only worker), or 'decode' (decode-only worker).",
)
parser.add_argument(
"--is-prefill-worker",
action="store_true",
default=False,
help="Register as Prefill model type instead of Chat+Completions (default: False)",
help="DEPRECATED: use --disaggregation-mode=prefill. "
"Register as Prefill model type instead of Chat+Completions (default: False)",
)
parser.add_argument(
"--is-decode-worker",
action="store_true",
default=False,
help="Mark this as a decode worker which does not publish KV events and skips prefill cost estimation (default: False)",
help="DEPRECATED: use --disaggregation-mode=decode. "
"Mark this as a decode worker which does not publish KV events (default: False)",
)
parser.add_argument(
"--durable-kv-events",
......
......@@ -108,7 +108,8 @@ class VllmV1ConfigModifier(BaseConfigModifier):
args = validate_and_get_worker_args(worker_service, backend="vllm")
args = break_arguments(args)
# remove --is-prefill-worker flag
# remove --disaggregation-mode and its value (or legacy --is-prefill-worker)
args = remove_valued_arguments(args, "--disaggregation-mode")
if "--is-prefill-worker" in args:
args.remove("--is-prefill-worker")
......
......@@ -67,7 +67,7 @@ python -m dynamo.router \
python -m dynamo.vllm --model MODEL_NAME --block-size 64 &
# Start prefill workers
python -m dynamo.vllm --model MODEL_NAME --block-size 64 --is-prefill-worker &
python -m dynamo.vllm --model MODEL_NAME --block-size 64 --disaggregation-mode prefill &
```
>[!Note]
......
......@@ -9,7 +9,6 @@ import socket
import sys
import tempfile
from argparse import Namespace
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Generator, Optional
......@@ -20,6 +19,7 @@ from sglang.srt.server_args_config_parser import ConfigArgumentMerger
from dynamo.common.config_dump import register_encoder
from dynamo.common.configuration.groups import DynamoRuntimeConfig
from dynamo.common.configuration.groups.runtime_args import DynamoRuntimeArgGroup
from dynamo.common.constants import DisaggregationMode
from dynamo.common.utils.runtime import parse_endpoint
from dynamo.llm import fetch_model
from dynamo.runtime.logging import configure_dynamo_logging
......@@ -28,12 +28,6 @@ from dynamo.sglang.backend_args import DynamoSGLangArgGroup, DynamoSGLangConfig
configure_dynamo_logging()
class DisaggregationMode(Enum):
AGGREGATED = "agg"
PREFILL = "prefill"
DECODE = "decode"
class DynamoConfig(DynamoRuntimeConfig, DynamoSGLangConfig):
"""Combined configuration container for SGLang server and Dynamo args."""
......
......@@ -16,6 +16,7 @@ import uvloop
from dynamo import prometheus_names
from dynamo.common.config_dump import dump_config
from dynamo.common.constants import DisaggregationMode
from dynamo.common.storage import get_fs
from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.common.utils.graceful_shutdown import graceful_shutdown_with_discovery
......@@ -23,7 +24,7 @@ from dynamo.common.utils.runtime import create_runtime
from dynamo.llm import ModelInput, ModelType
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.args import Config, DisaggregationMode, parse_args
from dynamo.sglang.args import Config, parse_args
from dynamo.sglang.health_check import (
ImageDiffusionHealthCheckPayload,
SglangHealthCheckPayload,
......
......@@ -9,8 +9,9 @@ from typing import Any, AsyncGenerator, Dict, Optional
import sglang as sgl
from dynamo._core import Context
from dynamo.common.constants import DisaggregationMode
from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.sglang.args import Config, DisaggregationMode
from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
......
......@@ -11,8 +11,9 @@ import torch
import dynamo.nixl_connect as connect
from dynamo._core import Client, Context
from dynamo.common.constants import DisaggregationMode
from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.sglang.args import Config, DisaggregationMode
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import (
DisaggSglangMultimodalRequest,
SglangMultimodalRequest,
......
......@@ -24,6 +24,7 @@ from dynamo.common.configuration.groups.runtime_args import (
)
from dynamo.common.utils.runtime import parse_endpoint
from dynamo.vllm.backend_args import DynamoVllmArgGroup, DynamoVllmConfig
from dynamo.vllm.constants import DisaggregationMode
from . import envs
......@@ -35,8 +36,6 @@ VALID_CONNECTORS = {"nixl", "lmcache", "kvbm", "null", "none"}
class Config(DynamoRuntimeConfig, DynamoVllmConfig):
component: str
is_prefill_worker: bool
is_decode_worker: bool
custom_jinja_template: Optional[str] = None
discovery_backend: str
request_plane: str
......@@ -157,7 +156,6 @@ def update_dynamo_config_with_engine(
# Capture user-provided --endpoint before defaults overwrite it
user_endpoint = dynamo_config.endpoint
# TODO: move to "disaggregation_mode" as the other engines.
if dynamo_config.route_to_encoder:
dynamo_config.component = "processor"
dynamo_config.endpoint = "generate"
......@@ -167,13 +165,16 @@ def update_dynamo_config_with_engine(
elif dynamo_config.multimodal_decode_worker:
dynamo_config.component = "decoder"
dynamo_config.endpoint = "generate"
elif dynamo_config.multimodal_worker and dynamo_config.is_prefill_worker:
elif (
dynamo_config.multimodal_worker
and dynamo_config.disaggregation_mode == DisaggregationMode.PREFILL
):
dynamo_config.component = "backend"
dynamo_config.endpoint = "generate"
elif dynamo_config.omni:
dynamo_config.component = "backend"
dynamo_config.endpoint = "generate"
elif dynamo_config.is_prefill_worker:
elif dynamo_config.disaggregation_mode == DisaggregationMode.PREFILL:
dynamo_config.component = "prefill"
dynamo_config.endpoint = "generate"
else:
......@@ -320,10 +321,10 @@ def create_kv_events_config(
dynamo_config: Config, engine_config: AsyncEngineArgs
) -> Optional[KVEventsConfig]:
"""Create KVEventsConfig for prefix caching if needed."""
if dynamo_config.is_decode_worker:
if dynamo_config.disaggregation_mode == DisaggregationMode.DECODE:
logger.info(
f"Decode worker detected (is_decode_worker={dynamo_config.is_decode_worker}): "
f"kv_events_config disabled (decode workers don't publish KV events)"
"Decode worker detected (disaggregation_mode=decode): "
"kv_events_config disabled (decode workers don't publish KV events)"
)
return None
......
......@@ -3,13 +3,15 @@
"""Dynamo vLLM wrapper configuration ArgGroup."""
from typing import Optional
import warnings
from typing import Optional, Union
from dynamo.common.configuration.arg_group import ArgGroup
from dynamo.common.configuration.config_base import ConfigBase
from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument
from . import __version__
from .constants import DisaggregationMode
class DynamoVllmArgGroup(ArgGroup):
......@@ -25,12 +27,23 @@ class DynamoVllmArgGroup(ArgGroup):
)
g = parser.add_argument_group("Dynamo vLLM Options")
add_argument(
g,
flag_name="--disaggregation-mode",
env_var="DYN_VLLM_DISAGGREGATION_MODE",
default=None,
help="Worker disaggregation mode: 'agg' (default, aggregated), "
"'prefill' (prefill-only worker), or 'decode' (decode-only worker).",
choices=[m.value for m in DisaggregationMode],
)
add_negatable_bool_argument(
g,
flag_name="--is-prefill-worker",
env_var="DYN_VLLM_IS_PREFILL_WORKER",
default=False,
help="Enable prefill functionality for this worker. Uses the provided namespace to construct dyn://namespace.prefill.generate",
help="DEPRECATED: use --disaggregation-mode=prefill. "
"Enable prefill functionality for this worker.",
)
add_negatable_bool_argument(
......@@ -38,7 +51,8 @@ class DynamoVllmArgGroup(ArgGroup):
flag_name="--is-decode-worker",
env_var="DYN_VLLM_IS_DECODE_WORKER",
default=False,
help="Mark this as a decode worker which does not publish KV events",
help="DEPRECATED: use --disaggregation-mode=decode. "
"Mark this as a decode worker which does not publish KV events.",
)
add_negatable_bool_argument(
......@@ -295,6 +309,9 @@ class DynamoVllmArgGroup(ArgGroup):
class DynamoVllmConfig(ConfigBase):
"""Configuration for Dynamo vLLM wrapper (vLLM-specific only). All fields optional."""
disaggregation_mode: Union[
None, str, DisaggregationMode
] # None when not provided; resolved to enum in validate()
is_prefill_worker: bool
is_decode_worker: bool
use_vllm_tokenizer: bool
......@@ -344,18 +361,64 @@ class DynamoVllmConfig(ConfigBase):
def validate(self) -> None:
"""Validate vLLM wrapper configuration."""
self._validate_prefill_decode_exclusive()
self._resolve_disaggregation_mode()
self._validate_multimodal_role_exclusivity()
self._validate_multimodal_requires_flag()
self._validate_omni_stage_config()
def _validate_prefill_decode_exclusive(self) -> None:
"""Ensure at most one of is_prefill_worker and is_decode_worker is set."""
if self.is_prefill_worker and self.is_decode_worker:
def _resolve_disaggregation_mode(self) -> None:
"""Resolve disaggregation_mode from new enum or legacy boolean flags.
Priority:
1. If --disaggregation-mode was explicitly provided, use it.
Raise if legacy booleans are also set.
2. If legacy --is-prefill-worker or --is-decode-worker is set,
emit DeprecationWarning and translate to enum.
3. Apply default (AGGREGATED) if nothing was provided.
4. Sync boolean fields from the resolved enum value.
"""
# Convert string to enum (non-None means explicitly provided)
explicit_mode = self.disaggregation_mode is not None
if isinstance(self.disaggregation_mode, str):
self.disaggregation_mode = DisaggregationMode(self.disaggregation_mode)
# Check for legacy boolean flags
has_legacy = self.is_prefill_worker or self.is_decode_worker
if has_legacy and explicit_mode:
raise ValueError(
"Cannot set both --is-prefill-worker and --is-decode-worker"
"Cannot combine --is-prefill-worker/--is-decode-worker with "
"--disaggregation-mode. Use only --disaggregation-mode."
)
if has_legacy:
if self.is_prefill_worker and self.is_decode_worker:
raise ValueError(
"Cannot set both --is-prefill-worker and --is-decode-worker"
)
if self.is_prefill_worker:
warnings.warn(
"--is-prefill-worker is deprecated, use --disaggregation-mode=prefill",
DeprecationWarning,
stacklevel=2,
)
self.disaggregation_mode = DisaggregationMode.PREFILL
elif self.is_decode_worker:
warnings.warn(
"--is-decode-worker is deprecated, use --disaggregation-mode=decode",
DeprecationWarning,
stacklevel=2,
)
self.disaggregation_mode = DisaggregationMode.DECODE
# Apply default if neither new flag nor legacy flags were provided
if self.disaggregation_mode is None:
self.disaggregation_mode = DisaggregationMode.AGGREGATED
# Sync booleans from enum (canonical source of truth)
self.is_prefill_worker = self.disaggregation_mode == DisaggregationMode.PREFILL
self.is_decode_worker = self.disaggregation_mode == DisaggregationMode.DECODE
def _count_multimodal_roles(self) -> int:
"""Return the number of multimodal worker roles set (0 or 1 allowed).
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Constants for vLLM backend.
DisaggregationMode is defined in dynamo.common.constants and re-exported here
so that existing imports from dynamo.vllm.constants continue to work.
"""
from dynamo.common.constants import DisaggregationMode
__all__ = ["DisaggregationMode"]
......@@ -54,6 +54,7 @@ from dynamo.vllm.worker_factory import WorkerFactory
from .args import Config, parse_args
from .checkpoint_restore import get_checkpoint_config
from .constants import DisaggregationMode
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
from .health_check import (
VllmHealthCheckPayload,
......@@ -190,7 +191,7 @@ async def worker():
elif config.omni:
await init_omni(runtime, config, shutdown_event)
logger.debug("init_omni completed")
elif config.is_prefill_worker:
elif config.disaggregation_mode == DisaggregationMode.PREFILL:
await init_prefill(
runtime, config, shutdown_event, pre_created_engine=pre_created_engine
)
......@@ -318,7 +319,7 @@ def setup_kv_event_publisher(
return None
# Skip KV event publishing for decode workers
if config.is_decode_worker:
if config.disaggregation_mode == DisaggregationMode.DECODE:
logger.info("Skipping KV event publisher setup for decode worker")
return None
......@@ -516,7 +517,8 @@ async def register_vllm_model(
runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
# Decode workers don't create the WorkerKvQuery endpoint, so don't advertise local indexer
runtime_config.enable_local_indexer = (
config.enable_local_indexer and not config.is_decode_worker
config.enable_local_indexer
and config.disaggregation_mode != DisaggregationMode.DECODE
)
# Add tool/reasoning parsers for decode models
......
......@@ -23,6 +23,7 @@ from dynamo.common.multimodal.embedding_transfer import (
from dynamo.runtime import Client, DistributedRuntime
from ..args import Config
from ..constants import DisaggregationMode
from ..handlers import BaseWorkerHandler, build_sampling_params
from ..multimodal_utils import (
MyRequestOutput,
......@@ -70,7 +71,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
self.config = config
self.encode_worker_client = encode_worker_client
self.decode_worker_client = decode_worker_client
self.enable_disagg = config.is_prefill_worker
self.enable_disagg = config.disaggregation_mode == DisaggregationMode.PREFILL
self.embedding_cache_manager: MultimodalEmbeddingCacheManager | None = None
if config.multimodal_embedding_cache_capacity_gb > 0:
capacity_bytes = int(
......
......@@ -9,6 +9,7 @@ import dynamo.nixl_connect as connect
from dynamo.runtime import DistributedRuntime
from ..args import Config
from ..constants import DisaggregationMode
from ..handlers import BaseWorkerHandler
from ..multimodal_utils import MyRequestOutput, vLLMMultimodalRequest
from ..multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_model
......@@ -44,7 +45,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
)
self.config = config
self.enable_disagg = config.is_prefill_worker
self.enable_disagg = config.disaggregation_mode == DisaggregationMode.PREFILL
async def async_init(self, runtime: DistributedRuntime):
"""Async initialization - connector needs async setup"""
......
......@@ -38,9 +38,16 @@ def _make_config(
multimodal_embedding_cache_capacity_gb: float = 0,
) -> MagicMock:
"""Create a mock Config with the fields used by MultimodalPDWorkerHandler."""
from dynamo.vllm.constants import DisaggregationMode
config = MagicMock()
config.model = model
config.is_prefill_worker = is_prefill_worker
config.disaggregation_mode = (
DisaggregationMode.PREFILL
if is_prefill_worker
else DisaggregationMode.AGGREGATED
)
config.enable_multimodal = enable_multimodal
config.multimodal_embedding_cache_capacity_gb = (
multimodal_embedding_cache_capacity_gb
......
......@@ -4,11 +4,13 @@
"""Unit tests for vLLM backend components."""
import re
import warnings
from pathlib import Path
import pytest
from dynamo.vllm.args import parse_args
from dynamo.vllm.constants import DisaggregationMode
from dynamo.vllm.tests.conftest import make_cli_args_fixture
# Get path relative to this test file
......@@ -169,13 +171,14 @@ def test_endpoint_not_provided_preserves_defaults(mock_vllm_cli):
def test_endpoint_overrides_with_prefill_worker(mock_vllm_cli):
"""Test that --endpoint overrides even with --is-prefill-worker."""
"""Test that --endpoint overrides even with --disaggregation-mode prefill."""
mock_vllm_cli(
"--model",
"Qwen/Qwen3-0.6B",
"--endpoint",
"dyn://custom.worker.serve",
"--is-prefill-worker",
"--disaggregation-mode",
"prefill",
)
config = parse_args()
assert config.namespace == "custom"
......@@ -216,3 +219,86 @@ def test_headless_namespace_has_required_fields(mock_vllm_cli):
# Core engine fields must survive the round-trip
assert hasattr(ns, "model")
assert hasattr(ns, "tensor_parallel_size")
# --disaggregation-mode tests
def test_disaggregation_mode_default(mock_vllm_cli):
"""Test that default disaggregation mode is AGGREGATED."""
mock_vllm_cli("--model", "Qwen/Qwen3-0.6B")
config = parse_args()
assert config.disaggregation_mode == DisaggregationMode.AGGREGATED
assert config.is_prefill_worker is False
assert config.is_decode_worker is False
def test_disaggregation_mode_prefill(mock_vllm_cli):
"""Test --disaggregation-mode prefill sets correct state."""
mock_vllm_cli("--model", "Qwen/Qwen3-0.6B", "--disaggregation-mode", "prefill")
config = parse_args()
assert config.disaggregation_mode == DisaggregationMode.PREFILL
assert config.is_prefill_worker is True
assert config.is_decode_worker is False
assert config.component == "prefill"
def test_disaggregation_mode_decode(mock_vllm_cli):
"""Test --disaggregation-mode decode sets correct state."""
mock_vllm_cli("--model", "Qwen/Qwen3-0.6B", "--disaggregation-mode", "decode")
config = parse_args()
assert config.disaggregation_mode == DisaggregationMode.DECODE
assert config.is_prefill_worker is False
assert config.is_decode_worker is True
def test_legacy_is_prefill_worker_emits_deprecation(mock_vllm_cli):
"""Test that --is-prefill-worker still works but emits DeprecationWarning."""
mock_vllm_cli("--model", "Qwen/Qwen3-0.6B", "--is-prefill-worker")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
config = parse_args()
deprecation_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)]
assert len(deprecation_warnings) >= 1
assert "deprecated" in str(deprecation_warnings[0].message).lower()
assert config.disaggregation_mode == DisaggregationMode.PREFILL
assert config.is_prefill_worker is True
def test_legacy_is_decode_worker_emits_deprecation(mock_vllm_cli):
"""Test that --is-decode-worker still works but emits DeprecationWarning."""
mock_vllm_cli("--model", "Qwen/Qwen3-0.6B", "--is-decode-worker")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
config = parse_args()
deprecation_warnings = [x for x in w if issubclass(x.category, DeprecationWarning)]
assert len(deprecation_warnings) >= 1
assert "deprecated" in str(deprecation_warnings[0].message).lower()
assert config.disaggregation_mode == DisaggregationMode.DECODE
assert config.is_decode_worker is True
def test_conflicting_legacy_and_new_flags_raises(mock_vllm_cli):
"""Test that combining legacy flags with explicit --disaggregation-mode raises ValueError."""
mock_vllm_cli(
"--model",
"Qwen/Qwen3-0.6B",
"--disaggregation-mode",
"prefill",
"--is-decode-worker",
)
with pytest.raises(ValueError, match="Cannot combine"):
parse_args()
def test_explicit_default_mode_with_legacy_flag_raises(mock_vllm_cli):
"""Test that --disaggregation-mode agg --is-decode-worker raises ValueError."""
mock_vllm_cli(
"--model",
"Qwen/Qwen3-0.6B",
"--disaggregation-mode",
"agg",
"--is-decode-worker",
)
with pytest.raises(ValueError, match="Cannot combine"):
parse_args()
......@@ -13,6 +13,7 @@ from dynamo.llm import ModelInput
from dynamo.runtime import DistributedRuntime
from .args import Config
from .constants import DisaggregationMode
from .multimodal_handlers import (
EncodeWorkerHandler,
MultimodalDecodeWorkerHandler,
......@@ -149,7 +150,7 @@ class WorkerFactory:
# Set up decode worker client for disaggregated mode
decode_worker_client = None
if config.is_prefill_worker:
if config.disaggregation_mode == DisaggregationMode.PREFILL:
decode_worker_client = await runtime.endpoint(
f"{config.namespace}.decoder.generate"
).client()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment