"container/deps/vllm/vllm_v0.7.2-dynemo-kv-disagg-patch.patch" did not exist on "fe83f8aa3c96e238ef97275d5fec94b216d26743"
kv_router.rs 7.12 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use anyhow::Result;
Neelay Shah's avatar
Neelay Shah committed
17
use dynamo_runtime::{component::Component, component::Namespace, DistributedRuntime};
18
19
20
use futures::stream::StreamExt;
use std::{sync::Arc, time::Duration};
use tokio_util::sync::CancellationToken;
Alec's avatar
Alec committed
21
use tracing;
22
23

pub mod indexer;
24
pub mod metrics_aggregator;
25
26
pub mod protocols;
pub mod publisher;
27
28
pub mod scheduler;
pub mod scoring;
29
30
31
32
33
34
35
36
37

use crate::kv_router::{
    indexer::{KvIndexer, KvIndexerInterface, RouterEvent},
    scheduler::{Endpoint, KvScheduler, Service},
    scoring::ProcessedEndpoints,
};

/// Event subject (relative to a backend component) on which KV cache events are published.
///
/// The router subscribes to this subject over NATS and forwards each decoded event to the
/// indexer.
// TODO: this should be discovered from the backend instead of being hard-coded.
pub const KV_EVENT_SUBJECT: &str = "kv_events";

/// Event subject name for KV hit-rate reporting.
///
/// NOTE(review): not referenced in this module; presumably consumed by publishers/metrics code
/// elsewhere — confirm.
pub const KV_HIT_RATE_SUBJECT: &str = "kv-hit-rate";
39
40
41
42
43
44
45
46

pub struct KvRouter {
    // properties of request plane
    // maybe rolled up into the generic object or not
    service_name: String,

    cancellation_token: CancellationToken,

47
    #[allow(dead_code)]
48
49
50
51
52
53
54
55
56
57
58
59
60
    scheduler: KvScheduler,

    indexer: KvIndexer,
}

impl KvRouter {
    pub async fn from_runtime(
        runtime: DistributedRuntime,
        backend: Component,
    ) -> Result<Arc<Self>> {
        let nats_client = runtime.nats_client();
        let service_name = backend.service_name();
        let kv_subject = backend.event_subject(KV_EVENT_SUBJECT);
61
62
63
        let namespace = runtime.namespace(backend.namespace())?;

        tracing::info!("Component Namespace {}", backend.namespace());
Alec's avatar
Alec committed
64
65
        tracing::info!("Component Service Name {}", service_name);
        tracing::info!("KV Subject {}", kv_subject);
66
        Self::new(nats_client, service_name, kv_subject, namespace).await
67
68
69
    }

    pub async fn new(
Neelay Shah's avatar
Neelay Shah committed
70
        nats_client: dynamo_runtime::transports::nats::Client,
71
72
        service_name: String,
        kv_subject: String,
73
        namespace: Namespace,
74
75
76
77
78
79
80
81
82
83
84
85
    ) -> Result<Arc<Self>> {
        let cancellation_token = CancellationToken::new();
        let (ep_tx, ep_rx) = tokio::sync::mpsc::channel(128);

        tokio::spawn(collect_endpoints(
            nats_client.clone(),
            service_name.clone(),
            ep_tx,
            cancellation_token.clone(),
        ));

        let indexer = KvIndexer::new(cancellation_token.clone());
86
        let scheduler = KvScheduler::start(ep_rx, namespace).await?;
87

Alec's avatar
Alec committed
88
        tracing::debug!("subscribing to kv events: {}", kv_subject);
89
90
91
92
93
        let mut kv_events_rx = nats_client.client().subscribe(kv_subject).await?;
        let kv_events_tx = indexer.event_sender();

        tokio::spawn(async move {
            while let Some(event) = kv_events_rx.next().await {
Alec's avatar
Alec committed
94
95
96
97
98
99
100
101
102
103
104
105
                let event: RouterEvent = match serde_json::from_slice(&event.payload) {
                    Ok(event) => {
                        tracing::debug!("received kv event: {:?}", event);
                        event
                    }
                    Err(e) => {
                        tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
                        // Choosing warn and continue to process other events from other workers
                        // A bad event likely signals a problem with a worker, but potentially other workers are still healthy
                        continue;
                    }
                };
106
                if let Err(e) = kv_events_tx.send(event).await {
Alec's avatar
Alec committed
107
                    tracing::trace!("failed to send kv event to indexer; shutting down: {:?}", e);
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
                }
            }
        });

        Ok(Arc::new(Self {
            service_name,
            cancellation_token,
            scheduler,
            indexer,
        }))
    }

