kv_router.rs 5.14 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use anyhow::Result;
Neelay Shah's avatar
Neelay Shah committed
17
use dynamo_runtime::{component::Component, component::Namespace, DistributedRuntime};
18
use futures::stream::StreamExt;
19
use std::sync::Arc;
20
use tokio_util::sync::CancellationToken;
Alec's avatar
Alec committed
21
use tracing;
22
23

pub mod indexer;
24
pub mod metrics_aggregator;
25
26
pub mod protocols;
pub mod publisher;
27
28
pub mod scheduler;
pub mod scoring;
29
30
31

use crate::kv_router::{
    indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
32
33
    metrics_aggregator::collect_endpoints,
    scheduler::KvScheduler,
34
35
36
37
38
    scoring::ProcessedEndpoints,
};

// TODO: these subjects should be discovered from the backend instead of being hard-coded.

/// Event subject name on which KV cache events are received; `KvRouter::from_runtime`
/// scopes it to the backend component via `Component::event_subject`.
pub const KV_EVENT_SUBJECT: &str = "kv_events";

/// Subject name for KV hit-rate reporting.
/// NOTE(review): not referenced in this file — presumably consumed by another module; confirm.
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";

pub struct KvRouter {
    // properties of request plane
    // maybe rolled up into the generic object or not
    service_name: String,

    cancellation_token: CancellationToken,

48
    #[allow(dead_code)]
49
50
51
52
53
54
55
56
57
    scheduler: KvScheduler,

    indexer: KvIndexer,
}

impl KvRouter {
    pub async fn from_runtime(
        runtime: DistributedRuntime,
        backend: Component,
58
        kv_block_size: usize,
59
60
61
62
    ) -> Result<Arc<Self>> {
        let nats_client = runtime.nats_client();
        let service_name = backend.service_name();
        let kv_subject = backend.event_subject(KV_EVENT_SUBJECT);
63
64
65
        let namespace = runtime.namespace(backend.namespace())?;

        tracing::info!("Component Namespace {}", backend.namespace());
Alec's avatar
Alec committed
66
67
        tracing::info!("Component Service Name {}", service_name);
        tracing::info!("KV Subject {}", kv_subject);
68
69
70
71
72
73
74
75
        Self::new(
            nats_client,
            service_name,
            kv_subject,
            namespace,
            kv_block_size,
        )
        .await
76
77
78
    }

    pub async fn new(
Neelay Shah's avatar
Neelay Shah committed
79
        nats_client: dynamo_runtime::transports::nats::Client,
80
81
        service_name: String,
        kv_subject: String,
82
        namespace: Namespace,
83
        kv_block_size: usize,
84
85
86
87
88
89
90
91
92
93
94
    ) -> Result<Arc<Self>> {
        let cancellation_token = CancellationToken::new();
        let (ep_tx, ep_rx) = tokio::sync::mpsc::channel(128);

        tokio::spawn(collect_endpoints(
            nats_client.clone(),
            service_name.clone(),
            ep_tx,
            cancellation_token.clone(),
        ));

95
96
        let indexer = KvIndexer::new(cancellation_token.clone(), kv_block_size);
        let scheduler = KvScheduler::start(ep_rx, namespace, kv_block_size).await?;
97

Alec's avatar
Alec committed
98
        tracing::debug!("subscribing to kv events: {}", kv_subject);
99
100
101
102
103
        let mut kv_events_rx = nats_client.client().subscribe(kv_subject).await?;
        let kv_events_tx = indexer.event_sender();

        tokio::spawn(async move {
            while let Some(event) = kv_events_rx.next().await {
Alec's avatar
Alec committed
104
105
106
107
108
109
110
111
112
113
114
115
                let event: RouterEvent = match serde_json::from_slice(&event.payload) {
                    Ok(event) => {
                        tracing::debug!("received kv event: {:?}", event);
                        event
                    }
                    Err(e) => {
                        tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
                        // Choosing warn and continue to process other events from other workers
                        // A bad event likely signals a problem with a worker, but potentially other workers are still healthy
                        continue;
                    }
                };
116
                if let Err(e) = kv_events_tx.send(event).await {
Alec's avatar
Alec committed
117
                    tracing::trace!("failed to send kv event to indexer; shutting down: {:?}", e);
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
                }
            }
        });

        Ok(Arc::new(Self {
            service_name,
            cancellation_token,
            scheduler,
            indexer,
        }))
    }

    pub fn cancellation_token(&self) -> CancellationToken {
        self.cancellation_token.clone()
    }

    pub fn service_name(&self) -> &str {
        &self.service_name
    }

    // [TODO] indexer needs to take 'lora_id' as parameter
GuanLuo's avatar
GuanLuo committed
139
    pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<i64> {
140
141
142
143
144
145
146
        // Extracting part of the code in KvRouter::generate() for only
        // the decision making part, routing is done by the caller
        let isl_tokens = token_ids.len();
        let overlap_scores = self
            .indexer
            .find_matches_for_request(token_ids.as_slice())
            .await?;
Alec's avatar
Alec committed
147
        tracing::debug!("KV router overlap_scores: {:?}", overlap_scores);
GuanLuo's avatar
GuanLuo committed
148
149
        let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
        Ok(worker_id)
150
151
    }
}