# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
import json
17
import logging
18
import math
19
import shlex
20
from typing import Optional
21

22
from pydantic import BaseModel
23

24
from dynamo.common.utils.paths import get_workspace_dir
25
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES, SubComponentType
26
27
28
29
30
31
32
33
34
35
36
37

# Module-level logger. Both the logger and its stream handler are set to INFO,
# so records below INFO (e.g. ``logger.debug``) are not emitted by this handler.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# ``logging.StreamHandler()`` with no argument writes to ``sys.stderr``.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
# Record layout: "<timestamp> - <logger name> - <level> - <message>", with the
# timestamp rendered as YYYY-MM-DD HH:MM:SS.
formatter = logging.Formatter(
    "%(asctime)s - %(name)s - %(levelname)s - %(message)s", "%Y-%m-%d %H:%M:%S"
)
console_handler.setFormatter(formatter)
# NOTE(review): the handler is attached at import time; if the application also
# configures the root logger, records may be printed twice — confirm intended.
logger.addHandler(console_handler)


38
class Container(BaseModel):
    """Container settings, used as ``PodSpec.mainContainer``.

    All declared fields are optional. ``extra="allow"`` keeps any additional
    keys supplied at validation time instead of rejecting or dropping them.
    """

    image: Optional[str] = None
    workingDir: Optional[str] = None
    command: Optional[list[str]] = None
    # Helpers in this module (e.g. ``break_arguments``) tokenise these entries.
    args: Optional[list[str]] = None
    model_config = {"extra": "allow"}
44
45
46


class PodSpec(BaseModel):
    """Pod-level settings of a service; only ``mainContainer`` is modelled.

    Unknown keys are preserved because of ``extra="allow"``.
    """

    mainContainer: Optional[Container] = None
    model_config = {"extra": "allow"}
49
50
51


class ServiceResources(BaseModel):
    """Resource ``requests`` and ``limits`` of a service.

    Each mapping goes from a resource name to either a string value or a
    nested dict (``setup_worker_service_resources`` writes the ``"gpu"`` key
    and, when present, ``"custom" -> "rdma/ib"``).
    """

    requests: Optional[dict[str, str | dict]] = None
    limits: Optional[dict[str, str | dict]] = None
54
55
56


class Service(BaseModel):
    """A single entry of ``Spec.services``.

    All declared fields are optional; unknown keys (for example ``multinode``,
    which ``set_multinode_config`` reads and writes) are kept thanks to
    ``extra="allow"``.
    """

    replicas: Optional[int] = None
    resources: Optional[ServiceResources] = None
    extraPodSpec: Optional[PodSpec] = None
    # Compared against ``SubComponentType.<member>.value`` when looking up
    # a worker service by type.
    subComponentType: Optional[str] = None
    model_config = {"extra": "allow"}
62
63
64
65


class Services(BaseModel):
    """Service collection with a mandatory ``Frontend`` entry.

    Additional services are accepted as extra fields (``extra="allow"``).
    """

    Frontend: Service
    model_config = {"extra": "allow"}
67
68
69
70


class Spec(BaseModel):
    """``spec`` section of the configuration: services keyed by service name.

    Unknown keys are preserved because of ``extra="allow"``.
    """

    services: dict[str, Service]
    model_config = {"extra": "allow"}
72
73
74
75


class Metadata(BaseModel):
    """``metadata`` section of the configuration; ``name`` is required.

    Unknown keys are preserved because of ``extra="allow"``.
    """

    name: str
    model_config = {"extra": "allow"}
77
78
79
80
81


class Config(BaseModel):
    """Top-level configuration document with ``metadata`` and ``spec``.

    Unknown top-level keys are preserved because of ``extra="allow"``.
    """

    metadata: Metadata
    spec: Spec
    model_config = {"extra": "allow"}
83
84


85
86
87
88
class MultinodeConfig(BaseModel):
    """Multi-node settings stored on a worker service as ``multinode``."""

    # ``set_multinode_config`` sets this to
    # ``math.ceil(gpu_count / num_gpus_per_node)``.
    nodeCount: int


89
class DgdPlannerServiceConfig(BaseModel):
    """Planner service configuration.

    Planner reads profiling data from a ConfigMap (planner-profile-data)
    automatically created and mounted by the profiler; no PVC dependencies.

    The ``dynamoNamespace`` and container ``image`` defaults are placeholders
    that callers are expected to overwrite. Unknown keys are preserved
    because of ``extra="allow"``.
    """

    dynamoNamespace: str = "dynamo"  # placeholder
    componentType: str = "planner"
    replicas: int = 1
    # NOTE: ``get_workspace_dir()`` is evaluated once, when the class is defined.
    extraPodSpec: PodSpec = PodSpec(
        mainContainer=Container(
            image="my-registry/dynamo-runtime:my-tag",  # placeholder
            workingDir=f"{get_workspace_dir()}/components/src/dynamo/planner",
            command=["python3", "-m", "dynamo.planner.planner_sla"],
            args=[],
        )
    )
    model_config = {"extra": "allow"}


110
111
112
113
def break_arguments(args: list[str] | None) -> list[str]:
    ans: list[str] = []
    if args is None:
        return ans
114
    if isinstance(args, str):
115
116
        # Use shlex.split to properly handle quoted arguments and JSON values
        ans = shlex.split(args)
117
118
    else:
        for arg in args:
119
            if arg is not None:
120
121
122
123
124
125
126
127
128
129
130
                # If the arg looks like it might be JSON (starts with { or [) or is already a single token,
                # don't split it further. Only split if it contains spaces AND doesn't look like JSON.
                if (
                    isinstance(arg, str)
                    and (" " in arg or "\t" in arg)
                    and not (arg.strip().startswith(("{", "[")))
                ):
                    # Use shlex.split to properly handle quoted arguments
                    ans.extend(shlex.split(arg))
                else:
                    ans.append(arg)
131
    return ans
132
133


134
135
136
137
138
139
140
141
142
143
def remove_valued_arguments(args: list[str], key: str) -> list[str]:
    """Drop the first ``key value`` pair from ``args`` in place.

    The list is left untouched when ``key`` is absent or when it is the last
    element (i.e. it has no value following it).

    Args:
        args: Argument tokens; mutated in place.
        key: The flag to remove together with the token right after it.

    Returns:
        The same ``args`` list object.
    """
    try:
        position = args.index(key)
    except ValueError:
        return args

    has_value = position + 1 < len(args)
    if has_value:
        del args[position : position + 2]
    return args


144
145
146
147
148
149
150
def append_argument(args: list[str], to_append) -> list[str]:
    """Insert ``to_append`` into ``args`` at the index from ``find_arg_index``.

    Args:
        args: Argument tokens; mutated in place.
        to_append: A single token, or a list of tokens that are spliced in
            order at the insertion point.

    Returns:
        The same ``args`` list object.
    """
    insert_at = find_arg_index(args)
    new_tokens = to_append if isinstance(to_append, list) else [to_append]
    # Slice assignment on an empty slice inserts without replacing anything.
    args[insert_at:insert_at] = new_tokens
    return args
151
152


153
154
155
def find_arg_index(args: list[str]) -> int:
    """Return the index at which a new argument should be inserted.

    New arguments go before the first ``"|"`` or ``"2>&1"`` token, whichever
    appears earlier; when neither is present the index is ``len(args)``.

    Args:
        args: Argument tokens.

    Returns:
        The insertion index.
    """
    idx = len(args)
    for marker in ("|", "2>&1"):
        try:
            idx = min(idx, args.index(marker))
        except ValueError:
            # Marker not present; keep the current insertion point.
            pass
    return idx
170
171


172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
def parse_override_engine_args(args: list[str]) -> tuple[dict, list[str]]:
    """Extract the JSON payload of ``--override-engine-args`` from ``args``.

    When the flag is present, followed by a value, and that value parses as
    JSON, the flag/value pair is removed from ``args`` in place and the
    parsed value is returned. In every other case (flag missing, flag is the
    last token, value is not valid JSON) ``args`` is left untouched and an
    empty dict is returned.

    Returns:
        tuple: ``(override_dict, args)`` where ``args`` is the same list
        object that was passed in.
    """
    flag = "--override-engine-args"
    if flag not in args:
        return {}, args

    flag_pos = args.index(flag)
    value_pos = flag_pos + 1
    if value_pos >= len(args):
        return {}, args

    try:
        parsed = json.loads(args[value_pos])
    except ValueError:
        # json.JSONDecodeError is a ValueError subclass: invalid JSON is ignored.
        return {}, args

    del args[flag_pos : value_pos + 1]
    return parsed, args


194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
def set_multinode_config(worker_service, gpu_count: int, num_gpus_per_node: int):
    """Set or clear ``worker_service.multinode`` to match the GPU count.

    Args:
        worker_service: Service object; its ``multinode`` attribute may be
            missing, ``None``, a dict, or an object with ``nodeCount``.
        gpu_count: Total number of GPUs the worker needs.
        num_gpus_per_node: Number of GPUs available on one node.
    """
    existing = getattr(worker_service, "multinode", None)

    if gpu_count <= num_gpus_per_node:
        # Fits on one node: clear any multinode setting that is present.
        if existing is not None:
            worker_service.multinode = None
        return

    nodes_needed = math.ceil(gpu_count / num_gpus_per_node)
    if existing is None:
        worker_service.multinode = MultinodeConfig(nodeCount=nodes_needed)
    elif isinstance(existing, dict):
        # Plain dict form (e.g. loaded from YAML): update the key in place.
        existing["nodeCount"] = nodes_needed
    else:
        existing.nodeCount = nodes_needed


217
def get_service_name_by_type(
    config: Config, backend: str, sub_component_type: SubComponentType
) -> str:
    """Return the name of the service for the given sub-component type.

    The first service whose ``subComponentType`` equals
    ``sub_component_type.value`` wins. If none matches, or the config has no
    services, the backend's default worker name from
    ``WORKER_COMPONENT_NAMES`` is returned, whether or not a service with
    that name exists in the config.

    Args:
        config: Configuration object
        backend: Backend name (e.g., "sglang", "vllm", "trtllm")
        sub_component_type: The type of sub-component to look for (PREFILL or DECODE)

    Returns:
        The service name

    Raises:
        KeyError: If the fallback is needed and ``backend`` is not a key of
            ``WORKER_COMPONENT_NAMES``.
    """

    def _default_name() -> str:
        # DECODE maps to the decode worker name; anything else is treated as prefill.
        names = WORKER_COMPONENT_NAMES[backend]
        if sub_component_type == SubComponentType.DECODE:
            return names.decode_worker_k8s_name
        return names.prefill_worker_k8s_name

    # Fall back to the default name if the structure is unexpected.
    if not config.spec or not config.spec.services:
        return _default_name()

    # Look through services to find one with a matching subComponentType.
    for service_name, service_config in config.spec.services.items():
        if service_config.subComponentType == sub_component_type.value:
            return service_name

    return _default_name()


def get_worker_service_from_config(
    config: Config,
    backend: str = "sglang",
    sub_component_type: SubComponentType = SubComponentType.DECODE,
):
    """Return the worker service entry for a backend and sub-component type.

    The service name is resolved by ``get_service_name_by_type`` (match on
    ``subComponentType`` first, then the backend's default worker name).

    Args:
        config: Configuration object (``Config`` model, not a raw dict)
        backend: Backend name (e.g., "sglang", "vllm", "trtllm"). Defaults to "sglang".
        sub_component_type: The type of sub-component to look for (PREFILL or DECODE). Defaults to DECODE.

    Returns:
        The worker service from the configuration

    Raises:
        ValueError: If ``backend`` is not a key of ``WORKER_COMPONENT_NAMES``.
        KeyError: If the resolved service name is not in ``config.spec.services``.
    """
    if backend not in WORKER_COMPONENT_NAMES:
        raise ValueError(
            f"Unsupported backend: {backend}. Supported backends: {list(WORKER_COMPONENT_NAMES.keys())}"
        )

    # Get the service name using the type-aware logic
    service_name = get_service_name_by_type(config, backend, sub_component_type)

    # Get the actual service from the config
    return config.spec.services[service_name]
287
288
289
290
291
292
293
294
295
296
297
298
299
300


def setup_worker_service_resources(
    worker_service, gpu_count: int, num_gpus_per_node: Optional[int] = None
):
    """Helper function to set up worker service resources (requests and limits)."""
    # Handle multinode configuration if num_gpus_per_node is provided
    if num_gpus_per_node is not None:
        set_multinode_config(worker_service, gpu_count, num_gpus_per_node)

    # Ensure resources exists
    if worker_service.resources is None:
        worker_service.resources = ServiceResources()

301
302
303
    # Ensure limits exists
    if worker_service.resources.limits is None:
        worker_service.resources.limits = {}
304

305
    # Calculate GPU value
306
307
308
309
310
311
    gpu_value = (
        min(gpu_count, num_gpus_per_node)
        if num_gpus_per_node is not None
        else gpu_count
    )

312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
    def _update_resource_dict(resource_dict: dict[str, str], gpu_value: int):
        """Helper function to update gpu and custom rdma/ib fields in a resource dictionary.

        Args:
            resource_dict: The resource dictionary (either limits or requests) to update
            gpu_value: The GPU value to set
        """
        resource_dict["gpu"] = str(gpu_value)

        # also update custom rdma/ib if it exists (some cluster requires this)
        if "custom" in resource_dict:
            if isinstance(resource_dict["custom"], dict):
                if "rdma/ib" in resource_dict["custom"]:
                    resource_dict["custom"]["rdma/ib"] = str(gpu_value)

    # Update limits
    _update_resource_dict(worker_service.resources.limits, gpu_value)
    # Also update requests if they exist
    if worker_service.resources.requests is not None:
        _update_resource_dict(worker_service.resources.requests, gpu_value)
332
333


334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
def validate_and_get_worker_args(worker_service, backend):
    """Validate a worker service and return its container arguments as tokens.

    Args:
        worker_service: Worker service object to validate
        backend: Backend name (e.g., "sglang", "vllm", "trtllm"); required,
            there is no default.

    Returns:
        List of arguments from the worker service, tokenised by
        ``break_arguments``.

    Raises:
        ValueError: If ``backend`` is not a key of ``WORKER_COMPONENT_NAMES``,
            or the service has no ``extraPodSpec`` / ``mainContainer``.
    """
    if backend not in WORKER_COMPONENT_NAMES:
        raise ValueError(
            f"Unsupported backend: {backend}. Supported backends: {list(WORKER_COMPONENT_NAMES.keys())}"
        )

    if not worker_service.extraPodSpec or not worker_service.extraPodSpec.mainContainer:
        # NOTE(review): the message always names the decode worker, even if
        # the service passed in is a prefill worker — confirm whether intended.
        raise ValueError(
            f"Missing extraPodSpec or mainContainer in {backend} decode worker service '{WORKER_COMPONENT_NAMES[backend].decode_worker_k8s_name}'"
        )

    args = worker_service.extraPodSpec.mainContainer.args
    return break_arguments(args)


def set_argument_value(args: list, arg_name: str, value: str):
    """Set the value following ``arg_name``, adding the pair if it is absent.

    Args:
        args: Argument tokens; mutated in place.
        arg_name: The flag whose value should be set.
        value: The value to store right after the flag.

    Returns:
        The same ``args`` list object.
    """
    try:
        idx = args.index(arg_name)
    except ValueError:
        # Flag not present: add the flag/value pair via ``append_argument``.
        return append_argument(args, [arg_name, value])

    if idx + 1 < len(args):
        args[idx + 1] = value
    else:
        # Flag is the final token with no value after it; previously this
        # raised IndexError. Add the value instead of overwriting.
        args.append(value)
    return args


368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
def update_image(config: dict, image: str) -> dict:
    """Set the container image on every service that has a main container.

    The input dict is validated into a ``Config`` model, each service with
    both ``extraPodSpec`` and ``mainContainer`` gets ``image`` assigned, and
    the model is dumped back to a dict. Services without a main container
    are left as they are.

    Args:
        config: Configuration dictionary
        image: Container image to set for all services

    Returns:
        Updated configuration dictionary (a fresh ``model_dump()`` result)
    """
    cfg = Config.model_validate(config)

    for service_name, service in cfg.spec.services.items():
        pod_spec = service.extraPodSpec
        if not pod_spec or not pod_spec.mainContainer:
            continue
        pod_spec.mainContainer.image = image
        logger.debug(f"Updated image for {service_name} to {image}")

    return cfg.model_dump()