kv_router.rs 7.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16
17
use std::sync::Arc;

18
use anyhow::Result;
19
use dynamo_runtime::{
20
    component::{Component, EndpointSource},
21
    pipeline::{
22
23
        async_trait, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter,
        ResponseStream, SingleIn,
24
25
26
27
28
    },
    prelude::*,
    protocols::annotated::Annotated,
};
use futures::stream::{self, StreamExt};
29
30

pub mod indexer;
31
pub mod metrics_aggregator;
32
33
pub mod protocols;
pub mod publisher;
34
pub mod recorder;
35
36
pub mod scheduler;
pub mod scoring;
37

38
39
40
41
42
43
44
45
use crate::{
    kv_router::{
        indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
        metrics_aggregator::KvMetricsAggregator,
        protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
        scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
        scoring::ProcessedEndpoints,
    },
46
47
    preprocessor::BackendInput,
    protocols::common::llm_backend::LLMEngineOutput,
Ryan Olson's avatar
Ryan Olson committed
48
    tokens::TokenBlockSequence,
49
50
};

51
use dynamo_runtime::traits::events::EventSubscriber;
52

/// Default number of tokens grouped into one KV block when hashing request tokens.
// TODO: Allow user to change
pub const DEFAULT_KV_BLOCK_SIZE: usize = 16;

// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

/// Subject that [`KvRouter::new`] subscribes to on the component to receive
/// JSON-encoded `RouterEvent`s for the indexer.
pub const KV_EVENT_SUBJECT: &str = "kv_events";

/// Subject name for KV hit-rate messages.
/// NOTE(review): not referenced in this file; presumably published/consumed elsewhere.
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";

/// Endpoint name for worker load metrics.
/// NOTE(review): not referenced in this file; presumably used by the metrics aggregator.
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
    fn select_worker(
        &self,
        workers: &ProcessedEndpoints,
        request: &SchedulingRequest,
        block_size: usize,
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
71

72
73
/// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route.
74
pub struct KvRouter {
75
    indexer: KvIndexer,
76
77
    scheduler: KvScheduler,
    block_size: usize,
78
79
80
81
}

impl KvRouter {
    pub async fn new(
82
        component: Component,
83
84
        block_size: usize,
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
85
    ) -> Result<Self> {
86
87
88
89
90
        let cancellation_token = component
            .drt()
            .primary_lease()
            .expect("Cannot KV route static workers")
            .primary_token();
91
92
93
94
95
96
97
98
99
100
101

        let metrics_aggregator =
            KvMetricsAggregator::new(component.clone(), cancellation_token.clone()).await;
        let indexer = KvIndexer::new(cancellation_token.clone(), block_size);
        let scheduler = KvScheduler::start(
            component.namespace().clone(),
            block_size,
            metrics_aggregator.endpoints_watcher(),
            selector,
        )
        .await?;
102

103
104
105
        // [gluo TODO] try subscribe_with_type::<RouterEvent>,
        // error checking below will be different.
        let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
106
107
108
109
        let kv_events_tx = indexer.event_sender();

        tokio::spawn(async move {
            while let Some(event) = kv_events_rx.next().await {
Alec's avatar
Alec committed
110
                let event: RouterEvent = match serde_json::from_slice(&event.payload) {
111
                    Ok(event) => event,
Alec's avatar
Alec committed
112
113
114
115
116
117
118
                    Err(e) => {
                        tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
                        // Choosing warn and continue to process other events from other workers
                        // A bad event likely signals a problem with a worker, but potentially other workers are still healthy
                        continue;
                    }
                };
119
                if let Err(e) = kv_events_tx.send(event).await {
120
                    tracing::debug!("failed to send kv event to indexer; shutting down: {:?}", e);
121
122
123
124
                }
            }
        });

125
        Ok(Self {
126
127
            scheduler,
            indexer,
128
            block_size,
129
        })
130
131
132
    }

    // [TODO] indexer needs to take 'lora_id' as parameter
GuanLuo's avatar
GuanLuo committed
133
    pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<i64> {
134
135
136
137
138
139
140
        // Extracting part of the code in KvRouter::generate() for only
        // the decision making part, routing is done by the caller
        let isl_tokens = token_ids.len();
        let overlap_scores = self
            .indexer
            .find_matches_for_request(token_ids.as_slice())
            .await?;
Alec's avatar
Alec committed
141
        tracing::debug!("KV router overlap_scores: {:?}", overlap_scores);
GuanLuo's avatar
GuanLuo committed
142
143
        let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
        Ok(worker_id)
144
    }
145

146
147
148
    /// Give these tokens, find the worker with the best match in it's KV cache.
    async fn find_best_match(&self, tokens: &[u32]) -> anyhow::Result<i64> {
        let isl_tokens = tokens.len();
149
150
        let block_size = self.block_size;

Ryan Olson's avatar
Ryan Olson committed
151
        let (complete_blocks, _partial_block) =
152
            TokenBlockSequence::split_tokens(tokens, block_size, 1337_u64);
Ryan Olson's avatar
Ryan Olson committed
153
154
155
156
157

        let local_block_hashes = complete_blocks
            .into_iter()
            .map(|block| LocalBlockHash(block.block_hash()))
            .collect();
158
159
        let overlap_scores = self.indexer.find_matches(local_block_hashes).await?;
        let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
160
161
162
163
164
165
166
167
168
169
170
171
        Ok(worker_id)
    }
}

#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (request, ctx) = request.into_parts();
        let worker_id = self.find_best_match(&request.tokens).await?;
172
173
174
175
176
177
178

        let response = RouterResponse { worker_id };
        let response = Annotated::from_data(response);
        let stream = stream::iter(vec![response]);
        Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
    }
}
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210

/// A push router that, for dynamically discovered endpoints, asks a [`KvRouter`]
/// which worker should receive each request before sending it there.
pub struct KvPushRouter {
    /// Underlying router used to actually deliver requests (static or direct-to-worker).
    inner: PushRouter<BackendInput, Annotated<LLMEngineOutput>>,
    /// Shared worker selector consulted for dynamic endpoints.
    chooser: Arc<KvRouter>,
}

impl KvPushRouter {
    /// Combines an existing push router with a shared KV-aware worker selector.
    pub fn new(
        inner: PushRouter<BackendInput, Annotated<LLMEngineOutput>>,
        chooser: Arc<KvRouter>,
    ) -> Self {
        Self { inner, chooser }
    }
}

#[async_trait]
impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Error>
    for KvPushRouter
{
    /// Routes `request` to a worker.
    ///
    /// - Dynamic endpoints: the KV router picks the best-matching worker, and the
    ///   request is sent directly to it.
    /// - Static endpoints: no KV-aware selection; the push router's static path is used.
    async fn generate(
        &self,
        request: SingleIn<BackendInput>,
    ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
        let endpoints = &self.inner.client.endpoints;
        match endpoints {
            EndpointSource::Dynamic(_) => {
                let target = self.chooser.find_best_match(&request.token_ids).await?;
                self.inner.direct(request, target).await
            }
            EndpointSource::Static => self.inner.r#static(request).await,
        }
    }
}