# config.py
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
17
import re
18
from typing import Literal, Optional, Protocol
19

20
from pydantic import BaseModel
21

22
23
24
25
from benchmarks.profiler.utils.defaults import (
    DEFAULT_MODEL_NAME,
    DYNAMO_RUN_DEFAULT_PORT,
)
26
27
28
29
30
31
32
33
34
35
36
37
38
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES

# Module-level logger for this file. It is pinned to INFO and gets its own
# stream handler at import time, so importing this module has the side effect
# of attaching a handler to the logger named after the module.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
# Record layout: timestamp (second resolution), logger name, level, message.
formatter = logging.Formatter(
    "%(asctime)s - %(name)s - %(levelname)s - %(message)s", "%Y-%m-%d %H:%M:%S"
)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)


39
class Container(BaseModel):
    """Container section of a pod spec.

    Only ``args`` is declared explicitly; the model is configured with
    ``extra = "allow"`` so keys that are not declared here are accepted.
    """

    # Command-line arguments of the container; may be absent.
    args: Optional[list[str]] = None
    model_config = {"extra": "allow"}
42
43
44


class PodSpec(BaseModel):
    """Pod-level overrides of a service (the ``extraPodSpec`` field of ``Service``).

    Only ``mainContainer`` is declared explicitly; the model is configured with
    ``extra = "allow"`` so keys that are not declared here are accepted.
    """

    # The container whose ``args`` the config modifiers read and rewrite.
    mainContainer: Optional[Container] = None
    model_config = {"extra": "allow"}
47
48
49


class ServiceResources(BaseModel):
    """Resource ``requests`` and ``limits`` of a service, as string-to-string maps."""

    # NOTE(review): unlike Container/PodSpec/Service this model sets no
    # ``extra`` option, so pydantic's default handling of unknown keys
    # applies -- confirm that is intended.
    requests: Optional[dict[str, str]] = None
    limits: Optional[dict[str, str]] = None


class Service(BaseModel):
    """One entry of ``spec.services`` in a deployment config.

    All declared fields are optional; the model is configured with
    ``extra = "allow"`` so keys that are not declared here are accepted.
    """

    # Number of replicas of the service.
    replicas: Optional[int] = None
    # Resource requests/limits.
    resources: Optional[ServiceResources] = None
    # Pod overrides, including the main container and its arguments.
    extraPodSpec: Optional[PodSpec] = None
    model_config = {"extra": "allow"}
59
60
61
62


class Services(BaseModel):
    """Service collection that requires a ``Frontend`` entry.

    Any other key is accepted through ``extra = "allow"``.
    """

    # NOTE(review): ``Spec`` below types its services as ``dict[str, Service]``
    # and does not reference this class -- check whether it is still used.
    Frontend: Service
    model_config = {"extra": "allow"}
64
65
66
67
68
69
70
71
72
73
74
75
76


class Spec(BaseModel):
    """The ``spec`` section of a deployment config: a name-to-``Service`` map."""

    # NOTE(review): no ``extra`` option is set here, so pydantic's default
    # handling of unknown keys applies -- confirm that is intended.
    services: dict[str, Service]


class Metadata(BaseModel):
    """The ``metadata`` section of a deployment config; only ``name`` is modelled."""

    # NOTE(review): no ``extra`` option is set here, so pydantic's default
    # handling of unknown keys applies -- confirm that is intended.
    name: str


class Config(BaseModel):
    """Top-level deployment config as validated by the config modifiers.

    ``metadata`` and ``spec`` are required; the model is configured with
    ``extra = "allow"`` so other top-level keys are accepted.
    """

    metadata: Metadata
    spec: Spec
    model_config = {"extra": "allow"}
78
79


80
81
82
83
def break_arguments(args: list[str] | None) -> list[str]:
    ans: list[str] = []
    if args is None:
        return ans
84
    if isinstance(args, str):
85
        ans = re.split(r"[ =]", args)
86
87
    else:
        for arg in args:
88
89
            if arg is not None:
                ans.extend(arg.split(" "))
90
    return ans
91
92


93
94
95
96
97
98
99
100
101
102
def remove_valued_arguments(args: list[str], key: str) -> list[str]:
    """Drop the first ``key value`` pair from ``args`` in place, if present.

    Nothing is removed when ``key`` is missing or when it is the last token
    (i.e. it has no value after it). The same list object is returned.
    """
    try:
        position = args.index(key)
    except ValueError:
        return args

    # Only strip the pair when a value token actually follows the key.
    if position < len(args) - 1:
        args[position : position + 2] = []
    return args


103
104
def join_arguments(args: list[str]) -> list[str]:
    """Collapse ``args`` into a one-element list holding the space-joined string."""
    joined = " ".join(args)
    return [joined]
105
106


107
108
109
110
111
112
113
def append_argument(args: list[str], to_append) -> list[str]:
    """Insert one token, or a list of tokens, into ``args`` in place.

    The insertion point comes from ``find_arg_index``. The same list object is
    returned.
    """
    insert_at = find_arg_index(args)
    new_items = to_append if isinstance(to_append, list) else [to_append]
    args[insert_at:insert_at] = new_items
    return args
114
115


116
117
118
def find_arg_index(args: list[str]) -> int:
    """Return the index at which a new argument should be inserted.

    That is the position of the first ``"|"`` or ``"2>&1"`` token, whichever
    comes first, or ``len(args)`` when neither token is present.
    """
    idx = len(args)
    for marker in ("|", "2>&1"):
        if marker in args:
            idx = min(idx, args.index(marker))
    return idx
133
134


135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
class ConfigModifierProtocol(Protocol):
    """Structural interface shared by the per-backend config modifier classes."""

    @classmethod
    def convert_config(cls, config: dict, target: Literal["prefill", "decode"]) -> dict:
        """Return a config converted for the given ``target``."""

    @classmethod
    def set_config_tp_size(cls, config: dict, tp_size: int) -> dict:
        """Return a config with the tensor-parallel size set to ``tp_size``."""

    @classmethod
    def get_model_name(cls, config: dict) -> str:
        """Return the model name found in ``config``."""

    @classmethod
    def get_port(cls, config: dict) -> int:
        """Return the port found in ``config``."""

    @classmethod
    def get_kv_cache_size_from_dynamo_log(cls, dynamo_log_fn: str) -> int:
        """Return the KV cache size read from the log file ``dynamo_log_fn``."""


