kv_router.rs 5.97 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use anyhow::Result;
17
18
19
20
21
22
23
24
25
26
use dynamo_runtime::{
    component::Component,
    pipeline::{
        async_trait, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream,
        SingleIn,
    },
    prelude::*,
    protocols::annotated::Annotated,
};
use futures::stream::{self, StreamExt};
27
use std::sync::Arc;
28
29

pub mod indexer;
30
pub mod metrics_aggregator;
31
32
pub mod protocols;
pub mod publisher;
33
pub mod recorder;
34
35
pub mod scheduler;
pub mod scoring;
36

37
38
39
40
41
42
43
44
45
use crate::{
    kv_router::{
        indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
        metrics_aggregator::KvMetricsAggregator,
        protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
        scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
        scoring::ProcessedEndpoints,
    },
    tokens::Tokens,
46
47
};

48
use dynamo_runtime::traits::events::EventSubscriber;
49
50
51

// [gluo TODO] shouldn't need to be public
// this should be discovered from the component

/// Subject that `KvRouter::new` subscribes to; each message payload is
/// deserialized as a JSON `RouterEvent` and forwarded to the indexer.
pub const KV_EVENT_SUBJECT: &str = "kv_events";

/// Subject name for KV hit-rate events (not referenced in this file).
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";

/// Endpoint name for worker load metrics (not referenced in this file;
/// presumably consumed by the metrics aggregator — confirm).
pub const KV_METRICS_ENDPOINT: &str = "load_metrics";
55

56
57
58
59
60
61
62
63
64
/// Pluggable worker-selection policy for the KV router.
///
/// Implement this trait to supply custom selection logic. A boxed implementation
/// can be given to `KvRouter::new` as its optional `selector` argument, which
/// forwards it to the scheduler.
pub trait WorkerSelector {
    /// Chooses a worker for `request` among the given `workers`.
    ///
    /// `block_size` is the router's KV block size (the same value used to hash
    /// request tokens into blocks). Returns the selection, or a
    /// `KvSchedulerError` if selection fails.
    fn select_worker(
        &self,
        workers: &ProcessedEndpoints,
        request: &SchedulingRequest,
        block_size: usize,
    ) -> Result<WorkerSelectionResult, KvSchedulerError>;
}
65

66
pub struct KvRouter {
67
    indexer: KvIndexer,
68
69
    scheduler: KvScheduler,
    block_size: usize,
70
71
72
73
}

impl KvRouter {
    pub async fn new(
74
        component: Component,
75
76
        block_size: usize,
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
77
    ) -> Result<Arc<Self>> {
78
79
80
81
82
        let cancellation_token = component
            .drt()
            .primary_lease()
            .expect("Cannot KV route static workers")
            .primary_token();
83
84
85
86
87
88
89
90
91
92
93

        let metrics_aggregator =
            KvMetricsAggregator::new(component.clone(), cancellation_token.clone()).await;
        let indexer = KvIndexer::new(cancellation_token.clone(), block_size);
        let scheduler = KvScheduler::start(
            component.namespace().clone(),
            block_size,
            metrics_aggregator.endpoints_watcher(),
            selector,
        )
        .await?;
94

95
96
97
        // [gluo TODO] try subscribe_with_type::<RouterEvent>,
        // error checking below will be different.
        let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
98
99
100
101
        let kv_events_tx = indexer.event_sender();

        tokio::spawn(async move {
            while let Some(event) = kv_events_rx.next().await {
Alec's avatar
Alec committed
102
103
104
105
106
107
108
109
110
111
112
113
                let event: RouterEvent = match serde_json::from_slice(&event.payload) {
                    Ok(event) => {
                        tracing::debug!("received kv event: {:?}", event);
                        event
                    }
                    Err(e) => {
                        tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
                        // Choosing warn and continue to process other events from other workers
                        // A bad event likely signals a problem with a worker, but potentially other workers are still healthy
                        continue;
                    }
                };
114
                if let Err(e) = kv_events_tx.send(event).await {
Alec's avatar
Alec committed
115
                    tracing::trace!("failed to send kv event to indexer; shutting down: {:?}", e);
116
117
118
119
120
121
122
                }
            }
        });

        Ok(Arc::new(Self {
            scheduler,
            indexer,
123
            block_size,
124
125
126
127
        }))
    }

    // [TODO] indexer needs to take 'lora_id' as parameter
GuanLuo's avatar
GuanLuo committed
128
    pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<i64> {
129
130
131
132
133
134
135
        // Extracting part of the code in KvRouter::generate() for only
        // the decision making part, routing is done by the caller
        let isl_tokens = token_ids.len();
        let overlap_scores = self
            .indexer
            .find_matches_for_request(token_ids.as_slice())
            .await?;
Alec's avatar
Alec committed
136
        tracing::debug!("KV router overlap_scores: {:?}", overlap_scores);
GuanLuo's avatar
GuanLuo committed
137
138
        let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
        Ok(worker_id)
139
140
    }
}
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169

#[async_trait]
impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Error> for KvRouter {
    /// Routes one [`RouterRequest`] and replies with a stream containing a single
    /// [`RouterResponse`] that names the chosen worker.
    ///
    /// The request tokens are hashed into [`LocalBlockHash`]es on the blocking thread
    /// pool. The hashes are matched against the indexer to get overlap scores, which
    /// are passed to the scheduler together with the input token count.
    ///
    /// # Errors
    /// Fails if the hashing task cannot be joined, or if the indexer or the scheduler
    /// returns an error.
    async fn generate(
        &self,
        request: SingleIn<RouterRequest>,
    ) -> Result<ManyOut<Annotated<RouterResponse>>> {
        let (payload, engine_ctx) = request.into_parts();
        let input_len = payload.tokens.len();
        let block_size = self.block_size;

        // Hashing is CPU-bound, so keep it off the async executor threads.
        let hashing_task = tokio::task::spawn_blocking(move || {
            let hashes: Vec<LocalBlockHash> =
                Tokens::compute_block_hash(&payload.tokens, block_size)
                    .into_iter()
                    .map(LocalBlockHash)
                    .collect();
            hashes
        });
        let block_hashes = hashing_task.await?;

        let overlap_scores = self.indexer.find_matches(block_hashes).await?;
        let worker_id = self.scheduler.schedule(overlap_scores, input_len).await?;

        let reply = Annotated::from_data(RouterResponse { worker_id });
        let single_reply = stream::iter(std::iter::once(reply));
        Ok(ResponseStream::new(Box::pin(single_reply), engine_ctx.context()))
    }
}