args.py 8.1 KB
Newer Older
1
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
# SPDX-License-Identifier: Apache-2.0

import argparse
5
import ipaddress
6
7
8
9
10
11
12
13
14
15
16
17
import logging
import os
import socket
import sys
from typing import Callable, List, Optional, Tuple

from vllm.config import KVTransferConfig
from vllm.distributed.kv_events import KVEventsConfig
from vllm.engine.arg_utils import AsyncEngineArgs

logger = logging.getLogger(__name__)

18
19
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.backend.generate"
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113


class Config:
    """Command line parameters or defaults.

    A plain attribute container: the fields are assigned after construction
    by ``base_parse_args`` and ``configure_ports`` in this module.
    """

    # dynamo specific
    # The three parts of the 'namespace.component.endpoint' endpoint string.
    namespace: str
    component: str
    endpoint: str
    # Port used to build the KV events publisher endpoint; None until
    # configure_ports has run.
    kv_port: Optional[int] = None

    # mirror vLLM
    model: str
    # The single served model name, or None when none was given.
    served_model_name: Optional[str]

    # rest vLLM args
    engine_args: AsyncEngineArgs


def parse_endpoint(endpoint: str) -> List[str]:
    """Split a Dynamo endpoint string into its three components.

    Both 'dyn://namespace.component.endpoint' and the bare
    'namespace.component.endpoint' form are accepted.

    Args:
        endpoint: The endpoint string to parse.

    Returns:
        The list ``[namespace, component, endpoint]``.

    Raises:
        SystemExit: With exit code 1, after logging an error, when the string
            is not made of exactly three non-blank dot-separated parts.
    """
    endpoint_str = endpoint.replace("dyn://", "", 1)
    endpoint_parts = endpoint_str.split(".")
    # A bare length check lets strings such as 'dyn://ns..ep' through with an
    # empty component name, so blank parts are rejected explicitly as well.
    has_blank_part = any(not part.strip() for part in endpoint_parts)
    if len(endpoint_parts) != 3 or has_blank_part:
        logger.error(
            f"Invalid endpoint format: '{endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
        )
        sys.exit(1)

    return endpoint_parts


def base_parse_args(
    parser: argparse.ArgumentParser, endpoint_overwrite: Optional[Callable] = None
) -> Tuple[argparse.Namespace, Config]:
    """
    Shared parsing logic for any dynamo vLLM deployment. Callers customize the
    result through 'parser' and 'endpoint_overwrite'.

    Args:
        parser (argparse.ArgumentParser): The argument parser, already carrying
            any use case specific arguments. An '--endpoint' option and the
            vLLM engine arguments are added to it here.
        endpoint_overwrite (Callable): Optional user provided function that
            receives the parsed arguments and returns them with the endpoint
            overwritten. A typical selector checks the worker type and returns
            a worker specific endpoint.

    Returns:
        Tuple[argparse.Namespace, Config]: The parsed arguments (after any
            overwrite) and a Config object holding the derived settings.
    """
    # Only register --endpoint when the caller has not already defined it.
    endpoint_registered = any(
        action.dest == "endpoint" for action in parser._actions
    )
    if not endpoint_registered:
        parser.add_argument(
            "--endpoint",
            type=str,
            default=DEFAULT_ENDPOINT,
            help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
        )

    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    engine_args = AsyncEngineArgs.from_cli_args(args)

    config = Config()
    config.model = args.model

    served_names = args.served_model_name
    if not served_names:
        # This becomes an `Option` on the Rust side
        config.served_model_name = None
    else:
        assert len(served_names) <= 1, "We do not support multiple model names."
        config.served_model_name = served_names[0]

    # The overwrite hook runs after the model fields are captured but before
    # the endpoint string is split.
    if endpoint_overwrite is not None:
        args = endpoint_overwrite(args)

    namespace_part, component_part, endpoint_part = parse_endpoint(args.endpoint)
    config.namespace = namespace_part
    config.component = component_part
    config.endpoint = endpoint_part
    config.engine_args = engine_args

    if config.engine_args.block_size is None:
        config.engine_args.block_size = 16
        logger.debug(
            f"Setting reasonable default of {config.engine_args.block_size} for block_size"
        )

    return args, config


114
115
116
def get_kv_port() -> int:
    """Get KV events port from environment or default.

    Reads ``DYN_VLLM_KV_EVENT_PORT``. An unset or blank variable falls back to
    the default port 20080.

    Returns:
        The port number as an int.

    Raises:
        ValueError: If the variable is set to a value that is not an integer.
    """
    env_var = "DYN_VLLM_KV_EVENT_PORT"
    raw_value = os.getenv(env_var, "").strip()
    if not raw_value:
        # Treat an exported-but-empty variable the same as an unset one
        # instead of failing on int("").
        return 20080
    try:
        return int(raw_value)
    except ValueError as err:
        # Name the offending variable; the bare int() error does not.
        raise ValueError(
            f"{env_var} must be an integer port number, got {raw_value!r}"
        ) from err
117

118

119
def _is_routable_ip(ip_str: str) -> bool:
    """Return True if ip_str is an IP that is not loopback, link-local, unspecified or multicast."""
    try:
        addr = ipaddress.ip_address(ip_str)
    except ValueError:
        return False
    return not (
        addr.is_loopback
        or addr.is_link_local
        or addr.is_unspecified
        or addr.is_multicast
    )


def _detect_host_via_hostname() -> Optional[str]:
    """Return the first routable, locally bindable address the hostname resolves to, else None."""
    try:
        host_name = socket.gethostname()
        # AF_UNSPEC so both IPv4 and IPv6 results are considered.
        infos = socket.getaddrinfo(
            host_name, None, socket.AF_UNSPEC, socket.SOCK_STREAM
        )
    except (OSError, UnicodeError) as exc:
        # getaddrinfo may raise UnicodeError (IDNA encoding of the hostname)
        # in addition to OSError; either way fall through to the next strategy.
        logger.debug("Hostname resolution failed: %s", exc)
        return None

    for family, socktype, _proto, _canonname, sockaddr in infos:
        candidate = sockaddr[0]
        # Cheap check first so no socket is created for unusable addresses.
        if not _is_routable_ip(candidate):
            continue
        try:
            # Binding succeeds only if the address belongs to a local interface.
            with socket.socket(family, socktype) as s:
                s.bind((candidate, 0))
        except OSError:
            continue
        return candidate
    return None


def _detect_host_via_udp() -> Tuple[Optional[str], Optional[str]]:
    """Return (ip, detection label) from the UDP connect trick, or (None, None)."""
    probes = [
        (socket.AF_INET, ("8.8.8.8", 80), "outbound interface detection (IPv4)"),
        (
            socket.AF_INET6,
            ("2001:4860:4860::8888", 80),
            "outbound interface detection (IPv6)",
        ),
    ]
    for family, target, label in probes:
        try:
            with socket.socket(family, socket.SOCK_DGRAM) as s:
                s.connect(target)
                # The local address chosen for the connected socket.
                candidate = s.getsockname()[0]
        except OSError:
            continue
        if _is_routable_ip(candidate):
            return candidate, label
    return None, None


def ensure_side_channel_host():
    """Ensure the NIXL side-channel host is available without overriding user settings.

    If ``VLLM_NIXL_SIDE_CHANNEL_HOST`` is already set to a non-empty value it
    is left untouched. Otherwise the host IP is detected and written to that
    environment variable, using hostname resolution first and a UDP connect
    fallback second. Supports IPv4 and IPv6.

    Raises:
        RuntimeError: If no routable IP can be determined.
    """
    existing_host = os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST")
    if existing_host:
        logger.info("Using existing VLLM_NIXL_SIDE_CHANNEL_HOST=%s", existing_host)
        return

    # Strategy 1: hostname resolution (AF_UNSPEC for IPv4+IPv6)
    host_ip = _detect_host_via_hostname()
    detection_method = "hostname resolution" if host_ip else None

    # Strategy 2: UDP connect trick (IPv4 then IPv6)
    if not host_ip:
        host_ip, detection_method = _detect_host_via_udp()

    if not host_ip:
        raise RuntimeError(
            "Unable to determine a routable host IP for NIXL side-channel. "
            "Please set the VLLM_NIXL_SIDE_CHANNEL_HOST environment variable to "
            "the IP address that peer nodes can reach this host on."
        )

    os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = host_ip
    logger.info(
        "Set VLLM_NIXL_SIDE_CHANNEL_HOST=%s (detected via %s)",
        host_ip,
        detection_method,
    )
198
199


200
201
def configure_ports(config: Config):
    """Fill in port settings on ``config`` from their environment overrides.

    Stores the KV events port on ``config.kv_port`` and then makes sure the
    NIXL side-channel host environment variable is populated.
    """
    # kv_port is assigned unconditionally because overwrite_args reads it
    # whether or not prefix caching is enabled.
    kv_events_port = get_kv_port()
    config.kv_port = kv_events_port

    ensure_side_channel_host()
207
208
209
210


def overwrite_args(config):
    """Set vLLM defaults for Dynamo.

    Overwrites a fixed set of attributes on ``config.engine_args`` in place.
    Keys that the engine-args object does not expose are skipped with a debug
    log rather than treated as errors.

    Args:
        config: Configuration object whose ``engine_args`` is updated and whose
            ``kv_port`` must already be set (see ``configure_ports``).

    Raises:
        AssertionError: If ``config.kv_port`` has not been set.
    """
    # kv_port is used to build the KV events endpoint on every call, so it is
    # required unconditionally, not only when prefix caching was requested.
    assert config.kv_port is not None, "Must set the kv_port, use configure_ports"

    dp_rank = config.engine_args.data_parallel_rank or 0

    defaults = {
        # vLLM 0.13+ renamed 'task' to 'runner'
        "runner": "generate",
        "skip_tokenizer_init": False,
        "enable_log_requests": False,
        "enable_prefix_caching": True,
        # KV routing relies on logging KV metrics
        "disable_log_stats": False,
        # Enable multimodal embeddings input
        "enable_mm_embeds": True,
        # Always setting up kv transfer for disagg
        "kv_transfer_config": KVTransferConfig(
            kv_connector="NixlConnector", kv_role="kv_both"
        ),
        # vLLM will iterate dp_rank for us, so we need to subtract it out
        # TODO: fix in vLLM
        "kv_events_config": KVEventsConfig(
            enable_kv_cache_events=True,
            publisher="zmq",
            endpoint=f"tcp://*:{config.kv_port - dp_rank}",
        ),
    }

    logger.debug("Setting Dynamo defaults for vLLM")
    for key, value in defaults.items():
        if hasattr(config.engine_args, key):
            setattr(config.engine_args, key, value)
            logger.debug(f" engine_args.{key} = {value}")
        else:
            logger.debug(
                f" Skipping engine_args.{key} (not available in this vLLM version)"
            )