Unverified Commit d392bbdd authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

feat: `--config` file support in sglang (#4272)

parent 8379b0cd
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import argparse import argparse
import contextlib import contextlib
import logging import logging
import os import os
import socket import socket
import sys import sys
import tempfile
from argparse import Namespace from argparse import Namespace
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional from typing import Any, Dict, Generator, List, Optional
import yaml
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.server_args_config_parser import ConfigArgumentMerger
from dynamo._core import get_reasoning_parser_names, get_tool_parser_names from dynamo._core import get_reasoning_parser_names, get_tool_parser_names
from dynamo.common.config_dump import register_encoder from dynamo.common.config_dump import register_encoder
...@@ -211,6 +214,70 @@ def _set_parser( ...@@ -211,6 +214,70 @@ def _set_parser(
return dynamo_str return dynamo_str
def _extract_config_section(
    args: List[str], config_path: str, config_key: str
) -> tuple[List[str], str]:
    """
    Extract one section of a nested YAML config into a flat temporary file.

    The top-level mapping in ``config_path`` is expected to hold one mapping
    per section. The mapping stored under ``config_key`` is dumped to a new
    temporary YAML file, and the value following ``--config`` in ``args`` is
    replaced by that file's path.

    Args:
        args: CLI arguments list. Must contain ``--config`` followed by a
            value; the input list itself is never mutated.
        config_path: Path to the YAML config file
        config_key: Key to extract from nested YAML

    Returns:
        tuple: (copy of args with the ``--config`` value pointing at the temp
        file, temp file path for cleanup). The returned list has the same
        length and element positions as ``args``, so indices computed on the
        original list remain valid. The caller owns the temp file and is
        responsible for deleting it.

    Raises:
        ValueError: If ``--config`` is absent or has no value, the config file
            is not found, is not valid YAML, the key is missing, or the
            top-level document / selected section is not a dictionary.
    """
    # Validate the argument list up front: failing after the temp file has
    # been written would leave it behind with nobody knowing its path.
    try:
        config_index = args.index("--config")
    except ValueError:
        raise ValueError(
            "'--config' must be present in args to extract a config section"
        ) from None
    if config_index + 1 >= len(args):
        raise ValueError("No config file specified after --config flag")

    logging.info(f"Extracting config section '{config_key}' from {config_path}")

    path = Path(config_path)
    # is_file() also rejects directories, which open() would refuse anyway.
    if not path.is_file():
        raise ValueError(f"Config file not found: {config_path}")

    with path.open("r", encoding="utf-8") as f:
        try:
            config_data = yaml.safe_load(f)
        except yaml.YAMLError as err:
            # Surface parse failures as the documented exception type.
            raise ValueError(
                f"Config file {config_path} is not valid YAML: {err}"
            ) from err

    if not isinstance(config_data, dict):
        raise ValueError(
            f"Config file must contain a dictionary, got {type(config_data).__name__}"
        )

    available_keys = list(config_data.keys())
    logging.info(f"Available config keys in {config_path}: {available_keys}")

    if config_key not in config_data:
        raise ValueError(
            f"Config key '{config_key}' not found in {config_path}. "
            f"Available keys: {available_keys}"
        )

    section_data = config_data[config_key]
    if not isinstance(section_data, dict):
        raise ValueError(
            f"Config section '{config_key}' must be a dictionary, got {type(section_data).__name__}"
        )

    temp_fd, temp_path = tempfile.mkstemp(suffix=".yaml", prefix="dynamo_config_")
    try:
        # fdopen takes ownership of the descriptor and closes it on exit.
        with os.fdopen(temp_fd, "w", encoding="utf-8") as f:
            yaml.dump(section_data, f)
        logging.info(f"Successfully wrote config section '{config_key}' to temp file")
    except Exception:
        # Do not leave a half-written temp file behind.
        os.unlink(temp_path)
        raise

    # Work on a copy so the caller's list is left untouched.
    new_args = list(args)
    new_args[config_index + 1] = temp_path
    return new_args, temp_path
async def parse_args(args: list[str]) -> Config: async def parse_args(args: list[str]) -> Config:
"""Parse CLI arguments and return combined configuration. """Parse CLI arguments and return combined configuration.
Download the model if necessary. Download the model if necessary.
...@@ -245,12 +312,56 @@ async def parse_args(args: list[str]) -> Config: ...@@ -245,12 +312,56 @@ async def parse_args(args: list[str]) -> Config:
parser.add_argument(*info["flags"], **kwargs) parser.add_argument(*info["flags"], **kwargs)
# Config key argument (for nested configs)
parser.add_argument(
"--config-key",
type=str,
default=None,
help="Key to select from nested config file (e.g., 'prefill', 'decode')",
)
# SGLang args # SGLang args
bootstrap_port = _reserve_disaggregation_bootstrap_port() bootstrap_port = _reserve_disaggregation_bootstrap_port()
ServerArgs.add_cli_args(parser) ServerArgs.add_cli_args(parser)
# Handle config file if present
temp_config_file = None # Track temp file for cleanup
if "--config" in args:
# Check if --config-key is also present
if "--config-key" in args:
key_index = args.index("--config-key")
config_key = args[key_index + 1]
config_index = args.index("--config")
config_path = args[config_index + 1]
# Extract nested section to temp file
args, temp_config_file = _extract_config_section(
args, config_path, config_key
)
# Remove --config-key from args (not recognized by SGLang)
args = args[:key_index] + args[key_index + 2 :]
# Extract boolean actions from the parser to handle them correctly in YAML
boolean_actions = []
for action in parser._actions:
if hasattr(action, "dest") and hasattr(action, "action"):
if action.action in ["store_true", "store_false"]:
boolean_actions.append(action.dest)
# Merge config file arguments with CLI arguments
config_merger = ConfigArgumentMerger(boolean_actions=boolean_actions)
args = config_merger.merge_config_with_args(args)
parsed_args = parser.parse_args(args) parsed_args = parser.parse_args(args)
# Clean up temp file if created
if temp_config_file and os.path.exists(temp_config_file):
try:
os.unlink(temp_config_file)
except Exception:
logging.warning(f"Failed to clean up temp config file: {temp_config_file}")
# Auto-set bootstrap port if not provided # Auto-set bootstrap port if not provided
if not any(arg.startswith("--disaggregation-bootstrap-port") for arg in args): if not any(arg.startswith("--disaggregation-bootstrap-port") for arg in args):
args_dict = vars(parsed_args) args_dict = vars(parsed_args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment