Unverified Commit 466b8e5f authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

refactor: use primary lease + self-contained graceful shutdown triggered by SIGINT/SIGTERM (#1001)

parent b2aa2317
...@@ -174,10 +174,21 @@ class CircusController: ...@@ -174,10 +174,21 @@ class CircusController:
waiting: Whether to wait for completion waiting: Whether to wait for completion
max_retries: Maximum number of retry attempts max_retries: Maximum number of retry attempts
retry_delay: Delay between retries in seconds retry_delay: Delay between retries in seconds
timeout: Timeout in seconds for waiting for graceful exit
Returns: Returns:
True if successful, False otherwise True if successful, False otherwise
""" """
# First send SIGTERM to the process
try:
logger.info(f"Sending SIGTERM to processes in watcher {name}")
response = self.client.send_message("signal", name=name, signum="SIGTERM")
if response.get("status") != "ok":
logger.warning(f"Failed to send SIGTERM to {name}: {response}")
except Exception as e:
logger.warning(f"Error sending SIGTERM to {name}: {e}")
# Now wait for the process to exit gracefully
exited = await self._wait_for_process_graceful_exit(name, timeout) exited = await self._wait_for_process_graceful_exit(name, timeout)
if not exited: if not exited:
logger.error( logger.error(
......
...@@ -231,24 +231,6 @@ class LocalConnector(PlannerConnector): ...@@ -231,24 +231,6 @@ class LocalConnector(PlannerConnector):
target_watcher = matching_components[highest_suffix] target_watcher = matching_components[highest_suffix]
logger.info(f"Removing watcher {target_watcher}") logger.info(f"Removing watcher {target_watcher}")
pre_remove_endpoint_ids = await self._get_endpoint_ids(component_name)
if component_name == "VllmWorker" or component_name == "PrefillWorker":
lease_id = state["components"][target_watcher]["lease"]
await self._revoke_lease(lease_id)
# Poll endpoint to ensure that worker has shut down gracefully and then remove the watcher
if blocking:
required_endpoint_ids = pre_remove_endpoint_ids - 1
while True:
current_endpoint_ids = await self._get_endpoint_ids(component_name)
if current_endpoint_ids == required_endpoint_ids:
break
logger.info(
f"Waiting for {component_name} to shutdown. Current endpoint IDs: {current_endpoint_ids}, Required endpoint IDs: {required_endpoint_ids}"
)
await asyncio.sleep(5)
success = await self.circus.remove_watcher(name=target_watcher) success = await self.circus.remove_watcher(name=target_watcher)
logger.info( logger.info(
f"Circus remove_watcher for {target_watcher} {'succeeded' if success else 'failed'}" f"Circus remove_watcher for {target_watcher} {'succeeded' if success else 'failed'}"
......
...@@ -35,7 +35,6 @@ from fastapi.responses import StreamingResponse ...@@ -35,7 +35,6 @@ from fastapi.responses import StreamingResponse
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
from dynamo.sdk import dynamo_context from dynamo.sdk import dynamo_context
from dynamo.sdk.cli.utils import append_dynamo_state
from dynamo.sdk.lib.service import LinkedServices from dynamo.sdk.lib.service import LinkedServices
from dynamo.sdk.lib.utils import get_host_port from dynamo.sdk.lib.utils import get_host_port
...@@ -208,18 +207,6 @@ def main( ...@@ -208,18 +207,6 @@ def main(
component = runtime.namespace(namespace).component(component_name) component = runtime.namespace(namespace).component(component_name)
try: try:
# if a custom lease is specified we need to create the service with that lease
lease = None
if service._dynamo_config.custom_lease:
lease = await component.create_service_with_custom_lease(
ttl=service._dynamo_config.custom_lease.ttl
)
lease_id = lease.id()
dynamo_context["lease"] = lease
logger.info(
f"Created {service.name} component with custom lease id {lease_id}"
)
else:
# Create service first # Create service first
await component.create_service() await component.create_service()
logger.info(f"Created {service.name} component") logger.info(f"Created {service.name} component")
...@@ -272,31 +259,15 @@ def main( ...@@ -272,31 +259,15 @@ def main(
f"Starting {service.name} instance with all registered endpoints" f"Starting {service.name} instance with all registered endpoints"
) )
# TODO:bis: convert to list # TODO:bis: convert to list
if lease is None:
logger.info(f"Serving {service.name} with primary lease") logger.info(f"Serving {service.name} with primary lease")
else:
logger.info(f"Serving {service.name} with lease: {lease.id()}")
# Map custom lease to component
watcher_name = None
if custom_component_name:
watcher_name = custom_component_name
else:
watcher_name = f"{namespace}_{component_name}"
append_dynamo_state(namespace, watcher_name, {"lease": lease.id()})
logger.info(
f"Appended lease {lease.id()}/{lease.id():x} to {watcher_name}"
)
# Launch serve_endpoint for all endpoints concurrently # Launch serve_endpoint for all endpoints concurrently
tasks = [ tasks = [
endpoint.serve_endpoint(handler, lease) endpoint.serve_endpoint(handler)
for endpoint, handler in zip(endpoints, dynamo_handlers) for endpoint, handler in zip(endpoints, dynamo_handlers)
] ]
# Wait for all tasks to complete # Wait for all tasks to complete
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
if class_instance.__class__.__name__ == "PrefillWorker":
await asyncio.wait_for(class_instance.task, timeout=None)
except GracefulExit: except GracefulExit:
logger.info(f"[{run_id}] Gracefully shutting down {service.name}") logger.info(f"[{run_id}] Gracefully shutting down {service.name}")
# Add any specific cleanup needed # Add any specific cleanup needed
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import asyncio import asyncio
import logging import logging
import os import os
import signal
import sys import sys
from pydantic import BaseModel from pydantic import BaseModel
...@@ -30,7 +31,6 @@ from vllm.inputs.data import TokensPrompt ...@@ -30,7 +31,6 @@ from vllm.inputs.data import TokensPrompt
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from dynamo.sdk import async_on_start, dynamo_context, dynamo_endpoint, service from dynamo.sdk import async_on_start, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.service import LeaseConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -43,7 +43,6 @@ class RequestType(BaseModel): ...@@ -43,7 +43,6 @@ class RequestType(BaseModel):
dynamo={ dynamo={
"enabled": True, "enabled": True,
"namespace": "dynamo", "namespace": "dynamo",
"custom_lease": LeaseConfig(ttl=1), # 1 second
}, },
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"}, resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1, workers=1,
...@@ -100,9 +99,32 @@ class PrefillWorker: ...@@ -100,9 +99,32 @@ class PrefillWorker:
sys.exit(1) sys.exit(1)
self.task.add_done_callback(prefill_queue_handler_cb) self.task.add_done_callback(prefill_queue_handler_cb)
self.lease = dynamo_context["lease"]
self.shutdown_requested = False
# Set up signal handler for graceful shutdown
# TODO: move to dynamo sdk
loop = asyncio.get_running_loop()
def signal_handler():
# Schedule the shutdown coroutine instead of calling it directly
asyncio.create_task(self.graceful_shutdown(runtime))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
logger.info("PrefillWorker initialized") logger.info("PrefillWorker initialized")
async def graceful_shutdown(self, runtime):
    """Drain in-flight prefill work, then shut down the DistributedRuntime.

    Scheduled from the SIGTERM/SIGINT signal handler. Sets the shutdown
    flag so the prefill-queue loop stops taking new work, waits for the
    background task to finish, and only then tears down the runtime.

    Args:
        runtime: The DistributedRuntime instance to shut down.
    """
    logger.info("Received shutdown signal, shutting down DistributedRuntime")
    # Signal the prefill-queue handler loop to wind down.
    self.shutdown_requested = True
    # Wait for the background prefill task to drain. A bare await is
    # equivalent to asyncio.wait_for(..., timeout=None) (which is
    # deprecated in Python 3.12) — we intentionally wait without limit
    # so in-flight prefill requests are never cut off.
    await self.task
    # Only once the worker is idle, shut down the distributed runtime.
    runtime.shutdown()
    logger.info("DistributedRuntime shutdown complete")
def shutdown_vllm_engine(self): def shutdown_vllm_engine(self):
"""Shutdown the background loop""" """Shutdown the background loop"""
logger.info("Shutting down vllm engine") logger.info("Shutting down vllm engine")
...@@ -140,8 +162,7 @@ class PrefillWorker: ...@@ -140,8 +162,7 @@ class PrefillWorker:
) )
async for _ in self.generate(prefill_request): async for _ in self.generate(prefill_request):
pass pass
is_valid = await self.lease.is_valid() if self.shutdown_requested:
if not is_valid:
logger.info( logger.info(
"Shutdown requested, checking if engine has any pending prefill sending requests" "Shutdown requested, checking if engine has any pending prefill sending requests"
) )
......
...@@ -33,7 +33,6 @@ from vllm.sampling_params import RequestOutputKind ...@@ -33,7 +33,6 @@ from vllm.sampling_params import RequestOutputKind
from dynamo.llm import KvMetricsPublisher from dynamo.llm import KvMetricsPublisher
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.service import LeaseConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -42,7 +41,6 @@ logger = logging.getLogger(__name__) ...@@ -42,7 +41,6 @@ logger = logging.getLogger(__name__)
dynamo={ dynamo={
"enabled": True, "enabled": True,
"namespace": "dynamo", "namespace": "dynamo",
"custom_lease": LeaseConfig(ttl=1), # 1 second
}, },
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"}, resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1, workers=1,
...@@ -138,8 +136,24 @@ class VllmWorker: ...@@ -138,8 +136,24 @@ class VllmWorker:
else: else:
self.disaggregated_router = None self.disaggregated_router = None
# Set up signal handler for graceful shutdown
# TODO: move to dynamo sdk
loop = asyncio.get_running_loop()
def signal_handler():
# Schedule the shutdown coroutine instead of calling it directly
asyncio.create_task(self.graceful_shutdown(runtime))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
logger.info("VllmWorker has been initialized") logger.info("VllmWorker has been initialized")
async def graceful_shutdown(self, runtime):
    """Shut down the DistributedRuntime in response to SIGTERM/SIGINT.

    Scheduled from the signal handler registered in async_init.

    Args:
        runtime: The DistributedRuntime instance to shut down.
    """
    logger.info("Received shutdown signal, shutting down DistributedRuntime")
    runtime.shutdown()
    logger.info("DistributedRuntime shutdown complete")
def shutdown_vllm_engine(self, signum, frame): def shutdown_vllm_engine(self, signum, frame):
"""Shutdown the background loop""" """Shutdown the background loop"""
logger.info(f"Received signal {signum}, shutting down") logger.info(f"Received signal {signum}, shutting down")
...@@ -154,12 +168,8 @@ class VllmWorker: ...@@ -154,12 +168,8 @@ class VllmWorker:
async def create_metrics_publisher_endpoint(self): async def create_metrics_publisher_endpoint(self):
component = dynamo_context["component"] component = dynamo_context["component"]
lease = dynamo_context["lease"]
if lease is None:
logger.info("Creating metrics publisher endpoint with primary lease") logger.info("Creating metrics publisher endpoint with primary lease")
else: await self.metrics_publisher.create_endpoint(component)
logger.info(f"Creating metrics publisher endpoint with lease: {lease}")
await self.metrics_publisher.create_endpoint(component, lease)
def get_remote_prefill_request_callback(self): def get_remote_prefill_request_callback(self):
# TODO: integrate prefill_queue to dynamo endpoint # TODO: integrate prefill_queue to dynamo endpoint
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import uvloop
from dynamo.runtime import DistributedRuntime, dynamo_worker
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
    """Entry point: parse a lease id from the command line and revoke it.

    Args:
        runtime: Connected DistributedRuntime injected by @dynamo_worker.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("lease_id", type=int, help="Lease ID to revoke")
    args = parser.parse_args()
    await init(runtime, args.lease_id)
async def init(runtime: DistributedRuntime, lease_id: int):
    """Revoke the given etcd lease via the runtime's etcd client.

    Revoking a lease invalidates the etcd key/value pairs attached to
    it, which is how a worker registered under that lease is told to
    shut down gracefully.

    Args:
        runtime: Connected DistributedRuntime providing the etcd client.
        lease_id: Integer id of the lease to revoke.
    """
    client = runtime.etcd_client()
    await client.revoke_lease(lease_id)
if __name__ == "__main__":
    # Install uvloop as the asyncio event-loop policy before running
    # the worker (faster drop-in replacement for the default loop).
    uvloop.install()
    asyncio.run(worker())
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import asyncio import asyncio
import signal
import uvloop import uvloop
...@@ -37,25 +38,41 @@ async def worker(runtime: DistributedRuntime): ...@@ -37,25 +38,41 @@ async def worker(runtime: DistributedRuntime):
print( print(
f"Primary lease ID: {runtime.etcd_client().primary_lease_id()}/{runtime.etcd_client().primary_lease_id():#x}" f"Primary lease ID: {runtime.etcd_client().primary_lease_id()}/{runtime.etcd_client().primary_lease_id():#x}"
) )
# Set up signal handler for graceful shutdown
loop = asyncio.get_running_loop()
def signal_handler():
# Schedule the shutdown coroutine instead of calling it directly
asyncio.create_task(graceful_shutdown(runtime))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
print("Signal handlers registered for graceful shutdown")
await init(runtime, "dynamo") await init(runtime, "dynamo")
async def graceful_shutdown(runtime: DistributedRuntime):
    """Shut down the DistributedRuntime; scheduled by the signal handler.

    runtime.shutdown() triggers the runtime's cancellation token, which
    lets served endpoints finish their open streams before exiting.

    Args:
        runtime: The DistributedRuntime instance to shut down.
    """
    print("Received shutdown signal, shutting down DistributedRuntime")
    runtime.shutdown()
    print("DistributedRuntime shutdown complete")
async def init(runtime: DistributedRuntime, ns: str): async def init(runtime: DistributedRuntime, ns: str):
""" """
Instantiate a `backend` component and serve the `generate` endpoint Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints A `Component` can serve multiple endpoints
""" """
component = runtime.namespace(ns).component("backend") component = runtime.namespace(ns).component("backend")
lease = await component.create_service_with_custom_lease(ttl=1) await component.create_service()
lease_id = lease.id()
print(f"Created custom lease with ID: {lease_id}/{lease_id:#x}")
endpoint = component.endpoint("generate") endpoint = component.endpoint("generate")
print("Started server instance") print("Started server instance")
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes) # the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked # after the lease is revoked
await endpoint.serve_endpoint(RequestHandler().generate, lease) await endpoint.serve_endpoint(RequestHandler().generate)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -185,12 +185,6 @@ struct Client { ...@@ -185,12 +185,6 @@ struct Client {
router: rs::pipeline::PushRouter<serde_json::Value, serde_json::Value>, router: rs::pipeline::PushRouter<serde_json::Value, serde_json::Value>,
} }
/// Python-facing wrapper around an etcd lease, exposed via pyo3.
/// Removed in this change in favor of the runtime's primary lease.
#[pyclass]
#[derive(Clone)]
struct PyLease {
// Underlying Rust lease handle (rs::transports::etcd::Lease).
// NOTE(review): Clone is derived — presumably the inner lease is a
// cheap shared handle rather than a deep copy; confirm in rs crate.
inner: rs::transports::etcd::Lease,
}
#[pyclass(eq, eq_int)] #[pyclass(eq, eq_int)]
#[derive(Clone, PartialEq)] #[derive(Clone, PartialEq)]
#[repr(i32)] #[repr(i32)]
...@@ -200,25 +194,6 @@ enum ModelType { ...@@ -200,25 +194,6 @@ enum ModelType {
Backend = 3, Backend = 3,
} }
/// Methods exported to Python on the `PyLease` wrapper.
#[pymethods]
impl PyLease {
/// Return the numeric etcd lease id (see the etcd v3 Lease API for
/// how lease ids are used).
fn id(&self) -> i64 {
self.inner.id()
}
/// Revoke the lease, invalidating the kv pairs associated with it.
fn revoke(&self) {
self.inner.revoke();
}
/// Check whether the lease is still valid (not revoked).
/// Returns a Python awaitable: the async check is bridged onto the
/// tokio runtime via `future_into_py`.
fn is_valid<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
// Clone the handle so the async block can own it ('static future).
let lease = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let is_valid = lease.is_valid().await.map_err(to_pyerr)?;
Ok(is_valid)
})
}
}
#[pymethods] #[pymethods]
impl DistributedRuntime { impl DistributedRuntime {
#[new] #[new]
...@@ -263,16 +238,6 @@ impl DistributedRuntime { ...@@ -263,16 +238,6 @@ impl DistributedRuntime {
} }
} }
/// Return the runtime's primary cancellation token, wrapped for Python.
fn primary_token(&self) -> CancellationToken {
let inner = self.inner.runtime().primary_token();
CancellationToken { inner }
}
/// Return a child cancellation token derived from the runtime's token,
/// wrapped for Python.
fn child_token(&self) -> CancellationToken {
let inner = self.inner.runtime().child_token();
CancellationToken { inner }
}
fn shutdown(&self) { fn shutdown(&self) {
self.inner.runtime().shutdown(); self.inner.runtime().shutdown();
} }
...@@ -461,61 +426,22 @@ impl Component { ...@@ -461,61 +426,22 @@ impl Component {
Ok(()) Ok(())
}) })
} }
/// Create this component's service and a custom etcd lease with the
/// given TTL (seconds, default 1). Returns a Python awaitable that
/// resolves to a `PyLease`. Removed in this change in favor of the
/// primary lease.
#[pyo3(signature = (ttl=1))]
fn create_service_with_custom_lease<'p>(
&self,
py: Python<'p>,
ttl: i64,
) -> PyResult<Bound<'p, PyAny>> {
// Clone the component handle so the async block can own it.
let component = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
// Get the etcd client from the distributed runtime; absent when
// the runtime was built without etcd-based discovery.
let etcd_client = component
.drt()
.etcd_client()
.ok_or_else(|| to_pyerr("etcd client not found"))?;
// Create a custom lease with the specified TTL.
let custom_lease = etcd_client.create_lease(ttl).await.map_err(to_pyerr)?;
tracing::info!("created custom lease: {:?}", custom_lease);
// Create the service itself; the lease is NOT yet attached here —
// callers must pass the returned lease when creating endpoints.
// TODO: tie the lease to service instead of endpoint
let _service = component
.service_builder()
.create()
.await
.map_err(to_pyerr)?;
// Hand the lease back to Python wrapped in PyLease.
Ok(PyLease {
inner: custom_lease,
})
})
}
} }
#[pymethods] #[pymethods]
impl Endpoint { impl Endpoint {
#[pyo3(signature = (generator, lease=None))] #[pyo3(signature = (generator))]
fn serve_endpoint<'p>( fn serve_endpoint<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
generator: PyObject, generator: PyObject,
lease: Option<&PyLease>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
let engine = Arc::new(engine::PythonAsyncEngine::new( let engine = Arc::new(engine::PythonAsyncEngine::new(
generator, generator,
self.event_loop.clone(), self.event_loop.clone(),
)?); )?);
let ingress = JsonServerStreamingIngress::for_engine(engine).map_err(to_pyerr)?; let ingress = JsonServerStreamingIngress::for_engine(engine).map_err(to_pyerr)?;
let mut builder = self.inner.endpoint_builder().handler(ingress); let builder = self.inner.endpoint_builder().handler(ingress);
if lease.is_some() {
builder = builder.lease(lease.map(|l| l.inner.clone()));
}
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
builder.start().await.map_err(to_pyerr)?; builder.start().await.map_err(to_pyerr)?;
Ok(()) Ok(())
......
...@@ -73,19 +73,17 @@ impl KvMetricsPublisher { ...@@ -73,19 +73,17 @@ impl KvMetricsPublisher {
}) })
} }
#[pyo3(signature = (component, lease=None))] #[pyo3(signature = (component))]
fn create_endpoint<'p>( fn create_endpoint<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
component: Component, component: Component,
lease: Option<&PyLease>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
let rs_publisher = self.inner.clone(); let rs_publisher = self.inner.clone();
let rs_component = component.inner.clone(); let rs_component = component.inner.clone();
let lease = lease.map(|l| l.inner.clone());
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
rs_publisher rs_publisher
.create_endpoint(rs_component, lease) .create_endpoint(rs_component)
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
Ok(()) Ok(())
......
...@@ -49,31 +49,11 @@ class DistributedRuntime: ...@@ -49,31 +49,11 @@ class DistributedRuntime:
""" """
... ...
class PyLease: def shutdown(self) -> None:
"""
A lease object
"""
def id(self) -> int:
"""
Return the id of the lease
Refer to https://etcd.io/docs/v3.4/learning/api/ for examples on how to use the lease id
"""
...
def revoke(self) -> None:
"""
Revoke the lease by triggering the cancellation token
This will invalidate the kv pairs associated with this lease
"""
...
def is_valid(self) -> bool:
""" """
Check if the lease is still valid (not revoked) Shutdown the runtime by triggering the cancellation token
""" """
... ...
class EtcdClient: class EtcdClient:
""" """
Etcd is used for discovery in the DistributedRuntime Etcd is used for discovery in the DistributedRuntime
...@@ -232,14 +212,6 @@ class Component: ...@@ -232,14 +212,6 @@ class Component:
""" """
... ...
def create_service_with_custom_lease(self, ttl: int) -> PyLease:
"""
Create a service with a custom lease
The lease needs to be tied to the endpoint of this services when creating the endpoints later
TODO: tie the lease to the service instead of the endpoint
"""
...
class Endpoint: class Endpoint:
""" """
An Endpoint is a single API endpoint An Endpoint is a single API endpoint
......
...@@ -23,7 +23,6 @@ use dynamo_runtime::{ ...@@ -23,7 +23,6 @@ use dynamo_runtime::{
SingleIn, SingleIn,
}, },
protocols::annotated::Annotated, protocols::annotated::Annotated,
transports::etcd::Lease,
Error, Result, Error, Result,
}; };
use futures::stream; use futures::stream;
...@@ -93,12 +92,12 @@ impl KvMetricsPublisher { ...@@ -93,12 +92,12 @@ impl KvMetricsPublisher {
self.tx.send(metrics) self.tx.send(metrics)
} }
pub async fn create_endpoint(&self, component: Component, lease: Option<Lease>) -> Result<()> { pub async fn create_endpoint(&self, component: Component) -> Result<()> {
let mut metrics_rx = self.rx.clone(); let mut metrics_rx = self.rx.clone();
let handler = Arc::new(KvLoadEndpoingHander::new(metrics_rx.clone())); let handler = Arc::new(KvLoadEndpoingHander::new(metrics_rx.clone()));
let handler = Ingress::for_engine(handler)?; let handler = Ingress::for_engine(handler)?;
let builder = component component
.endpoint(KV_METRICS_ENDPOINT) .endpoint(KV_METRICS_ENDPOINT)
.endpoint_builder() .endpoint_builder()
.stats_handler(move |_| { .stats_handler(move |_| {
...@@ -106,9 +105,8 @@ impl KvMetricsPublisher { ...@@ -106,9 +105,8 @@ impl KvMetricsPublisher {
serde_json::to_value(&*metrics).unwrap() serde_json::to_value(&*metrics).unwrap()
}) })
.handler(handler) .handler(handler)
.lease(lease); .start()
.await
builder.start().await
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment