Unverified Commit a590d103 authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

feat: cleanup EtcdKvCache and PrefillQueue before and after launch (#925)

parent 10e91264
...@@ -123,6 +123,37 @@ def create_dynamo_watcher( ...@@ -123,6 +123,37 @@ def create_dynamo_watcher(
return watcher, socket, uri return watcher, socket, uri
def clear_namespace(namespace: str) -> None:
    """
    Clear residual etcd/NATS state for ``namespace`` by running the
    ``utils.clear_namespace`` helper module, if its script is present.

    The helper script is looked up relative to the current working
    directory; when it is missing, the call is a no-op (logged at debug
    level). Failures of the helper are logged but not raised.

    Args:
        namespace: Dynamo namespace whose entries should be cleared.
    """
    import os.path
    import subprocess
    import sys

    clear_script_path = "utils/clear_namespace.py"
    if os.path.exists(clear_script_path):
        logger.info(f"Clearing namespace {namespace} using {clear_script_path}")
        try:
            # Run the helper with the same interpreter that is running this
            # process ("python" may not be on PATH, or may resolve to a
            # different interpreter/venv) and wait for it to complete.
            result = subprocess.run(
                [sys.executable, "-m", "utils.clear_namespace", "--namespace", namespace],
                check=True,
                capture_output=True,
                text=True,
            )
            logger.info(f"Clear namespace output: {result.stdout}")
            logger.info(f"Successfully cleared namespace {namespace}")
            if result.stderr:
                logger.info(f"Clear namespace stderr: {result.stderr}")
        except subprocess.CalledProcessError as e:
            # check=True raises on non-zero exit; surface the script's stderr.
            logger.error(f"Failed to clear namespace {namespace}: {e.stderr}")
    else:
        logger.debug(
            f"Script not found at {clear_script_path}, skip namespace clearing"
        )
@inject(squeeze_none=True) @inject(squeeze_none=True)
def serve_dynamo_graph( def serve_dynamo_graph(
bento_identifier: str | AnyService, bento_identifier: str | AnyService,
...@@ -174,10 +205,32 @@ def serve_dynamo_graph( ...@@ -174,10 +205,32 @@ def serve_dynamo_graph(
try: try:
if not service_name and not standalone: if not service_name and not standalone:
with contextlib.ExitStack() as port_stack: with contextlib.ExitStack() as port_stack:
# first check if all components has the same namespace
namespaces = set()
for name, dep_svc in svc.all_services().items(): for name, dep_svc in svc.all_services().items():
if name == svc.name: if name == svc.name or name in dependency_map:
continue continue
if name in dependency_map: if not (
hasattr(dep_svc, "is_dynamo_component")
and dep_svc.is_dynamo_component()
):
raise RuntimeError(
f"Service {dep_svc.name} is not a Dynamo component"
)
namespaces.add(dep_svc.dynamo_address()[0])
if len(namespaces) > 1:
raise RuntimeError(
f"All components must have the same namespace, got {namespaces}"
)
else:
namespace = namespaces.pop() if namespaces else ""
logger.info(f"Serving dynamo graph with namespace {namespace}")
# clear residue etcd/nats entry (if any) under this namespace
logger.info(f"Clearing namespace {namespace} before serving")
clear_namespace(namespace)
for name, dep_svc in svc.all_services().items():
if name == svc.name or name in dependency_map:
continue continue
if not ( if not (
hasattr(dep_svc, "is_dynamo_component") hasattr(dep_svc, "is_dynamo_component")
...@@ -194,7 +247,6 @@ def serve_dynamo_graph( ...@@ -194,7 +247,6 @@ def serve_dynamo_graph(
str(bento_path.absolute()), str(bento_path.absolute()),
env=env, env=env,
) )
namespace, _ = dep_svc.dynamo_address()
watchers.append(new_watcher) watchers.append(new_watcher)
sockets.append(new_socket) sockets.append(new_socket)
dependency_map[name] = uri dependency_map[name] = uri
...@@ -263,6 +315,7 @@ def serve_dynamo_graph( ...@@ -263,6 +315,7 @@ def serve_dynamo_graph(
} }
arbiter = create_arbiter(**arbiter_kwargs) arbiter = create_arbiter(**arbiter_kwargs)
arbiter.exit_stack.callback(clear_namespace, namespace)
arbiter.exit_stack.callback(shutil.rmtree, uds_path, ignore_errors=True) arbiter.exit_stack.callback(shutil.rmtree, uds_path, ignore_errors=True)
if enable_local_planner: if enable_local_planner:
arbiter.exit_stack.callback( arbiter.exit_stack.callback(
......
...@@ -25,12 +25,12 @@ class PyDisaggregatedRouter: ...@@ -25,12 +25,12 @@ class PyDisaggregatedRouter:
def __init__( def __init__(
self, self,
runtime, runtime,
served_model_name, namespace,
max_local_prefill_length=1000, max_local_prefill_length=1000,
max_prefill_queue_size=2, max_prefill_queue_size=2,
): ):
self.runtime = runtime self.runtime = runtime
self.served_model_name = served_model_name self.namespace = namespace
self.max_local_prefill_length = max_local_prefill_length self.max_local_prefill_length = max_local_prefill_length
self.max_prefill_queue_size = max_prefill_queue_size self.max_prefill_queue_size = max_prefill_queue_size
...@@ -38,7 +38,7 @@ class PyDisaggregatedRouter: ...@@ -38,7 +38,7 @@ class PyDisaggregatedRouter:
runtime = dynamo_context["runtime"] runtime = dynamo_context["runtime"]
self.etcd_kv_cache = await EtcdKvCache.create( self.etcd_kv_cache = await EtcdKvCache.create(
runtime.etcd_client(), runtime.etcd_client(),
"/dynamo/disagg_router/", f"/{self.namespace}/disagg_router/",
{ {
"max_local_prefill_length": str(self.max_local_prefill_length), "max_local_prefill_length": str(self.max_local_prefill_length),
"max_prefill_queue_size": str(self.max_prefill_queue_size), "max_prefill_queue_size": str(self.max_prefill_queue_size),
......
...@@ -21,7 +21,7 @@ from argparse import Namespace ...@@ -21,7 +21,7 @@ from argparse import Namespace
from typing import AsyncIterator, Tuple from typing import AsyncIterator, Tuple
from components.worker import VllmWorker from components.worker import VllmWorker
from utils.logging import check_required_workers from utils.check_worker import check_required_workers
from utils.protocol import Tokens from utils.protocol import Tokens
from utils.vllm import RouterType from utils.vllm import RouterType
......
...@@ -118,11 +118,8 @@ class PrefillWorker: ...@@ -118,11 +118,8 @@ class PrefillWorker:
async def prefill_queue_handler(self): async def prefill_queue_handler(self):
logger.info("Prefill queue handler entered") logger.info("Prefill queue handler entered")
prefill_queue_nats_server = os.getenv("NATS_SERVER", "nats://localhost:4222") prefill_queue_nats_server = os.getenv("NATS_SERVER", "nats://localhost:4222")
prefill_queue_stream_name = ( namespace, _ = PrefillWorker.dynamo_address() # type: ignore
self.engine_args.served_model_name prefill_queue_stream_name = f"{namespace}_prefill_queue"
if self.engine_args.served_model_name is not None
else "vllm"
)
logger.info( logger.info(
f"Prefill queue: {prefill_queue_nats_server}:{prefill_queue_stream_name}" f"Prefill queue: {prefill_queue_nats_server}:{prefill_queue_stream_name}"
) )
......
...@@ -23,7 +23,7 @@ from components.kv_router import Router ...@@ -23,7 +23,7 @@ from components.kv_router import Router
from components.worker import VllmWorker from components.worker import VllmWorker
from transformers import AutoTokenizer from transformers import AutoTokenizer
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.logging import check_required_workers from utils.check_worker import check_required_workers
from utils.protocol import MyRequestOutput, Tokens, vLLMGenerateRequest from utils.protocol import MyRequestOutput, Tokens, vLLMGenerateRequest
from utils.vllm import RouterType, parse_vllm_args from utils.vllm import RouterType, parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
...@@ -121,7 +121,7 @@ class Processor(ProcessMixIn): ...@@ -121,7 +121,7 @@ class Processor(ProcessMixIn):
self.etcd_kv_cache = await EtcdKvCache.create( self.etcd_kv_cache = await EtcdKvCache.create(
runtime.etcd_client(), runtime.etcd_client(),
"/dynamo/processor/", f"/{comp_ns}/processor/",
{"router": self.engine_args.router}, {"router": self.engine_args.router},
) )
......
...@@ -56,15 +56,11 @@ class VllmWorker: ...@@ -56,15 +56,11 @@ class VllmWorker:
class_name = self.__class__.__name__ class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "") self.engine_args = parse_vllm_args(class_name, "")
self.do_remote_prefill = self.engine_args.remote_prefill self.do_remote_prefill = self.engine_args.remote_prefill
self.model_name = (
self.engine_args.served_model_name
if self.engine_args.served_model_name is not None
else "vllm"
)
self._prefill_queue_nats_server = os.getenv( self._prefill_queue_nats_server = os.getenv(
"NATS_SERVER", "nats://localhost:4222" "NATS_SERVER", "nats://localhost:4222"
) )
self._prefill_queue_stream_name = self.model_name self.namespace, _ = VllmWorker.dynamo_address() # type: ignore
self._prefill_queue_stream_name = f"{self.namespace}_prefill_queue"
logger.info( logger.info(
f"Prefill queue: {self._prefill_queue_nats_server}:{self._prefill_queue_stream_name}" f"Prefill queue: {self._prefill_queue_nats_server}:{self._prefill_queue_stream_name}"
) )
...@@ -134,7 +130,7 @@ class VllmWorker: ...@@ -134,7 +130,7 @@ class VllmWorker:
if self.engine_args.conditional_disagg: if self.engine_args.conditional_disagg:
self.disaggregated_router = PyDisaggregatedRouter( self.disaggregated_router = PyDisaggregatedRouter(
runtime, runtime,
self.model_name, self.namespace,
max_local_prefill_length=self.engine_args.max_local_prefill_length, max_local_prefill_length=self.engine_args.max_local_prefill_length,
max_prefill_queue_size=self.engine_args.max_prefill_queue_size, max_prefill_queue_size=self.engine_args.max_prefill_queue_size,
) )
......
# SPDX-FileCopyrightText: Copyright (c) 2020 Atalaya Tech. Inc
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
import argparse
import asyncio
import logging
import os
from utils.prefill_queue import PrefillQueue
from dynamo.runtime import DistributedRuntime, EtcdKvCache, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
@dynamo_worker()
async def clear_namespace(runtime: DistributedRuntime, namespace: str):
    """
    Remove residual state for ``namespace``: every key under the
    ``/{namespace}/`` prefix in etcd, and all pending requests in the
    namespace's prefill queue on the NATS server.

    Args:
        runtime: Distributed runtime injected by the ``@dynamo_worker`` decorator.
        namespace: Dynamo namespace to clear.
    """
    # Wipe all etcd entries under /{namespace}/; the empty dict means no
    # initial values are re-seeded after the clear.
    etcd_kv_cache = await EtcdKvCache.create(
        runtime.etcd_client(),
        f"/{namespace}/",
        {},
    )
    await etcd_kv_cache.clear_all()
    logger.info(f"Cleared /{namespace} in EtcdKvCache")

    # Drain the namespace's prefill queue on the NATS server.
    prefill_queue_nats_server = os.getenv("NATS_SERVER", "nats://localhost:4222")
    prefill_queue_stream_name = f"{namespace}_prefill_queue"
    async with PrefillQueue.get_instance(
        nats_server=prefill_queue_nats_server,
        stream_name=prefill_queue_stream_name,
        dequeue_timeout=3,
    ) as prefill_queue:
        cleared_count = await prefill_queue.clear_queue()
        logger.info(
            f"Cleared {cleared_count} requests from prefill queue {prefill_queue_stream_name}"
        )
if __name__ == "__main__":
    # Parse the required --namespace flag and run the async clear worker.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--namespace", type=str, required=True)
    cli_args = arg_parser.parse_args()
    asyncio.run(clear_namespace(cli_args.namespace))
...@@ -82,8 +82,22 @@ class NATSQueue: ...@@ -82,8 +82,22 @@ class NATSQueue:
async def enqueue_task(self, task_data: bytes) -> None: async def enqueue_task(self, task_data: bytes) -> None:
await self.nats_q.enqueue_task(task_data) await self.nats_q.enqueue_task(task_data)
async def dequeue_task(self) -> Optional[bytes]: async def dequeue_task(self, timeout: Optional[float] = None) -> Optional[bytes]:
return await self.nats_q.dequeue_task() return await self.nats_q.dequeue_task(timeout)
async def get_queue_size(self) -> int: async def get_queue_size(self) -> int:
return await self.nats_q.get_queue_size() return await self.nats_q.get_queue_size()
async def clear_queue(self) -> int:
try:
cleared_count = 0
# Continue until we can't dequeue any more messages
while True:
# use a small timeout
message = await self.dequeue_task(timeout=0.1)
if message is None:
break
cleared_count += 1
return cleared_count
except Exception as e:
raise RuntimeError(f"Failed to clear queue: {e}")
...@@ -392,6 +392,41 @@ impl EtcdKvCache { ...@@ -392,6 +392,41 @@ impl EtcdKvCache {
Ok(()) Ok(())
}) })
} }
/// Python binding: delete `key` (relative to this cache's prefix) from the
/// cache and etcd. Returns an awaitable that resolves to `None` on success
/// and raises a Python exception if the underlying delete fails.
fn delete<'p>(&self, py: Python<'p>, key: String) -> PyResult<Bound<'p, PyAny>> {
// Clone the shared inner handle so the async block can own it ('static).
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner.delete(&key).await.map_err(to_pyerr)?;
Ok(())
})
}
/// Python binding: delete every key under this cache's prefix from the
/// cache and etcd. Returns an awaitable that resolves to `None`.
fn clear_all<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
// Get all keys with the prefix
let all_keys = inner
.get_all()
.await
.keys()
.cloned()
.collect::<Vec<String>>();
// Delete each key
for key in all_keys {
// Strip the prefix from the key before deleting, because
// inner.delete() re-applies the prefix; passing a full key
// through unchanged would double-prefix it.
if let Some(stripped_key) = key.strip_prefix(&inner.prefix) {
inner.delete(stripped_key).await.map_err(to_pyerr)?;
} else {
// NOTE(review): keys without the prefix are passed through as-is,
// yet delete() will still prepend the prefix — confirm get_all()
// can actually return un-prefixed keys, otherwise this branch is
// dead code (or deletes the wrong key).
inner.delete(&key).await.map_err(to_pyerr)?;
}
}
Ok(())
})
}
} }
#[pymethods] #[pymethods]
......
...@@ -81,13 +81,19 @@ impl NatsQueue { ...@@ -81,13 +81,19 @@ impl NatsQueue {
}) })
} }
fn dequeue_task<'p>(&mut self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> { #[pyo3(signature = (timeout=None))]
fn dequeue_task<'p>(
&mut self,
py: Python<'p>,
timeout: Option<f64>,
) -> PyResult<Bound<'p, PyAny>> {
let queue = self.inner.clone(); let queue = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let timeout_duration = timeout.map(std::time::Duration::from_secs_f64);
Ok(queue Ok(queue
.lock() .lock()
.await .await
.dequeue_task() .dequeue_task(timeout_duration)
.await .await
.map_err(to_pyerr)? .map_err(to_pyerr)?
.map(|bytes| bytes.to_vec())) .map(|bytes| bytes.to_vec()))
......
...@@ -188,6 +188,18 @@ class EtcdKvCache: ...@@ -188,6 +188,18 @@ class EtcdKvCache:
""" """
... ...
async def delete(self, key: str) -> None:
    """
    Delete a key-value pair from the cache and etcd.

    Args:
        key: Key to remove, relative to this cache's key prefix.
    """
    ...
async def clear_all(self) -> None:
    """
    Delete all key-value pairs from the cache and etcd.

    Removes every entry stored under this cache's key prefix.
    """
    ...
class Namespace: class Namespace:
""" """
A namespace is a collection of components A namespace is a collection of components
...@@ -614,3 +626,68 @@ async def register_llm(model_type: ModelType, endpoint: Endpoint, model_path: st ...@@ -614,3 +626,68 @@ async def register_llm(model_type: ModelType, endpoint: Endpoint, model_path: st
"""Attach the model at path to the given endpoint, and advertise it as model_type""" """Attach the model at path to the given endpoint, and advertise it as model_type"""
... ...
class NatsQueue:
    """
    A queue implementation using NATS JetStream for task distribution.
    """

    def __init__(self, stream_name: str, nats_server: str, dequeue_timeout: float) -> None:
        """
        Create a new NatsQueue instance.

        Args:
            stream_name: Name of the NATS JetStream stream
            nats_server: URL of the NATS server
            dequeue_timeout: Default timeout in seconds for dequeue operations
        """
        ...

    async def connect(self) -> None:
        """
        Connect to the NATS server.
        """
        ...

    async def ensure_connection(self) -> None:
        """
        Ensure connection to the NATS server, connecting if not already connected.
        """
        ...

    async def close(self) -> None:
        """
        Close the connection to the NATS server.
        """
        ...

    async def enqueue_task(self, task_data: bytes) -> None:
        """
        Enqueue a task to the NATS JetStream.

        Args:
            task_data: The task data as bytes
        """
        ...

    async def dequeue_task(self, timeout: Optional[float] = None) -> Optional[bytes]:
        """
        Dequeue a task from the NATS JetStream.

        Args:
            timeout: Optional timeout in seconds for this specific dequeue operation.
                If None, uses the default timeout specified during initialization.

        Returns:
            The task data as bytes if available, None if no task is available
        """
        ...

    async def get_queue_size(self) -> int:
        """
        Get the current size of the queue.

        Returns:
            The number of messages in the queue
        """
        ...
...@@ -551,7 +551,19 @@ impl KvCache { ...@@ -551,7 +551,19 @@ impl KvCache {
Ok(()) Ok(())
} }
// TODO: add a method to create/delete keys /// Delete a key from both the cache and etcd
/// Delete a key from both the cache and etcd.
///
/// `key` is relative to this cache's prefix; the full etcd key is
/// `prefix + key`.
pub async fn delete(&self, key: &str) -> Result<()> {
let full_key = format!("{}{}", self.prefix, key);
// Delete from etcd first: if the remote delete fails, `?` returns before
// the local cache is touched, so the cache never drops an entry that
// still exists in etcd.
self.client.kv_delete(full_key.clone(), None).await?;
// Then remove from local cache
let mut cache_write = self.cache.write().await;
cache_write.remove(&full_key);
Ok(())
}
} }
#[cfg(feature = "integration")] #[cfg(feature = "integration")]
......
...@@ -440,13 +440,14 @@ impl NatsQueue { ...@@ -440,13 +440,14 @@ impl NatsQueue {
} }
/// Dequeue and return a task as raw bytes /// Dequeue and return a task as raw bytes
pub async fn dequeue_task(&mut self) -> Result<Option<Bytes>> { pub async fn dequeue_task(&mut self, timeout: Option<time::Duration>) -> Result<Option<Bytes>> {
self.ensure_connection().await?; self.ensure_connection().await?;
if let Some(subscriber) = &self.subscriber { if let Some(subscriber) = &self.subscriber {
let timeout_duration = timeout.unwrap_or(self.dequeue_timeout);
let mut batch = subscriber let mut batch = subscriber
.fetch() .fetch()
.expires(self.dequeue_timeout) .expires(timeout_duration)
.max_messages(1) .max_messages(1)
.messages() .messages()
.await?; .await?;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment