"deploy/snapshot/internal/runtime/process.go" did not exist on "24523a1c297f33ded512127c990b0b7bf2251bf2"
config.py 7.86 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
17
from copy import deepcopy
18
19
from typing import Literal

20
21
from utils.defaults import DEFAULT_MODEL_NAME, DYNAMO_RUN_DEFAULT_PORT

22
23
24
25
26
27
28
29
30
31
32
33
34
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES

# Module logger: records at INFO and above are emitted through a
# StreamHandler (stderr by default) with a timestamped, name-tagged format.
formatter = logging.Formatter(
    fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(console_handler)


35
36
37
38
39
40
41
42
def break_arguments(args: list[str]) -> list[str]:
    """Split container args into a flat list of space-separated tokens.

    ``args`` may be a single command string or a list of strings. Each string
    is split on a single ``" "`` and the pieces are concatenated in order.
    Because the separator is exactly one space, consecutive spaces produce
    empty tokens, and tabs/newlines are not treated as separators.

    Args:
        args: a command string, or a list of argument strings.

    Returns:
        A new flat list of tokens.
    """
    if isinstance(args, str):
        return args.split(" ")
    return [token for chunk in args for token in chunk.split(" ")]
43
44


45
46
def join_arguments(args: list[str]) -> list[str]:
    """Join tokens with single spaces into a one-element args list.

    The result holds the whole command as one string, e.g.
    ``["a", "b"] -> ["a b"]``.
    """
    command = " ".join(args)
    return [command]
47
48


49
50
51
52
53
54
55
def append_argument(args: list[str], to_append) -> list[str]:
    """Insert ``to_append`` into ``args`` at the index chosen by ``find_arg_index``.

    That index is the first ``"|"`` or ``"2>&1"`` token, or the end of the list
    if neither is present. A list value is spliced in element by element; any
    other value is inserted as a single item. ``args`` is modified in place and
    also returned.
    """
    position = find_arg_index(args)
    new_items = to_append if isinstance(to_append, list) else [to_append]
    args[position:position] = new_items
    return args
56
57


58
59
60
def find_arg_index(args: list[str]) -> int:
    """Return the index at which a new argument should be inserted into ``args``.

    ``args`` is a tokenized command line that may be followed by shell plumbing
    (a ``2>&1`` redirection and/or a ``|`` pipe). New arguments go before the
    first such token, so they stay with the main command instead of landing
    after the redirection or in the piped command.

    Args:
        args: tokenized command, e.g. the output of ``break_arguments``.

    Returns:
        Index of the first ``"|"`` or ``"2>&1"`` token, or ``len(args)`` when
        neither is present.
    """
    shell_tokens = ("|", "2>&1")
    for idx, token in enumerate(args):
        if token in shell_tokens:
            return idx
    return len(args)
75
76
77
78
79


class VllmV1ConfigModifier:
    """Config rewriting helpers for vLLM (v1) deployments.

    Every method is a classmethod operating on a deployment config dict whose
    services live under ``config["spec"]["services"]``. Each service's command
    line is read from ``extraPodSpec.mainContainer.args``. Methods that change
    a config edit a deep copy and return it, leaving the input untouched.
    """

    @classmethod
    def _service_args(cls, config: dict, service_name: str) -> list[str]:
        """Return the tokenized container args of service ``service_name``."""
        container = config["spec"]["services"][service_name]["extraPodSpec"][
            "mainContainer"
        ]
        return break_arguments(container["args"])

    @classmethod
    def _set_service_args(
        cls, config: dict, service_name: str, args: list[str]
    ) -> None:
        """Store ``args`` on service ``service_name`` as one joined command string."""
        container = config["spec"]["services"][service_name]["extraPodSpec"][
            "mainContainer"
        ]
        container["args"] = join_arguments(args)

    @classmethod
    def convert_config(cls, config: dict, target: Literal["prefill", "decode"]) -> dict:
        """Reduce a disaggregated vLLM config to a single aggregated worker.

        Common changes (applied for every ``target``): the ``Planner`` service
        is removed, ``metadata.name`` is set to ``"vllm-agg"`` and the decode
        worker service is given ``replicas = 1``.

        Args:
            config: deployment config; not modified (a deep copy is edited).
            target: ``"prefill"`` moves the prefill worker's service spec under
                the decode worker's service name, drops ``--is-prefill-worker``
                and disables prefix caching. ``"decode"`` deletes the prefill
                worker (if present) and enables prefix caching on the decode
                worker. Any other value applies only the common changes.

        Returns:
            The converted copy of ``config``.

        Raises:
            KeyError: if an expected key is missing, e.g. the prefill worker
                service when ``target="prefill"``.
        """
        config = deepcopy(config)
        services = config["spec"]["services"]
        worker_names = WORKER_COMPONENT_NAMES["vllm"]
        decode_name = worker_names.decode_worker_k8s_name
        prefill_name = worker_names.prefill_worker_k8s_name

        config["metadata"]["name"] = "vllm-agg"

        # Disable the planner by removing its service entry, if present.
        services.pop("Planner", None)

        if target == "prefill":
            # Run the prefill worker's spec under the decode worker's name.
            services[decode_name] = services.pop(prefill_name)
            # Drop the prefill-role flag and every prefix-caching enable flag
            # (all occurrences, not only the first), then disable caching.
            args = [
                arg
                for arg in cls._service_args(config, decode_name)
                if arg not in ("--is-prefill-worker", "--enable-prefix-caching")
            ]
            if "--no-enable-prefix-caching" not in args:
                args = append_argument(args, "--no-enable-prefix-caching")
            cls._set_service_args(config, decode_name, args)

        elif target == "decode":
            # Delete the prefill worker, tolerating configs that have none.
            services.pop(prefill_name, None)
            # Enable prefix caching: drop every disable flag, add the enable flag.
            args = [
                arg
                for arg in cls._service_args(config, decode_name)
                if arg != "--no-enable-prefix-caching"
            ]
            if "--enable-prefix-caching" not in args:
                args = append_argument(args, "--enable-prefix-caching")
            cls._set_service_args(config, decode_name, args)

        # Set the number of workers to 1.
        services[decode_name]["replicas"] = 1
        return config

    @classmethod
    def set_config_tp_size(cls, config: dict, tp_size: int):
        """Return a copy of ``config`` whose decode worker uses ``tp_size`` GPUs.

        Sets the worker's GPU ``requests`` and ``limits`` to ``str(tp_size)``,
        creating the ``resources`` / ``requests`` / ``limits`` mappings if they
        are absent. Also sets ``--tensor-parallel-size`` in its args: an
        existing value is overwritten, a dangling flag gets a value, and a
        missing flag is inserted via ``append_argument``.
        """
        config = deepcopy(config)
        worker_name = WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
        worker = config["spec"]["services"][worker_name]

        resources = worker.setdefault("resources", {})
        for section in ("requests", "limits"):
            resources.setdefault(section, {})["gpu"] = str(tp_size)

        args = cls._service_args(config, worker_name)
        try:
            idx = args.index("--tensor-parallel-size")
        except ValueError:
            args = append_argument(args, ["--tensor-parallel-size", str(tp_size)])
        else:
            if idx + 1 < len(args):
                args[idx + 1] = str(tp_size)
            else:
                # The flag is the last token and has no value: supply one.
                args.append(str(tp_size))
        cls._set_service_args(config, worker_name, args)
        return config

    @classmethod
    def get_model_name(cls, config: dict) -> str:
        """Return the value following ``--model`` in the decode worker's args.

        Falls back to ``DEFAULT_MODEL_NAME`` (with a warning) when the flag is
        absent or has no value.
        """
        worker_name = WORKER_COMPONENT_NAMES["vllm"].decode_worker_k8s_name
        args = cls._service_args(config, worker_name)
        for i, arg in enumerate(args):
            if arg == "--model" and i + 1 < len(args):
                return args[i + 1]

        logger.warning(
            f"Model name not found in configuration args, using default model name: {DEFAULT_MODEL_NAME}"
        )
        return DEFAULT_MODEL_NAME

    @classmethod
    def get_port(cls, config: dict) -> int:
        """Return the integer after ``--http-port`` in the Frontend service args.

        Falls back to ``DYNAMO_RUN_DEFAULT_PORT`` (with a warning) when the flag
        is absent, is the last token, or is followed by a non-integer.
        """
        args = cls._service_args(config, "Frontend")
        try:
            idx = args.index("--http-port")
            return int(args[idx + 1])
        except (ValueError, IndexError):
            logger.warning(
                f"Port not found in configuration args, using default port: {DYNAMO_RUN_DEFAULT_PORT}"
            )
            return DYNAMO_RUN_DEFAULT_PORT

    @classmethod
    def get_kv_cache_size_from_dynamo_log(cls, dynamo_log_fn: str) -> int:
        """Estimate KV-cache capacity (in tokens) from a worker log file.

        Looks for lines of the form
        ``... Maximum concurrency for <N> tokens per request: <C>x`` and
        returns ``int(N * C)`` for the first one that parses. Commas in ``N``
        are ignored. The last character of the ``<C>`` field (presumably the
        ``x`` suffix) is dropped. Malformed matching lines are logged and
        skipped.

        Returns:
            The estimated token capacity, or 0 if the file cannot be read or no
            line parses.
        """
        marker = "Maximum concurrency for "
        separator = " tokens per request: "
        try:
            # errors="replace": an undecodable byte anywhere in the log must not
            # abort the scan.
            with open(dynamo_log_fn, "r", errors="replace") as f:
                for line in f:
                    if "Maximum concurrency for" not in line:
                        continue
                    try:
                        fields = line.strip().split(marker)[1].split(separator)
                        token_count = int(fields[0].replace(",", ""))
                        concurrency = float(fields[1][:-1])
                        kv_cache_size = int(token_count * concurrency)
                    except (IndexError, ValueError, OverflowError) as e:
                        logger.warning(
                            f"Failed to parse KV cache size from line: {line.strip()}. Error: {e}"
                        )
                        continue

                    logger.info(
                        f"Found KV cache info: {token_count} x {concurrency} = {kv_cache_size}"
                    )
                    return kv_cache_size
        except OSError as e:
            # Previously an unreadable file crashed the handler, because it
            # formatted the unbound `line`.
            logger.warning(f"Failed to read dynamo log {dynamo_log_fn}: {e}")
        return 0


# Registry mapping a backend name to the class that rewrites its deployment
# configs.
CONFIG_MODIFIERS = {
    "vllm": VllmV1ConfigModifier,
}