kv_router.py 13.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
18
import logging
19
20
import random
from argparse import Namespace
21
from typing import AsyncIterator, Tuple
22

23
import numpy as np  # Add numpy import
24
from components.worker import VllmWorker
25
from utils.check_worker import check_required_workers
26
from utils.protocol import LocalBlockHashes
27
from utils.vllm import RouterType
28
29

from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
30
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
31
32
33
from dynamo.sdk.lib.config import ServiceConfig

# Alias used in the generate endpoint's return annotation for worker identifiers.
WorkerId = str
# Appended to warning/error logs whenever the router yields ("", 0.0) instead
# of a concrete worker.
fallback_msg = "Will fallback to random routing."

logger = logging.getLogger(__name__)


def softmax_sample_from_logits(
    logits: dict[str, float], temperature: float = 1.0, lower_is_better: bool = True
) -> str:
    """Sample a key of ``logits`` with probability given by a softmax over its values.

    Values are divided by their range (max - min) before the softmax, so the
    sampling distribution depends on the relative spread of the values rather
    than their absolute scale. If every value is identical, keys are sampled
    uniformly.

    Args:
        logits: Mapping from candidate key (e.g. a worker id) to its score.
        temperature: Softmax temperature; smaller values concentrate the
            probability mass on the best-scoring key. Must be positive.
        lower_is_better: When True (the default), lower values receive higher
            probability, i.e. the values are treated as costs.

    Returns:
        One of the keys of ``logits``, returned as the original key object
        (not a NumPy scalar), so it can be used directly for dict lookups.

    Raises:
        ValueError: If ``logits`` is empty or ``temperature`` is not positive.
    """
    if not logits:
        raise ValueError("Empty logits dictionary")
    # `not (t > 0)` also rejects NaN.
    if not temperature > 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    keys = list(logits.keys())
    values = np.array(list(logits.values()), dtype=float)

    min_val = np.min(values)
    max_val = np.max(values)

    if min_val == max_val:
        # All values are the same, uniform probability
        probabilities = np.ones(len(keys)) / len(keys)
    else:
        normalized = values / (max_val - min_val)
        if lower_is_better:
            normalized = -normalized

        scaled = normalized / temperature

        # Subtracting the max before exponentiating avoids overflow; softmax
        # is invariant to adding a constant to every input.
        exp_values = np.exp(scaled - np.max(scaled))
        probabilities = exp_values / np.sum(exp_values)

    # Sample an index rather than the key itself: np.random.choice on the key
    # list would coerce the keys to a NumPy dtype (np.int64, or strings when
    # key types are mixed) instead of returning the original dict key.
    index = np.random.choice(len(keys), p=probabilities)
    return keys[index]


def parse_args(service_name, prefix) -> Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
71
        "--model",
72
73
74
75
        type=str,
        default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        help="Model that is being served",
    )
76
77
78
79
80
81
    parser.add_argument(
        "--min-workers",
        type=int,
        default=1,
        help="Minimum number of workers required before proceeding",
    )
82
83
84
85
86
87
88
89
90
91
92
93
94
    # TODO: Read block size
    parser.add_argument(
        "--block-size",
        type=int,
        default=64,
        help="KV block size",
    )
    parser.add_argument(
        "--custom-router",
        type=bool,
        default=False,
        help="Whether to use custom router or not",
    )
95
96
97
98
99
100
    parser.add_argument(
        "--router",
        type=str,
        default="kv",
        help="The router type",
    )
101
102
103
104
105
106
    parser.add_argument(
        "--softmax-sample",
        type=bool,
        default=False,
        help="Whether to do softmax sampling based on worker logits (default is to pick smallest)",
    )
107
108
109
110
111
112
113
114
    config = ServiceConfig.get_instance()
    config_args = config.as_args(service_name, prefix=prefix)
    args = parser.parse_args(config_args)
    return args


@service(
    dynamo={
        "namespace": "dynamo",
    },
    resources={"cpu": "10", "memory": "20Gi"},
    workers=1,
)
class Router:
    """
    Request handler for the generate endpoint.

    Picks a ``VllmWorker`` instance for each request, either with a KV-cache
    aware cost function (``RouterType.KV``) or by lowest reported
    ``gpu_cache_usage_perc`` (``RouterType.KV_LOAD``). When no decision can be
    made it yields ``("", 0.0)`` and logs ``fallback_msg``.
    """

    worker = depends(VllmWorker)

    def __init__(self):
        logger.info("Initializing Custom Router")

        self.args = parse_args(self.__class__.__name__, "")

        # Values used for any metric a worker does not report.
        self.default_metrics = {
            "kv_active_blocks": 0,
            "kv_total_blocks": 1,
            "num_requests_waiting": 0.0,
            "gpu_cache_usage_perc": 0.0,
            "gpu_prefix_cache_hit_rate": 0.0,
        }
        # Created in async_init only when the router type is KV.
        self.indexer = None
        # worker_id -> [last polled kv_active_blocks, predicted kv_active_blocks]
        self.active_blocks_dict = {}

    @async_on_start
    async def async_init(self):
        self.runtime = dynamo_context["runtime"]
        self.workers_client = (
            await self.runtime.namespace("dynamo")
            .component("VllmWorker")
            .endpoint("generate")
            .client()
        )

        self.router_type = self.args.router

        await check_required_workers(self.workers_client, self.args.min_workers)

        kv_listener = self.runtime.namespace("dynamo").component("VllmWorker")
        await kv_listener.create_service()

        if self.router_type == RouterType.KV:
            self.indexer = KvIndexer(kv_listener, self.args.block_size)
        self.metrics_aggregator = KvMetricsAggregator(kv_listener)

        # Seed tracking for the workers known at startup; workers that join
        # later are registered lazily (see _update_and_get_active_blocks).
        self.active_blocks_dict = {}
        for worker_id in self.workers_client.instance_ids():
            # [old_value, predictive_value]
            self.active_blocks_dict[worker_id] = [0, 0]

        logger.info("KV Router initialized")

    def _update_and_get_active_blocks(self, worker_id: str, polled_value: int) -> int:
        """Return the active-block count to use for ``worker_id`` when routing.

        Implements a predictive scheme on top of the polled metric:
        - If ``polled_value`` differs from the last polled value, the metrics
          were refreshed: both stored values are reset to it and it is returned.
        - Otherwise the stored predictive value is returned. ``_cost_function``
          increments it by the new blocks of every request routed to this
          worker, so requests dispatched since the last refresh are counted.

        Workers not seen before (e.g. ones that joined after startup) are
        registered with ``[0, 0]`` instead of raising KeyError.
        """
        old_value, predictive_value = self.active_blocks_dict.setdefault(
            worker_id, [0, 0]
        )

        if polled_value != old_value:
            self.active_blocks_dict[worker_id] = [polled_value, polled_value]
            return polled_value
        return predictive_value

    def _cost_function(
        self,
        scores: OverlapScores | None,
        metrics: AggregatedMetrics | None,
        token_length: int,
    ):
        """The cost function for deciding the best worker to route a request to.

        Each worker's logit is a cost (lower is better), the sum of:
          * request blocks not already cached on the worker, divided by the
            worker's total KV blocks;
          * the worker's (predicted) active KV blocks over its total KV blocks;
          * the worker's raw number of waiting requests.

        By default the worker with the lowest logit is selected, ties broken
        randomly. With ``--softmax-sample`` a worker is sampled from a softmax
        over the logits instead.

        Args:
            scores (OverlapScores | None): The number of matching blocks between
                the request and the prefix cache of each worker.
            metrics (AggregatedMetrics | None): Several worker metrics polled
                by the `KvMetricsAggregator`, currently including the
                KV block usage, number of waiting requests, and the
                GPU prefix cache hit rate.
            token_length (int): The number of tokens in the request.

        Returns:
            (worker_id, float): The selected worker id and its estimated prefix
            hit rate (cached tokens / request tokens), or ``("", 0.0)`` when
            every logit is zero or there are no workers.
        """
        # Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers
        # and we want all workers to be considered in the logit calculation
        worker_ids = self.workers_client.instance_ids()
        block_size = self.args.block_size
        # Ceiling division: a trailing partial block still occupies a KV block.
        request_blocks = (token_length + block_size - 1) // block_size

        overlap_blocks_dict = {worker_id: 0 for worker_id in worker_ids}
        new_blocks_dict = {worker_id: request_blocks for worker_id in worker_ids}

        if scores:
            for worker_id, score in scores.scores.items():
                # score is the number of request blocks already cached on the
                # worker; only the remaining blocks would be newly computed.
                overlap_blocks_dict[worker_id] = score
                new_blocks_dict[worker_id] = request_blocks - score
        else:
            logger.warning("Cannot get KV scores")

        worker_metrics = {}
        if metrics:
            for endpoint_metrics in metrics.endpoints:
                worker_id = endpoint_metrics.worker_id
                worker_metrics[worker_id] = {
                    key: getattr(endpoint_metrics, key, default)
                    for key, default in self.default_metrics.items()
                }

                # Replace the polled active-block count with the predictive
                # value that also covers requests routed since the last poll.
                polled_active_blocks = int(
                    worker_metrics[worker_id]["kv_active_blocks"]
                )
                worker_metrics[worker_id][
                    "kv_active_blocks"
                ] = self._update_and_get_active_blocks(worker_id, polled_active_blocks)
        else:
            logger.warning("Cannot get metrics")

        worker_logits = {}
        for worker_id in worker_ids:
            # Use default values if worker not in scores or metrics
            metrics_dict = worker_metrics.get(worker_id, self.default_metrics)
            # Guard against a worker reporting zero total blocks.
            kv_total_blocks = max(metrics_dict["kv_total_blocks"], 1)

            normalized_new_blocks = new_blocks_dict[worker_id] / kv_total_blocks
            gpu_cache_usage = metrics_dict["kv_active_blocks"] / kv_total_blocks

            # Use raw waiting value without normalization
            num_requests_waiting = metrics_dict["num_requests_waiting"]

            # Have 1 metric that weights towards cache hit
            # 2 metrics that penalize overloaded worker and queuing
            worker_logits[worker_id] = (
                normalized_new_blocks + gpu_cache_usage + num_requests_waiting
            )
            logger.info(
                f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = {normalized_new_blocks:.3f} + {gpu_cache_usage:.3f} + {num_requests_waiting:.3f}"
            )

        if not worker_logits or not any(worker_logits.values()):
            logger.warning(f"All worker logits are zero. {fallback_msg}.")
            return "", 0.0

        # Select the worker with the lowest logit (lowest cost)
        if self.args.softmax_sample:
            best_worker_id = int(softmax_sample_from_logits(worker_logits))
        else:
            min_logit = min(worker_logits.values())
            best_workers = [
                wid for wid, logit in worker_logits.items() if logit == min_logit
            ]
            best_worker_id = random.choice(best_workers)

        # Log the metrics for the selected worker. best_worker_id is always a
        # real worker here (the fallback returned above), so no truthiness
        # check: a worker id of 0 must not skip the bookkeeping below.
        metrics_dict = worker_metrics.get(best_worker_id, self.default_metrics)
        log_messages = [
            f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}",
            f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}",
            f"GPU Cache Hit Rate: {metrics_dict['gpu_prefix_cache_hit_rate']:.3f}",
            f"GPU Cache Usage: {metrics_dict['kv_active_blocks'] / max(metrics_dict['kv_total_blocks'], 1):.3f}",
            f"Requests Waiting: {metrics_dict['num_requests_waiting']}",
        ]
        for message in log_messages:
            logger.info(message)

        # Increment predictive active blocks for the selected worker so the
        # next routing decision accounts for this request before the next poll.
        self.active_blocks_dict.setdefault(best_worker_id, [0, 0])[
            1
        ] += new_blocks_dict[best_worker_id]

        # An empty request has no tokens to hit in the cache (and would
        # otherwise divide by zero).
        if token_length > 0:
            prefix_hit_rate = (
                overlap_blocks_dict[best_worker_id] * block_size / token_length
            )
        else:
            prefix_hit_rate = 0.0
        return best_worker_id, prefix_hit_rate

    def _get_underloaded_worker(self, metrics: AggregatedMetrics | None):
        """Pick the worker with the lowest reported ``gpu_cache_usage_perc``.

        Ties are broken randomly. Returns ``(worker_id, load)``, or
        ``("", 0.0)`` when metrics are missing or every load is zero.
        """
        if not metrics:
            logger.warning(f"Cannot get metrics. {fallback_msg}")
            return "", 0.0

        kv_load = {
            endpoint_metrics.worker_id: getattr(
                endpoint_metrics, "gpu_cache_usage_perc", 0.0
            )
            for endpoint_metrics in metrics.endpoints
        }

        if not kv_load or not any(kv_load.values()):
            logger.warning(f"All KV loads are zero. {fallback_msg}")
            return "", 0.0

        min_load = min(kv_load.values())
        min_load_workers = [
            worker_id for worker_id, load in kv_load.items() if load == min_load
        ]
        best_worker_id = random.choice(min_load_workers)

        logger.info(
            f"Selected worker: {best_worker_id}, KV load: {kv_load[best_worker_id]:.3f}"
        )
        return best_worker_id, kv_load[best_worker_id]

    @endpoint()
    async def generate(
        self, request: LocalBlockHashes
    ) -> AsyncIterator[Tuple[WorkerId, float]]:
        """Route one request; yields exactly one ``(worker_id, score)`` pair.

        In KV_LOAD mode the score is the chosen worker's
        ``gpu_cache_usage_perc``. Otherwise it is the estimated prefix hit
        rate from ``_cost_function``. ``("", 0.0)`` is yielded when no worker
        can be chosen.
        """
        metrics = await self.metrics_aggregator.get_metrics()

        # Quick return for KV_LOAD mode
        if self.router_type == RouterType.KV_LOAD:
            # Compute outside the yield so the handler only covers routing errors.
            try:
                result = self._get_underloaded_worker(metrics)
            except Exception as e:
                logger.exception(
                    f"Error finding underloaded worker: {e}. {fallback_msg}"
                )
                result = ("", 0.0)
            yield result
            return

        # KV routing requires the indexer, which only exists for RouterType.KV.
        if self.indexer is None:
            logger.warning(
                f"No KV indexer for router type {self.router_type!r}. {fallback_msg}"
            )
            yield "", 0.0
            return

        try:
            scores = await self.indexer.find_matches(request.hashes)
        except Exception as e:
            logger.exception(f"Error finding matches: {e}. {fallback_msg}")
            yield "", 0.0
            return

        worker_id, prefix_hit_rate = self._cost_function(
            scores, metrics, request.num_tokens
        )

        if worker_id:
            logger.info(
                f"Scheduling to worker_id: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
            )

        yield worker_id, prefix_hit_rate