Commit 19844fc0 authored by Hongkuan Zhou, committed by GitHub
Browse files

feat: kv aware router + disagg router + prefill queue (#11)


Signed-off-by: Hongkuan Zhou <tedzhouhk@gmail.com>
Co-authored-by: hongkuan <hongkuanz@nvidia.com>
Co-authored-by: Piotr Tarasiewicz <ptarasiewicz@nvidia.com>
Co-authored-by: Piotr Tarasiewicz Nvidia <ptarasiewicznv@Piotrs-MacBook-Pro.local>
Co-authored-by: alec-flowers <aflowers@nvidia.com>
Co-authored-by: Neelay Shah <neelays@nvidia.com>
parent 7567620f
......@@ -26,7 +26,7 @@ RequestHandler = Callable[[JsonLike], AsyncGenerator[JsonLike, None]]
class DistributedRuntime:
"""
The runtime object for dynemo applications
The runtime object for dynamo applications
"""
...
......@@ -173,6 +173,54 @@ class KvRouter:
"""
...
class DisaggregatedRouter:
    """
    Router that decides, from sequence-length thresholds, whether prefill is
    performed locally or remotely.
    """

    def __init__(self, drt: DistributedRuntime, model_name: str, default_max_local_prefill_length: int) -> None:
        """
        Build a `DisaggregatedRouter`.

        Args:
            drt: Distributed runtime instance the router is attached to.
            model_name: Name of the model this router serves.
            default_max_local_prefill_length: Default upper bound on the
                sequence length that may be prefilled locally.
        """
        ...

    def prefill_remote(self, prefill_length: int, prefix_hit_length: int) -> bool:
        """
        Decide, from the given sequence lengths, whether prefill should run remotely.

        Args:
            prefill_length: Total length of the sequence to prefill.
            prefix_hit_length: Length of the prefix that has already been processed.

        Returns:
            True when prefill should be performed remotely, False otherwise.
        """
        ...

    def update_value(self, max_local_prefill_length: int) -> None:
        """
        Replace the maximum local prefill length threshold.

        Args:
            max_local_prefill_length: New upper bound on the sequence length
                that may be prefilled locally.
        """
        ...

    def get_model_name(self) -> str:
        """
        Return the name of the model this router is associated with.

        Returns:
            The model name as a string.
        """
        ...
class KvMetricsPublisher:
"""
A metrics publisher will provide KV metrics to the router.
......
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dynamo._core import DisaggregatedRouter as DisaggregatedRouter
from dynamo._core import KvIndexer as KvIndexer
from dynamo._core import KvMetricsAggregator as KvMetricsAggregator
from dynamo._core import KvMetricsPublisher as KvMetricsPublisher
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};
use tokio::sync::watch;
use tracing;
use dynamo_runtime::transports::etcd::WatchEvent;
use dynamo_runtime::DistributedRuntime;
#[derive(Clone, Debug, Serialize, Deserialize)]
/// Configuration of the disaggregated router.
///
/// Values of this type are decoded from JSON stored in etcd (see
/// `DisaggRouterConf::from_etcd_with_watcher`) and broadcast to routers over a
/// `tokio::sync::watch` channel.
pub struct DisaggRouterConf {
/// Threshold used by `DisaggregatedRouter::prefill_remote`: a request is routed to
/// remote prefill when `prefill_length - prefix_hit_length` is strictly greater than
/// this value.
pub max_local_prefill_length: i32,
}
impl Default for DisaggRouterConf {
fn default() -> Self {
Self {
max_local_prefill_length: 1000,
}
}
}
impl DisaggRouterConf {
pub async fn from_etcd_with_watcher(
drt: Arc<DistributedRuntime>,
model_name: &str,
) -> Result<(Self, watch::Receiver<Self>), Box<dyn std::error::Error>> {
let etcd_key = format!("public/components/disagg_router/models/chat/{}", model_name);
// Get the initial value if it exists
let initial_config = match drt.etcd_client().kv_get_prefix(&etcd_key).await {
Ok(kvs) => {
if let Some(kv) = kvs.first() {
match serde_json::from_slice::<DisaggRouterConf>(kv.value()) {
Ok(config) => {
tracing::debug!(
"Found initial config for key {}: {:?}",
etcd_key,
config
);
config
}
Err(e) => {
tracing::warn!(
"Failed to parse initial config for key {}: {}",
etcd_key,
e
);
DisaggRouterConf::default()
}
}
} else {
tracing::debug!(
"No initial config found for key {}, using default",
etcd_key
);
DisaggRouterConf::default()
}
}
Err(e) => {
tracing::warn!("Error fetching initial config for key {}: {}", etcd_key, e);
DisaggRouterConf::default()
}
};
// Create watch channel for config updates
let (watch_tx, watch_rx) = watch::channel(initial_config.clone());
// Set up the watcher after getting the initial value
let prefix_watcher = drt.etcd_client().kv_get_and_watch_prefix(&etcd_key).await?;
let (key, _watcher, mut kv_event_rx) = prefix_watcher.dissolve();
// Spawn background task to watch for config changes
drt.runtime().secondary().spawn(async move {
tracing::info!("Starting config watcher for disagg router key: {}", key);
loop {
let kv_event = tokio::select! {
_ = watch_tx.closed() => {
tracing::debug!("All watchers have closed; shutting down config watcher for key: {}", key);
break;
}
kv_event = kv_event_rx.recv() => {
match kv_event {
Some(kv_event) => kv_event,
None => {
tracing::debug!("Watch stream has closed; shutting down config watcher for key: {}", key);
break;
}
}
}
};
tracing::debug!("Received watch event for key {}", key);
match kv_event {
WatchEvent::Put(kv) => {
let val = serde_json::from_slice::<DisaggRouterConf>(kv.value());
if let Ok(config) = val {
tracing::info!("Config updated for key {}: {:?}", key, config);
// Broadcast the update
if watch_tx.send(config).is_err() {
tracing::debug!("Unable to send watch updates; shutting down config watcher for key: {}", key);
break;
}
} else {
tracing::error!("Unable to parse router config for key {}", key);
break;
}
}
WatchEvent::Delete(_) => {
tracing::warn!("Config key was deleted: {}", key);
// Reset to default values
if watch_tx.send(DisaggRouterConf::default()).is_err() {
tracing::debug!("Unable to send watch updates; shutting down config watcher for key: {}", key);
break;
}
}
}
}
tracing::debug!("Completed config watcher for key: {}", key);
});
Ok((initial_config, watch_rx))
}
}
#[derive(Clone)]
/// Router that decides whether a request's prefill is done locally or remotely by
/// comparing the un-cached prefill length against a threshold.
pub struct DisaggregatedRouter {
// Current threshold. Held behind `Arc<Mutex<_>>` so that clones of the router and the
// background task started by `start_config_watcher` all read and write the same value.
max_local_prefill_length: Arc<Mutex<i32>>,
// Name of the model this router was created for; also used to build the etcd key.
model_name: String,
// Receiver for config updates coming from etcd. `None` when the router was built with
// `new`, `Some` when it was built with `new_with_etcd_and_default`.
config_watcher: Option<watch::Receiver<DisaggRouterConf>>,
}
impl DisaggregatedRouter {
    /// Creates a router with a fixed threshold and no etcd-backed config watcher.
    ///
    /// The threshold can still be changed later through [`Self::update_value`].
    pub fn new(max_local_prefill_length: i32, model_name: String) -> Self {
        DisaggregatedRouter {
            max_local_prefill_length: Arc::new(Mutex::new(max_local_prefill_length)),
            model_name,
            config_watcher: None,
        }
    }

