# kv_router.py
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
18
import logging
19
20
21
22
import random
from argparse import Namespace
from typing import AsyncIterator

23
from components.worker import VllmWorker
24
from utils.logging import check_required_workers
25
26
27
from utils.protocol import Tokens

from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
28
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
29
30
31
32
from dynamo.sdk.lib.config import ServiceConfig

# Type alias used to annotate the values yielded by ``Router.generate``.
WorkerId = str

logger = logging.getLogger(__name__)


def _str_to_bool(value) -> bool:
    """Convert a command-line / config value to a real ``bool``.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, and every non-empty
    string (including ``"False"`` and ``"0"``) is truthy. This converter parses
    the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    # "" stays False to match the old ``bool("")`` behaviour.
    if text in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(service_name, prefix) -> Namespace:
    """Define the router's options and parse them from the service config.

    The arguments are not read from ``sys.argv``; they come from
    ``ServiceConfig.get_instance().as_args(service_name, prefix=prefix)``.

    Args:
        service_name: Name of the service whose config section is converted
            to CLI-style arguments.
        prefix: Prefix passed through to ``ServiceConfig.as_args``.

    Returns:
        Namespace with ``min_workers`` (int), ``model`` (str),
        ``block_size`` (int) and ``custom_router`` (bool).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--min-workers",
        type=int,
        default=1,
        help="Minimum number of workers required before proceeding",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        help="Model that is being served",
    )
    # TODO: Read block size
    parser.add_argument(
        "--block-size",
        type=int,
        default=64,
        help="KV block size",
    )
    parser.add_argument(
        "--custom-router",
        # type=bool would turn "False" into True; parse the text explicitly.
        type=_str_to_bool,
        # Accept both "--custom-router <bool>" and a bare "--custom-router".
        nargs="?",
        const=True,
        default=False,
        help="Whether to use custom router or not",
    )
    config = ServiceConfig.get_instance()
    config_args = config.as_args(service_name, prefix=prefix)
    args = parser.parse_args(config_args)
    return args


@service(
    dynamo={
        "enabled": True,
        "namespace": "dynamo",
    },
    resources={"cpu": "10", "memory": "20Gi"},
    workers=1,
)
class Router:
    """KV-aware router for the ``VllmWorker`` ``generate`` endpoint.

    For each tokenized request it combines the prefix-cache overlap reported
    by a ``KvIndexer`` with load metrics from a ``KvMetricsAggregator``. It
    yields the id of the chosen worker together with the estimated prefix hit
    rate.
    """

    worker = depends(VllmWorker)

    def __init__(self):
        logger.info("Initializing Custom Router")
        self.args = parse_args(self.__class__.__name__, "")

        # Fallback value for each metric a worker does not report. The whole
        # dict is also used for workers missing from the aggregated metrics.
        self.default_metrics = {
            "gpu_cache_usage_perc": 0.0,
            "num_requests_waiting": 0.0,
            "gpu_prefix_cache_hit_rate": 0.0,
        }

    @async_on_start
    async def async_init(self):
        """Connect to the workers and set up KV indexing and metrics polling.

        Creates a client for ``dynamo/VllmWorker/generate`` and calls
        ``check_required_workers`` with ``--min-workers``. That call presumably
        waits until enough workers are registered (TODO confirm). It then
        builds the ``KvIndexer`` and ``KvMetricsAggregator`` on the
        ``VllmWorker`` component.
        """
        self.runtime = dynamo_context["runtime"]
        self.workers_client = (
            await self.runtime.namespace("dynamo")
            .component("VllmWorker")
            .endpoint("generate")
            .client()
        )

        await check_required_workers(self.workers_client, self.args.min_workers)

        kv_listener = self.runtime.namespace("dynamo").component("VllmWorker")
        await kv_listener.create_service()
        self.indexer = KvIndexer(kv_listener, self.args.block_size)
        self.metrics_aggregator = KvMetricsAggregator(kv_listener)
        logger.info("KV Router initialized")

    def _cost_function(
        self,
        scores: OverlapScores | None,
        metrics: AggregatedMetrics | None,
        token_length: int,
    ):
        """The cost function for deciding the best worker to route a request to.

        Each worker gets the logit
        ``2 * prefix_hit_rate - gpu_cache_usage - waiting / max_waiting``.
        The worker with the highest logit wins. Ties are broken uniformly at
        random.

        Args:
            scores (OverlapScores | None): The number of matching blocks between
                the request and the prefix cache of each worker.
            metrics (AggregatedMetrics | None): Several worker metrics polled
                by the `KvMetricsAggregator`, currently including the
                GPU cache usage, number of waiting requests, and the
                GPU prefix cache hit rate.
            token_length (int): The number of tokens in the request. When it is
                0, no prefix hit rate is computed (every worker scores 0.0).

        Returns:
            (str, float): The best worker id and its estimated prefix hit rate
            (matched tokens / request tokens). Returns ``("", 0.0)`` when no
            workers are registered or every worker's logit is exactly 0.
        """
        # Estimated prefix hit rate per worker.
        worker_scores = {}
        if not scores:
            logger.warning("Cannot get KV scores")
        elif token_length > 0:
            # score is the number of matching blocks; multiply by block_size to
            # get tokens and compare to token_length. The larger the cache hit
            # the better. Guarding token_length avoids a ZeroDivisionError for
            # empty requests.
            block_size = self.indexer.block_size()
            worker_scores = {
                worker_id: score * block_size / token_length
                for worker_id, score in scores.scores.items()
            }

        worker_metrics = {}
        max_waiting = 0.0
        if metrics:
            for endpoint in metrics.endpoints:
                worker_id = endpoint.worker_id
                worker_metrics[worker_id] = {
                    key: getattr(endpoint, key, default)
                    for key, default in self.default_metrics.items()
                }
                max_waiting = max(
                    max_waiting, worker_metrics[worker_id]["num_requests_waiting"]
                )
        else:
            logger.warning("Cannot get metrics")

        # Scores / metrics may not cover every worker, so take the ids from the
        # client. That way all registered workers are considered in the logit
        # calculation, using default values where data is missing.
        worker_logits = {}
        for worker_id in self.workers_client.endpoint_ids():
            score = worker_scores.get(worker_id, 0.0)
            metrics_dict = worker_metrics.get(worker_id, self.default_metrics)
            gpu_cache_usage = metrics_dict["gpu_cache_usage_perc"]
            normalized_waiting = (
                metrics_dict["num_requests_waiting"] / max_waiting
                if max_waiting > 0
                else 0.0
            )

            # One term (weighted 2x) rewards cache hits; two terms penalize an
            # overloaded KV cache and queuing.
            worker_logits[worker_id] = 2 * score - gpu_cache_usage - normalized_waiting
            logger.info(
                "Formula for %s: %.3f = 2.0 * %.3f - %.3f - %.3f",
                worker_id,
                worker_logits[worker_id],
                score,
                gpu_cache_usage,
                normalized_waiting,
            )

        if not worker_logits or all(logit == 0 for logit in worker_logits.values()):
            return "", 0.0

        # Select the worker with the highest logit, breaking ties at random.
        max_logit = max(worker_logits.values())
        best_workers = [
            wid for wid, logit in worker_logits.items() if logit == max_logit
        ]
        best_worker_id = random.choice(best_workers)

        # Log the inputs that led to the selected worker.
        metrics_dict = worker_metrics.get(best_worker_id, self.default_metrics)
        raw_score = scores.scores.get(best_worker_id, 0.0) if scores else 0.0
        logger.info(
            "Selected worker: %s, logit: %.3f",
            best_worker_id,
            worker_logits[best_worker_id],
        )
        logger.info("Score: %.3f", raw_score)
        logger.info(
            "GPU Cache Hit Rate: %.3f", metrics_dict["gpu_prefix_cache_hit_rate"]
        )
        logger.info("GPU Cache Usage: %.3f", metrics_dict["gpu_cache_usage_perc"])
        logger.info("Requests Waiting: %s", metrics_dict["num_requests_waiting"])

        return best_worker_id, worker_scores.get(best_worker_id, 0.0)

    @dynamo_endpoint()
    async def generate(self, request: Tokens) -> AsyncIterator[WorkerId]:
        """Yield the routing decision for ``request``.

        The single yielded value is ``f"{worker_id}_{prefix_hit_rate}"``. The
        worker id is empty when ``_cost_function`` expresses no preference.
        Failures of the indexer or metrics lookup are logged, and routing
        continues without that input.
        """
        # Hard-coded: no LoRA adapter id is taken from the request.
        lora_id = 0
        try:
            scores = await self.indexer.find_matches_for_request(
                request.tokens, lora_id
            )
        except Exception as e:
            # Route without cache-overlap information rather than failing.
            scores = None
            logger.exception("Error finding matches: %s", e)

        try:
            metrics = await self.metrics_aggregator.get_metrics()
        except Exception as e:
            # _cost_function already falls back to default metrics on None.
            metrics = None
            logger.exception("Error getting metrics: %s", e)

        worker_id, prefix_hit_rate = self._cost_function(
            scores, metrics, len(request.tokens)
        )

        logger.info(
            "Scheduling to worker_id: %s with estimated prefix hit rate: %s",
            worker_id,
            prefix_hit_rate,
        )
        yield f"{worker_id}_{prefix_hit_rate}"