    pub fn cancellation_token(&self) -> CancellationToken {
        self.cancellation_token.clone()
    }

    pub fn service_name(&self) -> &str {
        &self.service_name
    }

    // [TODO] indexer needs to take 'lora_id' as parameter
GuanLuo's avatar
GuanLuo committed
129
    pub async fn schedule(&self, token_ids: &Vec<u32>, _lora_id: u64) -> Result<i64> {
130
131
132
133
134
135
136
        // Extracting part of the code in KvRouter::generate() for only
        // the decision making part, routing is done by the caller
        let isl_tokens = token_ids.len();
        let overlap_scores = self
            .indexer
            .find_matches_for_request(token_ids.as_slice())
            .await?;
Alec's avatar
Alec committed
137
        tracing::debug!("KV router overlap_scores: {:?}", overlap_scores);
GuanLuo's avatar
GuanLuo committed
138
139
        let worker_id = self.scheduler.schedule(overlap_scores, isl_tokens).await?;
        Ok(worker_id)
140
141
142
143
    }
}

async fn collect_endpoints(
Neelay Shah's avatar
Neelay Shah committed
144
    nats_client: dynamo_runtime::transports::nats::Client,
145
146
147
148
149
150
151
    service_name: String,
    ep_tx: tokio::sync::mpsc::Sender<ProcessedEndpoints>,
    cancel: CancellationToken,
) {
    loop {
        tokio::select! {
            _ = cancel.cancelled() => {
Alec's avatar
Alec committed
152
                tracing::debug!("cancellation token triggered");
153
154
155
                break;
            }
            _ = tokio::time::sleep(Duration::from_secs(1)) => {
Alec's avatar
Alec committed
156
                tracing::trace!("collecting endpoints for service: {}", service_name);
157
158
159
            }
        }

Alec's avatar
Alec committed
160
        let values = match nats_client
161
162
            .get_endpoints(&service_name, Duration::from_secs(1))
            .await
Alec's avatar
Alec committed
163
164
165
166
167
168
169
        {
            Ok(v) => v,
            Err(e) => {
                tracing::warn!("Failed to retrieve endpoints for {}: {:?}", service_name, e);
                continue;
            }
        };
170

Alec's avatar
Alec committed
171
        tracing::debug!("values: {:?}", values);
172
173
174
        let services: Vec<Service> = values
            .into_iter()
            .filter(|v| !v.is_empty())
Alec's avatar
Alec committed
175
176
177
178
179
180
            .filter_map(|v| match serde_json::from_slice::<Service>(&v) {
                Ok(service) => Some(service),
                Err(e) => {
                    tracing::warn!("For value: {:?} \nFailed to parse service: {:?}", v, e);
                    None
                }
181
182
            })
            .collect();
Alec's avatar
Alec committed
183
        tracing::debug!("services: {:?}", services);
184

Alec's avatar
Alec committed
185
186
187
188
189
190
191
192
193
194
195
        let endpoints: Vec<Endpoint> = services
            .into_iter()
            .flat_map(|s| s.endpoints)
            .filter(|s| s.data.is_some())
            .map(|s| Endpoint {
                name: s.name,
                subject: s.subject,
                data: s.data.unwrap(),
            })
            .collect();
        tracing::debug!("endpoints: {:?}", endpoints);
196

Alec's avatar
Alec committed
197
        tracing::trace!(
198
199
200
201
202
203
204
205
206
            "found {} endpoints for service: {}",
            endpoints.len(),
            service_name
        );

        let processed = ProcessedEndpoints::new(endpoints);

        // process endpoints into
        if ep_tx.send(processed).await.is_err() {
Alec's avatar
Alec committed
207
            tracing::trace!("failed to send processed endpoints; shutting down");
208
209
210
211
            break;
        }
    }
}