config.py 13.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
import json
17
import logging
18
import math
19
import shlex
20
from typing import Literal, Optional, Protocol
21

22
from pydantic import BaseModel
23

24
from dynamo.common.utils.paths import get_workspace_dir
25
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES, SubComponentType
26
27
28
29
30
31
32
33
34
35
36
37

# Module logger: INFO and above, written by a console StreamHandler (stderr)
# using a timestamped "time - name - level - message" layout.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

formatter = logging.Formatter(
    fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)


38
39
40
41
42
class VolumeMount(BaseModel):
    """A named volume mounted into a service container at ``mountPoint``.

    Extra keys are allowed, consistent with the other DGD models in this
    module, so mount options this model does not declare are preserved
    instead of being silently discarded by pydantic's default
    ``extra="ignore"`` behaviour.
    """

    name: str = "dynamo-pvc"
    mountPoint: str = "/data"
    model_config = {"extra": "allow"}


43
class Container(BaseModel):
    """Main container of a service pod: image, working dir and command line.

    Keys not declared here are kept as extra fields.
    """

    model_config = dict(extra="allow")

    image: str | None = None
    workingDir: str | None = None
    command: list[str] | None = None
    args: list[str] | None = None
49
50
51


class PodSpec(BaseModel):
    """``extraPodSpec`` of a service; only ``mainContainer`` is modelled.

    Every other pod-spec key is kept as an extra field.
    """

    model_config = dict(extra="allow")

    mainContainer: Container | None = None
54
55
56


class ServiceResources(BaseModel):
    """Resource ``requests`` / ``limits`` maps of a service.

    Values are strings (for example ``{"gpu": "4"}``).

    Extra keys are allowed, consistent with the other DGD models in this
    module. Without this, any resource field this model does not declare
    would be silently dropped by pydantic's default ``extra="ignore"``
    whenever a config goes through a ``model_validate`` / ``model_dump``
    round trip (as ``update_image`` does).
    """

    requests: Optional[dict[str, str]] = None
    limits: Optional[dict[str, str]] = None
    model_config = {"extra": "allow"}


class Service(BaseModel):
    """One entry of ``spec.services`` in a DGD config.

    Only the fields this module reads or writes are declared; all other keys
    are kept as extra fields.
    """

    model_config = dict(extra="allow")

    replicas: int | None = None
    resources: ServiceResources | None = None
    extraPodSpec: PodSpec | None = None
    subComponentType: str | None = None
67
68
69
70


class Services(BaseModel):
    """Service map that must contain a ``Frontend`` entry; other keys are kept."""

    model_config = dict(extra="allow")

    Frontend: Service
72
73


74
75
76
77
78
79
class PVCConfig(BaseModel):
    """Entry of ``spec.pvcs``: a PVC name and whether it should be created."""

    model_config = dict(extra="allow")

    name: str = "dynamo-pvc"
    create: bool | None = False


80
81
class Spec(BaseModel):
    """``spec`` section of a DGD config: services keyed by name plus optional PVCs."""

    model_config = dict(extra="allow")

    services: dict[str, Service]
    pvcs: list[PVCConfig] | None = None
84
85
86
87


class Metadata(BaseModel):
    """``metadata`` section of a DGD config; ``name`` is required."""

    model_config = dict(extra="allow")

    name: str
89
90
91
92
93


class Config(BaseModel):
    """Top-level DGD config document (``metadata`` + ``spec``).

    Undeclared top-level keys are kept as extra fields, so they survive a
    validate/dump round trip.
    """

    model_config = dict(extra="allow")

    metadata: Metadata
    spec: Spec
95
96


97
98
99
100
class MultinodeConfig(BaseModel):
    """Multi-node section of a worker service: the number of nodes it spans."""

    nodeCount: int


101
102
103
104
class DgdPlannerServiceConfig(BaseModel):
    """Service entry for the planner component of a generated DGD.

    The default container runs ``python3 -m planner_sla`` from the planner
    source directory. The namespace and image defaults are placeholders.
    """

    model_config = dict(extra="allow")

    dynamoNamespace: str = "dynamo"  # placeholder
    componentType: str = "planner"
    replicas: int = 1
    # Intentionally no PVC mount: planner data is supplied through a mounted
    # ConfigMap instead.
    volumeMounts: list[VolumeMount] = []
    extraPodSpec: PodSpec = PodSpec(
        mainContainer=Container(
            image="my-registry/dynamo-runtime:my-tag",  # placeholder
            workingDir=f"{get_workspace_dir()}/components/src/dynamo/planner",
            command=["python3", "-m", "planner_sla"],
            args=[],
        )
    )


118
119
120
121
def break_arguments(args: list[str] | None) -> list[str]:
    ans: list[str] = []
    if args is None:
        return ans
122
    if isinstance(args, str):
123
124
        # Use shlex.split to properly handle quoted arguments and JSON values
        ans = shlex.split(args)
125
126
    else:
        for arg in args:
127
            if arg is not None:
128
129
130
131
132
133
134
135
136
137
138
                # If the arg looks like it might be JSON (starts with { or [) or is already a single token,
                # don't split it further. Only split if it contains spaces AND doesn't look like JSON.
                if (
                    isinstance(arg, str)
                    and (" " in arg or "\t" in arg)
                    and not (arg.strip().startswith(("{", "[")))
                ):
                    # Use shlex.split to properly handle quoted arguments
                    ans.extend(shlex.split(arg))
                else:
                    ans.append(arg)
139
    return ans
140
141


142
143
144
145
146
147
148
149
150
151
def remove_valued_arguments(args: list[str], key: str) -> list[str]:
    """Remove every occurrence of a valued argument from *args*, in place.

    Both spellings are handled: ``key value`` (two tokens, both removed) and
    ``key=value`` (one token). A trailing ``key`` with no value after it is
    removed on its own, so no dangling flag is left behind.

    Args:
        args: Argument tokens; modified in place.
        key: The flag to remove, e.g. ``"--tp-size"``.

    Returns:
        The same list object *args*, for call chaining.
    """
    inline_prefix = f"{key}="
    kept: list[str] = []
    i = 0
    while i < len(args):
        token = args[i]
        if token == key:
            # Skip the flag and the value that follows it (if any).
            i += 2
            continue
        if isinstance(token, str) and token.startswith(inline_prefix):
            i += 1
            continue
        kept.append(token)
        i += 1

    # Slice assignment keeps the caller's list object (in-place contract).
    args[:] = kept
    return args


152
153
154
155
156
157
158
def append_argument(args: list[str], to_append) -> list[str]:
    """Insert *to_append* into *args* in place, ahead of any shell suffix.

    The insertion point comes from ``find_arg_index``. A list is spliced in
    element by element; any other value is inserted as a single item.

    Returns:
        The same list object *args*.
    """
    insert_at = find_arg_index(args)
    new_items = to_append if isinstance(to_append, list) else [to_append]
    args[insert_at:insert_at] = new_items
    return args
159
160


161
162
163
def find_arg_index(args: list[str]) -> int:
    """Return the index at which a new argument should be inserted.

    This is the position of the first ``"|"`` or ``"2>&1"`` token, whichever
    comes first. New arguments then land before a pipe or redirection. When
    neither token is present, the index is ``len(args)`` (append at the end).
    """
    shell_markers = ("|", "2>&1")
    positions = [args.index(marker) for marker in shell_markers if marker in args]
    return min(positions, default=len(args))
178
179


180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
def parse_override_engine_args(args: list[str]) -> tuple[dict, list[str]]:
    """
    Parse and extract --override-engine-args from argument list.

    Only the first ``--override-engine-args`` occurrence is inspected. When
    its value is a JSON object, the flag and its value are removed from
    *args* (in place) and the parsed object is returned. Otherwise *args* is
    left unchanged and ``{}`` is returned. The unusable cases are: the flag
    is absent, the flag has no value, the value is invalid JSON, or the JSON
    is not an object. The last two are logged as warnings instead of being
    silently ignored.

    Returns:
        tuple: (override_dict, modified_args) where override_dict is the parsed JSON
               and modified_args is the args list with --override-engine-args removed
    """
    flag = "--override-engine-args"
    if flag not in args:
        return {}, args

    idx = args.index(flag)
    if idx + 1 >= len(args):
        # Flag without a value: nothing to parse.
        return {}, args

    raw_value = args[idx + 1]
    try:
        parsed = json.loads(raw_value)
    except json.JSONDecodeError as err:
        logger.warning(
            "Ignoring %s: value is not valid JSON (%s): %r", flag, err, raw_value
        )
        return {}, args

    if not isinstance(parsed, dict):
        # Callers treat the result as a dict of overrides; a list/scalar would
        # break them later in a far less obvious place.
        logger.warning(
            "Ignoring %s: expected a JSON object, got %s",
            flag,
            type(parsed).__name__,
        )
        return {}, args

    del args[idx : idx + 2]
    return parsed, args


202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
def set_multinode_config(worker_service, gpu_count: int, num_gpus_per_node: int):
    """Add, update, or clear ``worker_service.multinode`` for *gpu_count* GPUs.

    * ``gpu_count <= num_gpus_per_node``: the worker fits on one node, so an
      existing non-None ``multinode`` is reset to ``None``. No attribute is
      added when there was none.
    * Otherwise the node count is ``ceil(gpu_count / num_gpus_per_node)``. It
      is written into an existing ``multinode`` (either a plain dict, e.g.
      loaded from YAML, or an object with a ``nodeCount`` attribute), or a
      new ``MultinodeConfig`` is created when none exists.
    """
    current = getattr(worker_service, "multinode", None)

    if gpu_count <= num_gpus_per_node:
        if current is not None:
            worker_service.multinode = None
        return

    nodes_needed = math.ceil(gpu_count / num_gpus_per_node)
    if current is None:
        worker_service.multinode = MultinodeConfig(nodeCount=nodes_needed)
    elif isinstance(current, dict):
        current["nodeCount"] = nodes_needed
    else:
        current.nodeCount = nodes_needed


225
def get_service_name_by_type(
    config: Config, backend: str, sub_component_type: SubComponentType
) -> str:
    """Resolve the DGD service name for a worker sub-component.

    Services in ``config.spec.services`` are scanned in order. The first one
    whose ``subComponentType`` equals ``sub_component_type.value`` wins. If
    none matches, or there are no services, the backend's default worker k8s
    name from ``WORKER_COMPONENT_NAMES`` is returned. That is the decode name
    for ``SubComponentType.DECODE`` and the prefill name otherwise, whether or
    not a service with that name exists.

    Args:
        config: Configuration object
        backend: Backend name (e.g., "sglang", "vllm", "trtllm"); only used
            to look up the fallback name.
        sub_component_type: The type of sub-component to look for (PREFILL or DECODE)

    Returns:
        The service name
    """
    spec = config.spec
    services = spec.services if spec else None

    if services:
        wanted = sub_component_type.value
        for name, service in services.items():
            if service.subComponentType == wanted:
                return name

    # No explicit match: use the backend's conventional worker name. This is
    # looked up only here, so a matched service never needs a known backend.
    defaults = WORKER_COMPONENT_NAMES[backend]
    if sub_component_type == SubComponentType.DECODE:
        return defaults.decode_worker_k8s_name
    return defaults.prefill_worker_k8s_name


def get_worker_service_from_config(
    config: Config,
    backend: str = "sglang",
    sub_component_type: SubComponentType = SubComponentType.DECODE,
) -> Service:
    """Helper function to get a worker service from config.

    The service name is resolved with ``get_service_name_by_type``: first by
    ``subComponentType``, then by the backend's default component name.

    Args:
        config: Parsed configuration object (``Config``).
        backend: Backend name (e.g., "sglang", "vllm", "trtllm"). Defaults to "sglang".
        sub_component_type: The type of sub-component to look for (PREFILL or DECODE). Defaults to DECODE.

    Returns:
        The worker service from the configuration.

    Raises:
        ValueError: If *backend* is not a key of ``WORKER_COMPONENT_NAMES``.
        KeyError: If the resolved service name is not present in
            ``config.spec.services``.
    """
    if backend not in WORKER_COMPONENT_NAMES:
        raise ValueError(
            f"Unsupported backend: {backend}. Supported backends: {list(WORKER_COMPONENT_NAMES.keys())}"
        )

    # Get the service name using the type-aware logic
    service_name = get_service_name_by_type(config, backend, sub_component_type)

    services = config.spec.services
    if service_name not in services:
        # Same exception type as the former bare lookup, but say what was
        # searched for and what is actually available.
        raise KeyError(
            f"No {sub_component_type.value} worker service found for backend "
            f"'{backend}': expected a service with subComponentType "
            f"'{sub_component_type.value}' or a service named '{service_name}'. "
            f"Available services: {sorted(services)}"
        )
    return services[service_name]
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325


def setup_worker_service_resources(
    worker_service, gpu_count: int, num_gpus_per_node: Optional[int] = None
):
    """Write the per-pod GPU count into a worker service's resources.

    With *num_gpus_per_node*, the multinode section is updated first via
    ``set_multinode_config``, and each pod asks for
    ``min(gpu_count, num_gpus_per_node)`` GPUs. Without it, each pod asks for
    *gpu_count*. ``resources`` and ``resources.requests`` are created when
    missing. ``requests["gpu"]`` is always set. ``limits["gpu"]`` is updated
    only when ``limits`` already exists. Both values are stored as strings.
    """
    per_pod_gpus = gpu_count
    if num_gpus_per_node is not None:
        set_multinode_config(worker_service, gpu_count, num_gpus_per_node)
        # NOTE(review): when gpu_count is not a multiple of num_gpus_per_node,
        # every node still requests num_gpus_per_node GPUs (e.g. 12 GPUs on
        # 8-GPU nodes -> 2 nodes x 8). Confirm this over-allocation is intended.
        per_pod_gpus = min(gpu_count, num_gpus_per_node)

    if worker_service.resources is None:
        worker_service.resources = ServiceResources()
    resources = worker_service.resources

    if resources.requests is None:
        resources.requests = {}

    gpu_text = str(per_pod_gpus)
    resources.requests["gpu"] = gpu_text
    if resources.limits is not None:
        resources.limits["gpu"] = gpu_text


326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
def validate_and_get_worker_args(worker_service, backend):
    """Validate a worker service and return its container args as CLI tokens.

    Args:
        worker_service: Worker service object to validate; its
            ``extraPodSpec.mainContainer.args`` are read.
        backend: Backend name (e.g., "sglang", "vllm", "trtllm"); must be a
            key of ``WORKER_COMPONENT_NAMES``. There is no default.

    Returns:
        List of arguments from the worker service, flattened by
        ``break_arguments`` (``[]`` when the container has no args).

    Raises:
        ValueError: If *backend* is unsupported, or the service has no
            ``extraPodSpec`` / ``mainContainer``.
    """
    if backend not in WORKER_COMPONENT_NAMES:
        raise ValueError(
            f"Unsupported backend: {backend}. Supported backends: {list(WORKER_COMPONENT_NAMES.keys())}"
        )

    pod_spec = worker_service.extraPodSpec
    if not pod_spec or not pod_spec.mainContainer:
        # This function does not know which sub-component worker_service is
        # (it may be a prefill or a decode worker). Naming the default decode
        # service here would be misleading for prefill workers.
        raise ValueError(
            f"Missing extraPodSpec or mainContainer in {backend} worker service"
        )

    return break_arguments(pod_spec.mainContainer.args)