157
158
159
class VllmV1ConfigModifier:
    """Config modifier for the vLLM backend.

    Each config-taking method validates the raw deployment dict into ``Config``
    and works on the services named by ``WORKER_COMPONENT_NAMES["vllm"]``.
    """

    @staticmethod
    def _require_main_container(worker_service: "Service") -> "Container":
        """Return the worker's main container, or raise ``ValueError`` if absent."""
        pod_spec = worker_service.extraPodSpec
        if not pod_spec or not pod_spec.mainContainer:
            raise ValueError("Missing extraPodSpec or mainContainer in worker service")
        return pod_spec.mainContainer

    @classmethod
    def convert_config(cls, config: dict, target: Literal["prefill", "decode"]) -> dict:
        """Build a single-worker config for ``target`` from ``config``.

        The deployment is renamed to ``"vllm-agg"``, the ``Planner`` service is
        dropped if present, and the remaining worker is set to one replica.

        * ``"prefill"``: the prefill worker is re-registered under the decode
          worker name, ``--is-prefill-worker`` and ``--enable-prefix-caching``
          are removed and ``--no-enable-prefix-caching`` is ensured.
        * ``"decode"``: the prefill worker is deleted,
          ``--enable-prefix-caching`` is ensured and
          ``--no-enable-prefix-caching`` is removed.

        Returns:
            The dumped (``model_dump``) form of the modified config.

        Raises:
            KeyError: if an expected worker service is missing.
            ValueError: if the worker has no ``extraPodSpec``/``mainContainer``.
        """
        cfg = Config.model_validate(config)
        services = cfg.spec.services
        prefill_name = WORKER_COMPONENT_NAMES["vllm"].prefill_worker_k8s_name
        decode_name = WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name

        # set metadata name
        cfg.metadata.name = "vllm-agg"

        # disable planner
        if "Planner" in services:
            del services["Planner"]

        if target == "prefill":
            # convert prefill worker into decode worker
            services[decode_name] = services[prefill_name]
            del services[prefill_name]

            container = cls._require_main_container(services[decode_name])
            args = break_arguments(container.args)

            # remove --is-prefill-worker flag (tolerate configs that lack it)
            if "--is-prefill-worker" in args:
                args.remove("--is-prefill-worker")

            # disable prefix caching
            if "--enable-prefix-caching" in args:
                args.remove("--enable-prefix-caching")
            if "--no-enable-prefix-caching" not in args:
                args = append_argument(args, "--no-enable-prefix-caching")

            container.args = join_arguments(args)

        elif target == "decode":
            # delete prefill worker
            del services[prefill_name]

            container = cls._require_main_container(services[decode_name])
            args = break_arguments(container.args)

            # enable prefix caching
            if "--enable-prefix-caching" not in args:
                args = append_argument(args, "--enable-prefix-caching")
            if "--no-enable-prefix-caching" in args:
                args.remove("--no-enable-prefix-caching")

            container.args = join_arguments(args)

        # set num workers to 1
        services[decode_name].replicas = 1

        return cfg.model_dump()

    @classmethod
    def set_config_tp_size(cls, config: dict, tp_size: int) -> dict:
        """Set the worker's GPU request and ``--tensor-parallel-size`` to ``tp_size``.

        ``resources.requests["gpu"]`` is always written (creating ``resources``
        and ``requests`` when missing); ``resources.limits["gpu"]`` is written
        only when ``limits`` already exists.

        Returns:
            The dumped (``model_dump``) form of the modified config.

        Raises:
            KeyError: if the decode worker service is missing.
            ValueError: if the worker has no ``extraPodSpec``/``mainContainer``.
        """
        cfg = Config.model_validate(config)
        worker_service = cfg.spec.services[
            WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
        ]

        # Ensure resources exists
        if worker_service.resources is None:
            worker_service.resources = ServiceResources()

        # Ensure requests exists
        if worker_service.resources.requests is None:
            worker_service.resources.requests = {}

        worker_service.resources.requests["gpu"] = str(tp_size)

        # Update limits if they exist
        if worker_service.resources.limits is not None:
            worker_service.resources.limits["gpu"] = str(tp_size)

        container = cls._require_main_container(worker_service)
        args = break_arguments(container.args)

        if "--tensor-parallel-size" in args:
            value_idx = args.index("--tensor-parallel-size") + 1
            if value_idx < len(args):
                args[value_idx] = str(tp_size)
            else:
                # The flag is the last token and has no value yet; supply one
                # instead of failing with an IndexError.
                args.append(str(tp_size))
        else:
            args = append_argument(args, ["--tensor-parallel-size", str(tp_size)])

        container.args = join_arguments(args)

        return cfg.model_dump()

    @classmethod
    def get_model_name(cls, config: dict) -> str:
        """Return the value following ``--model`` in the decode worker's args.

        Falls back to ``DEFAULT_MODEL_NAME`` (with a warning) when the worker
        has no main container or the flag is not found.
        """
        cfg = Config.model_validate(config)
        worker_name = WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
        worker_service = cfg.spec.services[worker_name]
        if (
            not worker_service.extraPodSpec
            or not worker_service.extraPodSpec.mainContainer
        ):
            logger.warning(
                f"Worker service missing extraPodSpec or mainContainer, using default model name: {DEFAULT_MODEL_NAME}"
            )
            return DEFAULT_MODEL_NAME

        args = break_arguments(worker_service.extraPodSpec.mainContainer.args)
        for i, arg in enumerate(args):
            if arg == "--model" and i + 1 < len(args):
                return args[i + 1]

        logger.warning(
            f"Model name not found in configuration args, using default model name: {DEFAULT_MODEL_NAME}"
        )
        return DEFAULT_MODEL_NAME

    @classmethod
    def get_port(cls, config: dict) -> int:
        """Return the integer following ``--http-port`` in the ``Frontend`` args.

        Falls back to ``DYNAMO_RUN_DEFAULT_PORT`` (with a warning) when the
        Frontend service, its container, its args, or a usable port value is
        missing.
        """
        cfg = Config.model_validate(config)
        frontend_service = cfg.spec.services.get("Frontend")
        if (
            not frontend_service
            or not frontend_service.extraPodSpec
            or not frontend_service.extraPodSpec.mainContainer
        ):
            logger.warning(
                f"Frontend service or container not found, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

        args = frontend_service.extraPodSpec.mainContainer.args
        if not args:
            logger.warning(
                f"No args found in Frontend configuration, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

        args = break_arguments(args)
        try:
            idx = args.index("--http-port")
            return int(args[idx + 1])
        except (ValueError, IndexError):
            # Flag missing, no value after it, or the value is not an integer.
            logger.warning(
                f"Port not found in configuration args, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

    @classmethod
    def get_kv_cache_size_from_dynamo_log(cls, dynamo_log_fn: str) -> int:
        """Parse the KV cache size (in tokens) from a log file.

        The first line containing ``"Maximum concurrency for"`` is parsed as
        ``Maximum concurrency for <tokens> tokens per request: <factor>x`` and
        ``int(tokens * factor)`` is returned. Any failure (unreadable file,
        unexpected line format) is logged as a warning.

        Returns:
            The computed size, or 0 when no line matched or parsing failed.
        """
        # TODO
        # Pre-bind ``line`` so the warning below can be formatted even when the
        # file cannot be opened and the loop never runs.
        line = ""
        try:
            with open(dynamo_log_fn, "r") as f:
                for line in f:
                    if "Maximum concurrency for" in line:
                        line = line.strip().split("Maximum concurrency for ")[1]
                        tokens_part, factor_part = line.split(
                            " tokens per request: "
                        )[:2]
                        token_count = int(tokens_part.replace(",", ""))
                        # Drop the trailing character of the factor (e.g. "x").
                        concurrency = float(factor_part[:-1])

                        logger.info(
                            f"Found KV cache info: {token_count} x {concurrency} = {int(token_count * concurrency)}"
                        )
                        return int(token_count * concurrency)
        except Exception as e:
            logger.warning(
                f"Failed to parse KV cache size from line: {line}. Error: {e}"
            )
        return 0


362
363
364
class SGLangConfigModifier:
    """Config modifier for the SGLang backend.

    Each config-taking method validates the raw deployment dict into ``Config``
    and works on the services named by ``WORKER_COMPONENT_NAMES["sglang"]``.
    """

    @staticmethod
    def _require_main_container(worker_service: "Service") -> "Container":
        """Return the worker's main container, or raise ``ValueError`` if absent."""
        pod_spec = worker_service.extraPodSpec
        if not pod_spec or not pod_spec.mainContainer:
            raise ValueError("Missing extraPodSpec or mainContainer in worker service")
        return pod_spec.mainContainer

    @classmethod
    def convert_config(cls, config: dict, target: Literal["prefill", "decode"]) -> dict:
        """Build a single-worker config for ``target`` from ``config``.

        The deployment is renamed to ``"sglang-agg"``, the ``Planner`` service
        is dropped if present, and the remaining worker is set to one replica.
        In both modes ``--disaggregation-mode`` and
        ``--disaggregation-transfer-backend`` (with their values) are removed.

        * ``"prefill"``: the prefill worker is re-registered under the decode
          worker name and ``--disable-radix-cache`` is ensured.
        * ``"decode"``: the prefill worker is deleted and
          ``--disable-radix-cache`` is removed.

        Returns:
            The dumped (``model_dump``) form of the modified config, so that
            every edit made on the validated model is part of the result.

        Raises:
            KeyError: if an expected worker service is missing.
            ValueError: if the worker has no ``extraPodSpec``/``mainContainer``.
        """
        cfg = Config.model_validate(config)
        services = cfg.spec.services
        prefill_name = WORKER_COMPONENT_NAMES["sglang"].prefill_worker_k8s_name
        decode_name = WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name

        # set metadata name
        cfg.metadata.name = "sglang-agg"

        # disable planner
        if "Planner" in services:
            del services["Planner"]

        if target == "prefill":
            # convert prefill worker into decode worker
            services[decode_name] = services[prefill_name]
            del services[prefill_name]

            container = cls._require_main_container(services[decode_name])
            args = break_arguments(container.args)

            # remove `--disaggregation-mode` and `--disaggregation-transfer-backend`
            args = remove_valued_arguments(args, "--disaggregation-mode")
            args = remove_valued_arguments(args, "--disaggregation-transfer-backend")

            # disable prefix caching
            if "--disable-radix-cache" not in args:
                args = append_argument(args, "--disable-radix-cache")

            container.args = join_arguments(args)

        elif target == "decode":
            # delete prefill worker
            del services[prefill_name]

            container = cls._require_main_container(services[decode_name])
            args = break_arguments(container.args)

            # remove `--disaggregation-mode` and `--disaggregation-transfer-backend`
            args = remove_valued_arguments(args, "--disaggregation-mode")
            args = remove_valued_arguments(args, "--disaggregation-transfer-backend")

            # enable prefix caching
            if "--disable-radix-cache" in args:
                args.remove("--disable-radix-cache")

            container.args = join_arguments(args)

        # set num workers to 1 -- on the validated model, not on the raw input
        # dict, otherwise all edits above would be dropped from the result.
        services[decode_name].replicas = 1

        return cfg.model_dump()

    @classmethod
    def set_config_tp_size(cls, config: dict, tp_size: int) -> dict:
        """Set the worker's GPU request and ``--tp`` argument to ``tp_size``.

        ``resources.requests["gpu"]`` is always written (creating ``resources``
        and ``requests`` when missing); ``resources.limits["gpu"]`` is written
        only when ``limits`` already exists.

        Returns:
            The dumped (``model_dump``) form of the modified config.

        Raises:
            KeyError: if the decode worker service is missing.
            ValueError: if the worker has no ``extraPodSpec``/``mainContainer``.
        """
        cfg = Config.model_validate(config)
        worker_service = cfg.spec.services[
            WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
        ]

        # Ensure resources exists
        if worker_service.resources is None:
            worker_service.resources = ServiceResources()

        # Ensure requests exists
        if worker_service.resources.requests is None:
            worker_service.resources.requests = {}

        worker_service.resources.requests["gpu"] = str(tp_size)

        # Update limits if they exist
        if worker_service.resources.limits is not None:
            worker_service.resources.limits["gpu"] = str(tp_size)

        container = cls._require_main_container(worker_service)
        args = break_arguments(container.args)

        if "--tp" in args:
            value_idx = args.index("--tp") + 1
            if value_idx < len(args):
                args[value_idx] = str(tp_size)
            else:
                # The flag is the last token and has no value yet; supply one
                # instead of failing with an IndexError.
                args.append(str(tp_size))
        else:
            args = append_argument(args, ["--tp", str(tp_size)])

        container.args = join_arguments(args)

        return cfg.model_dump()

    @classmethod
    def get_model_name(cls, config: dict) -> str:
        """Return the value following ``--served-model-name`` in the worker's args.

        Falls back to ``DEFAULT_MODEL_NAME`` (with a warning) when the worker
        has no main container or the flag is not found.
        """
        cfg = Config.model_validate(config)
        worker_name = WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
        worker_service = cfg.spec.services[worker_name]
        if (
            not worker_service.extraPodSpec
            or not worker_service.extraPodSpec.mainContainer
        ):
            logger.warning(
                f"Worker service missing extraPodSpec or mainContainer, using default model name: {DEFAULT_MODEL_NAME}"
            )
            return DEFAULT_MODEL_NAME

        args = break_arguments(worker_service.extraPodSpec.mainContainer.args)
        for i, arg in enumerate(args):
            if arg == "--served-model-name" and i + 1 < len(args):
                return args[i + 1]

        logger.warning(
            f"Model name not found in configuration args, using default model name: {DEFAULT_MODEL_NAME}"
        )
        return DEFAULT_MODEL_NAME

    @classmethod
    def get_port(cls, config: dict) -> int:
        """Return the integer following ``--http-port`` in the ``Frontend`` args.

        Falls back to ``DYNAMO_RUN_DEFAULT_PORT`` (with a warning) when the
        Frontend service, its container, its args, or a usable port value is
        missing.
        """
        cfg = Config.model_validate(config)
        frontend_service = cfg.spec.services.get("Frontend")
        if (
            not frontend_service
            or not frontend_service.extraPodSpec
            or not frontend_service.extraPodSpec.mainContainer
        ):
            logger.warning(
                f"Frontend service or container not found, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

        args = frontend_service.extraPodSpec.mainContainer.args
        if not args:
            logger.warning(
                f"No args found in Frontend configuration, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

        args = break_arguments(args)
        try:
            idx = args.index("--http-port")
            return int(args[idx + 1])
        except (ValueError, IndexError):
            # Flag missing, no value after it, or the value is not an integer.
            logger.warning(
                f"Port not found in configuration args, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

    @classmethod
    def get_kv_cache_size_from_dynamo_log(cls, dynamo_log_fn: str) -> int:
        """Parse the KV cache size (in tokens) from a log file.

        Returns the integer after ``#tokens:`` on the first line that contains
        both ``"KV Cache is allocated"`` and ``"#tokens:"`` and matches the
        pattern. Any failure is logged as a warning.

        Returns:
            The parsed token count, or 0 when nothing matched or reading failed.
        """
        # TODO
        try:
            with open(dynamo_log_fn, "r") as f:
                for line in f:
                    if "KV Cache is allocated" in line and "#tokens:" in line:
                        # Extract the number after "#tokens:"
                        match = re.search(r"#tokens:\s*(\d+)", line)
                        if match:
                            return int(match.group(1))
        except Exception as e:
            logger.warning(f"Failed to parse KV cache size from log file. Error: {e}")
        return 0


560
# Maps a backend name to the config modifier class that handles it.
CONFIG_MODIFIERS: dict[str, type[ConfigModifierProtocol]] = {
    "vllm": VllmV1ConfigModifier,
    "sglang": SGLangConfigModifier,
}

# Re-export WORKER_COMPONENT_NAMES for profile_sla.py
__all__ = ["CONFIG_MODIFIERS", "WORKER_COMPONENT_NAMES"]