config.py 18.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
import json
17
import logging
18
import math
19
import shlex
20
from typing import Literal, Optional, Protocol
21

22
import yaml
23
from pydantic import BaseModel
24

25
from benchmarks.profiler.utils.planner_utils import build_planner_args_from_namespace
26
from dynamo.common.utils.paths import get_workspace_dir
27
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES, SubComponentType
28
29
30
31
32
33
34
35
36
37
38
39

# Module-level logging setup: records at INFO and above from this module's
# logger are written to a stream handler using a timestamped format.
formatter = logging.Formatter(
    fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.INFO)

logger = logging.getLogger(__name__)
logger.addHandler(console_handler)
logger.setLevel(logging.INFO)


40
41
42
43
44
class VolumeMount(BaseModel):
    """Volume mount entry: a volume name plus the path it is mounted at."""

    # Name of the volume to mount; defaults to "dynamo-pvc".
    name: str = "dynamo-pvc"
    # Mount path; defaults to "/data".
    mountPoint: str = "/data"


45
class Container(BaseModel):
    """Main-container settings of a service.

    Every declared field is optional and defaults to ``None``. Keys that are
    not declared here are accepted and kept (``extra="allow"``).
    """

    image: Optional[str] = None
    workingDir: Optional[str] = None
    command: Optional[list[str]] = None
    args: Optional[list[str]] = None
    model_config = {"extra": "allow"}
51
52
53


class PodSpec(BaseModel):
    """Extra pod specification of a service.

    Only ``mainContainer`` is modelled explicitly; any other keys are
    accepted and kept (``extra="allow"``).
    """

    mainContainer: Optional[Container] = None
    model_config = {"extra": "allow"}
56
57
58


class ServiceResources(BaseModel):
    """Resource requests and limits of a service.

    Both fields map a resource name (e.g. ``"gpu"``) to a string value and
    default to ``None``.
    """

    requests: Optional[dict[str, str]] = None
    limits: Optional[dict[str, str]] = None


class Service(BaseModel):
    """A single service entry under ``spec.services``.

    All declared fields are optional; undeclared keys are accepted and kept
    (``extra="allow"``).
    """

    replicas: Optional[int] = None
    resources: Optional[ServiceResources] = None
    extraPodSpec: Optional[PodSpec] = None
    # Compared against ``SubComponentType.<member>.value`` when a worker
    # service is looked up by type.
    subComponentType: Optional[str] = None
    model_config = {"extra": "allow"}
69
70
71
72


class Services(BaseModel):
    """Service collection that requires a ``Frontend`` entry.

    Additional services are accepted as extra keys (``extra="allow"``).
    """

    Frontend: Service
    model_config = {"extra": "allow"}
74
75


76
77
78
79
80
81
class PVCConfig(BaseModel):
    """PVC entry listed under ``spec.pvcs``; undeclared keys are kept."""

    # PVC name; defaults to "dynamo-pvc".
    name: str = "dynamo-pvc"
    # Defaults to False. NOTE(review): presumably controls whether the PVC
    # is created rather than reused -- confirm against the consumer.
    create: Optional[bool] = False
    model_config = {"extra": "allow"}


82
83
class Spec(BaseModel):
    """``spec`` section of the config.

    ``services`` maps a service name to its definition and is required;
    ``pvcs`` is optional. Undeclared keys are accepted and kept
    (``extra="allow"``).
    """

    services: dict[str, Service]
    pvcs: Optional[list[PVCConfig]] = None
    model_config = {"extra": "allow"}
86
87
88
89


class Metadata(BaseModel):
    """``metadata`` section of the config.

    Only ``name`` is required; undeclared keys are accepted and kept
    (``extra="allow"``).
    """

    name: str
    model_config = {"extra": "allow"}
91
92
93
94
95


class Config(BaseModel):
    """Top-level config model with required ``metadata`` and ``spec``.

    Undeclared top-level keys are accepted and kept (``extra="allow"``), so
    validating a dict and dumping it again does not drop them.
    """

    metadata: Metadata
    spec: Spec
    model_config = {"extra": "allow"}
97
98


99
100
101
102
class MultinodeConfig(BaseModel):
    """Multinode settings attached to a worker service."""

    # Number of nodes; set_multinode_config computes it as
    # ceil(gpu_count / num_gpus_per_node).
    nodeCount: int


103
104
105
106
107
108
109
110
class DgdPlannerServiceConfig(BaseModel):
    """Default definition of the planner service added to a DGD config.

    The namespace and image defaults are placeholders;
    ``generate_dgd_config_with_planner`` overwrites them from the Frontend
    service. Undeclared keys are accepted and kept (``extra="allow"``).

    NOTE(review): the default expressions are evaluated once, when the class
    body runs, so ``get_workspace_dir()`` is called at import time.
    """

    dynamoNamespace: str = "dynamo"  # placeholder
    componentType: str = "planner"
    replicas: int = 1
    volumeMounts: list[VolumeMount] = [VolumeMount()]
    extraPodSpec: PodSpec = PodSpec(
        mainContainer=Container(
            image="my-registry/dynamo-runtime:my-tag",  # placeholder
            workingDir=f"{get_workspace_dir()}/components/src/dynamo/planner",
            command=["python3", "-m", "planner_sla"],
            # Starts empty; planner arguments are appended by the caller.
            args=[],
        )
    )
    model_config = {"extra": "allow"}


119
120
121
122
def break_arguments(args: list[str] | None) -> list[str]:
    ans: list[str] = []
    if args is None:
        return ans
123
    if isinstance(args, str):
124
125
        # Use shlex.split to properly handle quoted arguments and JSON values
        ans = shlex.split(args)
126
127
    else:
        for arg in args:
128
            if arg is not None:
129
130
131
132
133
134
135
136
137
138
139
                # If the arg looks like it might be JSON (starts with { or [) or is already a single token,
                # don't split it further. Only split if it contains spaces AND doesn't look like JSON.
                if (
                    isinstance(arg, str)
                    and (" " in arg or "\t" in arg)
                    and not (arg.strip().startswith(("{", "[")))
                ):
                    # Use shlex.split to properly handle quoted arguments
                    ans.extend(shlex.split(arg))
                else:
                    ans.append(arg)
140
    return ans
141
142


143
144
145
146
147
148
149
150
151
152
def remove_valued_arguments(args: list[str], key: str) -> list[str]:
    """Drop the first ``key value`` pair from ``args``, if there is one.

    The list is modified in place and also returned. Only the first
    occurrence of ``key`` is considered, and a ``key`` that is the last
    element (no value after it) is left untouched.
    """
    try:
        position = args.index(key)
    except ValueError:
        return args

    if position < len(args) - 1:
        args[position : position + 2] = []
    return args


153
154
155
156
157
158
159
def append_argument(args: list[str], to_append) -> list[str]:
    """Insert one argument, or a list of arguments, into ``args``.

    The insertion point comes from ``find_arg_index``. ``args`` is modified
    in place and also returned.
    """
    new_tokens = to_append if isinstance(to_append, list) else [to_append]
    position = find_arg_index(args)
    args[position:position] = new_tokens
    return args
160
161


162
163
164
def find_arg_index(args: list[str]) -> int:
    """Return the index at which a new argument should be inserted.

    New arguments go before the first ``"|"`` or ``"2>&1"`` token. When
    neither token is present the index is ``len(args)``, i.e. the end.

    Args:
        args: Tokenized argument list.

    Returns:
        The insertion index.
    """
    shell_markers = ("|", "2>&1")
    return next(
        (pos for pos, token in enumerate(args) if token in shell_markers),
        len(args),
    )
179
180


181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
def parse_override_engine_args(args: list[str]) -> tuple[dict, list[str]]:
    """Parse and extract ``--override-engine-args`` from an argument list.

    Only the first occurrence of the flag is considered. When its value
    parses as JSON, the flag and value are removed from ``args`` in place.
    When the flag is missing, has no value, or the value is not valid JSON,
    ``args`` is left untouched and an empty dict is returned; the invalid
    JSON case additionally logs a warning instead of failing silently.

    Args:
        args: Tokenized argument list (modified in place on success).

    Returns:
        tuple: ``(override_dict, args)`` where ``override_dict`` is the
        parsed JSON value (or ``{}``) and ``args`` is the same list object
        that was passed in.
    """
    flag = "--override-engine-args"
    try:
        idx = args.index(flag)
    except ValueError:
        return {}, args  # no existing override

    if idx + 1 >= len(args):
        return {}, args  # flag present but no value follows it

    try:
        override_dict = json.loads(args[idx + 1])
    except json.JSONDecodeError as err:
        # Keep the best-effort behavior, but leave a trace for debugging.
        logging.getLogger(__name__).warning(
            "Ignoring %s: value is not valid JSON (%s)", flag, err
        )
        return {}, args

    # Remove the old override only after it was parsed successfully.
    del args[idx : idx + 2]
    return override_dict, args


203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
def set_multinode_config(worker_service, gpu_count: int, num_gpus_per_node: int):
    """Set or clear ``worker_service.multinode`` for the given GPU count.

    If the GPUs fit on one node, an existing multinode setting is reset to
    ``None``. Otherwise ``nodeCount`` is set to
    ``ceil(gpu_count / num_gpus_per_node)``, creating a ``MultinodeConfig``
    when the service has none yet.
    """
    current = getattr(worker_service, "multinode", None)

    if gpu_count <= num_gpus_per_node:
        # Single node: only touch the attribute if something is set.
        if current is not None:
            worker_service.multinode = None
        return

    nodes_needed = math.ceil(gpu_count / num_gpus_per_node)
    if current is None:
        worker_service.multinode = MultinodeConfig(nodeCount=nodes_needed)
    elif isinstance(current, dict):
        # Loaded from YAML as a plain dict.
        current["nodeCount"] = nodes_needed
    else:
        current.nodeCount = nodes_needed


226
def get_service_name_by_type(
    config: Config, backend: str, sub_component_type: SubComponentType
) -> str:
    """Return the name of the worker service of the given sub-component type.

    Services are matched by their ``subComponentType`` first. If none
    matches, or the config has no services, the backend's default worker
    name is returned, whether or not a service with that name exists.

    Args:
        config: Configuration object
        backend: Backend name (e.g., "sglang", "vllm", "trtllm")
        sub_component_type: The type of sub-component to look for (PREFILL or DECODE)

    Returns:
        The service name

    Raises:
        KeyError: If the default name is needed and ``backend`` is not a key
            of ``WORKER_COMPONENT_NAMES``.
    """

    def _default_name() -> str:
        # Looked up lazily so that an unknown backend only fails when the
        # fallback is actually needed.
        names = WORKER_COMPONENT_NAMES[backend]
        if sub_component_type == SubComponentType.DECODE:
            return names.decode_worker_k8s_name
        return names.prefill_worker_k8s_name

    # Fall back to the default name if the structure is unexpected.
    if not config.spec or not config.spec.services:
        return _default_name()

    # Prefer the service explicitly tagged with the requested type.
    for service_name, service_config in config.spec.services.items():
        if service_config.subComponentType == sub_component_type.value:
            return service_name

    return _default_name()


def get_worker_service_from_config(
    config: Config,
    backend: str = "sglang",
    sub_component_type: SubComponentType = SubComponentType.DECODE,
):
    """Helper function to get a worker service from config.

    The service name is resolved with ``get_service_name_by_type``: first by
    subComponentType, then by the backend's default component name.

    Args:
        config: Validated ``Config`` object
        backend: Backend name (e.g., "sglang", "vllm", "trtllm"). Defaults to "sglang".
        sub_component_type: The type of sub-component to look for (PREFILL or DECODE). Defaults to DECODE.

    Returns:
        The worker service from the configuration

    Raises:
        ValueError: If ``backend`` is not a key of ``WORKER_COMPONENT_NAMES``.
        KeyError: If the resolved service name is not in ``config.spec.services``.
    """
    if backend not in WORKER_COMPONENT_NAMES:
        raise ValueError(
            f"Unsupported backend: {backend}. Supported backends: {list(WORKER_COMPONENT_NAMES.keys())}"
        )

    # Get the service name using the type-aware logic
    service_name = get_service_name_by_type(config, backend, sub_component_type)

    # Get the actual service from the config
    return config.spec.services[service_name]
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326


def setup_worker_service_resources(
    worker_service, gpu_count: int, num_gpus_per_node: Optional[int] = None
):
    """Write the GPU count into the worker service's resource settings.

    When ``num_gpus_per_node`` is given, the multinode configuration is
    updated first and the GPU value is capped at that number. The value is
    stored as a string under ``"gpu"`` in ``requests`` (created if missing)
    and in ``limits`` only when limits already exist.
    """
    per_node_known = num_gpus_per_node is not None
    if per_node_known:
        set_multinode_config(worker_service, gpu_count, num_gpus_per_node)

    # Create the resources / requests containers on demand.
    if worker_service.resources is None:
        worker_service.resources = ServiceResources()
    resources = worker_service.resources
    if resources.requests is None:
        resources.requests = {}

    if per_node_known:
        gpus = min(gpu_count, num_gpus_per_node)
    else:
        gpus = gpu_count
    gpu_text = str(gpus)

    resources.requests["gpu"] = gpu_text
    if resources.limits is not None:
        resources.limits["gpu"] = gpu_text