def set_argument_value(args: list, arg_name: str, value: str):
    """Set the value of *arg_name* in *args*, adding the pair if absent.

    * If the first occurrence of *arg_name* is followed by a token, that token
      is replaced with *value*.
    * If *arg_name* is the last token (no value yet), *value* is appended
      right after it. Previously this raised ``IndexError``.
    * If *arg_name* is absent, ``[arg_name, value]`` is inserted via
      ``append_argument`` (ahead of any pipe or redirection).

    Returns:
        The same list object *args*, modified in place.
    """
    try:
        idx = args.index(arg_name)
    except ValueError:
        return append_argument(args, [arg_name, value])

    if idx + 1 < len(args):
        args[idx + 1] = value
    else:
        args.append(value)
    return args


360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
def update_image(config: dict, image: str) -> dict:
    """Set the container image on every DGD service that has a main container.

    Shared by the backend config modifiers. The input is validated as a
    ``Config``. Every service with an ``extraPodSpec.mainContainer`` gets its
    ``image`` replaced (services without one are skipped). The model is then
    dumped back to a dict.

    Args:
        config: Configuration dictionary
        image: Container image to set for all services

    Returns:
        A new configuration dictionary produced by ``model_dump()``.
    """
    parsed = Config.model_validate(config)

    for name, service in parsed.spec.services.items():
        pod_spec = service.extraPodSpec
        if not (pod_spec and pod_spec.mainContainer):
            continue
        pod_spec.mainContainer.image = image
        logger.debug(f"Updated image for {name} to {image}")

    return parsed.model_dump()


383
384
class ConfigModifierProtocol(Protocol):
    """Structural interface for backend-specific DGD config modifiers.

    Every member is declared as a ``classmethod``. Configs are passed in and
    returned as plain ``dict`` objects. The one-line descriptions below are
    inferred from member names and signatures; confirm them against the
    concrete implementations.
    """

    @classmethod
    def convert_config(
        cls,
        config: dict,
        target: Literal["prefill", "decode"],
        is_moe_model: bool = False,
    ) -> dict:
        """Return *config* adapted for the ``"prefill"`` or ``"decode"`` target."""

    @classmethod
    def set_config_tp_size(
        cls,
        config: dict,
        tp_size: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ) -> dict:
        """Return *config* with TP size *tp_size* applied to *component_type*."""

    @classmethod
    def set_config_tep_size(
        cls,
        config: dict,
        tep_size: int,
        num_gpus_per_node: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ) -> dict:
        """Return *config* with TEP size *tep_size* applied to *component_type*."""

    @classmethod
    def set_config_dep_size(
        cls,
        config: dict,
        dep_size: int,
        num_gpus_per_node: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ) -> dict:
        """Return *config* with DEP size *dep_size* applied to *component_type*."""

    @classmethod
    def get_model_name(cls, config: dict) -> str:
        """Return the model name configured in *config*."""

    @classmethod
    def get_port(cls, config: dict) -> int:
        """Return the port configured in *config*."""

    @classmethod
    def get_kv_cache_size_from_dynamo_log(
        cls, dynamo_log_fn: str, attention_dp_size: int = 1
    ) -> int:
        """Return the KV-cache size read from the Dynamo log file *dynamo_log_fn*."""

    @classmethod
    def load_default_config(cls) -> dict:
        """Return the backend's default config."""

    @classmethod
    def update_model(cls, config: dict, model_name: str) -> dict:
        """Return *config* updated to use *model_name*."""

    @classmethod
    def update_image(cls, config: dict, image: str) -> dict:
        """Return *config* updated to use container *image*."""