    /// Creates a router whose threshold is loaded from etcd and kept in sync with it.
    ///
    /// `default_max_local_prefill_length` is used when the config obtained from etcd
    /// equals the built-in default. A background task is spawned with `tokio::spawn`
    /// to apply later config changes, so this must be called from within a Tokio runtime.
    ///
    /// # Errors
    ///
    /// Propagates the error from `DisaggRouterConf::from_etcd_with_watcher`.
    pub async fn new_with_etcd_and_default(
        drt: Arc<DistributedRuntime>,
        model_name: String,
        default_max_local_prefill_length: i32,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let (mut config, watcher) =
            DisaggRouterConf::from_etcd_with_watcher(drt, &model_name).await?;

        // Use the provided default if no etcd value was found (when config is the default value).
        // NOTE(review): "not found" is detected by comparing against the built-in default,
        // so a value explicitly stored in etcd that happens to equal that default is also
        // replaced by `default_max_local_prefill_length`.
        if config.max_local_prefill_length == DisaggRouterConf::default().max_local_prefill_length {
            config.max_local_prefill_length = default_max_local_prefill_length;
        }

        let router = Self {
            max_local_prefill_length: Arc::new(Mutex::new(config.max_local_prefill_length)),
            // `model_name` is owned and no longer borrowed here, so it is moved, not cloned.
            model_name,
            config_watcher: Some(watcher),
        };

        // Start background task to watch for config updates
        router.start_config_watcher();

        Ok(router)
    }

