"lib/bindings/python/vscode:/vscode.git/clone" did not exist on "feb6d272ac7dc392bdbd4a759c69f310bb3a78ba"
Unverified Commit dadf0e22 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore: Remove old DisaggregatedRouter, making etcd presence optional (#4011)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 30d85883
...@@ -150,7 +150,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -150,7 +150,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Endpoint>()?; m.add_class::<Endpoint>()?;
m.add_class::<Client>()?; m.add_class::<Client>()?;
m.add_class::<AsyncResponseStream>()?; m.add_class::<AsyncResponseStream>()?;
m.add_class::<llm::disagg_router::DisaggregatedRouter>()?;
m.add_class::<llm::entrypoint::EntrypointArgs>()?; m.add_class::<llm::entrypoint::EntrypointArgs>()?;
m.add_class::<llm::entrypoint::EngineConfig>()?; m.add_class::<llm::entrypoint::EngineConfig>()?;
m.add_class::<llm::entrypoint::EngineType>()?; m.add_class::<llm::entrypoint::EngineType>()?;
......
...@@ -27,7 +27,6 @@ ...@@ -27,7 +27,6 @@
use super::*; use super::*;
pub mod backend; pub mod backend;
pub mod disagg_router;
pub mod entrypoint; pub mod entrypoint;
pub mod kv; pub mod kv;
pub mod local_model; pub mod local_model;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
use pyo3::exceptions::PyRuntimeError;
use std::sync::Arc;
use tokio::runtime::Runtime;
#[pyclass]
pub struct DisaggregatedRouter {
inner: Arc<dynamo_llm::disagg_router::DisaggregatedRouter>,
}
#[pymethods]
impl DisaggregatedRouter {
#[new]
#[pyo3(signature = (drt, model_name, default_max_local_prefill_length))]
fn new(
drt: PyObject,
model_name: String,
default_max_local_prefill_length: i32,
) -> PyResult<Self> {
let drt_arc = Python::with_gil(|py| {
let drt_ref = drt.extract::<DistributedRuntime>(py)?;
Ok::<_, PyErr>(Arc::new(drt_ref.inner))
})?;
// Create the runtime directly with the correct import
let runtime = Runtime::new().map_err(|e| {
PyRuntimeError::new_err(format!("Failed to create tokio runtime: {}", e))
})?;
let router = runtime.block_on(async {
dynamo_llm::disagg_router::DisaggregatedRouter::new_with_etcd_and_default(
drt_arc,
model_name,
default_max_local_prefill_length,
)
.await
.map_err(|e| {
PyRuntimeError::new_err(format!("Failed to create DisaggregatedRouter: {}", e))
})
})?;
Ok(DisaggregatedRouter {
inner: Arc::new(router),
})
}
fn prefill_remote(&self, prefill_length: i32, prefix_hit_length: i32) -> bool {
self.inner.prefill_remote(prefill_length, prefix_hit_length)
}
fn get_model_name(&self) -> &str {
self.inner.get_model_name()
}
}
...@@ -217,58 +217,6 @@ class Client: ...@@ -217,58 +217,6 @@ class Client:
""" """
... ...
class DisaggregatedRouter:
"""
A router that determines whether to perform prefill locally or remotely based on
sequence length thresholds.
"""
def __init__(
self,
drt: DistributedRuntime,
model_name: str,
default_max_local_prefill_length: int,
) -> None:
"""
Create a `DisaggregatedRouter` object.
Args:
drt: The distributed runtime instance
model_name: Name of the model
default_max_local_prefill_length: Default maximum sequence length that can be processed locally
"""
...
def prefill_remote(self, prefill_length: int, prefix_hit_length: int) -> bool:
"""
Determine if prefill should be performed remotely based on sequence lengths.
Args:
prefill_length: Total length of the sequence to prefill
prefix_hit_length: Length of the prefix that was already processed
Returns:
True if prefill should be performed remotely, False otherwise
"""
...
def update_value(self, max_local_prefill_length: int) -> None:
"""
Update the maximum local prefill length threshold.
Args:
max_local_prefill_length: New maximum sequence length that can be processed locally
"""
...
def get_model_name(self) -> str:
"""
Get the name of the model associated with this router.
Returns:
The model name as a string
"""
...
def compute_block_hash_for_seq_py(tokens: List[int], kv_block_size: int) -> List[int]: def compute_block_hash_for_seq_py(tokens: List[int], kv_block_size: int) -> List[int]:
""" """
......
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
import logging import logging
from dynamo._core import ApproxKvIndexer as ApproxKvIndexer from dynamo._core import ApproxKvIndexer as ApproxKvIndexer
from dynamo._core import DisaggregatedRouter as DisaggregatedRouter
from dynamo._core import EngineType from dynamo._core import EngineType
from dynamo._core import EntrypointArgs as EntrypointArgs from dynamo._core import EntrypointArgs as EntrypointArgs
from dynamo._core import ForwardPassMetrics as ForwardPassMetrics from dynamo._core import ForwardPassMetrics as ForwardPassMetrics
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};
use tokio::sync::watch;
use tracing;
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::transports::etcd::WatchEvent;
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DisaggRouterConf {
pub max_local_prefill_length: i32,
}
impl Default for DisaggRouterConf {
fn default() -> Self {
Self {
max_local_prefill_length: 1000,
}
}
}
impl DisaggRouterConf {
pub async fn from_etcd_with_watcher(
drt: Arc<DistributedRuntime>,
model_name: &str,
) -> anyhow::Result<(Self, watch::Receiver<Self>)> {
let etcd_key = format!("public/components/disagg_router/models/chat/{}", model_name);
// Get the initial value if it exists
let Some(etcd_client) = drt.etcd_client() else {
anyhow::bail!("Static components don't have an etcd client");
};
let initial_config = match etcd_client.kv_get_prefix(&etcd_key).await {
Ok(kvs) => {
if let Some(kv) = kvs.first() {
match serde_json::from_slice::<DisaggRouterConf>(kv.value()) {
Ok(config) => {
tracing::debug!(
"Found initial config for key {}: {:?}",
etcd_key,
config
);
config
}
Err(e) => {
tracing::warn!(
"Failed to parse initial config for key {}: {}",
etcd_key,
e
);
DisaggRouterConf::default()
}
}
} else {
tracing::debug!(
"No initial config found for key {}, using default",
etcd_key
);
DisaggRouterConf::default()
}
}
Err(e) => {
tracing::warn!("Error fetching initial config for key {}: {}", etcd_key, e);
DisaggRouterConf::default()
}
};
// Create watch channel for config updates
let (watch_tx, watch_rx) = watch::channel(initial_config.clone());
// Set up the watcher after getting the initial value
let prefix_watcher = etcd_client.kv_get_and_watch_prefix(&etcd_key).await?;
let (key, _watcher, mut kv_event_rx) = prefix_watcher.dissolve();
// Spawn background task to watch for config changes
drt.runtime().secondary().spawn(async move {
tracing::info!("Starting config watcher for disagg router key: {}", key);
loop {
let kv_event = tokio::select! {
_ = watch_tx.closed() => {
tracing::debug!("All watchers have closed; shutting down config watcher for key: {}", key);
break;
}
kv_event = kv_event_rx.recv() => {
match kv_event {
Some(kv_event) => kv_event,
None => {
tracing::debug!("Watch stream has closed; shutting down config watcher for key: {}", key);
break;
}
}
}
};
tracing::debug!("Received watch event for key {}", key);
match kv_event {
WatchEvent::Put(kv) => {
let val = serde_json::from_slice::<DisaggRouterConf>(kv.value());
if let Ok(config) = val {
tracing::info!("Config updated for key {}: {:?}", key, config);
// Broadcast the update
if watch_tx.send(config).is_err() {
tracing::debug!("Unable to send watch updates; shutting down config watcher for key: {}", key);
break;
}
} else {
tracing::error!("Unable to parse router config for key {}", key);
break;
}
}
WatchEvent::Delete(_) => {
tracing::warn!("Config key was deleted: {}", key);
// Reset to default values
if watch_tx.send(DisaggRouterConf::default()).is_err() {
tracing::debug!("Unable to send watch updates; shutting down config watcher for key: {}", key);
break;
}
}
}
}
tracing::debug!("Completed config watcher for key: {}", key);
});
Ok((initial_config, watch_rx))
}
}
#[derive(Clone)]
pub struct DisaggregatedRouter {
max_local_prefill_length: Arc<Mutex<i32>>,
model_name: String,
config_watcher: Option<watch::Receiver<DisaggRouterConf>>,
}
impl DisaggregatedRouter {
pub fn new(max_local_prefill_length: i32, model_name: String) -> Self {
DisaggregatedRouter {
max_local_prefill_length: Arc::new(Mutex::new(max_local_prefill_length)),
model_name,
config_watcher: None,
}
}
pub async fn new_with_etcd_and_default(
drt: Arc<DistributedRuntime>,
model_name: String,
default_max_local_prefill_length: i32,
) -> anyhow::Result<Self> {
let (mut config, watcher) =
DisaggRouterConf::from_etcd_with_watcher(drt, &model_name).await?;
// Use the provided default if no etcd value was found (when config is the default value)
if config.max_local_prefill_length == DisaggRouterConf::default().max_local_prefill_length {
config.max_local_prefill_length = default_max_local_prefill_length;
}
let router = Self {
max_local_prefill_length: Arc::new(Mutex::new(config.max_local_prefill_length)),
model_name: model_name.clone(),
config_watcher: Some(watcher),
};
// Start background task to watch for config updates
router.start_config_watcher();
Ok(router)
}
fn start_config_watcher(&self) {
if let Some(watcher) = self.config_watcher.clone() {
let mut watcher = watcher;
// Create a clone for the task
let model_name = self.model_name.clone();
let max_local_prefill_length = self.max_local_prefill_length.clone();
tokio::spawn(async move {
tracing::info!("Starting config update watcher for model: {}", model_name);
while watcher.changed().await.is_ok() {
let config = watcher.borrow().clone();
let new_value = config.max_local_prefill_length;
// Update the value using the mutex
let mut current_value = max_local_prefill_length.lock().unwrap();
let old_value = *current_value;
if old_value != new_value {
*current_value = new_value;
tracing::info!(
"Applied config update for model {}: max_local_prefill_length changed from {} to {}",
model_name,
old_value,
new_value
);
}
}
tracing::debug!("Config watcher closed for model: {}", model_name);
});
}
}
pub fn check_for_updates(&self) {
if let Some(watcher) = &self.config_watcher
&& watcher.has_changed().unwrap_or(false)
{
let config = watcher.borrow().clone();
let new_value = config.max_local_prefill_length;
// Update the value using the mutex
let mut current_value = self.max_local_prefill_length.lock().unwrap();
let old_value = *current_value;
if old_value != new_value {
*current_value = new_value;
tracing::info!(
"Applied config update for model {}: max_local_prefill_length changed from {} to {}",
self.model_name,
old_value,
new_value
);
}
}
}
pub fn prefill_remote(&self, prefill_length: i32, prefix_hit_length: i32) -> bool {
// Check for updates before making the decision
self.check_for_updates();
// Get the current value from the mutex
let max_local_prefill_length = *self.max_local_prefill_length.lock().unwrap();
// schedule the request purely based on the prefill length
// TODO: apply math models and compare local vs remote prefill TTFT
prefill_length - prefix_hit_length > max_local_prefill_length
}
pub fn update_value(&self, max_local_prefill_length: i32) {
let mut current = self.max_local_prefill_length.lock().unwrap();
*current = max_local_prefill_length;
}
pub fn get_model_name(&self) -> &str {
&self.model_name
}
}
...@@ -12,7 +12,6 @@ use anyhow::Context as _; ...@@ -12,7 +12,6 @@ use anyhow::Context as _;
pub mod backend; pub mod backend;
pub mod common; pub mod common;
pub mod disagg_router;
pub mod discovery; pub mod discovery;
pub mod endpoint_type; pub mod endpoint_type;
pub mod engines; pub mod engines;
......
...@@ -52,12 +52,20 @@ impl DistributedRuntime { ...@@ -52,12 +52,20 @@ impl DistributedRuntime {
let runtime_clone = runtime.clone(); let runtime_clone = runtime.clone();
// TODO: Here is where we will later select the KeyValueStore impl
let (etcd_client, store) = if is_static { let (etcd_client, store) = if is_static {
(None, KeyValueStoreManager::memory()) (None, KeyValueStoreManager::memory())
} else { } else {
let etcd_client = etcd::Client::new(etcd_config.clone(), runtime_clone).await?; match etcd::Client::new(etcd_config.clone(), runtime_clone).await {
let store = KeyValueStoreManager::etcd(etcd_client.clone()); Ok(etcd_client) => {
(Some(etcd_client), store) let store = KeyValueStoreManager::etcd(etcd_client.clone());
(Some(etcd_client), store)
}
Err(err) => {
tracing::info!(%err, "Did not connect to etcd. Using memory storage.");
(None, KeyValueStoreManager::memory())
}
}
}; };
let nats_client = Some(nats_config.clone().connect().await?); let nats_client = Some(nats_config.clone().connect().await?);
...@@ -266,16 +274,13 @@ impl DistributedRuntime { ...@@ -266,16 +274,13 @@ impl DistributedRuntime {
} }
// todo(ryan): deprecate this as we move to Discovery traits and Component Identifiers // todo(ryan): deprecate this as we move to Discovery traits and Component Identifiers
//
// Try to use `store()` instead of this. Only use this if you have not been able to migrate
// yet, or if you require etcd-specific features like distributed locking (rare).
pub fn etcd_client(&self) -> Option<etcd::Client> { pub fn etcd_client(&self) -> Option<etcd::Client> {
self.etcd_client.clone() self.etcd_client.clone()
} }
// Deprecated but our CI blocks us using the feature currently.
//#[deprecated(note = "Use KeyValueStoreManager via store(); this will be removed")]
pub fn deprecated_etcd_client(&self) -> Option<etcd::Client> {
self.etcd_client.clone()
}
/// An interface to store things. Will eventually replace `etcd_client`. /// An interface to store things. Will eventually replace `etcd_client`.
/// Currently does key-value, but will grow to include whatever we need to store. /// Currently does key-value, but will grow to include whatever we need to store.
pub fn store(&self) -> &KeyValueStoreManager { pub fn store(&self) -> &KeyValueStoreManager {
......
...@@ -305,7 +305,6 @@ impl KeyValueStoreManager { ...@@ -305,7 +305,6 @@ impl KeyValueStoreManager {
} }
/// An online storage for key-value config values. /// An online storage for key-value config values.
/// Usually backed by `nats-server`.
#[async_trait] #[async_trait]
pub trait KeyValueBucket: Send + Sync { pub trait KeyValueBucket: Send + Sync {
/// A bucket is a collection of key/value pairs. /// A bucket is a collection of key/value pairs.
......
...@@ -151,7 +151,7 @@ impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + Deseriali ...@@ -151,7 +151,7 @@ impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + Deseriali
data: &LeaderData, data: &LeaderData,
) -> anyhow::Result<HashMap<String, WorkerData>, LeaderWorkerBarrierError> { ) -> anyhow::Result<HashMap<String, WorkerData>, LeaderWorkerBarrierError> {
let etcd_client = rt let etcd_client = rt
.deprecated_etcd_client() .etcd_client()
.ok_or(LeaderWorkerBarrierError::EtcdClientNotFound)?; .ok_or(LeaderWorkerBarrierError::EtcdClientNotFound)?;
let lease_id = etcd_client.lease_id(); let lease_id = etcd_client.lease_id();
...@@ -245,7 +245,7 @@ impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + Deseriali ...@@ -245,7 +245,7 @@ impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + Deseriali
data: &WorkerData, data: &WorkerData,
) -> anyhow::Result<LeaderData, LeaderWorkerBarrierError> { ) -> anyhow::Result<LeaderData, LeaderWorkerBarrierError> {
let etcd_client = rt let etcd_client = rt
.deprecated_etcd_client() .etcd_client()
.ok_or(LeaderWorkerBarrierError::EtcdClientNotFound)?; .ok_or(LeaderWorkerBarrierError::EtcdClientNotFound)?;
let lease_id = etcd_client.lease_id(); let lease_id = etcd_client.lease_id();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment