"docs/components/vscode:/vscode.git/clone" did not exist on "a3e6468df08c7760ded88120e90dfd9519b2bdeb"
args.py 19 KB
Newer Older
1
2
3
4
5
6
7
8
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
import contextlib
import logging
import os
import socket
import sys
9
import tempfile
10
11
12
from argparse import Namespace
from dataclasses import dataclass
from enum import Enum
13
from pathlib import Path
14
from typing import Any, Dict, Generator, List, Optional
15

16
import yaml
17
from sglang.srt.server_args import ServerArgs
18
from sglang.srt.server_args_config_parser import ConfigArgumentMerger
19

20
from dynamo._core import get_reasoning_parser_names, get_tool_parser_names
21
from dynamo.common.config_dump import register_encoder
22
from dynamo.llm import fetch_llm
23
from dynamo.runtime.logging import configure_dynamo_logging
24
25
from dynamo.sglang import __version__

26
27
# Configure logging as an import-time side effect, before any function in this
# module emits a log record.
configure_dynamo_logging()

28
29
30
# Namespace read once, at import time, from the DYN_NAMESPACE environment
# variable; falls back to "dynamo" when the variable is unset.
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
# Example endpoint string interpolated into the --endpoint help text.
DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.backend.generate"

31
32
33
34
35
36
37
38
39
40
41
42
# Declarative table of the Dynamo-specific CLI flags. Each entry is turned into
# one ``parser.add_argument`` call by ``parse_args``:
#   "flags"   - option strings passed positionally to add_argument (required)
#   "help"    - help text (required)
#   "type" / "choices" / "action" / "default" - forwarded when present;
#               a missing "default" is treated as None.
# The parser-name choices are resolved, and the store-kv / request-plane
# defaults are read from the environment, once at import time.
DYNAMO_ARGS: Dict[str, Dict[str, Any]] = {
    "endpoint": {
        "flags": ["--endpoint"],
        "type": str,
        "help": f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Example: {DEFAULT_ENDPOINT}",
    },
    "migration-limit": {
        "flags": ["--migration-limit"],
        "type": int,
        "default": 0,
        "help": "Maximum number of times a request may be migrated to a different engine worker",
    },
    "tool-call-parser": {
        "flags": ["--dyn-tool-call-parser"],
        "type": str,
        "default": None,
        "choices": get_tool_parser_names(),
        "help": "Tool call parser name for the model.",
    },
    "reasoning-parser": {
        "flags": ["--dyn-reasoning-parser"],
        "type": str,
        "default": None,
        "choices": get_reasoning_parser_names(),
        "help": "Reasoning parser name for the model. If not specified, no reasoning parsing is performed.",
    },
    "custom-jinja-template": {
        "flags": ["--custom-jinja-template"],
        "type": str,
        "default": None,
        "help": "Path to a custom Jinja template file to override the model's default chat template. This template will take precedence over any template found in the model repository. This template will be applied by Dynamo's preprocessor and cannot be used with --use-sglang-tokenizer.",
    },
    "use-sglang-tokenizer": {
        "flags": ["--use-sglang-tokenizer"],
        "action": "store_true",
        "default": False,
        "help": "Use SGLang's tokenizer. This will skip tokenization of the input and output and only v1/chat/completions will be available when using the dynamo frontend. Cannot be used with --custom-jinja-template.",
    },
    "multimodal-processor": {
        "flags": ["--multimodal-processor"],
        "action": "store_true",
        "default": False,
        "help": "Run as multimodal processor component for handling multimodal requests",
    },
    "multimodal-encode-worker": {
        "flags": ["--multimodal-encode-worker"],
        "action": "store_true",
        "default": False,
        "help": "Run as multimodal encode worker component for processing images/videos",
    },
    "multimodal-worker": {
        "flags": ["--multimodal-worker"],
        "action": "store_true",
        "default": False,
        "help": "Run as multimodal worker component for LLM inference with multimodal data",
    },
    "embedding-worker": {
        "flags": ["--embedding-worker"],
        "action": "store_true",
        "default": False,
        "help": "Run as embedding worker component (Dynamo flag, also sets SGLang's --is-embedding)",
    },
    "dump-config-to": {
        "flags": ["--dump-config-to"],
        "type": str,
        "default": None,
        "help": "Dump debug config to the specified file path. If not specified, the config will be dumped to stdout at INFO level.",
    },
    "store-kv": {
        "flags": ["--store-kv"],
        "type": str,
        "choices": ["etcd", "file", "mem"],
        "default": os.environ.get("DYN_STORE_KV", "etcd"),
        "help": "Which key-value backend to use: etcd, mem, file. Etcd uses the ETCD_* env vars (e.g. ETCD_ENDPOINTS) for connection details. File uses root dir from env var DYN_FILE_KV or defaults to $TMPDIR/dynamo_store_kv.",
    },
    "request-plane": {
        "flags": ["--request-plane"],
        "type": str,
        "choices": ["nats", "http", "tcp"],
        "default": os.environ.get("DYN_REQUEST_PLANE", "nats"),
        "help": "Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
    },
}


@dataclass
class DynamoArgs:
    """Dynamo-specific settings resolved from the command line.

    ``namespace``, ``component`` and ``endpoint`` are the three dot-separated
    parts of a ``dyn://namespace.component.endpoint`` string. The remaining
    fields mirror the Dynamo CLI flags of the same name.
    """

    namespace: str
    component: str
    endpoint: str
    migration_limit: int
    store_kv: str
    request_plane: str

    # tool and reasoning parser options
    tool_call_parser: Optional[str] = None
    reasoning_parser: Optional[str] = None
    custom_jinja_template: Optional[str] = None

    # preprocessing options
    use_sglang_tokenizer: bool = False

    # multimodal options
    multimodal_processor: bool = False
    multimodal_encode_worker: bool = False
    multimodal_worker: bool = False

    # embedding options
    embedding_worker: bool = False

    # config dump options
    dump_config_to: Optional[str] = None
142

143
144
145
146
147
148
149
150

class DisaggregationMode(Enum):
    """Serving role of a worker: aggregated, prefill-only, or decode-only."""

    # Single worker handles the whole request (no disaggregation).
    AGGREGATED = "agg"
    # Worker runs in the "prefill" disaggregation mode.
    PREFILL = "prefill"
    # Worker runs in the "decode" disaggregation mode.
    DECODE = "decode"


class Config:
    """Combined configuration container for SGLang server and Dynamo args.

    Attributes:
        server_args: SGLang server arguments.
        dynamo_args: Dynamo-specific arguments.
        serving_mode: ``DisaggregationMode`` derived from
            ``server_args.disaggregation_mode`` at construction time.
    """

    def __init__(self, server_args: ServerArgs, dynamo_args: DynamoArgs) -> None:
        self.server_args = server_args
        self.dynamo_args = dynamo_args
        self.serving_mode = self._set_serving_strategy()

    def _set_serving_strategy(self) -> DisaggregationMode:
        """Map SGLang's ``disaggregation_mode`` string to a ``DisaggregationMode``.

        Returns:
            ``PREFILL`` for "prefill", ``DECODE`` for "decode", and
            ``AGGREGATED`` for "null" or any other value.
        """
        mode = self.server_args.disaggregation_mode
        if mode == "prefill":
            return DisaggregationMode.PREFILL
        if mode == "decode":
            return DisaggregationMode.DECODE
        # "null" and every unrecognized value both mean aggregated serving.
        return DisaggregationMode.AGGREGATED
167
168


169
170
171
172
173
174
175
176
177
178
179
180
181
182
# Register SGLang-specific encoders with the shared system
@register_encoder(Config)
def _preprocess_for_encode_config(
    config: Config,
) -> Dict[str, Any]:  # pyright: ignore[reportUnusedFunction]
    """Convert a ``Config`` into a plain dict for the shared config encoder.

    Args:
        config: Configuration object to flatten.

    Returns:
        Dict holding the server args, the Dynamo args, and the serving mode's
        string value (the literal string "None" when no serving mode is set).
    """
    encoded: Dict[str, Any] = {
        "server_args": config.server_args,
        "dynamo_args": config.dynamo_args,
    }
    mode = config.serving_mode
    encoded["serving_mode"] = "None" if mode is None else mode.value
    return encoded


183
184
185
186
187
def _set_parser(
    sglang_str: Optional[str],
    dynamo_str: Optional[str],
    arg_name: str = "tool-call-parser",
) -> Optional[str]:
188
189
190
191
192
193
194
195
196
197
198
199
200
    """Resolve parser name from SGLang and Dynamo arguments.

    Args:
        sglang_str: Parser value from SGLang argument.
        dynamo_str: Parser value from Dynamo argument.
        arg_name: Name of the parser argument for logging.

    Returns:
        Resolved parser name, preferring Dynamo's value if both set.

    Raises:
        ValueError: If parser name is not valid.
    """
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
    # If both are present, give preference to dynamo_str
    if sglang_str is not None and dynamo_str is not None:
        logging.warning(
            f"--dyn-{arg_name} and --{arg_name} are both set. Giving preference to --dyn-{arg_name}"
        )
        return dynamo_str
    # If dynamo_str is not set, use try to use sglang_str if it matches with the allowed parsers
    elif sglang_str is not None:
        logging.warning(f"--dyn-{arg_name} is not set. Using --{arg_name}.")
        if arg_name == "tool-call-parser" and sglang_str not in get_tool_parser_names():
            raise ValueError(
                f"--{arg_name} is not a valid tool call parser. Valid parsers are: {get_tool_parser_names()}"
            )
        elif (
            arg_name == "reasoning-parser"
            and sglang_str not in get_reasoning_parser_names()
        ):
            raise ValueError(
                f"--{arg_name} is not a valid reasoning parser. Valid parsers are: {get_reasoning_parser_names()}"
            )
        return sglang_str
    else:
        return dynamo_str


226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
def _extract_config_section(
    args: List[str], config_path: str, config_key: str
) -> tuple[List[str], str]:
    """
    Extract a section from nested YAML and create temp flat file.

    Args:
        args: CLI arguments list
        config_path: Path to the YAML config file
        config_key: Key to extract from nested YAML

    Returns:
        tuple: (modified args with temp file path, temp file path for cleanup)

    Raises:
        ValueError: If config file not found, key missing, or invalid format
    """
    logging.info(f"Extracting config section '{config_key}' from {config_path}")

    path = Path(config_path)
    if not path.exists():
        raise ValueError(f"Config file not found: {config_path}")

    with open(config_path, "r") as f:
        config_data = yaml.safe_load(f)

    if not isinstance(config_data, dict):
        raise ValueError(
            f"Config file must contain a dictionary, got {type(config_data).__name__}"
        )

    available_keys = list(config_data.keys())
    logging.info(f"Available config keys in {config_path}: {available_keys}")

    if config_key not in config_data:
        raise ValueError(
            f"Config key '{config_key}' not found in {config_path}. "
            f"Available keys: {available_keys}"
        )

    section_data = config_data[config_key]

    if not isinstance(section_data, dict):
        raise ValueError(
            f"Config section '{config_key}' must be a dictionary, got {type(section_data).__name__}"
        )

    temp_fd, temp_path = tempfile.mkstemp(suffix=".yaml", prefix="dynamo_config_")

    try:
        with os.fdopen(temp_fd, "w") as f:
            yaml.dump(section_data, f)
        logging.info(f"Successfully wrote config section '{config_key}' to temp file")
    except Exception:
        os.unlink(temp_path)
        raise

    config_index = args.index("--config")
    args = list(args)
    args[config_index + 1] = temp_path

    return args, temp_path


290
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: Dynamo flags, ``--config-key``, then SGLang's flags."""
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--version", action="version", version=f"Dynamo Backend SGLang {__version__}"
    )

    # Dynamo args
    for info in DYNAMO_ARGS.values():
        kwargs = {"default": info.get("default"), "help": info["help"]}
        # Forward only the optional argparse settings this entry declares.
        for option in ("type", "choices", "action"):
            if option in info:
                kwargs[option] = info[option]
        parser.add_argument(*info["flags"], **kwargs)

    # Config key argument (for nested configs)
    parser.add_argument(
        "--config-key",
        type=str,
        default=None,
        help="Key to select from nested config file (e.g., 'prefill', 'decode')",
    )

    # SGLang args
    ServerArgs.add_cli_args(parser)
    return parser


def _default_endpoint(parsed_args: Namespace, namespace: str) -> str:
    """Choose the default endpoint for the worker role implied by the parsed flags."""
    if parsed_args.embedding_worker:
        component = "backend"
    elif getattr(parsed_args, "disaggregation_mode", None) == "prefill":
        # Any prefill worker, multimodal or not, registers under "prefill".
        component = "prefill"
    elif parsed_args.multimodal_processor:
        component = "processor"
    elif parsed_args.multimodal_encode_worker:
        component = "encoder"
    else:
        component = "backend"
    return f"dyn://{namespace}.{component}.generate"


async def parse_args(args: list[str]) -> Config:
    """Parse CLI arguments and return combined configuration.
    Download the model if necessary.

    Args:
        args: Command-line argument strings.

    Returns:
        Config object with server_args and dynamo_args.

    Raises:
        SystemExit: If arguments are invalid or incompatible.
        ValueError: If a nested config section cannot be extracted, or an
            SGLang parser name is not supported by Dynamo.
        FileNotFoundError: If --custom-jinja-template does not point to a file.
    """
    parser = _build_arg_parser()

    # Handle config file if present
    temp_config_file: Optional[str] = None  # Track temp file for cleanup
    try:
        if "--config" in args:
            # Check if --config-key is also present
            if "--config-key" in args:
                key_index = args.index("--config-key")
                config_key = args[key_index + 1]
                config_path = args[args.index("--config") + 1]

                # Extract nested section to temp file
                args, temp_config_file = _extract_config_section(
                    args, config_path, config_key
                )

                # Remove --config-key from args (not recognized by SGLang).
                # The extracted list has the same length, so key_index still applies.
                args = args[:key_index] + args[key_index + 2 :]

            # Extract boolean actions from the parser to handle them correctly in YAML
            # NOTE(review): argparse Action objects do not appear to define an
            # `action` attribute, so this filter looks like it never matches and
            # the list stays empty - confirm what ConfigArgumentMerger expects
            # before changing it.
            boolean_actions = [
                action.dest
                for action in parser._actions
                if hasattr(action, "dest")
                and hasattr(action, "action")
                and action.action in ["store_true", "store_false"]
            ]

            # Merge config file arguments with CLI arguments
            config_merger = ConfigArgumentMerger(boolean_actions=boolean_actions)
            args = config_merger.merge_config_with_args(args)

        parsed_args = parser.parse_args(args)
    finally:
        # Clean up temp file if created, even when merging/parsing fails or
        # argparse exits the process.
        if temp_config_file and os.path.exists(temp_config_file):
            try:
                os.unlink(temp_config_file)
            except OSError:
                logging.warning(
                    f"Failed to clean up temp config file: {temp_config_file}"
                )

    # Auto-set bootstrap port if not provided
    if not any(arg.startswith("--disaggregation-bootstrap-port") for arg in args):
        parsed_args.disaggregation_bootstrap_port = (
            _reserve_disaggregation_bootstrap_port()
        )

    # Dynamo argument processing
    # If an endpoint is provided, validate and use it
    # otherwise fall back to default endpoints
    namespace = os.environ.get("DYN_NAMESPACE", "dynamo")

    # If --embedding-worker is set, also set SGLang's --is-embedding flag
    if parsed_args.embedding_worker:
        parsed_args.is_embedding = True

    endpoint = parsed_args.endpoint
    if endpoint is None:
        endpoint = _default_endpoint(parsed_args, namespace)

    # Always parse the endpoint (whether auto-generated or user-provided)
    try:
        parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
            endpoint
        )
    except ValueError:
        # parse_endpoint has already logged the reason.
        sys.exit(1)

    tool_call_parser = _set_parser(
        parsed_args.tool_call_parser,
        parsed_args.dyn_tool_call_parser,
        "tool-call-parser",
    )
    reasoning_parser = _set_parser(
        parsed_args.reasoning_parser,
        parsed_args.dyn_reasoning_parser,
        "reasoning-parser",
    )

    if parsed_args.custom_jinja_template and parsed_args.use_sglang_tokenizer:
        logging.error(
            "Cannot use --custom-jinja-template and --use-sglang-tokenizer together. "
            "--custom-jinja-template requires Dynamo's preprocessor to apply the template, "
            "while --use-sglang-tokenizer bypasses Dynamo's preprocessor entirely. "
            "If you want to use the SGLang tokenizer with a custom chat template, "
            "please use the --chat-template argument from SGLang."
        )
        sys.exit(1)

    # Replaces any environment variables or home dir (~) to get absolute path
    expanded_template_path = None
    if parsed_args.custom_jinja_template:
        expanded_template_path = os.path.expandvars(
            os.path.expanduser(parsed_args.custom_jinja_template)
        )
        # Validate custom Jinja template file exists
        if not os.path.isfile(expanded_template_path):
            raise FileNotFoundError(
                f"Custom Jinja template file not found: {expanded_template_path}"
            )

    dynamo_args = DynamoArgs(
        namespace=parsed_namespace,
        component=parsed_component_name,
        endpoint=parsed_endpoint_name,
        migration_limit=parsed_args.migration_limit,
        store_kv=parsed_args.store_kv,
        request_plane=parsed_args.request_plane,
        tool_call_parser=tool_call_parser,
        reasoning_parser=reasoning_parser,
        custom_jinja_template=expanded_template_path,
        use_sglang_tokenizer=parsed_args.use_sglang_tokenizer,
        multimodal_processor=parsed_args.multimodal_processor,
        multimodal_encode_worker=parsed_args.multimodal_encode_worker,
        multimodal_worker=parsed_args.multimodal_worker,
        embedding_worker=parsed_args.embedding_worker,
        dump_config_to=parsed_args.dump_config_to,
    )
    logging.debug(f"Dynamo args: {dynamo_args}")

    # TODO: sglang downloads the model in `from_cli_args`, so we need to do it here.
    # That's unfortunate because `parse_args` isn't the right place for this. Fix.
    model_path = parsed_args.model_path
    if not parsed_args.served_model_name:
        parsed_args.served_model_name = model_path
    if not os.path.exists(model_path):
        parsed_args.model_path = await fetch_llm(model_path)

    server_args = ServerArgs.from_cli_args(parsed_args)

    if parsed_args.use_sglang_tokenizer:
        logging.info(
            "Using SGLang's built in tokenizer. Setting skip_tokenizer_init to False"
        )
        server_args.skip_tokenizer_init = False
    else:
        logging.info(
            "Using dynamo's built in tokenizer. Setting skip_tokenizer_init to True"
        )
        server_args.skip_tokenizer_init = True

    return Config(server_args, dynamo_args)


@contextlib.contextmanager
def reserve_free_port(host: str = "localhost") -> Generator[int, None, None]:
    """Find and reserve a free port until context exits.

    A TCP socket is bound to ``host`` on an OS-chosen port and kept open for
    the lifetime of the ``with`` block; it is closed when the block exits,
    including on error.

    Args:
        host: Host address to bind to.

    Yields:
        Available port number.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Port 0 asks the OS to pick any currently free port.
        sock.bind((host, 0))
        _, port = sock.getsockname()
        yield port


516
def parse_endpoint(endpoint: str) -> List[str]:
    """Parse endpoint string into namespace, component, and endpoint parts.

    Args:
        endpoint: Endpoint string in 'dyn://namespace.component.endpoint' format.
            The 'dyn://' scheme is optional.

    Returns:
        List of [namespace, component, endpoint] strings.

    Raises:
        ValueError: If endpoint format is invalid, i.e. it does not consist of
            exactly three non-empty dot-separated parts.
    """
    # Strip the scheme only when it is an actual prefix, not wherever it occurs.
    endpoint_str = endpoint.removeprefix("dyn://")
    endpoint_parts = endpoint_str.split(".")
    # Reject wrong arity as well as empty components such as "dyn://ns..ep".
    if len(endpoint_parts) != 3 or not all(endpoint_parts):
        error_msg = (
            f"Invalid endpoint format: '{endpoint}'. "
            f"Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
        )
        logging.error(error_msg)
        raise ValueError(error_msg)

    return endpoint_parts


541
542
543
544
545
def _reserve_disaggregation_bootstrap_port() -> int:
    """Reserve a unique port for disaggregation bootstrap.

    Returns:
        Available port number.
    """
    # Returning from inside the ``with`` exits the reserve_free_port context,
    # so the port is no longer held by the time the caller receives it.
    with reserve_free_port() as port:
        return port