Unverified Commit 08fd2897 authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

feat: ETCD prefix watcher + python binding + runtime reconfiguration for...


feat: ETCD prefix watcher + python binding + runtime reconfiguration for router and disagg router  (#581)
Signed-off-by: default avatarHongkuan Zhou <tedzhouhk@gmail.com>
Co-authored-by: default avatarNeelay Shah <neelays@nvidia.com>
parent 1c67be28
......@@ -15,6 +15,9 @@
import logging
from dynamo.runtime import EtcdKvCache
from dynamo.sdk import dynamo_context
logger = logging.getLogger(__name__)
......@@ -31,16 +34,33 @@ class PyDisaggregatedRouter:
self.max_local_prefill_length = max_local_prefill_length
self.max_prefill_queue_size = max_prefill_queue_size
def prefill_remote(
async def async_init(self):
    """Create the etcd-backed KV cache holding this router's tunables.

    Seeds ``max_local_prefill_length`` and ``max_prefill_queue_size`` under
    the ``/dynamo/disagg_router/`` prefix so they can be reconfigured at
    runtime without restarting the worker.
    """
    runtime = dynamo_context["runtime"]
    initial_values = {
        "max_local_prefill_length": str(self.max_local_prefill_length),
        "max_prefill_queue_size": str(self.max_prefill_queue_size),
    }
    self.etcd_kv_cache = await EtcdKvCache.create(
        runtime.etcd_client(),
        "/dynamo/disagg_router/",
        initial_values,
    )
async def prefill_remote(
    self, prompt_length: int, prefix_hit_rate: float, queue_size: int
):
    """Decide whether this request's prefill should run on a remote worker.

    The thresholds are re-read from the etcd KV cache on every call so that
    runtime reconfiguration takes effect immediately.

    Args:
        prompt_length: Total number of prompt tokens in the request.
        prefix_hit_rate: Fraction of the prompt already covered by prefix
            cache hits (0.0 - 1.0).
        queue_size: Current depth of the remote prefill queue.

    Returns:
        True if the prefill should be dispatched to a remote prefill worker.
    """
    max_local_prefill_length = int(
        await self.etcd_kv_cache.get("max_local_prefill_length")
    )
    max_prefill_queue_size = int(
        await self.etcd_kv_cache.get("max_prefill_queue_size")
    )
    # Tokens that still need prefilling after discounting cached prefix hits.
    absolute_prefill_length = int(prompt_length * (1 - prefix_hit_rate))
    # TODO: consider size of each request in the queue when making the decision
    decision = (
        absolute_prefill_length > max_local_prefill_length
        and queue_size < max_prefill_queue_size
    )
    logger.info(
        f"Remote prefill: {decision} (prefill length: {absolute_prefill_length}/{max_local_prefill_length}, prefill queue size: {queue_size}/{max_prefill_queue_size})"
    )
    return decision
......@@ -30,6 +30,7 @@ from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRe
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from dynamo.runtime import EtcdKvCache
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
logger = logging.getLogger(__name__)
......@@ -65,7 +66,6 @@ class Processor(ProcessMixIn):
self.completions_processor = CompletionsProcessor(
self.tokenizer, self.model_config
)
self.router_mode = self.engine_args.router
self.min_workers = 1
def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
......@@ -95,6 +95,12 @@ class Processor(ProcessMixIn):
await check_required_workers(self.worker_client, self.min_workers)
self.etcd_kv_cache = await EtcdKvCache.create(
runtime.etcd_client(),
"/dynamo/processor/",
{"router": self.engine_args.router},
)
async def _generate(
self,
raw_request: Union[CompletionRequest, ChatCompletionRequest],
......@@ -109,7 +115,8 @@ class Processor(ProcessMixIn):
engine_prompt,
sampling_params,
) = await self._parse_raw_request(raw_request)
if self.router_mode == "kv":
router_mode = (await self.etcd_kv_cache.get("router")).decode()
if router_mode == "kv":
async for route_response in self.router.generate(
Tokens(tokens=engine_prompt["prompt_token_ids"]).model_dump_json()
):
......@@ -139,7 +146,7 @@ class Processor(ProcessMixIn):
).model_dump_json(),
int(worker_id),
)
elif self.router_mode == "random":
elif router_mode == "random":
engine_generator = await self.worker_client.generate(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
......@@ -147,7 +154,7 @@ class Processor(ProcessMixIn):
request_id=request_id,
).model_dump_json()
)
elif self.router_mode == "round-robin":
elif router_mode == "round-robin":
engine_generator = await self.worker_client.round_robin(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
......
......@@ -135,6 +135,7 @@ class VllmWorker:
max_local_prefill_length=self.engine_args.max_local_prefill_length,
max_prefill_queue_size=self.engine_args.max_prefill_queue_size,
)
await self.disaggregated_router.async_init()
else:
self.disaggregated_router = None
logger.info("VllmWorker has been initialized")
......@@ -164,7 +165,7 @@ class VllmWorker:
stream_name=self._prefill_queue_stream_name,
) as prefill_queue:
prefill_queue_size = await prefill_queue.get_queue_size()
disagg_router_decision = self.disaggregated_router.prefill_remote(
disagg_router_decision = await self.disaggregated_router.prefill_remote(
len(request.engine_prompt["prompt_token_ids"]),
request.prefix_hit_rate,
prefill_queue_size,
......
......@@ -76,6 +76,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<http::HttpService>()?;
m.add_class::<http::HttpError>()?;
m.add_class::<http::HttpAsyncEngine>()?;
m.add_class::<EtcdKvCache>()?;
engine::add_to_module(m)?;
......@@ -96,6 +97,12 @@ fn log_message(level: &str, message: &str, module: &str, file: &str, line: u32)
logging::log_message(level, message, module, file, line);
}
/// Python-facing wrapper around the Rust etcd `KvCache`, exposed via pyo3.
///
/// Cloning is cheap: the underlying cache is shared behind an `Arc`, so the
/// Python object can be duplicated freely across async tasks.
#[pyclass]
#[derive(Clone)]
struct EtcdKvCache {
    // Shared handle to the watching KV cache; async methods clone this `Arc`
    // and move the clone into the spawned future instead of borrowing `self`.
    inner: Arc<rs::transports::etcd::KvCache>,
}
#[pyclass]
#[derive(Clone)]
struct DistributedRuntime {
......@@ -205,6 +212,118 @@ impl DistributedRuntime {
}
}
#[pymethods]
impl EtcdKvCache {
    /// Direct construction is unsupported: building the underlying cache is
    /// async, which a pyo3 `#[new]` cannot be. Callers must use the async
    /// `create` static method instead.
    #[new]
    fn py_new(
        _etcd_client: &EtcdClient,
        _prefix: String,
        _initial_values: &Bound<'_, PyDict>,
    ) -> PyResult<Self> {
        // We can't create the KvCache here because it's async, so we return
        // an error pointing at the real factory. (Fixed the message: the
        // factory is the `create` static method, not `new`.)
        Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
            "EtcdKvCache must be created using the 'create' static method",
        ))
    }

    /// Asynchronously create an `EtcdKvCache` watching `prefix`.
    ///
    /// `initial_values` may map keys to either `str` or `bytes`; per the
    /// underlying `KvCache`, keys already present in etcd are not overwritten.
    #[staticmethod]
    #[allow(clippy::new_ret_no_self)]
    fn create<'p>(
        py: Python<'p>,
        etcd_client: &EtcdClient,
        prefix: String,
        initial_values: &Bound<'p, PyDict>,
    ) -> PyResult<Bound<'p, PyAny>> {
        let client = etcd_client.inner.clone();

        // Convert the Python dict to a Rust HashMap, normalizing both `str`
        // and `bytes` values to raw bytes.
        let mut rust_initial_values = std::collections::HashMap::new();
        for (key, value) in initial_values.iter() {
            let key_str = key.extract::<String>()?;
            let value_bytes = if let Ok(bytes) = value.extract::<Vec<u8>>() {
                bytes
            } else if let Ok(string) = value.extract::<String>() {
                string.into_bytes()
            } else {
                return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
                    "Values must be either strings or bytes",
                ));
            };
            rust_initial_values.insert(key_str, value_bytes);
        }

        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let kv_cache = rs::transports::etcd::KvCache::new(client, prefix, rust_initial_values)
                .await
                .map_err(to_pyerr)?;
            Ok(EtcdKvCache {
                inner: Arc::new(kv_cache),
            })
        })
    }

    /// Fetch `key` (prefix-relative); resolves to `bytes` or `None`.
    ///
    /// The lookup only touches the local in-memory cache; the background
    /// watcher keeps that cache in sync with etcd.
    fn get<'p>(&self, py: Python<'p>, key: String) -> PyResult<Bound<'p, PyAny>> {
        let inner = self.inner.clone();
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            // Simplified from a redundant `match` over `Python::with_gil`'s
            // result that just re-wrapped Ok/Err unchanged.
            match inner.get(&key).await {
                Some(value) => Python::with_gil(|py| {
                    let py_obj = PyBytes::new(py, &value).into_pyobject(py)?;
                    Ok(py_obj.unbind().into_any())
                }),
                None => Ok(Python::with_gil(|py| py.None())),
            }
        })
    }

    /// Return every cached key/value pair as a `dict[str, bytes]`, with the
    /// cache prefix stripped from each key.
    fn get_all<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
        let inner = self.inner.clone();
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            let all_values = inner.get_all().await;
            Python::with_gil(|py| {
                let dict = PyDict::new(py);
                for (key, value) in all_values {
                    // Strip the prefix so keys match the `initial_values` format.
                    let stripped_key = if let Some(stripped) = key.strip_prefix(&inner.prefix) {
                        stripped.to_string()
                    } else {
                        key
                    };
                    dict.set_item(stripped_key, PyBytes::new(py, &value))?;
                }
                let py_obj = dict.into_pyobject(py)?;
                Ok(py_obj.unbind().into_any())
            })
        })
    }

    /// Write `key` (prefix-relative) to etcd and the local cache, optionally
    /// attached to an etcd lease.
    #[pyo3(signature = (key, value, lease_id=None))]
    fn put<'p>(
        &self,
        py: Python<'p>,
        key: String,
        value: Vec<u8>,
        lease_id: Option<i64>,
    ) -> PyResult<Bound<'p, PyAny>> {
        let inner = self.inner.clone();
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            inner.put(&key, value, lease_id).await.map_err(to_pyerr)?;
            Ok(())
        })
    }
}
#[pymethods]
impl CancellationToken {
fn cancel(&self) {
......
......@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import AsyncGenerator, AsyncIterator, Callable, Dict, List, Optional
from typing import AsyncGenerator, AsyncIterator, Callable, Dict, List, Optional, Union
def log_message(level: str, message: str, module: str, file: str, line: int) -> None:
"""
......@@ -76,6 +76,73 @@ class EtcdClient:
"""
...
class EtcdKvCache:
    """
    A cache for key-value pairs stored in etcd.
    """

    # NOTE: the factory is named `create` to match the Rust binding
    # (`#[staticmethod] fn create`) and the call sites
    # (`await EtcdKvCache.create(...)`); the stub previously named it `new`.
    @staticmethod
    async def create(
        etcd_client: EtcdClient,
        prefix: str,
        initial_values: Dict[str, Union[str, bytes]]
    ) -> "EtcdKvCache":
        """
        Create a new EtcdKvCache instance.

        Args:
            etcd_client: The etcd client to use for operations
            prefix: The prefix to use for all keys in this cache.
                EtcdKvCache will continuously watch the changes of the keys under this prefix.
            initial_values: Initial key-value pairs to populate the cache with
                NOTE: if the key already exists, it won't be updated

        Returns:
            A new EtcdKvCache instance
        """
        ...

    async def get(self, key: str) -> Optional[bytes]:
        """
        Get a value from the cache.

        Args:
            key: The key to retrieve

        Returns:
            The value as bytes if found, None otherwise

        NOTE: this get is cheap because internally there is a cache that holds the latest kv pairs.
        To prevent race condition, there is a lock when reading/writing the internal cache.
        """
        ...

    async def get_all(self) -> Dict[str, bytes]:
        """
        Get all key-value pairs from the cache.

        Returns:
            A dictionary of all key-value pairs, with keys stripped of the prefix
            (i.e., in the same format as in `initial_values`.keys())
        """
        ...

    async def put(
        self,
        key: str,
        value: bytes,
        lease_id: Optional[int] = None
    ) -> None:
        """
        Put a key-value pair into the cache and etcd.

        Args:
            key: The key to store
            value: The value to store
            lease_id: Optional lease ID to associate with this key-value pair
        """
        ...
class Namespace:
"""
A namespace is a collection of components
......
......@@ -26,6 +26,7 @@ from dynamo._core import Backend as Backend
from dynamo._core import Client as Client
from dynamo._core import Component as Component
from dynamo._core import DistributedRuntime as DistributedRuntime
from dynamo._core import EtcdKvCache as EtcdKvCache
from dynamo._core import ModelDeploymentCard as ModelDeploymentCard
from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor
......
......@@ -19,7 +19,9 @@ use async_nats::jetstream::kv;
use derive_builder::Builder;
use derive_getters::Dissolve;
use futures::StreamExt;
use tokio::sync::mpsc;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use validator::Validate;
use etcd_client::{
......@@ -375,6 +377,129 @@ fn default_servers() -> Vec<String> {
}
}
/// A cache for etcd key-value pairs that watches for changes
pub struct KvCache {
    /// Etcd client used for the initial read, writes, and the prefix watch.
    client: Client,
    /// Key prefix this cache covers; public methods take prefix-relative keys.
    pub prefix: String,
    /// Latest known key/value pairs, keyed by FULL key (prefix included),
    /// kept in sync by the background watcher task.
    cache: Arc<RwLock<HashMap<String, Vec<u8>>>>,
    /// Held only between construction and `start_watcher`, which takes it
    /// and moves it into the spawned watcher task.
    watcher: Option<PrefixWatcher>,
}
impl KvCache {
    /// Create a new KV cache for the given prefix.
    ///
    /// Seeds the cache from etcd, writes any `initial_values` not already
    /// present, then spawns a background task that mirrors watch events into
    /// the cache.
    pub async fn new(
        client: Client,
        prefix: String,
        initial_values: HashMap<String, Vec<u8>>,
    ) -> Result<Self> {
        let mut cache = HashMap::new();

        // First get all existing keys with this prefix
        let existing_kvs = client.kv_get_prefix(&prefix).await?;
        for kv in existing_kvs {
            let key = String::from_utf8_lossy(kv.key()).to_string();
            cache.insert(key, kv.value().to_vec());
        }

        // For any keys in initial_values that don't exist in etcd, write them.
        // TODO: proper lease handling; this requires that the first process writing to a prefix
        // atomically create a lease and write the lease to etcd. Later processes will attach to
        // the lease and help refresh it.
        for (key, value) in initial_values.iter() {
            let full_key = format!("{}{}", prefix, key);
            if let std::collections::hash_map::Entry::Vacant(e) = cache.entry(full_key.clone()) {
                client.kv_put(&full_key, value.clone(), None).await?;
                e.insert(value.clone());
            }
        }

        // Start watching for changes.
        // We won't miss events between the initial push and the watcher starting because
        // client.kv_get_and_watch_prefix() will get all kv pairs and put them back again.
        let watcher = client.kv_get_and_watch_prefix(&prefix).await?;

        let cache = Arc::new(RwLock::new(cache));
        let mut result = Self {
            client,
            prefix,
            cache,
            watcher: Some(watcher),
        };

        // Start the background watcher task
        result.start_watcher().await?;

        Ok(result)
    }

    /// Start the background watcher task.
    ///
    /// Takes the stored `watcher` (no-op if already taken) and spawns a task
    /// that applies Put/Delete events to the shared cache until the watch
    /// channel closes.
    async fn start_watcher(&mut self) -> Result<()> {
        if let Some(watcher) = self.watcher.take() {
            let cache = self.cache.clone();
            let prefix = self.prefix.clone();

            tokio::spawn(async move {
                let mut rx = watcher.rx;
                while let Some(event) = rx.recv().await {
                    match event {
                        WatchEvent::Put(kv) => {
                            let key = String::from_utf8_lossy(kv.key()).to_string();
                            let value = kv.value().to_vec();
                            tracing::debug!("KvCache update: {} = {:?}", key, value);
                            // Write lock held only for the single insert.
                            let mut cache_write = cache.write().await;
                            cache_write.insert(key, value);
                        }
                        WatchEvent::Delete(kv) => {
                            let key = String::from_utf8_lossy(kv.key()).to_string();
                            tracing::debug!("KvCache delete: {}", key);
                            let mut cache_write = cache.write().await;
                            cache_write.remove(&key);
                        }
                    }
                }
                tracing::info!("KvCache watcher for prefix '{}' stopped", prefix);
            });
        }
        Ok(())
    }

    /// Get a value from the cache (local read only; `key` is prefix-relative).
    pub async fn get(&self, key: &str) -> Option<Vec<u8>> {
        let full_key = format!("{}{}", self.prefix, key);
        let cache_read = self.cache.read().await;
        cache_read.get(&full_key).cloned()
    }

    /// Get all key-value pairs in the cache.
    ///
    /// NOTE: keys in the returned map are FULL keys (prefix included).
    pub async fn get_all(&self) -> HashMap<String, Vec<u8>> {
        let cache_read = self.cache.read().await;
        cache_read.clone()
    }

    /// Update a value in both the cache and etcd (`key` is prefix-relative).
    pub async fn put(&self, key: &str, value: Vec<u8>, lease_id: Option<i64>) -> Result<()> {
        let full_key = format!("{}{}", self.prefix, key);

        // Update etcd first
        self.client
            .kv_put(&full_key, value.clone(), lease_id)
            .await?;

        // Then update local cache; the watcher will later re-apply the same
        // value when the Put event arrives, which is harmless.
        let mut cache_write = self.cache.write().await;
        cache_write.insert(full_key, value);

        Ok(())
    }

    // TODO: add a method to create/delete keys
}
#[cfg(feature = "integration")]
#[cfg(test)]
mod tests {
......@@ -428,4 +553,102 @@ mod tests {
Ok(())
}
/// Integration test entry point: builds a runtime + distributed runtime and
/// drives the async KvCache test body to completion on the primary executor.
#[test]
fn test_kv_cache() {
    let rt = Runtime::from_settings().unwrap();
    let rt_clone = rt.clone();
    let config = DistributedConfig::from_settings(false);
    rt_clone.primary().block_on(async move {
        let drt = DistributedRuntime::new(rt, config).await.unwrap();
        test_kv_cache_operations(drt).await.unwrap();
    });
}
/// Exercises KvCache end-to-end against a live etcd: seeding, get/get_all,
/// put, local update, external update via another client, and cleanup.
/// Uses a UUID-based prefix so concurrent test runs don't collide.
async fn test_kv_cache_operations(drt: DistributedRuntime) -> Result<()> {
    // Get the client and unwrap it
    let client = drt.etcd_client().expect("etcd client should be available");

    // Create a unique test prefix to avoid conflicts with other tests
    let test_id = uuid::Uuid::new_v4().to_string();
    let prefix = format!("test_kv_cache_{}/", test_id);

    // Initial values
    let mut initial_values = HashMap::new();
    initial_values.insert("key1".to_string(), b"value1".to_vec());
    initial_values.insert("key2".to_string(), b"value2".to_vec());

    // Create the KV cache
    let kv_cache = KvCache::new(client.clone(), prefix.clone(), initial_values).await?;

    // Test get (keys are prefix-relative)
    let value1 = kv_cache.get("key1").await;
    assert_eq!(value1, Some(b"value1".to_vec()));

    let value2 = kv_cache.get("key2").await;
    assert_eq!(value2, Some(b"value2".to_vec()));

    // Test get_all (returned keys include the prefix)
    let all_values = kv_cache.get_all().await;
    assert_eq!(all_values.len(), 2);
    assert_eq!(
        all_values.get(&format!("{}key1", prefix)),
        Some(&b"value1".to_vec())
    );
    assert_eq!(
        all_values.get(&format!("{}key2", prefix)),
        Some(&b"value2".to_vec())
    );

    // Test put - using None for lease_id
    kv_cache.put("key3", b"value3".to_vec(), None).await?;

    // Allow some time for the update to propagate
    tokio::time::sleep(std::time::Duration::from_millis(100)).await;

    // Verify the new value
    let value3 = kv_cache.get("key3").await;
    assert_eq!(value3, Some(b"value3".to_vec()));

    // Test update
    kv_cache
        .put("key1", b"updated_value1".to_vec(), None)
        .await?;

    // Allow some time for the update to propagate
    tokio::time::sleep(std::time::Duration::from_millis(100)).await;

    // Verify the updated value
    let updated_value1 = kv_cache.get("key1").await;
    assert_eq!(updated_value1, Some(b"updated_value1".to_vec()));

    // Test external update (simulating another client updating a value);
    // this must reach the cache via the watch stream, not a local write.
    client
        .kv_put(
            &format!("{}key2", prefix),
            b"external_update".to_vec(),
            None,
        )
        .await?;

    // Allow some time for the update to propagate
    tokio::time::sleep(std::time::Duration::from_millis(100)).await;

    // Verify the cache was updated
    let external_update = kv_cache.get("key2").await;
    assert_eq!(external_update, Some(b"external_update".to_vec()));

    // Clean up - delete the test keys with a ranged (prefix) delete
    let etcd_client = client.etcd_client();
    let _ = etcd_client
        .kv_client()
        .delete(
            prefix,
            Some(etcd_client::DeleteOptions::new().with_prefix()),
        )
        .await?;

    Ok(())
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment