Unverified Commit de27efe6 authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

feat: Migrate vllm configuration system (#6075)

parent b94f9dcd
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""ArgGroup implementations for different configuration domains."""
from .runtime_args import DynamoRuntimeArgGroup, DynamoRuntimeConfig
__all__ = ["DynamoRuntimeArgGroup", "DynamoRuntimeConfig"]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Dynamo runtime configuration ArgGroup."""
from typing import Optional
from dynamo._core import get_reasoning_parser_names, get_tool_parser_names
from dynamo.common.configuration.arg_group import ArgGroup
from dynamo.common.configuration.config_base import ConfigBase
from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument
class DynamoRuntimeConfig(ConfigBase):
"""Configuration for Dynamo runtime (common across all backends)."""
namespace: str
store_kv: str
request_plane: str
event_plane: str
connector: list[str]
enable_local_indexer: bool
durable_kv_events: bool
dyn_tool_call_parser: Optional[str] = None
dyn_reasoning_parser: Optional[str] = None
custom_jinja_template: Optional[str] = None
endpoint_types: str
dump_config_to: Optional[str] = None
def validate(self) -> None:
# TODO get a better way for spot fixes like this.
self.enable_local_indexer = not self.durable_kv_events
class DynamoRuntimeArgGroup(ArgGroup):
"""Dynamo runtime configuration parameters (common to all backends)."""
def add_arguments(self, parser) -> None:
"""Add Dynamo runtime arguments to parser."""
g = parser.add_argument_group("Dynamo Runtime Options")
add_argument(
g,
flag_name="--namespace",
env_var="DYN_NAMESPACE",
default="dynamo",
help="Dynamo namespace",
)
add_argument(
g,
flag_name="--store-kv",
env_var="DYN_STORE_KV",
default="etcd",
help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENDPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.",
choices=["etcd", "file", "mem"],
)
add_argument(
g,
flag_name="--request-plane",
env_var="DYN_REQUEST_PLANE",
default="tcp",
help="Determines how requests are distributed from routers to workers. 'tcp' is fastest.",
choices=["tcp", "nats", "http"],
)
add_argument(
g,
flag_name="--event-plane",
env_var="DYN_EVENT_PLANE",
default="nats",
help="Determines how events are published.",
choices=["nats", "zmq"],
)
add_argument(
g,
flag_name="--connector",
env_var="DYN_CONNECTOR",
default=["nixl"],
help="List of connectors to use in order (e.g., --connector nixl lmcache). Options: nixl, lmcache, kvbm, null, none. Order will be preserved in MultiConnector.",
nargs="*",
)
add_negatable_bool_argument(
g,
flag_name="--durable-kv-events",
env_var="DYN_DURABLE_KV_EVENTS",
default=False,
help="Enable durable KV events using NATS JetStream instead of the local indexer. By default, local indexer is enabled for lower latency. Use this flag when you need durability and multi-replica router consistency. Requires NATS with JetStream enabled. Can also be set via DYN_DURABLE_KV_EVENTS=true env var.",
)
# Optional: tool/reasoning parsers (choices from dynamo._core when available)
# To avoid name conflicts with different backends, prefix "dyn-" for dynamo specific args
add_argument(
g,
flag_name="--dyn-tool-call-parser",
env_var="DYN_TOOL_CALL_PARSER",
default=None,
help="Tool call parser name for the model.",
choices=get_tool_parser_names(),
)
add_argument(
g,
flag_name="--dyn-reasoning-parser",
env_var="DYN_REASONING_PARSER",
default=None,
help="Reasoning parser name for the model. If not specified, no reasoning parsing is performed.",
choices=get_reasoning_parser_names(),
)
add_argument(
g,
flag_name="--custom-jinja-template",
env_var="DYN_CUSTOM_JINJA_TEMPLATE",
default=None,
help="Path to a custom Jinja template file to override the model's default chat template. This template will take precedence over any template found in the model repository.",
)
add_argument(
g,
flag_name="--endpoint-types",
env_var="DYN_ENDPOINT_TYPES",
default="chat,completions",
obsolete_flag="--dyn-endpoint-types",
help="Comma-separated list of endpoint types to enable. Options: 'chat', 'completions'. Use 'completions' for models without chat templates.",
)
add_argument(
g,
flag_name="--dump-config-to",
env_var="DYN_DUMP_CONFIG_TO",
default=None,
help="Dump resolved configuration to the specified file path.",
)
...@@ -40,6 +40,9 @@ def env_or_default(env_var: str, default: T) -> T: ...@@ -40,6 +40,9 @@ def env_or_default(env_var: str, default: T) -> T:
return int(value) # type: ignore return int(value) # type: ignore
elif isinstance(default, float): elif isinstance(default, float):
return float(value) # type: ignore return float(value) # type: ignore
elif isinstance(default, list):
# Env vars for list options (e.g. DYN_CONNECTOR) are space-separated; downstream expects a list.
return [x.strip() for x in value.split() if x.strip()] # type: ignore
else: else:
return value # type: ignore return value # type: ignore
...@@ -75,7 +78,11 @@ def add_argument( ...@@ -75,7 +78,11 @@ def add_argument(
names = [flag_name] names = [flag_name]
env_help = _build_help_message(help, env_var, default_with_env, obsolete_flag) if obsolete_flag:
# Accept obsolete flag as an alias (still show deprecation note in help)
names.append(obsolete_flag)
env_help = _build_help_message(help, env_var, default, obsolete_flag)
add_arg_opts = { add_arg_opts = {
"dest": arg_dest, "dest": arg_dest,
...@@ -126,7 +133,7 @@ def _build_help_message( ...@@ -126,7 +133,7 @@ def _build_help_message(
Build help message with env var and default value. Build help message with env var and default value.
""" """
if obsolete_flag: if obsolete_flag:
return f"{help_text}\nenv var: {env_var} | default: {default}\nobsolete flag: {obsolete_flag}" return f"{help_text}\nenv var: {env_var} | default: {default}\ndeprecating flag: {obsolete_flag}"
return f"{help_text}\nenv var: {env_var} | default: {default}" return f"{help_text}\nenv var: {env_var} | default: {default}"
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import argparse
import logging import logging
import os import os
import socket import socket
...@@ -16,10 +16,14 @@ try: ...@@ -16,10 +16,14 @@ try:
except ImportError: except ImportError:
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from dynamo._core import get_reasoning_parser_names, get_tool_parser_names from dynamo.common.config_dump import register_encoder
from dynamo.common.config_dump import add_config_dump_args, register_encoder from dynamo.common.configuration.groups.runtime_args import (
DynamoRuntimeArgGroup,
DynamoRuntimeConfig,
)
from dynamo.vllm.backend_args import DynamoVllmArgGroup, DynamoVllmConfig
from . import __version__, envs from . import envs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -27,11 +31,7 @@ DEFAULT_MODEL = "Qwen/Qwen3-0.6B" ...@@ -27,11 +31,7 @@ DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
VALID_CONNECTORS = {"nixl", "lmcache", "kvbm", "null", "none"} VALID_CONNECTORS = {"nixl", "lmcache", "kvbm", "null", "none"}
class Config: class Config(DynamoRuntimeConfig, DynamoVllmConfig):
"""Command line parameters or defaults"""
# dynamo specific
namespace: str
component: str component: str
endpoint: str endpoint: str
is_prefill_worker: bool is_prefill_worker: bool
...@@ -41,57 +41,18 @@ class Config: ...@@ -41,57 +41,18 @@ class Config:
request_plane: str request_plane: str
event_plane: str event_plane: str
enable_local_indexer: bool = True enable_local_indexer: bool = True
use_kv_events: bool
# mirror vLLM # mirror vLLM
model: str model: str
served_model_name: Optional[str] served_model_name: Optional[str] = None
# rest vLLM args # rest vLLM args
engine_args: AsyncEngineArgs engine_args: AsyncEngineArgs
# Connector list from CLI def validate(self) -> None:
connector_list: Optional[list] = None DynamoRuntimeConfig.validate(self)
DynamoVllmConfig.validate(self)
# tool and reasoning parser info
tool_call_parser: Optional[str] = None
reasoning_parser: Optional[str] = None
# endpoint types to enable
dyn_endpoint_types: str = "chat,completions"
# multimodal options
multimodal_processor: bool = False
# Embedding Cache Processor is different from the regular processor
# TODO: Have a single processor for all cases and adopting rust based processor
ec_processor: bool = False
multimodal_encode_worker: bool = False
multimodal_worker: bool = False
multimodal_decode_worker: bool = False
enable_multimodal: bool = False
multimodal_encode_prefill_worker: bool = False
mm_prompt_template: str = "USER: <image>\n<prompt> ASSISTANT:"
frontend_decoding: bool = False
# vLLM-native encoder worker (ECConnector mode)
vllm_native_encoder_worker: bool = False
ec_connector_backend: Optional[str] = "ECExampleConnector"
ec_storage_path: Optional[str] = None
ec_extra_config: Optional[str] = None
ec_consumer_mode: bool = False
# vLLM-Omni worker for multi-stage pipelines
omni: bool = False
# Path to vLLM-Omni stage configuration YAML
stage_configs_path: Optional[str] = None
# dump config to file
dump_config_to: Optional[str] = None
# Use vLLM's tokenizer for pre/post processing
use_vllm_tokenizer: bool = False
# Whether to enable NATS for KV events (derived from kv_events_config in overwrite_args)
use_kv_events: bool = False
def has_connector(self, connector_name: str) -> bool: def has_connector(self, connector_name: str) -> bool:
""" """
...@@ -103,7 +64,7 @@ class Config: ...@@ -103,7 +64,7 @@ class Config:
Returns: Returns:
True if the connector is in the connector list, False otherwise True if the connector is in the connector list, False otherwise
""" """
return self.connector_list is not None and connector_name in self.connector_list return self.connector is not None and connector_name in self.connector
@register_encoder(Config) @register_encoder(Config)
...@@ -118,404 +79,247 @@ def parse_args() -> Config: ...@@ -118,404 +79,247 @@ def parse_args() -> Config:
Returns: Returns:
Config: Parsed configuration object. Config: Parsed configuration object.
""" """
parser = FlexibleArgumentParser(
description="vLLM server integrated with Dynamo LLM." dynamo_runtime_argspec = DynamoRuntimeArgGroup()
) dynamo_vllm_argspec = DynamoVllmArgGroup()
parser.add_argument(
"--version", action="version", version=f"Dynamo Backend VLLM {__version__}" parser = argparse.ArgumentParser(
) description="Dynamo vLLM worker configuration",
parser.add_argument( formatter_class=argparse.RawTextHelpFormatter,
"--is-prefill-worker",
action="store_true",
help="Enable prefill functionality for this worker. Uses the provided namespace to construct dyn://namespace.prefill.generate",
)
parser.add_argument(
"--is-decode-worker",
action="store_true",
help="Mark this as a decode worker which does not publish KV events.",
)
parser.add_argument(
"--connector",
nargs="*",
default=["nixl"],
help="List of connectors to use in order (e.g., --connector nixl lmcache). "
"Options: nixl, lmcache, kvbm, null, none. Default: nixl. Order will be preserved in MultiConnector.",
)
# To avoid name conflicts with different backends, adopted prefix "dyn-" for dynamo specific args
parser.add_argument(
"--dyn-tool-call-parser",
type=str,
default=None,
choices=get_tool_parser_names(),
help="Tool call parser name for the model.",
)
parser.add_argument(
"--dyn-reasoning-parser",
type=str,
default=None,
choices=get_reasoning_parser_names(),
help="Reasoning parser name for the model. If not specified, no reasoning parsing is performed.",
)
parser.add_argument(
"--custom-jinja-template",
type=str,
default=None,
help="Path to a custom Jinja template file to override the model's default chat template. This template will take precedence over any template found in the model repository.",
)
parser.add_argument(
"--dyn-endpoint-types",
type=str,
default="chat,completions",
help="Comma-separated list of endpoint types to enable. Options: 'chat', 'completions'. Default: 'chat,completions'. Use 'completions' for models without chat templates.",
)
parser.add_argument(
"--multimodal-processor",
action="store_true",
help="Run as multimodal processor component for handling multimodal requests",
)
parser.add_argument(
"--ec-processor",
action="store_true",
help="Run as ECConnector processor (routes multimodal requests to encoder then PD workers)",
)
parser.add_argument(
"--multimodal-encode-worker",
action="store_true",
help="Run as multimodal encode worker component for processing images/videos",
)
parser.add_argument(
"--multimodal-worker",
action="store_true",
help="Run as multimodal worker component for LLM inference with multimodal data",
)
parser.add_argument(
"--multimodal-decode-worker",
action="store_true",
help="Run as multimodal decode worker in disaggregated mode",
)
parser.add_argument(
"--multimodal-encode-prefill-worker",
action="store_true",
help="Run as unified encode+prefill+decode worker for models requiring integrated image encoding (e.g., Llama 4)",
)
parser.add_argument(
"--enable-multimodal",
action="store_true",
help="Enable multimodal processing. If not set, none of the multimodal components can be used.",
)
parser.add_argument(
"--mm-prompt-template",
type=str,
default="USER: <image>\n<prompt> ASSISTANT:",
help=(
"Different multi-modal models expect the prompt to contain different special media prompts. "
"The processor will use this argument to construct the final prompt. "
"User prompt will replace '<prompt>' in the provided template. "
"For example, if the user prompt is 'please describe the image' and the prompt template is "
"'USER: <image> <prompt> ASSISTANT:', the resulting prompt is "
"'USER: <image> please describe the image ASSISTANT:'."
),
)
parser.add_argument(
"--frontend-decoding",
action="store_true",
help=(
"Enable frontend decoding of multimodal images. "
"When enabled, images are decoded in the Rust frontend and transferred to the backend via NIXL RDMA. "
"Without this flag, images are decoded in the Python backend (default behavior)."
),
)
parser.add_argument(
"--vllm-native-encoder-worker",
action="store_true",
help="Run as vLLM-native encoder worker using ECConnector for encoder disaggregation (requires shared storage). The following flags only work when this flag is enabled: --ec-connector-backend, --ec-storage-path, --ec-extra-config, --ec-consumer-mode.",
)
parser.add_argument(
"--ec-connector-backend",
type=str,
default="ECExampleConnector",
help="ECConnector implementation class for encoder disaggregation. Default: ECExampleConnector (disk-based)",
)
parser.add_argument(
"--ec-storage-path",
type=str,
default=None,
help="Storage path for ECConnector (required for ECExampleConnector, optional for other backends)",
)
parser.add_argument(
"--ec-extra-config",
type=str,
default=None,
help="Additional ECConnector configuration as JSON string",
)
parser.add_argument(
"--ec-consumer-mode",
action="store_true",
help="Configure as ECConnector consumer for receiving encoder embeddings (for PD workers)",
)
parser.add_argument(
"--omni",
action="store_true",
help="Run as vLLM-Omni worker for multi-stage pipelines (supports text-to-text, text-to-image, etc.)",
)
parser.add_argument(
"--stage-configs-path",
type=str,
default=None,
help="Path to vLLM-Omni stage configuration YAML file for --omni mode (optional).",
)
parser.add_argument(
"--store-kv",
type=str,
choices=["etcd", "file", "mem"],
default=os.environ.get("DYN_STORE_KV", "etcd"),
help="Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENDPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.",
)
parser.add_argument(
"--request-plane",
type=str,
choices=["nats", "http", "tcp"],
default=os.environ.get("DYN_REQUEST_PLANE", "tcp"),
help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
)
parser.add_argument(
"--event-plane",
type=str,
choices=["nats", "zmq"],
default=os.environ.get("DYN_EVENT_PLANE", "nats"),
help="Determines how events are published [nats|zmq]",
)
parser.add_argument(
"--durable-kv-events",
action="store_true",
dest="durable_kv_events",
default=os.environ.get("DYN_DURABLE_KV_EVENTS", "false").lower() == "true",
help="Enable durable KV events using NATS JetStream instead of the local indexer. By default, local indexer is enabled for lower latency. Use this flag when you need durability and multi-replica router consistency. Requires NATS with JetStream enabled. Can also be set via DYN_DURABLE_KV_EVENTS=true env var.",
) )
parser.add_argument(
"--use-vllm-tokenizer", # Build argument parser
action="store_true", dynamo_runtime_argspec.add_arguments(parser)
default=False, dynamo_vllm_argspec.add_arguments(parser)
help="Use vLLM's tokenizer for pre and post processing. This bypasses Dynamo's preprocessor and only v1/chat/completions will be available through the Dynamo frontend.",
# trick to add vllm engine flags to a specific group without breaking the Dynamo groups.
vg = parser.add_argument_group(
"vLLM Engine Options. Please refer to vLLM documentation for more details."
) )
add_config_dump_args(parser) vllm_parser = FlexibleArgumentParser(add_help=False)
AsyncEngineArgs.add_cli_args(vllm_parser, async_args_only=False)
for action in vllm_parser._actions:
if not action.option_strings:
continue
vg._group_actions.append(action)
args, unknown = parser.parse_known_args()
dynamo_config = Config.from_cli_args(args)
# Validate arguments
dynamo_config.validate()
vllm_args = vllm_parser.parse_args(unknown)
# Set the model name from the command line arguments
# model is defined in AsyncEngineArgs, but when AsyncEngineArgs.from_cli_args is called,
# vllm will update the model name to the full path of the model, which will break the dynamo logic,
# as we use the model name as served_model_name (if served_model_name is not set)
dynamo_config.model = vllm_args.model
engine_config = AsyncEngineArgs.from_cli_args(vllm_args)
cross_validate_config(dynamo_config, engine_config)
update_dynamo_config_with_engine(dynamo_config, engine_config)
update_engine_config_with_dynamo(dynamo_config, engine_config)
dynamo_config.engine_args = engine_config
return dynamo_config
parser = AsyncEngineArgs.add_cli_args(parser) def cross_validate_config(
args = parser.parse_args() dynamo_config: Config, engine_config: AsyncEngineArgs
engine_args = AsyncEngineArgs.from_cli_args(args) ) -> None:
"""Validate dynamo and engine config together. This should not modify the configs."""
if hasattr(engine_args, "stream_interval") and engine_args.stream_interval != 1: if hasattr(engine_config, "stream_interval") and engine_config.stream_interval != 1:
logger.warning( logger.warning(
"--stream-interval is currently not respected in Dynamo. " "--stream-interval is currently not respected in Dynamo. "
"Dynamo uses its own post-processing implementation on the frontend, " "Dynamo uses its own post-processing implementation on the frontend, "
"bypassing vLLM's OutputProcessor buffering. " "bypassing vLLM's OutputProcessor buffering."
) )
# Workaround for vLLM GIL contention bug with NIXL connector when using UniProcExecutor.
# With TP=1, vLLM defaults to UniProcExecutor which runs scheduler and worker in the same
# process. This causes a hot loop in _process_engine_step that doesn't release the GIL,
# blocking NIXL's add_remote_agent from completing. Using "mp" backend forces separate
# processes, avoiding the GIL contention.
# Note: Only apply for NIXL - other connectors (kvbm, lmcache) work fine with UniProcExecutor
# and forcing mp can expose race conditions in vLLM's scheduler.
# See: https://github.com/vllm-project/vllm/issues/29369
connector_list = [c.lower() for c in args.connector] if args.connector else []
uses_nixl = "nixl" in connector_list
tp_size = getattr(engine_args, "tensor_parallel_size", None) or 1
if uses_nixl and tp_size == 1 and engine_args.distributed_executor_backend is None:
logger.info(
"Setting --distributed-executor-backend=mp for TP=1 to avoid "
"UniProcExecutor GIL contention with NIXL connector"
)
engine_args.distributed_executor_backend = "mp"
if engine_args.enable_prefix_caching is None:
logger.debug(
"--enable-prefix-caching or --no-enable-prefix-caching not specified. Defaulting to True (vLLM v1 default behavior)"
)
engine_args.enable_prefix_caching = True
config = Config()
config.model = args.model
if args.served_model_name:
assert (
len(args.served_model_name) <= 1
), "We do not support multiple model names."
config.served_model_name = args.served_model_name[0]
else:
# This becomes an `Option` on the Rust side
config.served_model_name = None
config.namespace = os.environ.get("DYN_NAMESPACE", "dynamo")
# Check multimodal role exclusivity
mm_flags = (
int(bool(args.multimodal_processor))
+ int(bool(args.ec_processor))
+ int(bool(args.multimodal_encode_worker))
+ int(bool(args.multimodal_worker))
+ int(bool(args.multimodal_decode_worker))
+ int(bool(args.multimodal_encode_prefill_worker))
+ int(bool(args.vllm_native_encoder_worker))
)
if mm_flags > 1:
raise ValueError(
"Use only one of --multimodal-processor, --ec-processor, --multimodal-encode-worker, --multimodal-worker, "
"--multimodal-decode-worker, --multimodal-encode-prefill-worker, or --vllm-native-encoder-worker"
)
if mm_flags == 1 and not args.enable_multimodal: def update_dynamo_config_with_engine(
raise ValueError("Use --enable-multimodal to enable multimodal processing") dynamo_config: Config, engine_config: AsyncEngineArgs
) -> None:
"""Update dynamo_config fields from engine_config and worker flags."""
# Validate vLLM-native encoder worker config if getattr(engine_config, "served_model_name", None) is not None:
if args.vllm_native_encoder_worker: served = engine_config.served_model_name
if ( if len(served) > 1:
args.ec_connector_backend == "ECExampleConnector" raise ValueError("We do not support multiple model names.")
and not args.ec_storage_path dynamo_config.served_model_name = served[0]
): else:
raise ValueError( dynamo_config.served_model_name = None
"--ec-storage-path is required when using ECExampleConnector backend. "
"Specify a shared storage path for encoder cache."
)
# Validate omni worker requirements
if args.stage_configs_path and not args.omni:
raise ValueError(
"--stage-configs-path is only allowed when using --omni. "
"Specify a YAML file containing stage configurations for the multi-stage pipeline."
)
# Set component and endpoint based on worker type # TODO: move to "disaggregation_mode" as the other engines.
if args.multimodal_processor or args.ec_processor: if dynamo_config.multimodal_processor or dynamo_config.ec_processor:
config.component = "processor" dynamo_config.component = "processor"
config.endpoint = "generate" dynamo_config.endpoint = "generate"
elif ( elif (
args.vllm_native_encoder_worker dynamo_config.vllm_native_encoder_worker
or args.multimodal_encode_worker or dynamo_config.multimodal_encode_worker
or args.multimodal_encode_prefill_worker or dynamo_config.multimodal_encode_prefill_worker
): ):
config.component = "encoder" dynamo_config.component = "encoder"
config.endpoint = "generate" dynamo_config.endpoint = "generate"
elif args.multimodal_decode_worker: elif dynamo_config.multimodal_decode_worker:
# Uses "decoder" component name because prefill worker connects to "decoder" dynamo_config.component = "decoder"
# (prefill uses "backend" to receive from encoder) dynamo_config.endpoint = "generate"
config.component = "decoder" elif dynamo_config.multimodal_worker and dynamo_config.is_prefill_worker:
config.endpoint = "generate" dynamo_config.component = "backend"
elif args.multimodal_worker and args.is_prefill_worker: dynamo_config.endpoint = "generate"
# Multimodal prefill worker stays as "backend" to maintain encoder connection elif dynamo_config.omni:
config.component = "backend" dynamo_config.component = "backend"
config.endpoint = "generate" dynamo_config.endpoint = "generate"
elif args.omni: elif dynamo_config.is_prefill_worker:
# Omni worker uses "backend" component for multi-stage pipeline orchestration dynamo_config.component = "prefill"
config.component = "backend" dynamo_config.endpoint = "generate"
config.endpoint = "generate"
elif args.is_prefill_worker:
config.component = "prefill"
config.endpoint = "generate"
else: else:
config.component = "backend" dynamo_config.component = "backend"
config.endpoint = "generate" dynamo_config.endpoint = "generate"
config.engine_args = engine_args if dynamo_config.custom_jinja_template is not None:
config.is_prefill_worker = args.is_prefill_worker
config.is_decode_worker = args.is_decode_worker
config.tool_call_parser = args.dyn_tool_call_parser
config.reasoning_parser = args.dyn_reasoning_parser
config.custom_jinja_template = args.custom_jinja_template
config.dyn_endpoint_types = args.dyn_endpoint_types
config.multimodal_processor = args.multimodal_processor
config.ec_processor = args.ec_processor
config.multimodal_encode_worker = args.multimodal_encode_worker
config.multimodal_worker = args.multimodal_worker
config.multimodal_decode_worker = args.multimodal_decode_worker
config.multimodal_encode_prefill_worker = args.multimodal_encode_prefill_worker
config.enable_multimodal = args.enable_multimodal
config.mm_prompt_template = args.mm_prompt_template
config.frontend_decoding = args.frontend_decoding
config.vllm_native_encoder_worker = args.vllm_native_encoder_worker
config.ec_connector_backend = args.ec_connector_backend
config.ec_storage_path = args.ec_storage_path
config.ec_extra_config = args.ec_extra_config
config.ec_consumer_mode = args.ec_consumer_mode
config.omni = args.omni
config.stage_configs_path = args.stage_configs_path
config.store_kv = args.store_kv
config.request_plane = args.request_plane
config.event_plane = args.event_plane
config.enable_local_indexer = not args.durable_kv_events
# For omni mode, use vLLM (AsyncOmni) tokenizer on backend
config.use_vllm_tokenizer = args.use_vllm_tokenizer or args.omni
# use_kv_events is set later in overwrite_args() based on kv_events_config
# Validate custom Jinja template file exists if provided
if config.custom_jinja_template is not None:
# Expand environment variables and user home (~) before validation
expanded_template_path = os.path.expanduser( expanded_template_path = os.path.expanduser(
os.path.expandvars(config.custom_jinja_template) os.path.expandvars(dynamo_config.custom_jinja_template)
) )
config.custom_jinja_template = expanded_template_path dynamo_config.custom_jinja_template = expanded_template_path
if not os.path.isfile(expanded_template_path): if not os.path.isfile(expanded_template_path):
raise FileNotFoundError( raise FileNotFoundError(
f"Custom Jinja template file not found: {expanded_template_path}. " f"Custom Jinja template file not found: {expanded_template_path}. "
f"Please ensure the file exists and the path is correct." "Please ensure the file exists and the path is correct."
) )
normalized = [c.lower() for c in args.connector] normalized = [c.lower() for c in (dynamo_config.connector or [])]
invalid = [c for c in normalized if c not in VALID_CONNECTORS] invalid = [c for c in normalized if c not in VALID_CONNECTORS]
if invalid: if invalid:
raise ValueError( raise ValueError(
f"Invalid connector(s): {', '.join(invalid)}. Valid options are: {', '.join(sorted(VALID_CONNECTORS))}" f"Invalid connector(s): {', '.join(invalid)}. "
f"Valid options are: {', '.join(sorted(VALID_CONNECTORS))}"
) )
# Check for custom kv_transfer_config
has_kv_transfer_config = ( has_kv_transfer_config = (
hasattr(engine_args, "kv_transfer_config") hasattr(engine_config, "kv_transfer_config")
and engine_args.kv_transfer_config is not None and engine_config.kv_transfer_config is not None
) )
if not normalized or "none" in normalized or "null" in normalized: if not normalized or "none" in normalized or "null" in normalized:
if len(normalized) > 1: if len(normalized) > 1:
raise ValueError( raise ValueError(
"'none' and 'null' cannot be combined with other connectors" "'none' and 'null' cannot be combined with other connectors"
) )
config.connector_list = [] dynamo_config.connector = [] # type: ignore[assignment]
else: else:
# Check for conflicting flags
if has_kv_transfer_config: if has_kv_transfer_config:
raise ValueError( raise ValueError(
"Cannot specify both --kv-transfer-config and --connector flags" "Cannot specify both --kv-transfer-config and --connector flags"
) )
dynamo_config.connector = normalized # type: ignore[assignment]
config.connector_list = normalized
if config.engine_args.block_size is None: def update_engine_config_with_dynamo(
config.engine_args.block_size = 16 dynamo_config: Config, engine_config: AsyncEngineArgs
) -> None:
"""Update engine config base on Dynamo config."""
# Workaround for vLLM GIL contention bug with NIXL connector when using UniProcExecutor.
# With TP=1, vLLM defaults to UniProcExecutor which runs scheduler and worker in the same
# process. This causes a hot loop in _process_engine_step that doesn't release the GIL,
# blocking NIXL's add_remote_agent from completing. Using "mp" backend forces separate
# processes, avoiding the GIL contention.
# Note: Only apply for NIXL - other connectors (kvbm, lmcache) work fine with UniProcExecutor
# and forcing mp can expose race conditions in vLLM's scheduler.
# See: https://github.com/vllm-project/vllm/issues/29369
connector_list = (
[c.lower() for c in dynamo_config.connector] if dynamo_config.connector else []
)
uses_nixl = "nixl" in connector_list
tp_size = getattr(engine_config, "tensor_parallel_size", None) or 1
if (
uses_nixl
and tp_size == 1
and getattr(engine_config, "distributed_executor_backend", None) is None
):
logger.info(
"Setting --distributed-executor-backend=mp for TP=1 to avoid "
"UniProcExecutor GIL contention with NIXL connector"
)
engine_config.distributed_executor_backend = "mp"
if engine_config.enable_prefix_caching is None:
logger.debug( logger.debug(
f"Setting reasonable default of {config.engine_args.block_size} for block_size" "--enable-prefix-caching or --no-enable-prefix-caching not specified. "
"Defaulting to True (vLLM v1 default behavior)"
) )
engine_config.enable_prefix_caching = True
config.dump_config_to = args.dump_config_to if getattr(engine_config, "block_size", None) is None:
engine_config.block_size = 16
logger.debug(
f"Setting reasonable default of {engine_config.block_size} for block_size"
)
return config if dynamo_config.has_connector("nixl") or (
# Check if the user provided their own kv_transfer_config
getattr(engine_config, "kv_transfer_config", None) is not None
# and the connector is NixlConnector
and engine_config.kv_transfer_config.kv_connector == "NixlConnector"
):
ensure_side_channel_host()
defaults = {
# vLLM 0.13+ renamed 'task' to 'runner'
"runner": "generate",
# As of vLLM >=0.10.0 the engine unconditionally calls
# `sampling_params.update_from_tokenizer(...)`, so we can no longer
# skip tokenizer initialisation. Setting this to **False** avoids
# a NoneType error when the processor accesses the tokenizer.
"skip_tokenizer_init": False,
"enable_log_requests": False,
"disable_log_stats": False,
}
def create_kv_events_config(config: Config) -> Optional[KVEventsConfig]: kv_transfer_config = create_kv_transfer_config(dynamo_config, engine_config)
if kv_transfer_config:
defaults["kv_transfer_config"] = kv_transfer_config
kv_cfg = create_kv_events_config(dynamo_config, engine_config)
defaults["kv_events_config"] = kv_cfg
dynamo_config.use_kv_events = kv_cfg is not None and kv_cfg.enable_kv_cache_events
logger.info(
f"Using kv_events_config for publishing vLLM kv events over zmq: {kv_cfg} "
f"(use_kv_events={dynamo_config.use_kv_events})"
)
logger.debug("Setting Dynamo defaults for vLLM")
for key, value in defaults.items():
if hasattr(engine_config, key):
setattr(engine_config, key, value)
logger.debug(f" engine_args.{key} = {value}")
else:
logger.debug(
f" Skipping engine_args.{key} (not available in this vLLM version)"
)
def create_kv_events_config(
dynamo_config: Config, engine_config: AsyncEngineArgs
) -> Optional[KVEventsConfig]:
"""Create KVEventsConfig for prefix caching if needed.""" """Create KVEventsConfig for prefix caching if needed."""
if config.is_decode_worker: if dynamo_config.is_decode_worker:
logger.info( logger.info(
f"Decode worker detected (is_decode_worker={config.is_decode_worker}): " f"Decode worker detected (is_decode_worker={dynamo_config.is_decode_worker}): "
f"kv_events_config disabled (decode workers don't publish KV events)" "kv_events_config disabled (decode workers don't publish KV events)",
dynamo_config.is_decode_worker,
) )
return None return None
# If prefix caching is not enabled, no events config needed # If prefix caching is not enabled, no events config needed
if not config.engine_args.enable_prefix_caching: if not engine_config.enable_prefix_caching:
logger.info("No kv_events_config required: prefix caching is disabled") logger.info("No kv_events_config required: prefix caching is disabled")
return None return None
# If user provided their own config, use that # If user provided their own config, use that
if c := getattr(config.engine_args, "kv_events_config"): if c := getattr(engine_config, "kv_events_config"):
# Warn user that enable_kv_cache_events probably should be True (user may have omitted it from JSON)
if not c.enable_kv_cache_events: if not c.enable_kv_cache_events:
logger.warning( logger.warning(
"User provided --kv_events_config which set enable_kv_cache_events to False (default). " "User provided --kv_events_config which set enable_kv_cache_events to False (default). "
...@@ -525,12 +329,12 @@ def create_kv_events_config(config: Config) -> Optional[KVEventsConfig]: ...@@ -525,12 +329,12 @@ def create_kv_events_config(config: Config) -> Optional[KVEventsConfig]:
return c return c
# Create default events config for prefix caching # Create default events config for prefix caching
# TODO: move this to configuration system.
port = envs.DYN_VLLM_KV_EVENT_PORT port = envs.DYN_VLLM_KV_EVENT_PORT
logger.info( logger.info(
f"Using env-var DYN_VLLM_KV_EVENT_PORT={port} to create kv_events_config" f"Using env-var DYN_VLLM_KV_EVENT_PORT={port} to create kv_events_config"
) )
dp_rank = config.engine_args.data_parallel_rank or 0 dp_rank = engine_config.data_parallel_rank or 0
return KVEventsConfig( return KVEventsConfig(
enable_kv_cache_events=True, enable_kv_cache_events=True,
publisher="zmq", publisher="zmq",
...@@ -538,40 +342,40 @@ def create_kv_events_config(config: Config) -> Optional[KVEventsConfig]: ...@@ -538,40 +342,40 @@ def create_kv_events_config(config: Config) -> Optional[KVEventsConfig]:
) )
def create_kv_transfer_config(config: Config) -> Optional[KVTransferConfig]: def create_kv_transfer_config(
dynamo_config: Config, engine_config: AsyncEngineArgs
) -> Optional[KVTransferConfig]:
"""Create KVTransferConfig based on user config or connector list. """Create KVTransferConfig based on user config or connector list.
Handles logging and returns the appropriate config or None. Handles logging and returns the appropriate config or None.
""" """
has_user_kv_config = ( has_user_kv_config = (
hasattr(config.engine_args, "kv_transfer_config") hasattr(engine_config, "kv_transfer_config")
and config.engine_args.kv_transfer_config is not None and engine_config.kv_transfer_config is not None
) )
if has_user_kv_config: if has_user_kv_config:
logger.info("Using user-provided kv_transfer_config from --kv-transfer-config") logger.info("Using user-provided kv_transfer_config from --kv-transfer-config")
return None # Let vLLM use the user's config return None
if not dynamo_config.connector:
# No connector list or empty list means no config
if not config.connector_list:
logger.info("Using vLLM defaults for kv_transfer_config") logger.info("Using vLLM defaults for kv_transfer_config")
return None return None
logger.info(
logger.info(f"Creating kv_transfer_config from --connector {config.connector_list}") f"Creating kv_transfer_config from --connector {dynamo_config.connector}"
)
# Create connector configs in specified order
multi_connectors = [] multi_connectors = []
for connector in config.connector_list: for conn in dynamo_config.connector:
if connector == "lmcache": if conn == "lmcache":
connector_cfg = {"kv_connector": "LMCacheConnectorV1", "kv_role": "kv_both"} connector_cfg = {"kv_connector": "LMCacheConnectorV1", "kv_role": "kv_both"}
elif connector == "nixl": elif conn == "nixl":
connector_cfg = {"kv_connector": "NixlConnector", "kv_role": "kv_both"} connector_cfg = {"kv_connector": "NixlConnector", "kv_role": "kv_both"}
elif connector == "kvbm": elif conn == "kvbm":
connector_cfg = { connector_cfg = {
"kv_connector": "DynamoConnector", "kv_connector": "DynamoConnector",
"kv_connector_module_path": "kvbm.vllm_integration.connector", "kv_connector_module_path": "kvbm.vllm_integration.connector",
"kv_role": "kv_both", "kv_role": "kv_both",
} }
else:
continue
multi_connectors.append(connector_cfg) multi_connectors.append(connector_cfg)
# For single connector, return direct config # For single connector, return direct config
...@@ -588,54 +392,6 @@ def create_kv_transfer_config(config: Config) -> Optional[KVTransferConfig]: ...@@ -588,54 +392,6 @@ def create_kv_transfer_config(config: Config) -> Optional[KVTransferConfig]:
) )
def overwrite_args(config):
"""Set vLLM defaults for Dynamo."""
if config.has_connector("nixl") or (
# Check if the user provided their own kv_transfer_config
config.engine_args.kv_transfer_config is not None
# and the connector is NixlConnector
and config.engine_args.kv_transfer_config.kv_connector == "NixlConnector"
):
ensure_side_channel_host()
defaults = {
# vLLM 0.13+ renamed 'task' to 'runner'
"runner": "generate",
# As of vLLM >=0.10.0 the engine unconditionally calls
# `sampling_params.update_from_tokenizer(...)`, so we can no longer
# skip tokenizer initialisation. Setting this to **False** avoids
# a NoneType error when the processor accesses the tokenizer.
"skip_tokenizer_init": False,
"enable_log_requests": False,
"disable_log_stats": False,
}
kv_transfer_config = create_kv_transfer_config(config)
if kv_transfer_config:
defaults["kv_transfer_config"] = kv_transfer_config
kv_cfg = create_kv_events_config(config)
defaults["kv_events_config"] = kv_cfg
# Derive use_kv_events from whether kv_events_config is set AND enable_kv_cache_events is True
config.use_kv_events = kv_cfg is not None and kv_cfg.enable_kv_cache_events
logger.info(
f"Using kv_events_config for publishing vLLM kv events over zmq: {kv_cfg} "
f"(use_kv_events={config.use_kv_events})"
)
logger.debug("Setting Dynamo defaults for vLLM")
for key, value in defaults.items():
if hasattr(config.engine_args, key):
setattr(config.engine_args, key, value)
logger.debug(f" engine_args.{key} = {value}")
else:
logger.debug(
f" Skipping engine_args.{key} (not available in this vLLM version)"
)
def get_host_ip() -> str: def get_host_ip() -> str:
"""Get the IP address of the host for side-channel coordination.""" """Get the IP address of the host for side-channel coordination."""
try: try:
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Dynamo vLLM wrapper configuration ArgGroup."""
from typing import Optional
from dynamo.common.configuration.arg_group import ArgGroup
from dynamo.common.configuration.config_base import ConfigBase
from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument
from . import __version__
class DynamoVllmArgGroup(ArgGroup):
"""vLLM-specific Dynamo wrapper configuration (not native vLLM engine args)."""
name = "dynamo-vllm"
def add_arguments(self, parser) -> None:
"""Add Dynamo vLLM arguments to parser."""
parser.add_argument(
"--version", action="version", version=f"Dynamo Backend VLLM {__version__}"
)
g = parser.add_argument_group("Dynamo vLLM Options")
add_negatable_bool_argument(
g,
flag_name="--is-prefill-worker",
env_var="DYN_VLLM_IS_PREFILL_WORKER",
default=False,
help="Enable prefill functionality for this worker. Uses the provided namespace to construct dyn://namespace.prefill.generate",
)
add_negatable_bool_argument(
g,
flag_name="--is-decode-worker",
env_var="DYN_VLLM_IS_DECODE_WORKER",
default=False,
help="Mark this as a decode worker which does not publish KV events",
)
add_negatable_bool_argument(
g,
flag_name="--use-vllm-tokenizer",
env_var="DYN_VLLM_USE_TOKENIZER",
default=False,
help="Use vLLM's tokenizer for pre and post processing. This bypasses Dynamo's preprocessor and only v1/chat/completions will be available through the Dynamo frontend.",
)
add_argument(
g,
flag_name="--sleep-mode-level",
env_var="DYN_VLLM_SLEEP_MODE_LEVEL",
default=1,
help="Sleep mode level (1=offload to CPU, 2=discard weights, 3=discard all).",
choices=[1, 2, 3],
arg_type=int,
)
# Multimodal
add_negatable_bool_argument(
g,
flag_name="--multimodal-processor",
env_var="DYN_VLLM_MULTIMODAL_PROCESSOR",
default=False,
help="Run as multimodal processor component for handling multimodal requests.",
)
add_negatable_bool_argument(
g,
flag_name="--ec-processor",
env_var="DYN_VLLM_EC_PROCESSOR",
default=False,
help="Run as ECConnector processor (routes multimodal requests to encoder then PD workers).",
)
add_negatable_bool_argument(
g,
flag_name="--multimodal-encode-worker",
env_var="DYN_VLLM_MULTIMODAL_ENCODE_WORKER",
default=False,
help="Run as multimodal encode worker component for processing images/videos.",
)
add_negatable_bool_argument(
g,
flag_name="--multimodal-worker",
env_var="DYN_VLLM_MULTIMODAL_WORKER",
default=False,
help="Run as multimodal worker component for LLM inference with multimodal data.",
)
add_negatable_bool_argument(
g,
flag_name="--multimodal-decode-worker",
env_var="DYN_VLLM_MULTIMODAL_DECODE_WORKER",
default=False,
help="Run as multimodal decode worker in disaggregated mode.",
)
add_negatable_bool_argument(
g,
flag_name="--multimodal-encode-prefill-worker",
env_var="DYN_VLLM_MULTIMODAL_ENCODE_PREFILL_WORKER",
default=False,
help="Run as unified encode+prefill+decode worker for models requiring integrated image encoding (e.g., Llama 4).",
)
add_negatable_bool_argument(
g,
flag_name="--enable-multimodal",
env_var="DYN_VLLM_ENABLE_MULTIMODAL",
default=False,
help="Enable multimodal processing. If not set, none of the multimodal components can be used.",
)
add_argument(
g,
flag_name="--mm-prompt-template",
env_var="DYN_VLLM_MM_PROMPT_TEMPLATE",
default="USER: <image>\n<prompt> ASSISTANT:",
help=(
"Different multi-modal models expect the prompt to contain different special media prompts. "
"The processor will use this argument to construct the final prompt. "
"User prompt will replace '<prompt>' in the provided template. "
"For example, if the user prompt is 'please describe the image' and the prompt template is "
"'USER: <image> <prompt> ASSISTANT:', the resulting prompt is "
"'USER: <image> please describe the image ASSISTANT:'."
),
)
add_negatable_bool_argument(
g,
flag_name="--frontend-decoding",
env_var="DYN_VLLM_FRONTEND_DECODING",
default=False,
help=(
"Enable frontend decoding of multimodal images. "
"When enabled, images are decoded in the Rust frontend and transferred to the backend via NIXL RDMA. "
"Without this flag, images are decoded in the Python backend (default behavior)."
),
)
# vLLM-native encoder (ECConnector)
add_negatable_bool_argument(
g,
flag_name="--vllm-native-encoder-worker",
env_var="DYN_VLLM_NATIVE_ENCODER_WORKER",
default=False,
help="Run as vLLM-native encoder worker using ECConnector for encoder disaggregation (requires shared storage). The following flags only work when this flag is enabled: --ec-connector-backend, --ec-storage-path, --ec-extra-config, --ec-consumer-mode.",
)
add_argument(
g,
flag_name="--ec-connector-backend",
env_var="DYN_VLLM_EC_CONNECTOR_BACKEND",
default="ECExampleConnector",
help="ECConnector implementation class for encoder disaggregation.",
)
add_argument(
g,
flag_name="--ec-storage-path",
env_var="DYN_VLLM_EC_STORAGE_PATH",
default=None,
help="Storage path for ECConnector (required for ECExampleConnector, optional for other backends).",
)
add_argument(
g,
flag_name="--ec-extra-config",
env_var="DYN_VLLM_EC_EXTRA_CONFIG",
default=None,
help="Additional ECConnector configuration as JSON string.",
)
add_negatable_bool_argument(
g,
flag_name="--ec-consumer-mode",
env_var="DYN_VLLM_EC_CONSUMER_MODE",
default=False,
help="Configure as ECConnector consumer for receiving encoder embeddings (for PD workers).",
)
# vLLM-Omni
add_negatable_bool_argument(
g,
flag_name="--omni",
env_var="DYN_VLLM_OMNI",
default=False,
help="Run as vLLM-Omni worker for multi-stage pipelines (supports text-to-text, text-to-image, etc.).",
)
add_argument(
g,
flag_name="--stage-configs-path",
env_var="DYN_VLLM_STAGE_CONFIGS_PATH",
default=None,
help="Path to vLLM-Omni stage configuration YAML file for --omni mode (optional).",
)
# @dataclass()
class DynamoVllmConfig(ConfigBase):
"""Configuration for Dynamo vLLM wrapper (vLLM-specific only). All fields optional."""
is_prefill_worker: bool
is_decode_worker: bool
use_vllm_tokenizer: bool
sleep_mode_level: int
# Multimodal
multimodal_processor: bool
ec_processor: bool
multimodal_encode_worker: bool
multimodal_worker: bool
multimodal_decode_worker: bool
multimodal_encode_prefill_worker: bool
enable_multimodal: bool
mm_prompt_template: str
frontend_decoding: bool
# vLLM-native encoder (ECConnector)
vllm_native_encoder_worker: bool
ec_connector_backend: str
ec_storage_path: Optional[str] = None
ec_extra_config: Optional[str] = None
ec_consumer_mode: bool
# vLLM-Omni
omni: bool
stage_configs_path: Optional[str] = None
def validate(self) -> None:
"""Validate vLLM wrapper configuration."""
self._validate_prefill_decode_exclusive()
self._validate_multimodal_role_exclusivity()
self._validate_multimodal_requires_flag()
self._validate_ec_connector_storage()
self._validate_omni_stage_config()
def _validate_prefill_decode_exclusive(self) -> None:
"""Ensure at most one of is_prefill_worker and is_decode_worker is set."""
if self.is_prefill_worker and self.is_decode_worker:
raise ValueError(
"Cannot set both --is-prefill-worker and --is-decode-worker"
)
def _count_multimodal_roles(self) -> int:
"""Return the number of multimodal roles set (0 or 1 allowed)."""
return sum(
[
bool(self.multimodal_processor),
bool(self.ec_processor),
bool(self.multimodal_encode_worker),
bool(self.multimodal_worker),
bool(self.multimodal_decode_worker),
bool(self.multimodal_encode_prefill_worker),
bool(self.vllm_native_encoder_worker),
]
)
def _validate_multimodal_role_exclusivity(self) -> None:
"""Ensure only one multimodal role is set at a time."""
if self._count_multimodal_roles() > 1:
raise ValueError(
"Only one multimodal role can be set at a time: "
"multimodal-processor, ec-processor, multimodal-encode-worker, "
"multimodal-worker, multimodal-decode-worker, "
"multimodal-encode-prefill-worker, vllm-native-encoder-worker"
)
def _validate_multimodal_requires_flag(self) -> None:
"""Require --enable-multimodal when any multimodal role is set."""
if self._count_multimodal_roles() == 1 and not self.enable_multimodal:
raise ValueError(
"Use --enable-multimodal when enabling any multimodal component"
)
def _validate_ec_connector_storage(self) -> None:
"""Require ec_storage_path when using ECExampleConnector backend."""
if self.vllm_native_encoder_worker:
if (
self.ec_connector_backend == "ECExampleConnector"
and not self.ec_storage_path
):
raise ValueError(
"--ec-storage-path is required when using ECExampleConnector backend. "
"Specify a shared storage path for encoder cache."
)
def _validate_omni_stage_config(self) -> None:
"""Require stage_configs_path when using --omni."""
if self.stage_configs_path and not self.omni:
raise ValueError(
"--stage-configs-path is only allowed when using --omni. "
"Specify a YAML file containing stage configurations for the multi-stage pipeline."
)
...@@ -12,6 +12,8 @@ import os ...@@ -12,6 +12,8 @@ import os
from collections.abc import Callable from collections.abc import Callable
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
# TODO: move this to configuration system.
# Port range constants # Port range constants
REGISTERED_PORT_MIN = 1024 REGISTERED_PORT_MIN = 1024
REGISTERED_PORT_MAX = 49151 REGISTERED_PORT_MAX = 49151
......
...@@ -55,7 +55,7 @@ from dynamo.vllm.multimodal_handlers import ( ...@@ -55,7 +55,7 @@ from dynamo.vllm.multimodal_handlers import (
) )
from dynamo.vllm.multimodal_utils.encode_utils import create_ec_transfer_config from dynamo.vllm.multimodal_utils.encode_utils import create_ec_transfer_config
from .args import Config, overwrite_args, parse_args from .args import Config, parse_args
from .chrek import get_checkpoint_config from .chrek import get_checkpoint_config
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
from .health_check import ( from .health_check import (
...@@ -99,7 +99,6 @@ async def graceful_shutdown(runtime, shutdown_event): ...@@ -99,7 +99,6 @@ async def graceful_shutdown(runtime, shutdown_event):
async def worker(): async def worker():
config = parse_args() config = parse_args()
overwrite_args(config)
dump_config(config.dump_config_to, config) dump_config(config.dump_config_to, config)
# Name the model. Use either the full path (vllm and sglang do the same), # Name the model. Use either the full path (vllm and sglang do the same),
...@@ -494,8 +493,8 @@ async def register_vllm_model( ...@@ -494,8 +493,8 @@ async def register_vllm_model(
# Add tool/reasoning parsers for decode models # Add tool/reasoning parsers for decode models
if model_type != ModelType.Prefill: if model_type != ModelType.Prefill:
runtime_config.tool_call_parser = config.tool_call_parser runtime_config.tool_call_parser = config.dyn_tool_call_parser
runtime_config.reasoning_parser = config.reasoning_parser runtime_config.reasoning_parser = config.dyn_reasoning_parser
# Get data_parallel_size from vllm_config (defaults to 1) # Get data_parallel_size from vllm_config (defaults to 1)
data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1) data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
...@@ -785,14 +784,14 @@ async def init( ...@@ -785,14 +784,14 @@ async def init(
await _handle_non_leader_node(config.engine_args.data_parallel_rank) await _handle_non_leader_node(config.engine_args.data_parallel_rank)
return return
# Parse endpoint types from --dyn-endpoint-types flag # Parse endpoint types from --endpoint-types flag
model_type = parse_endpoint_types(config.dyn_endpoint_types) model_type = parse_endpoint_types(config.endpoint_types)
logger.info(f"Registering model with endpoint types: {config.dyn_endpoint_types}") logger.info(f"Registering model with endpoint types: {config.endpoint_types}")
model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
# Warn if custom template provided but chat endpoint not enabled # Warn if custom template provided but chat endpoint not enabled
if config.custom_jinja_template and "chat" not in config.dyn_endpoint_types: if config.custom_jinja_template and "chat" not in config.endpoint_types:
logger.warning( logger.warning(
"Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --dyn-endpoint-types. " "Custom Jinja template provided (--custom-jinja-template) but 'chat' not in --dyn-endpoint-types. "
"The chat template will be loaded but the /v1/chat/completions endpoint will not be available." "The chat template will be loaded but the /v1/chat/completions endpoint will not be available."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment