# kv_router.py
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
18
import logging
19
20
import random
from argparse import Namespace
21
from typing import AsyncIterator, Tuple
22

23
import numpy as np  # Add numpy import
24
from components.worker import VllmWorker
25
from utils.check_worker import check_required_workers
26
from utils.protocol import LocalBlockHashes
27
from utils.vllm import RouterType
28

29
30
31
32
33
34
35
from dynamo.llm import (
    AggregatedMetrics,
    ApproxKvIndexer,
    KvIndexer,
    KvMetricsAggregator,
    OverlapScores,
)
36
from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
37
38
39
from dynamo.sdk.lib.config import ServiceConfig

# Module logger shared by every routing helper below.
logger = logging.getLogger(__name__)

# Type alias used in the ``generate`` endpoint's return annotation.
WorkerId = str

# Appended to warnings whenever the router gives up on a decision and leaves
# the choice of worker to random routing.
fallback_msg = "Will fallback to random routing."
def softmax_sample_from_logits(
    logits: dict[str, float], temperature: float = 1.0, lower_is_better: bool = True
) -> str:
    """Sample one key of ``logits`` with softmax-weighted probability.

    The values are divided by their range (max - min) before the softmax, so
    how sharply the distribution favours the best key depends only on
    ``temperature`` and the relative spread of the values, not their scale.

    Args:
        logits: Mapping from candidate key (e.g. a worker id) to its score.
            Any hashable key works; the selected key object is returned as-is.
        temperature: Softmax temperature; must be positive. Smaller values
            concentrate probability on the best-scoring key.
        lower_is_better: If True, lower scores get higher probability (the
            scores are treated as costs); otherwise higher scores win.

    Returns:
        One key of ``logits``. If all values are equal, the choice is uniform.

    Raises:
        ValueError: If ``logits`` is empty or ``temperature`` is not positive.
    """
    if not logits:
        raise ValueError("Empty logits dictionary")
    # `not x > 0` also rejects NaN; a zero/negative temperature would yield
    # NaN or silently inverted probabilities.
    if not temperature > 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    keys = list(logits.keys())
    values = np.array(list(logits.values()), dtype=float)

    min_val = np.min(values)
    max_val = np.max(values)

    if min_val == max_val:
        # All values are the same: uniform probability.
        probabilities = np.ones(len(keys)) / len(keys)
    else:
        normalized = values / (max_val - min_val)
        if lower_is_better:
            normalized = -normalized

        scaled = normalized / temperature

        # Subtract the max before exponentiating for numerical stability;
        # softmax is invariant to a constant shift.
        exp_values = np.exp(scaled - np.max(scaled))
        probabilities = exp_values / np.sum(exp_values)

    # Sample an index instead of the key: np.random.choice over the keys would
    # convert them to numpy scalars (np.str_ / np.int64, or a coerced common
    # dtype for mixed keys), whereas indexing returns the original key object.
    index = np.random.choice(len(keys), p=probabilities)
    return keys[index]
def _str_to_bool(value) -> bool:
    """Convert a command-line boolean value such as ``"true"`` or ``"0"`` to ``bool``.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"`` and ``"0"``, as True, so boolean options need this converter.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    # "" keeps the previous type=bool behaviour (bool("") is False).
    if text in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Invalid boolean value: {value!r}")


def parse_args(service_name, prefix) -> Namespace:
    """Parse the router options for ``service_name`` from the service config.

    The arguments are not read from ``sys.argv``; they are produced by
    ``ServiceConfig.get_instance().as_args(service_name, prefix=prefix)``.

    Args:
        service_name: Name of the service whose configuration section is used.
        prefix: Prefix forwarded to ``ServiceConfig.as_args``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        help="Model that is being served",
    )
    parser.add_argument(
        "--min-workers",
        type=int,
        default=1,
        help="Minimum number of workers required before proceeding",
    )
    # TODO: Read block size
    parser.add_argument(
        "--block-size",
        type=int,
        default=64,
        help="KV block size",
    )
    # Accept both a bare flag (`--custom-router`) and an explicit value
    # (`--custom-router false`); type=bool would turn "False" into True.
    parser.add_argument(
        "--custom-router",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help="Whether to use custom router or not",
    )
    parser.add_argument(
        "--router",
        type=str,
        default="kv",
        help="The router type",
    )
    parser.add_argument(
        "--softmax-sample",
        action="store_true",
        help="Whether to do softmax sampling based on worker logits (default is to pick smallest)",
    )
    config = ServiceConfig.get_instance()
    config_args = config.as_args(service_name, prefix=prefix)
    args = parser.parse_args(config_args)
    return args


@service(
    dynamo={
        "namespace": "dynamo",
    },
    resources={"cpu": "10", "memory": "20Gi"},
    workers=1,
)
class Router:
    """
    Request handler for the generate endpoint.

    For every request the router picks one ``VllmWorker`` instance:

    * ``RouterType.KV_LOAD``: the worker reporting the lowest
      ``gpu_cache_usage_perc``.
    * otherwise (``RouterType.KV`` / ``RouterType.APPROX_KV``): the worker with
      the lowest cost from ``_cost_function``. That cost combines the KV blocks
      the request still needs on the worker after prefix-cache overlap, the
      worker's active KV blocks, and its number of waiting requests.

    When no decision can be made, an empty worker id ``""`` is produced and
    ``fallback_msg`` is logged.
    """

    worker = depends(VllmWorker)

    def __init__(self):
        logger.info("Initializing Custom Router")
        self.args = parse_args(self.__class__.__name__, "")

        # Used for any metric a worker endpoint does not report, and for
        # workers with no polled metrics at all. kv_total_blocks is 1 so it is
        # always safe to divide by.
        self.default_metrics = {
            "kv_active_blocks": 0,
            "kv_total_blocks": 1,
            "num_requests_waiting": 0.0,
            "gpu_cache_usage_perc": 0.0,
            "gpu_prefix_cache_hit_rate": 0.0,
        }

    @async_on_start
    async def async_init(self):
        """Connect to the VllmWorker component and build the routing state.

        Calls ``check_required_workers`` with ``--min-workers`` (presumably
        waiting until that many workers are up), then creates the KV indexer
        for ``--router`` (none for other router types), the metrics aggregator,
        and the per-worker active-block predictions.
        """
        self.runtime = dynamo_context["runtime"]
        self.workers_client = (
            await self.runtime.namespace("dynamo")
            .component("VllmWorker")
            .endpoint("generate")
            .client()
        )

        self.router_type = self.args.router

        await check_required_workers(self.workers_client, self.args.min_workers)

        kv_listener = self.runtime.namespace("dynamo").component("VllmWorker")
        await kv_listener.create_service()

        if self.router_type == RouterType.KV:
            self.indexer = KvIndexer(kv_listener, self.args.block_size)
        elif self.router_type == RouterType.APPROX_KV:
            # For now, hardcode the TTL to 2 minutes.
            self.indexer = ApproxKvIndexer(kv_listener, self.args.block_size, 120.0)

        self.metrics_aggregator = KvMetricsAggregator(kv_listener)

        # worker_id -> [last polled kv_active_blocks, predicted kv_active_blocks]
        self.active_blocks_dict = {}
        for worker_id in self.workers_client.instance_ids():
            self.active_blocks_dict[worker_id] = [0, 0]

        logger.info("KV Router initialized")

    def _update_and_get_active_blocks(self, worker_id: str, polled_value: int) -> int:
        """Return the active-KV-block count to use for ``worker_id``.

        ``self.active_blocks_dict[worker_id]`` holds ``[old_value, predictive_value]``:

        - If ``polled_value`` differs from the stored old value, fresh metrics
          have arrived. Both entries are reset to ``polled_value``, which is
          returned.
        - Otherwise the metrics have not changed since the last poll, and the
          predictive value is returned. ``_cost_function`` increments it by the
          new blocks of every request routed to the worker, so dispatched
          requests not yet reflected in the polled metrics still count.

        A worker seen for the first time is initialised with ``polled_value``.
        """
        if worker_id not in self.active_blocks_dict:
            logger.warning(f"New Worker added: {worker_id}")
            self.active_blocks_dict[worker_id] = [polled_value, polled_value]
            return polled_value

        old_value, predictive_value = self.active_blocks_dict[worker_id]
        if polled_value != old_value:
            self.active_blocks_dict[worker_id] = [polled_value, polled_value]
            return polled_value
        return predictive_value

    def _cost_function(
        self,
        scores: OverlapScores | None,
        metrics: AggregatedMetrics | None,
        token_length: int,
    ):
        """Choose the worker with the lowest routing cost.

        For every worker id reported by ``self.workers_client.instance_ids()``::

            cost = new_blocks / kv_total_blocks
                 + kv_active_blocks / kv_total_blocks
                 + num_requests_waiting

        ``new_blocks`` is the number of the request's KV blocks not already in
        the worker's prefix cache. ``kv_active_blocks`` is the predicted value
        from ``_update_and_get_active_blocks``. Workers without scores/metrics
        use zero overlap and ``self.default_metrics``. Ties are broken
        uniformly at random. With ``--softmax-sample`` a worker is instead
        sampled via ``softmax_sample_from_logits`` (lower cost is more likely).

        Args:
            scores (OverlapScores | None): The number of matching blocks between
                the request and the prefix cache of each worker.
            metrics (AggregatedMetrics | None): Worker metrics polled by the
                ``KvMetricsAggregator`` (KV active/total blocks, waiting
                requests, GPU cache usage, GPU prefix cache hit rate).
            token_length (int): The number of tokens in the request.

        Returns:
            (worker_id, float): The selected worker id and its estimated prefix
            hit rate (cached blocks * block_size / token_length, 0.0 for an
            empty request), or ``("", 0.0)`` when no worker can be chosen.
        """
        # Consider every registered worker, not only those present in
        # scores / metrics, so idle or newly added workers remain candidates.
        worker_ids = self.workers_client.instance_ids()
        block_size = self.args.block_size
        # Number of KV blocks the request occupies (ceiling division).
        request_blocks = (token_length + block_size - 1) // block_size

        overlap_blocks_dict = {worker_id: 0 for worker_id in worker_ids}
        new_blocks_dict = {worker_id: request_blocks for worker_id in worker_ids}

        if scores:
            for worker_id, score in scores.scores.items():
                # score = number of the request's blocks already cached on the
                # worker; only the remaining blocks must be newly allocated.
                overlap_blocks_dict[worker_id] = score
                new_blocks_dict[worker_id] = request_blocks - score
        else:
            logger.warning("Cannot get KV scores")

        worker_metrics = {}
        if metrics:
            for endpoint_metrics in metrics.endpoints:
                worker_id = endpoint_metrics.worker_id
                worker_metrics[worker_id] = {
                    key: getattr(endpoint_metrics, key, default)
                    for key, default in self.default_metrics.items()
                }
                # Replace the polled active-block count with the prediction
                # that also covers requests already routed to this worker.
                polled_active_blocks = int(
                    worker_metrics[worker_id]["kv_active_blocks"]
                )
                worker_metrics[worker_id][
                    "kv_active_blocks"
                ] = self._update_and_get_active_blocks(worker_id, polled_active_blocks)
        else:
            logger.warning("Cannot get metrics")

        worker_logits = {}
        for worker_id in worker_ids:
            metrics_dict = worker_metrics.get(worker_id, self.default_metrics)
            # A worker may report 0 total blocks (e.g. before its KV cache is
            # allocated); use 1 instead of raising ZeroDivisionError.
            kv_total_blocks = metrics_dict["kv_total_blocks"] or 1

            normalized_new_blocks = new_blocks_dict[worker_id] / kv_total_blocks
            gpu_cache_usage = metrics_dict["kv_active_blocks"] / kv_total_blocks
            # Raw queue length, deliberately not normalized.
            num_requests_waiting = metrics_dict["num_requests_waiting"]

            # One term rewards prefix-cache hits (fewer new blocks); two terms
            # penalize an overloaded worker and queuing.
            worker_logits[worker_id] = (
                normalized_new_blocks + gpu_cache_usage + num_requests_waiting
            )
            # Lazy %-formatting: this runs once per worker per request.
            logger.info(
                "Formula for %s: %.3f = %.3f + %.3f + %.3f",
                worker_id,
                worker_logits[worker_id],
                normalized_new_blocks,
                gpu_cache_usage,
                num_requests_waiting,
            )

        if not worker_logits:
            logger.warning(f"No workers available. {fallback_msg}")
            return "", 0.0
        if not any(worker_logits.values()):
            logger.warning(f"All worker logits are zero. {fallback_msg}")
            return "", 0.0

        # Select the worker with the LOWEST cost.
        if self.args.softmax_sample:
            sampled = softmax_sample_from_logits(worker_logits)
            # The sampler may hand back a numpy scalar rather than the original
            # key object; map it back so lookups and the return value use the
            # worker id exactly as instance_ids() reported it (no int() cast,
            # which would break non-integer ids).
            best_worker_id = next(
                (wid for wid in worker_logits if wid == sampled), sampled
            )
        else:
            min_logit = min(worker_logits.values())
            best_workers = [
                wid for wid, logit in worker_logits.items() if logit == min_logit
            ]
            best_worker_id = random.choice(best_workers)

        # Log the metrics for the selected worker.
        metrics_dict = worker_metrics.get(best_worker_id, self.default_metrics)
        kv_total_blocks = metrics_dict["kv_total_blocks"] or 1
        best_score = scores.scores.get(best_worker_id, 0.0) if scores else 0.0
        log_messages = [
            f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}",
            f"Score: {best_score:.3f}",
            f"GPU Cache Hit Rate: {metrics_dict['gpu_prefix_cache_hit_rate']:.3f}",
            f"GPU Cache Usage: {metrics_dict['kv_active_blocks'] / kv_total_blocks:.3f}",
            f"Requests Waiting: {metrics_dict['num_requests_waiting']}",
        ]
        for message in log_messages:
            logger.info(message)

        # Count this request's new blocks in the worker's predicted load until
        # fresh metrics arrive. setdefault covers a worker that never reported
        # metrics (e.g. joined after async_init), which previously raised KeyError.
        self.active_blocks_dict.setdefault(best_worker_id, [0, 0])[
            1
        ] += new_blocks_dict[best_worker_id]

        if token_length > 0:
            prefix_hit_rate = (
                overlap_blocks_dict[best_worker_id] * block_size / token_length
            )
        else:
            # Empty request: nothing can be cached; avoid ZeroDivisionError.
            prefix_hit_rate = 0.0
        return best_worker_id, prefix_hit_rate

    def _get_underloaded_worker(self, metrics: AggregatedMetrics | None):
        """Pick the worker with the lowest reported ``gpu_cache_usage_perc``.

        Ties are broken uniformly at random. Returns ``(worker_id, load)``, or
        ``("", 0.0)`` when there are no metrics or every load is zero.
        """
        if not metrics:
            logger.warning(f"Cannot get metrics. {fallback_msg}")
            return "", 0.0

        kv_load = {
            endpoint_metrics.worker_id: getattr(
                endpoint_metrics, "gpu_cache_usage_perc", 0.0
            )
            for endpoint_metrics in metrics.endpoints
        }

        if not kv_load or not any(kv_load.values()):
            logger.warning(f"All KV loads are zero. {fallback_msg}")
            return "", 0.0

        min_load = min(kv_load.values())
        min_load_workers = [
            worker_id for worker_id, load in kv_load.items() if load == min_load
        ]
        best_worker_id = random.choice(min_load_workers)

        logger.info(
            f"Selected worker: {best_worker_id}, KV load: {kv_load[best_worker_id]:.3f}"
        )
        return best_worker_id, kv_load[best_worker_id]

    @endpoint()
    async def generate(
        self, request: LocalBlockHashes
    ) -> AsyncIterator[Tuple[WorkerId, float]]:
        """Yield one ``(worker_id, estimated prefix hit rate)`` routing decision.

        An empty ``worker_id`` means no decision was made. Presumably the caller
        then selects a worker itself (the APPROX_KV note below says the engine
        client would pick a random one) — TODO confirm against the caller.
        """
        # Both routing paths treat missing metrics explicitly, so a failed poll
        # degrades to "no metrics" instead of failing the request.
        try:
            metrics = await self.metrics_aggregator.get_metrics()
        except Exception as e:
            logger.exception(f"Error getting metrics: {e}")
            metrics = None

        # Quick return for KV_LOAD mode.
        if self.router_type == RouterType.KV_LOAD:
            # Compute outside the yield so an exception thrown into the
            # generator at the yield point cannot trigger a second yield.
            try:
                decision = self._get_underloaded_worker(metrics)
            except Exception as e:
                logger.exception(
                    f"Error finding underloaded worker: {e}. {fallback_msg}"
                )
                decision = ("", 0.0)
            yield decision
            return

        # KV-aware routing.
        try:
            if self.router_type == RouterType.APPROX_KV:
                scores = await self.indexer.find_matches_for_request(request.tokens)
            else:
                scores = await self.indexer.find_matches(request.hashes)
        except Exception as e:
            logger.exception(f"Error finding matches: {e}. {fallback_msg}")
            yield "", 0.0
            return

        try:
            worker_id, prefix_hit_rate = self._cost_function(
                scores, metrics, request.num_tokens
            )
        except Exception as e:
            logger.exception(f"Error computing routing cost: {e}. {fallback_msg}")
            worker_id, prefix_hit_rate = "", 0.0

        if self.router_type == RouterType.APPROX_KV:
            # For the approx kv router, we need to know what worker we route to.
            # We can't defer to the engine client to select a random worker.
            # Because of this, we need to select a worker here.
            if not worker_id:
                all_workers = self.workers_client.instance_ids()
                if all_workers:
                    worker_id = random.choice(all_workers)

            if worker_id:
                await self.log_router_decision(request.tokens, worker_id)

        if worker_id:
            logger.info(
                f"Scheduling to worker_id: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
            )

        yield worker_id, prefix_hit_rate

    async def log_router_decision(self, tokens: list[int], worker_id: str):
        """Report the chosen worker to the approximate KV indexer.

        Only acts for ``RouterType.APPROX_KV``. Errors are logged, not raised,
        so a failed bookkeeping call never fails the request.
        """
        if self.router_type == RouterType.APPROX_KV:
            try:
                await self.indexer.process_routing_decision_for_request(
                    tokens, worker_id
                )
            except Exception as e:
                logger.exception(
                    f"Error processing routing decision: {e}. {fallback_msg}"
                )