# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import asyncio
import random
from argparse import Namespace
from typing import AsyncIterator

import uvloop
from utils.protocol import Tokens
from vllm.logger import logger as vllm_logger

from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker

# Response type declared for the router's ``generate`` endpoint. The value
# actually yielded is a string of the form "<worker_id>_<prefix_hit_rate>"
# (see ``CustomRouter.generate``).
WorkerId = str


class CustomRouter:
    """
    KV-cache-aware router serving the ``generate`` endpoint.

    For each request, every worker known to ``workers_client`` is given a
    logit that rewards KV-cache prefix overlap (from the ``KvIndexer``) and
    penalizes GPU KV-cache usage and queue depth (from the
    ``KvMetricsAggregator``). The highest-logit worker is chosen.
    """

    # Per-worker metric attributes read from each metrics endpoint. Any
    # attribute an endpoint does not expose is treated as 0.0.
    _METRIC_KEYS = (
        "gpu_cache_usage_perc",
        "num_requests_waiting",
        "gpu_prefix_cache_hit_rate",
    )

    def __init__(
        self,
        workers_client,
        indexer: KvIndexer,
        metrics_aggregator: KvMetricsAggregator,
    ):
        """
        Args:
            workers_client: Client of the workers' ``generate`` endpoint. Its
                ``endpoint_ids()`` supplies the candidate worker ids.
            indexer: Index used to look up cached-prefix overlap per worker.
            metrics_aggregator: Source of per-worker load metrics.
        """
        vllm_logger.info("Initializing Custom Router")
        self.indexer = indexer
        self.metrics_aggregator = metrics_aggregator
        self.workers_client = workers_client

    def _cost_function(
        self,
        scores: OverlapScores | None,
        metrics: AggregatedMetrics | None,
        token_length: int,
    ):
        """
        Choose the best worker for a request.

        Each worker's logit is
        ``2 * overlap - gpu_cache_usage_perc - normalized_waiting``, where:

        - ``overlap`` is the fraction of the request's tokens already cached
          on that worker.
        - ``normalized_waiting`` is the worker's waiting-request count divided
          by the maximum across reporting workers.

        Ties are broken at random.

        Args:
            scores: Overlap scores (matching block counts per worker), or a
                falsy value when unavailable.
            metrics: Aggregated per-worker metrics, or a falsy value when
                unavailable.
            token_length: Number of tokens in the request.

        Returns:
            ``(worker_id, overlap)`` for the chosen worker. Returns
            ``("", 0.0)`` when there are no workers or every logit is exactly
            zero, i.e. nothing distinguishes the workers.
        """
        # Overlap scores count matching KV blocks. Multiply by the block size
        # to get tokens and divide by the request length to get the cached
        # fraction. The larger the cache hit, the better. An empty request
        # has nothing to match, so skip it rather than divide by zero.
        worker_scores = {}
        if scores and token_length > 0:
            block_size = self.indexer.block_size()
            for worker_id, score in scores.scores.items():
                worker_scores[worker_id] = score * block_size / token_length

        # Pull load metrics for each reporting worker. Track the largest queue
        # so waiting counts can be normalized to [0, 1].
        worker_metrics = {}
        max_waiting = 0.0
        if metrics:
            for endpoint in metrics.endpoints:
                reported = {
                    key: getattr(endpoint, key, 0.0) for key in self._METRIC_KEYS
                }
                worker_metrics[endpoint.worker_id] = reported
                max_waiting = max(max_waiting, reported["num_requests_waiting"])

        # Scores/metrics may not cover every worker. Take the full id list
        # from the client so that all workers are considered; workers without
        # data get zeros.
        worker_ids = self.workers_client.endpoint_ids()
        default_metrics = dict.fromkeys(self._METRIC_KEYS, 0.0)

        worker_logits = {}
        for worker_id in worker_ids:
            score = worker_scores.get(worker_id, 0.0)
            metrics_dict = worker_metrics.get(worker_id, default_metrics)
            normalized_waiting = (
                metrics_dict["num_requests_waiting"] / max_waiting
                if max_waiting > 0
                else 0.0
            )

            # One term rewards cache hits; two terms penalize an overloaded
            # KV cache and queuing.
            logit = 2 * score - metrics_dict["gpu_cache_usage_perc"] - normalized_waiting
            worker_logits[worker_id] = logit
            vllm_logger.info(
                f"Formula for {worker_id}: {logit:.3f} = 2.0 * {score:.3f} - {metrics_dict['gpu_cache_usage_perc']:.3f} - {normalized_waiting:.3f}"
            )

        # No workers, or nothing distinguishes them: express no preference.
        if not worker_logits or all(logit == 0 for logit in worker_logits.values()):
            return "", 0.0

        # Select the worker with the highest logit, breaking ties at random.
        max_logit = max(worker_logits.values())
        best_workers = [
            wid for wid, logit in worker_logits.items() if logit == max_logit
        ]
        best_worker_id = random.choice(best_workers)

        # Log the inputs that led to this choice.
        vllm_logger.info(
            f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}"
        )
        vllm_logger.info(
            f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}"
        )
        best_metrics = worker_metrics.get(best_worker_id, {})
        best_waiting = (
            best_metrics.get("num_requests_waiting", 0.0) / max_waiting
            if max_waiting > 0
            else 0.0
        )
        vllm_logger.info(
            f"GPU Cache Hit Rate: {best_metrics.get('gpu_prefix_cache_hit_rate', 0.0):.3f}"
        )
        vllm_logger.info(
            f"GPU Cache Usage: {best_metrics.get('gpu_cache_usage_perc', 0.0):.3f}"
        )
        vllm_logger.info(f"Requests Waiting: {best_waiting:.3f}")

        return best_worker_id, worker_scores.get(best_worker_id, 0.0)

    @dynamo_endpoint(Tokens, WorkerId)
    async def generate(self, request) -> AsyncIterator[WorkerId]:
        """
        Route ``request`` and yield one ``"<worker_id>_<prefix_hit_rate>"`` string.

        ``worker_id`` is empty (payload ``"_0.0"``) when the router has no
        preference between workers. Routing is best-effort: if overlap scores
        or load metrics cannot be fetched, the error is logged and the
        decision is made without that input.
        """
        lora_id = 0
        try:
            scores = await self.indexer.find_matches_for_request(
                request.tokens, lora_id
            )
        except Exception as e:
            scores = None
            vllm_logger.exception(f"Error finding matches: {e}")

        try:
            metrics = await self.metrics_aggregator.get_metrics()
        except Exception as e:
            metrics = None
            vllm_logger.exception(f"Error fetching metrics: {e}")

        worker_id, prefix_hit_rate = self._cost_function(
            scores, metrics, len(request.tokens)
        )

        vllm_logger.info(
            f"Scheduling to worker_id: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
        )

        yield f"{worker_id}_{prefix_hit_rate}"


@dynamo_worker()
async def worker(runtime: DistributedRuntime, args: Namespace):
    """
    Set up the worker clients and serve the ``dynamo-init.router.generate`` endpoint.

    Waits until at least ``args.min_workers`` workers are registered on
    ``dynamo-init.vllm.generate``, polling every 5 seconds. It then builds a
    ``KvIndexer`` (with ``args.block_size``) and a ``KvMetricsAggregator``
    over the ``dynamo-init.vllm`` component, and serves
    ``CustomRouter.generate`` on ``dynamo-init.router.generate``.
    """
    workers_client = (
        await runtime.namespace("dynamo-init")
        .component("vllm")
        .endpoint("generate")
        .client()
    )

    # Snapshot the ids once per poll so the logged count matches the value
    # that was actually compared against the threshold.
    while True:
        worker_ids = workers_client.endpoint_ids()
        if len(worker_ids) >= args.min_workers:
            break
        vllm_logger.info(
            f"Waiting for more workers... Current: {len(worker_ids)}, Required: {args.min_workers}"
        )
        await asyncio.sleep(5)

    vllm_logger.info(
        f"Required number of workers ({args.min_workers}) are ready:\n"
        + "\n".join(f"id: {worker_id}" for worker_id in worker_ids)
    )

    kv_listener = runtime.namespace("dynamo-init").component("vllm")
    await kv_listener.create_service()

    router_component = runtime.namespace("dynamo-init").component("router")
    await router_component.create_service()

    endpoint = router_component.endpoint("generate")

    indexer = KvIndexer(kv_listener, args.block_size)
    metrics_aggregator = KvMetricsAggregator(kv_listener)
    await endpoint.serve_endpoint(
        CustomRouter(workers_client, indexer, metrics_aggregator).generate
    )


if __name__ == "__main__":
    uvloop.install()

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--min-workers",
        type=int,
        default=1,
        help="Minimum number of workers required before proceeding",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        help="Model that is being served",
    )
236
    # TODO: Read block size
237
238
239
    parser.add_argument(
        "--block-size",
        type=int,
240
        default=64,
241
242
        help="KV block size",
    )
243
244
245
246
247
248
249
250
251
    parser.add_argument(
        "--custom-router",
        type=bool,
        default=False,
        help="Whether to use custom router or not",
    )
    args = parser.parse_args()

    asyncio.run(worker(args))