config.py 15 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
17
import re
18
from typing import Literal, Optional, cast
19

20
from pydantic import BaseModel
21
22
from utils.defaults import DEFAULT_MODEL_NAME, DYNAMO_RUN_DEFAULT_PORT

23
24
25
26
27
28
29
30
31
32
33
34
35
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES

# Module logger: INFO and above is written to a stream handler with a
# "timestamp - logger name - level - message" layout.
formatter = logging.Formatter(
    fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.INFO)

logger = logging.getLogger(__name__)
logger.addHandler(console_handler)
logger.setLevel(logging.INFO)


36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
class Container(BaseModel):
    """Container section of a pod spec; only the ``args`` list is modeled."""

    # Command-line arguments of the container; an absent key yields an empty list.
    args: list[str] = []


class PodSpec(BaseModel):
    """Pod-level settings of a service; only ``mainContainer`` is modeled."""

    # Container whose ``args`` the config modifiers in this module read and rewrite.
    mainContainer: Container


class ServiceResources(BaseModel):
    """Resource requests and optional limits of a service, as string-to-string maps."""

    # Required map; ``set_config_tp_size`` writes its ``"gpu"`` entry.
    requests: dict[str, str]
    # Optional map; its ``"gpu"`` entry is only written when the map is present.
    limits: Optional[dict[str, str]] = None


class Service(BaseModel):
    """One entry of ``spec.services``: replica count, resources and pod spec."""

    # Number of replicas; ``convert_config`` sets this to 1 for the remaining worker.
    replicas: int
    # Resource requests/limits of the service.
    resources: ServiceResources
    # Pod spec carrying the main container and its arguments.
    extraPodSpec: PodSpec


class Services(BaseModel):
    """Service map that requires a ``Frontend`` entry and accepts any other service.

    Pydantic v2 (whose ``model_validate``/``model_dump`` API this module uses) refuses
    to build a model that declares a ``__root__`` field, so additional, arbitrarily
    named services are admitted through ``extra="allow"`` instead.
    """

    # Pydantic v2 replacement for the v1-style ``__root__: dict[str, Service]`` field.
    # NOTE(review): extras admitted this way are presumably stored without being
    # validated as ``Service`` -- confirm if typed access to them is ever needed.
    model_config = {"extra": "allow"}

    # The only service that must always be present.
    Frontend: Service


class Spec(BaseModel):
    """``spec`` section of the deployment config; only ``services`` is modeled."""

    # Services keyed by name (e.g. "Frontend", "Planner", and the worker names
    # looked up through WORKER_COMPONENT_NAMES).
    services: dict[str, Service]


class Metadata(BaseModel):
    """``metadata`` section of the deployment config; only ``name`` is modeled."""

    # Deployment name; ``convert_config`` overwrites it with an "-agg" name.
    name: str


class Config(BaseModel):
    """Top-level deployment config as consumed by the config modifiers in this module.

    NOTE(review): the modifiers round-trip dicts through ``model_validate`` and
    ``model_dump``; keys not declared on these models are presumably dropped under
    pydantic's default settings -- confirm this is intended.
    """

    # Deployment metadata (name).
    metadata: Metadata
    # Deployment spec holding the services map.
    spec: Spec


73
74
75
def break_arguments(args: list[str]) -> list[str]:
    """Split container arguments into individual tokens.

    A list has each element split on single spaces, with the pieces concatenated in
    order. A plain string is instead split on every space or ``=`` character.
    Consecutive separators produce empty-string tokens in both cases.
    """
    if isinstance(args, str):
        return re.split(r"[ =]", args)
    return [token for chunk in args for token in chunk.split(" ")]
81
82


83
84
85
86
87
88
89
90
91
92
def remove_valued_arguments(args: list[str], key: str) -> list[str]:
    """Drop the first ``key value`` pair from ``args`` in place and return ``args``.

    Nothing is removed when ``key`` is absent or when it is the last element
    (i.e. it has no value following it).
    """
    try:
        position = args.index(key)
    except ValueError:
        return args

    has_value = position + 1 < len(args)
    if has_value:
        args[position : position + 2] = []
    return args


93
94
def join_arguments(args: list[str]) -> list[str]:
    """Return a one-element list holding all of ``args`` joined by single spaces."""
    command_line = " ".join(args)
    return [command_line]
95
96


97
98
99
100
101
102
103
def append_argument(args: list[str], to_append) -> list[str]:
    """Insert ``to_append`` into ``args`` in place and return ``args``.

    The insertion point comes from ``find_arg_index``. ``to_append`` may be a single
    token or a list of tokens; a list is spliced in element by element.
    """
    insert_at = find_arg_index(args)
    new_tokens = to_append if isinstance(to_append, list) else [to_append]
    args[insert_at:insert_at] = new_tokens
    return args
104
105


106
107
108
def find_arg_index(args: list[str]) -> int:
    """Return the index at which a new argument should be inserted.

    This is the position of the first ``|`` or ``2>&1`` token, whichever comes
    earlier, or ``len(args)`` when neither token is present.
    """
    candidates = [len(args)]
    for marker in ("|", "2>&1"):
        if marker in args:
            candidates.append(args.index(marker))
    return min(candidates)
123
124
125
126
127


class VllmV1ConfigModifier:
    """Config rewriting and lookup helpers for vLLM deployments.

    Every config-taking method validates the incoming dict with ``Config`` and locates
    workers through ``WORKER_COMPONENT_NAMES["vllm"]``. Worker arguments are tokenized
    with ``break_arguments`` and written back via ``join_arguments``.
    """

    @classmethod
    def convert_config(cls, config: dict, target: Literal["prefill", "decode"]) -> dict:
        """Collapse *config* to a single worker that exercises *target*.

        The deployment is renamed to ``vllm-agg``, the ``Planner`` service is dropped
        if present, and the surviving worker (stored under the decode worker name) is
        set to one replica. For ``"prefill"`` the prefill worker is moved into the
        decode worker's slot, loses ``--is-prefill-worker`` and has prefix caching
        disabled. For ``"decode"`` the prefill worker is deleted and prefix caching is
        enabled.

        Args:
            config: Deployment configuration accepted by ``Config.model_validate``.
            target: Worker role the resulting deployment should exercise.

        Returns:
            ``model_dump()`` of the modified configuration.

        Raises:
            KeyError: If a worker service needed for the conversion is missing.
        """
        cfg = Config.model_validate(config)
        names = WORKER_COMPONENT_NAMES["vllm"]
        decode_name = names.decode_worker_k8s_name
        prefill_name = names.prefill_worker_k8s_name
        services = cfg.spec.services

        # set metadata name
        cfg.metadata.name = "vllm-agg"

        # disable planner
        if "Planner" in services:
            del services["Planner"]

        if target == "prefill":
            # convert prefill worker into decode worker
            services[decode_name] = services.pop(prefill_name)

            container = services[decode_name].extraPodSpec.mainContainer
            args = break_arguments(container.args)

            # remove --is-prefill-worker flag; guarded like the removals below so a
            # worker that never carried the flag does not abort the conversion
            if "--is-prefill-worker" in args:
                args.remove("--is-prefill-worker")

            # disable prefix caching
            if "--enable-prefix-caching" in args:
                args.remove("--enable-prefix-caching")
            if "--no-enable-prefix-caching" not in args:
                args = append_argument(args, "--no-enable-prefix-caching")

            container.args = join_arguments(args)

        elif target == "decode":
            # delete prefill worker
            del services[prefill_name]

            container = services[decode_name].extraPodSpec.mainContainer
            args = break_arguments(container.args)

            # enable prefix caching
            if "--enable-prefix-caching" not in args:
                args = append_argument(args, "--enable-prefix-caching")
            if "--no-enable-prefix-caching" in args:
                args.remove("--no-enable-prefix-caching")

            container.args = join_arguments(args)

        # set num workers to 1
        services[decode_name].replicas = 1

        return cfg.model_dump()

    @classmethod
    def set_config_tp_size(cls, config: dict, tp_size: int) -> dict:
        """Set the decode worker's GPU count and ``--tensor-parallel-size`` to *tp_size*.

        ``resources.requests["gpu"]`` is always written; ``resources.limits["gpu"]``
        only when a limits map exists. The flag's value is replaced when the flag is
        present, otherwise flag and value are inserted via ``append_argument``.

        Args:
            config: Deployment configuration accepted by ``Config.model_validate``.
            tp_size: Tensor-parallel size to apply.

        Returns:
            ``model_dump()`` of the modified configuration.
        """
        cfg = Config.model_validate(config)
        worker = cfg.spec.services[
            WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
        ]
        tp_value = str(tp_size)

        worker.resources.requests["gpu"] = tp_value
        # Binding to a local lets the None check narrow the type without a cast.
        limits = worker.resources.limits
        if limits is not None:
            limits["gpu"] = tp_value

        container = worker.extraPodSpec.mainContainer
        args = break_arguments(container.args)

        try:
            idx = args.index("--tensor-parallel-size")
        except ValueError:
            args = append_argument(args, ["--tensor-parallel-size", tp_value])
        else:
            if idx + 1 < len(args):
                args[idx + 1] = tp_value
            else:
                # Flag is the final token with no value after it: supply the value
                # instead of failing with IndexError.
                args.append(tp_value)

        container.args = join_arguments(args)

        return cfg.model_dump()

    @classmethod
    def get_model_name(cls, config: dict) -> str:
        """Return the token following ``--model`` in the decode worker's arguments.

        Falls back to ``DEFAULT_MODEL_NAME`` (with a warning) when no ``--model`` flag
        with a following token is found.
        """
        cfg = Config.model_validate(config)
        worker_name = WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
        args = cfg.spec.services[worker_name].extraPodSpec.mainContainer.args

        args = break_arguments(args)
        for i, arg in enumerate(args):
            if arg == "--model" and i + 1 < len(args):
                return args[i + 1]

        logger.warning(
            f"Model name not found in configuration args, using default model name: {DEFAULT_MODEL_NAME}"
        )
        return DEFAULT_MODEL_NAME

    @classmethod
    def get_port(cls, config: dict) -> int:
        """Return the ``--http-port`` value from the ``Frontend`` service arguments.

        Falls back to ``DYNAMO_RUN_DEFAULT_PORT`` (with a warning) when the flag is
        absent, has no value after it, or its value is not an integer.
        """
        cfg = Config.model_validate(config)
        args = cfg.spec.services["Frontend"].extraPodSpec.mainContainer.args
        args = break_arguments(args)
        try:
            idx = args.index("--http-port")
            return int(args[idx + 1])
        except (ValueError, IndexError):
            # IndexError: the flag was the last token and carried no value.
            logger.warning(
                f"Port not found in configuration args, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

    @classmethod
    def get_kv_cache_size_from_dynamo_log(cls, dynamo_log_fn: str) -> int:
        """Read the KV cache size, in tokens, from a log file.

        Scans for the first line containing ``Maximum concurrency for`` and parses it
        as ``... <tokens> tokens per request: <concurrency><one trailing char>``. The
        result is ``int(tokens * concurrency)``.

        Args:
            dynamo_log_fn: Path of the log file to scan.

        Returns:
            The computed size, or 0 when no such line exists or reading/parsing fails
            (a warning is logged on failure).
        """
        # TODO
        separator = " tokens per request: "
        # Pre-bind so the handler below can format its message even when open()
        # fails before any line has been read.
        line = ""
        try:
            with open(dynamo_log_fn, "r") as f:
                for line in f:
                    if "Maximum concurrency for" not in line:
                        continue
                    line = line.strip().split("Maximum concurrency for ")[1]
                    fields = line.split(separator)
                    token_count = int(fields[0].replace(",", ""))
                    # The last character is dropped before parsing the float.
                    concurrency = float(fields[1][:-1])

                    logger.info(
                        f"Found KV cache info: {token_count} x {concurrency} = {int(token_count * concurrency)}"
                    )
                    return int(token_count * concurrency)
        except Exception as e:
            logger.warning(
                f"Failed to parse KV cache size from line: {line}. Error: {e}"
            )
        return 0


291
292
293
class SGLangConfigModifier:
    """Config rewriting and lookup helpers for SGLang deployments.

    Every config-taking method validates the incoming dict with ``Config`` and locates
    workers through ``WORKER_COMPONENT_NAMES["sglang"]``. Worker arguments are
    tokenized with ``break_arguments`` and written back via ``join_arguments``.
    """

    @classmethod
    def convert_config(cls, config: dict, target: Literal["prefill", "decode"]) -> dict:
        """Collapse *config* to a single worker that exercises *target*.

        The deployment is renamed to ``sglang-agg``, the ``Planner`` service is dropped
        if present, the ``--disaggregation-mode`` and
        ``--disaggregation-transfer-backend`` pairs are removed, and the surviving
        worker (stored under the decode worker name) is set to one replica. For
        ``"prefill"`` the prefill worker is moved into the decode worker's slot and
        ``--disable-radix-cache`` is added. For ``"decode"`` the prefill worker is
        deleted, ``dynamo.sglang.decode_worker`` is replaced by
        ``dynamo.sglang.worker`` and ``--disable-radix-cache`` is removed.

        Args:
            config: Deployment configuration accepted by ``Config.model_validate``.
            target: Worker role the resulting deployment should exercise.

        Returns:
            ``model_dump()`` of the modified configuration. The input dict itself is
            not what gets returned, so all edits made here are reflected in the result.

        Raises:
            KeyError: If a worker service needed for the conversion is missing.
            ValueError: For ``"decode"``, if ``dynamo.sglang.decode_worker`` is not
                among the worker's arguments.
        """
        cfg = Config.model_validate(config)
        names = WORKER_COMPONENT_NAMES["sglang"]
        decode_name = names.decode_worker_k8s_name
        prefill_name = names.prefill_worker_k8s_name
        services = cfg.spec.services

        # set metadata name
        cfg.metadata.name = "sglang-agg"

        # disable planner
        if "Planner" in services:
            del services["Planner"]

        if target == "prefill":
            # convert prefill worker into decode worker
            services[decode_name] = services.pop(prefill_name)

            container = services[decode_name].extraPodSpec.mainContainer
            args = break_arguments(container.args)

            # remove `--disaggregation-mode` and `--disaggregation-transfer-backend`
            args = remove_valued_arguments(args, "--disaggregation-mode")
            args = remove_valued_arguments(args, "--disaggregation-transfer-backend")

            # disable prefix caching
            if "--disable-radix-cache" not in args:
                args = append_argument(args, "--disable-radix-cache")

            container.args = join_arguments(args)

        elif target == "decode":
            # delete prefill worker
            del services[prefill_name]

            container = services[decode_name].extraPodSpec.mainContainer
            args = break_arguments(container.args)

            # call `dynamo.sglang.worker` instead of `dynamo.sglang.decode_worker`
            idx = args.index("dynamo.sglang.decode_worker")
            args[idx] = "dynamo.sglang.worker"

            # remove `--disaggregation-mode` and `--disaggregation-transfer-backend`
            args = remove_valued_arguments(args, "--disaggregation-mode")
            args = remove_valued_arguments(args, "--disaggregation-transfer-backend")

            # enable prefix caching
            if "--disable-radix-cache" in args:
                args.remove("--disable-radix-cache")

            container.args = join_arguments(args)

        # set num workers to 1 on the validated model (not on the input dict), so the
        # replica count is returned together with every other edit made above
        services[decode_name].replicas = 1

        return cfg.model_dump()

    @classmethod
    def set_config_tp_size(cls, config: dict, tp_size: int) -> dict:
        """Set the decode worker's GPU count and ``--tp`` value to *tp_size*.

        ``resources.requests["gpu"]`` is always written; ``resources.limits["gpu"]``
        only when a limits map exists. The flag's value is replaced when the flag is
        present, otherwise flag and value are inserted via ``append_argument``.

        Args:
            config: Deployment configuration accepted by ``Config.model_validate``.
            tp_size: Tensor-parallel size to apply.

        Returns:
            ``model_dump()`` of the modified configuration.
        """
        cfg = Config.model_validate(config)
        worker = cfg.spec.services[
            WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
        ]
        tp_value = str(tp_size)

        worker.resources.requests["gpu"] = tp_value
        # Binding to a local lets the None check narrow the type without a cast.
        limits = worker.resources.limits
        if limits is not None:
            limits["gpu"] = tp_value

        container = worker.extraPodSpec.mainContainer
        args = break_arguments(container.args)

        try:
            idx = args.index("--tp")
        except ValueError:
            args = append_argument(args, ["--tp", tp_value])
        else:
            if idx + 1 < len(args):
                args[idx + 1] = tp_value
            else:
                # Flag is the final token with no value after it: supply the value
                # instead of failing with IndexError.
                args.append(tp_value)

        container.args = join_arguments(args)

        return cfg.model_dump()

    @classmethod
    def get_model_name(cls, config: dict) -> str:
        """Return the token following ``--served-model-name`` in the decode worker args.

        Falls back to ``DEFAULT_MODEL_NAME`` (with a warning) when no such flag with a
        following token is found.
        """
        cfg = Config.model_validate(config)
        worker_name = WORKER_COMPONENT_NAMES["sglang"].decode_worker_k8s_name
        args = cfg.spec.services[worker_name].extraPodSpec.mainContainer.args

        args = break_arguments(args)
        for i, arg in enumerate(args):
            if arg == "--served-model-name" and i + 1 < len(args):
                return args[i + 1]

        logger.warning(
            f"Model name not found in configuration args, using default model name: {DEFAULT_MODEL_NAME}"
        )
        return DEFAULT_MODEL_NAME

    @classmethod
    def get_port(cls, config: dict) -> int:
        """Return the ``--http-port`` value from the ``Frontend`` service arguments.

        Falls back to ``DYNAMO_RUN_DEFAULT_PORT`` (with a warning) when the flag is
        absent, has no value after it, or its value is not an integer.
        """
        cfg = Config.model_validate(config)
        args = cfg.spec.services["Frontend"].extraPodSpec.mainContainer.args
        args = break_arguments(args)
        try:
            idx = args.index("--http-port")
            return int(args[idx + 1])
        except (ValueError, IndexError):
            # IndexError: the flag was the last token and carried no value.
            logger.warning(
                f"Port not found in configuration args, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

    @classmethod
    def get_kv_cache_size_from_dynamo_log(cls, dynamo_log_fn: str) -> int:
        """Read the KV cache size, in tokens, from a log file.

        Returns the integer following ``#tokens:`` on the first line that contains
        both ``KV Cache is allocated`` and ``#tokens:``.

        Args:
            dynamo_log_fn: Path of the log file to scan.

        Returns:
            The parsed token count, or 0 when no such line exists or reading the file
            fails (a warning is logged on failure).
        """
        # TODO
        try:
            with open(dynamo_log_fn, "r") as f:
                for line in f:
                    if "KV Cache is allocated" in line and "#tokens:" in line:
                        # Extract the number after "#tokens:"
                        match = re.search(r"#tokens:\s*(\d+)", line)
                        if match:
                            return int(match.group(1))
        except Exception as e:
            logger.warning(f"Failed to parse KV cache size from log file. Error: {e}")
        return 0


454
# Registry mapping a backend name to the class that modifies that backend's configs.
CONFIG_MODIFIERS = {"vllm": VllmV1ConfigModifier, "sglang": SGLangConfigModifier}