327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
def validate_and_get_worker_args(worker_service, backend):
    """Helper function to validate worker service and get its arguments.

    Args:
        worker_service: Worker service object to validate
        backend: Backend name (e.g., "sglang", "vllm", "trtllm"). Required.

    Returns:
        List of arguments from the worker service's main container,
        tokenized with ``break_arguments``

    Raises:
        ValueError: If ``backend`` is not a key of ``WORKER_COMPONENT_NAMES``,
            or the service has no ``extraPodSpec`` / ``mainContainer``.
    """
    if backend not in WORKER_COMPONENT_NAMES:
        raise ValueError(
            f"Unsupported backend: {backend}. Supported backends: {list(WORKER_COMPONENT_NAMES.keys())}"
        )

    if not worker_service.extraPodSpec or not worker_service.extraPodSpec.mainContainer:
        # NOTE(review): the message always names the decode worker, even if
        # the service passed in is a prefill worker.
        raise ValueError(
            f"Missing extraPodSpec or mainContainer in {backend} decode worker service '{WORKER_COMPONENT_NAMES[backend].decode_worker_k8s_name}'"
        )

    args = worker_service.extraPodSpec.mainContainer.args
    return break_arguments(args)


def set_argument_value(args: list, arg_name: str, value: str):
    """Helper function to set an argument value, adding it if not present.

    If ``arg_name`` is in ``args``, the token after its first occurrence is
    replaced with ``value``. If ``arg_name`` is the last token (no value
    follows it), ``value`` is appended instead of raising ``IndexError``.
    If ``arg_name`` is absent, ``[arg_name, value]`` is inserted via
    ``append_argument``.

    Args:
        args: Tokenized argument list (modified in place).
        arg_name: Flag to set, e.g. ``"--tp"``.
        value: Value to store after the flag.

    Returns:
        The updated argument list.
    """
    try:
        idx = args.index(arg_name)
    except ValueError:
        return append_argument(args, [arg_name, value])

    if idx + 1 < len(args):
        args[idx + 1] = value
    else:
        # Flag is the final token: there is no value slot to overwrite.
        args.append(value)
    return args


361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
def update_image(config: dict, image: str) -> dict:
    """Set the container image on every service that has a main container.

    Shared utility for the backend config modifiers. The input dict is
    validated into a ``Config`` model, updated, and dumped back to a dict.

    Args:
        config: Configuration dictionary
        image: Container image to set for all services

    Returns:
        Updated configuration dictionary
    """
    cfg = Config.model_validate(config)

    for service_name, service_config in cfg.spec.services.items():
        pod_spec = service_config.extraPodSpec
        # Services without a pod spec or main container are skipped.
        if not pod_spec or not pod_spec.mainContainer:
            continue
        pod_spec.mainContainer.image = image
        logger.debug(f"Updated image for {service_name} to {image}")

    return cfg.model_dump()


384
385
class ConfigModifierProtocol(Protocol):
    """Structural interface of the per-backend config modifiers.

    Every member is a classmethod. Config-transforming methods take a config
    ``dict`` and return a ``dict``. The bodies are placeholders; the actual
    semantics are defined by the implementing classes, which are not part
    of this protocol definition.
    """

    @classmethod
    def convert_config(
        cls,
        config: dict,
        target: Literal["prefill", "decode"],
        is_moe_model: bool = False,
    ) -> dict:
        """Return a config converted for the ``prefill`` or ``decode`` target."""
        ...

    @classmethod
    def set_config_tp_size(
        cls,
        config: dict,
        tp_size: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ) -> dict:
        """Return a config with ``tp_size`` applied to ``component_type``."""
        ...

    @classmethod
    def set_config_tep_size(
        cls,
        config: dict,
        tep_size: int,
        num_gpus_per_node: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ) -> dict:
        """Return a config with ``tep_size`` applied to ``component_type``."""
        ...

    @classmethod
    def set_config_dep_size(
        cls,
        config: dict,
        dep_size: int,
        num_gpus_per_node: int,
        component_type: SubComponentType = SubComponentType.DECODE,
    ) -> dict:
        """Return a config with ``dep_size`` applied to ``component_type``."""
        ...

    @classmethod
    def get_model_name(cls, config: dict) -> str:
        """Return the model name found in ``config``."""
        ...

    @classmethod
    def get_port(cls, config: dict) -> int:
        """Return the port found in ``config``."""
        ...

    @classmethod
    def get_kv_cache_size_from_dynamo_log(
        cls, dynamo_log_fn: str, attention_dp_size: int = 1
    ) -> int:
        """Return the KV cache size read from the log file ``dynamo_log_fn``."""
        ...

    @classmethod
    def load_default_config(cls) -> dict:
        """Return the backend's default config as a dict."""
        ...

    @classmethod
    def update_model(cls, config: dict, model_name: str) -> dict:
        """Return a config that uses ``model_name``."""
        ...

    @classmethod
    def update_image(cls, config: dict, image: str) -> dict:
        """Return a config that uses the container ``image``."""
        ...

449

450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
def generate_dgd_config_with_planner(
    config_path: str,
    config_modifier,
    best_prefill_gpus: int,
    best_decode_gpus: int,
    output_dir: str,
    args,
    is_moe_model: bool = False,
    num_gpus_per_node: int = 8,
):
    """Generate DGD config with planner based on profiling results.

    Steps: load the YAML config, apply the model name and optional image
    from ``args``, set the parallelism sizes of the prefill and decode
    workers, ensure a PVC entry exists, and add a ``Planner`` service whose
    namespace and image follow the ``Frontend`` service.

    Args:
        config_path: Path to the YAML config file
        config_modifier: Config modifier instance (e.g., SGLangConfigModifier)
        best_prefill_gpus: Number of GPUs for prefill engine
        best_decode_gpus: Number of GPUs for decode engine
        output_dir: Output directory for profile results
        args: Parsed arguments namespace from profile_sla
        is_moe_model: Whether this is an MoE model
        num_gpus_per_node: Number of GPUs per node (for MoE models)

    Returns:
        dict: Final DGD config with planner service configured

    Raises:
        KeyError: If the config has no ``Frontend`` service.
    """
    # Load config from file (YAML streams are UTF-8 by convention).
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Update model name in config from profiling args
    # This ensures the final DGD uses the model specified in the DGDR, not the default in the config file
    config = config_modifier.update_model(config, args.model)

    # Update container image if provided
    # This overrides the default image in the config file for all DGD components
    if args.dgd_image:
        config = config_modifier.update_image(config, args.dgd_image)

    if is_moe_model:
        # MoE model, use TEP for prefill and DEP for decode
        config = config_modifier.set_config_tep_size(
            config,
            best_prefill_gpus,
            num_gpus_per_node,
            SubComponentType.PREFILL,
        )
        config = config_modifier.set_config_dep_size(
            config,
            best_decode_gpus,
            num_gpus_per_node,
            SubComponentType.DECODE,
        )
    else:
        # dense model, use TP for both prefill and decode
        config = config_modifier.set_config_tp_size(
            config, best_prefill_gpus, SubComponentType.PREFILL
        )
        config = config_modifier.set_config_tp_size(
            config, best_decode_gpus, SubComponentType.DECODE
        )
    dgd_config = Config.model_validate(config)

    # add PVC config if not present
    if not dgd_config.spec.pvcs:
        dgd_config.spec.pvcs = [PVCConfig()]

    # add the planner service; namespace and image follow the Frontend
    planner_config = DgdPlannerServiceConfig()
    frontend_service = dgd_config.spec.services["Frontend"]
    frontend_namespace = getattr(frontend_service, "dynamoNamespace", "dynamo")  # type: ignore[attr-defined]
    planner_config.dynamoNamespace = frontend_namespace
    if frontend_service.extraPodSpec and frontend_service.extraPodSpec.mainContainer:
        frontend_image = frontend_service.extraPodSpec.mainContainer.image
        if frontend_image and planner_config.extraPodSpec.mainContainer:
            planner_config.extraPodSpec.mainContainer.image = frontend_image

    # Build planner args dynamically from parsed arguments
    # This includes shared args (ttft, itl, backend, namespace) from profile_sla
    # and planner-specific args (with planner_ prefix)
    planner_args = build_planner_args_from_namespace(args, prefix="planner_")

    # Override profiling-specific arguments with results from profiling
    # Remove and re-add to ensure correct values from profiling context
    overridden_prefixes = tuple(
        f"--{key}="
        for key in (
            "namespace",
            "prefill-engine-num-gpu",
            "decode-engine-num-gpu",
            "profile-results-dir",
        )
    )
    planner_args = [
        arg for arg in planner_args if not arg.startswith(overridden_prefixes)
    ]

    # Add arguments determined by profiling results
    planner_args.extend(
        [
            f"--namespace={frontend_namespace}",
            f"--prefill-engine-num-gpu={best_prefill_gpus}",
            f"--decode-engine-num-gpu={best_decode_gpus}",
            f"--profile-results-dir={output_dir}",
        ]
    )

    planner_container = planner_config.extraPodSpec.mainContainer
    if planner_container and planner_container.args is not None:
        planner_container.args.extend(planner_args)

    # Convert planner config to dict first, then the entire config to dict
    planner_dict = planner_config.model_dump(exclude_unset=False)
    config_dict = dgd_config.model_dump(exclude_unset=False)
    config_dict["spec"]["services"]["Planner"] = planner_dict

    return config_dict