    /// Writes `new_value` into the shared threshold and logs the transition when the
    /// value actually changes. Shared by the background watcher and `check_for_updates`.
    fn apply_config_value(current: &Mutex<i32>, model_name: &str, new_value: i32) {
        let mut current_value = current.lock().unwrap();
        let old_value = *current_value;
        if old_value != new_value {
            *current_value = new_value;
            tracing::info!(
                "Applied config update for model {}: max_local_prefill_length changed from {} to {}",
                model_name,
                old_value,
                new_value
            );
        }
    }

    /// Spawns the task that applies every config broadcast to the shared threshold.
    ///
    /// Does nothing when the router has no config watcher. The task ends when the
    /// sending side of the watch channel is dropped.
    fn start_config_watcher(&self) {
        if let Some(mut watcher) = self.config_watcher.clone() {
            // The task outlives `&self`, so it gets its own handles to the shared state.
            let model_name = self.model_name.clone();
            let max_local_prefill_length = self.max_local_prefill_length.clone();

            tokio::spawn(async move {
                tracing::info!("Starting config update watcher for model: {}", model_name);

                while watcher.changed().await.is_ok() {
                    // Copy the single field out; the borrow is released at the end of
                    // this statement, before the mutex is taken.
                    let new_value = watcher.borrow().max_local_prefill_length;
                    Self::apply_config_value(&max_local_prefill_length, &model_name, new_value);
                }

                tracing::debug!("Config watcher closed for model: {}", model_name);
            });
        }
    }

    /// Synchronously applies the latest broadcast config, if the stored receiver reports
    /// an unseen value.
    ///
    /// NOTE(review): the stored receiver is only ever read through `&self`
    /// (`has_changed` + `borrow`), neither of which marks a value as seen. After the
    /// first etcd update `has_changed` therefore keeps returning `true`, so every call
    /// re-applies the watched value -- including over a value set via `update_value`.
    pub fn check_for_updates(&self) {
        if let Some(watcher) = &self.config_watcher {
            if watcher.has_changed().unwrap_or(false) {
                let new_value = watcher.borrow().max_local_prefill_length;
                Self::apply_config_value(&self.max_local_prefill_length, &self.model_name, new_value);
            }
        }
    }

    /// Returns `true` when the request should be prefilled remotely, i.e. when
    /// `prefill_length - prefix_hit_length` is strictly greater than the current
    /// threshold.
    ///
    /// # Panics
    ///
    /// Panics if the threshold mutex has been poisoned.
    pub fn prefill_remote(&self, prefill_length: i32, prefix_hit_length: i32) -> bool {
        // Check for updates before making the decision
        self.check_for_updates();

        // Get the current value from the mutex
        let max_local_prefill_length = *self.max_local_prefill_length.lock().unwrap();

        // schedule the request purely based on the prefill length
        // TODO: apply math models and compare local vs remote prefill TTFT
        // The difference is computed in i64 so extreme inputs cannot overflow i32
        // (which would panic in debug builds and wrap in release builds).
        i64::from(prefill_length) - i64::from(prefix_hit_length)
            > i64::from(max_local_prefill_length)
    }

    /// Overwrites the current threshold with `max_local_prefill_length`.
    pub fn update_value(&self, max_local_prefill_length: i32) {
        let mut current = self.max_local_prefill_length.lock().unwrap();
        *current = max_local_prefill_length;
    }

    /// Returns the name of the model this router was created for.
    pub fn get_model_name(&self) -> &str {
        &self.model_name
    }
}
......@@ -72,6 +72,15 @@ impl Endpoint {
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Serializable description of a service: identifying strings plus its list of
/// `FlexibleEndpoint`s.
// NOTE(review): presumably a more lenient counterpart of `Service` below -- confirm
// against the definition of `FlexibleEndpoint` and the code that decodes this type.
pub struct FlexibleService {
pub name: String,
pub id: String,
pub version: String,
// Kept as a plain string; the timestamp format is not visible here -- TODO confirm.
pub started: String,
pub endpoints: Vec<FlexibleEndpoint>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Service {
pub name: String,
......
......@@ -20,6 +20,7 @@
pub mod backend;
pub mod common;
pub mod disagg_router;
pub mod engines;
pub mod http;
pub mod kv_router;
......
......@@ -27,6 +27,7 @@ dependencies = [
"pytest>=8.3.4",
"pydantic>=2.10.6",
"uvloop>=0.21.0",
"nats-py>=2.6.0",
]
[tool.maturin]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment