"launch/dynamo-run/src/main.rs" did not exist on "c70de37fcb50559ecd051df140db38250077c245"
Unverified Commit 466b8e5f authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

refactor: use primary lease + self-contained graceful shutdown triggered by SIGINT/SIGTERM (#1001)

parent b2aa2317
......@@ -174,10 +174,21 @@ class CircusController:
waiting: Whether to wait for completion
max_retries: Maximum number of retry attempts
retry_delay: Delay between retries in seconds
timeout: Timeout in seconds for waiting for graceful exit
Returns:
True if successful, False otherwise
"""
# First send SIGTERM to the process
try:
logger.info(f"Sending SIGTERM to processes in watcher {name}")
response = self.client.send_message("signal", name=name, signum="SIGTERM")
if response.get("status") != "ok":
logger.warning(f"Failed to send SIGTERM to {name}: {response}")
except Exception as e:
logger.warning(f"Error sending SIGTERM to {name}: {e}")
# Now wait for the process to exit gracefully
exited = await self._wait_for_process_graceful_exit(name, timeout)
if not exited:
logger.error(
......
......@@ -231,24 +231,6 @@ class LocalConnector(PlannerConnector):
target_watcher = matching_components[highest_suffix]
logger.info(f"Removing watcher {target_watcher}")
pre_remove_endpoint_ids = await self._get_endpoint_ids(component_name)
if component_name == "VllmWorker" or component_name == "PrefillWorker":
lease_id = state["components"][target_watcher]["lease"]
await self._revoke_lease(lease_id)
# Poll endpoint to ensure that worker has shut down gracefully and then remove the watcher
if blocking:
required_endpoint_ids = pre_remove_endpoint_ids - 1
while True:
current_endpoint_ids = await self._get_endpoint_ids(component_name)
if current_endpoint_ids == required_endpoint_ids:
break
logger.info(
f"Waiting for {component_name} to shutdown. Current endpoint IDs: {current_endpoint_ids}, Required endpoint IDs: {required_endpoint_ids}"
)
await asyncio.sleep(5)
success = await self.circus.remove_watcher(name=target_watcher)
logger.info(
f"Circus remove_watcher for {target_watcher} {'succeeded' if success else 'failed'}"
......
......@@ -35,7 +35,6 @@ from fastapi.responses import StreamingResponse
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
from dynamo.sdk import dynamo_context
from dynamo.sdk.cli.utils import append_dynamo_state
from dynamo.sdk.lib.service import LinkedServices
from dynamo.sdk.lib.utils import get_host_port
......@@ -208,18 +207,6 @@ def main(
component = runtime.namespace(namespace).component(component_name)
try:
# if a custom lease is specified we need to create the service with that lease
lease = None
if service._dynamo_config.custom_lease:
lease = await component.create_service_with_custom_lease(
ttl=service._dynamo_config.custom_lease.ttl
)
lease_id = lease.id()
dynamo_context["lease"] = lease
logger.info(
f"Created {service.name} component with custom lease id {lease_id}"
)
else:
# Create service first
await component.create_service()
logger.info(f"Created {service.name} component")
......@@ -272,31 +259,15 @@ def main(
f"Starting {service.name} instance with all registered endpoints"
)
# TODO:bis: convert to list
if lease is None:
logger.info(f"Serving {service.name} with primary lease")
else:
logger.info(f"Serving {service.name} with lease: {lease.id()}")
# Map custom lease to component
watcher_name = None
if custom_component_name:
watcher_name = custom_component_name
else:
watcher_name = f"{namespace}_{component_name}"
append_dynamo_state(namespace, watcher_name, {"lease": lease.id()})
logger.info(
f"Appended lease {lease.id()}/{lease.id():x} to {watcher_name}"
)
# Launch serve_endpoint for all endpoints concurrently
tasks = [
endpoint.serve_endpoint(handler, lease)
endpoint.serve_endpoint(handler)
for endpoint, handler in zip(endpoints, dynamo_handlers)
]
# Wait for all tasks to complete
await asyncio.gather(*tasks)
if class_instance.__class__.__name__ == "PrefillWorker":
await asyncio.wait_for(class_instance.task, timeout=None)
except GracefulExit:
logger.info(f"[{run_id}] Gracefully shutting down {service.name}")
# Add any specific cleanup needed
......
......@@ -17,6 +17,7 @@
import asyncio
import logging
import os
import signal
import sys
from pydantic import BaseModel
......@@ -30,7 +31,6 @@ from vllm.inputs.data import TokensPrompt
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from dynamo.sdk import async_on_start, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.service import LeaseConfig
logger = logging.getLogger(__name__)
......@@ -43,7 +43,6 @@ class RequestType(BaseModel):
dynamo={
"enabled": True,
"namespace": "dynamo",
"custom_lease": LeaseConfig(ttl=1), # 1 second
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
......@@ -100,9 +99,32 @@ class PrefillWorker:
sys.exit(1)
self.task.add_done_callback(prefill_queue_handler_cb)
self.lease = dynamo_context["lease"]
self.shutdown_requested = False
# Set up signal handler for graceful shutdown
# TODO: move to dynamo sdk
loop = asyncio.get_running_loop()
def signal_handler():
# Schedule the shutdown coroutine instead of calling it directly
asyncio.create_task(self.graceful_shutdown(runtime))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
logger.info("PrefillWorker initialized")
async def graceful_shutdown(self, runtime):
logger.info("Received shutdown signal, shutting down DistributedRuntime")
# first shutdown the vllm engine
self.shutdown_requested = True
await asyncio.wait_for(self.task, timeout=None)
# then shutdown the mock endpoint
runtime.shutdown()
logger.info("DistributedRuntime shutdown complete")
def shutdown_vllm_engine(self):
"""Shutdown the background loop"""
logger.info("Shutting down vllm engine")
......@@ -140,8 +162,7 @@ class PrefillWorker:
)
async for _ in self.generate(prefill_request):
pass
is_valid = await self.lease.is_valid()
if not is_valid:
if self.shutdown_requested:
logger.info(
"Shutdown requested, checking if engine has any pending prefill sending requests"
)
......
......@@ -33,7 +33,6 @@ from vllm.sampling_params import RequestOutputKind
from dynamo.llm import KvMetricsPublisher
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.service import LeaseConfig
logger = logging.getLogger(__name__)
......@@ -42,7 +41,6 @@ logger = logging.getLogger(__name__)
dynamo={
"enabled": True,
"namespace": "dynamo",
"custom_lease": LeaseConfig(ttl=1), # 1 second
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
......@@ -138,8 +136,24 @@ class VllmWorker:
else:
self.disaggregated_router = None
# Set up signal handler for graceful shutdown
# TODO: move to dynamo sdk
loop = asyncio.get_running_loop()
def signal_handler():
# Schedule the shutdown coroutine instead of calling it directly
asyncio.create_task(self.graceful_shutdown(runtime))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
logger.info("VllmWorker has been initialized")
async def graceful_shutdown(self, runtime):
logger.info("Received shutdown signal, shutting down DistributedRuntime")
runtime.shutdown()
logger.info("DistributedRuntime shutdown complete")
def shutdown_vllm_engine(self, signum, frame):
"""Shutdown the background loop"""
logger.info(f"Received signal {signum}, shutting down")
......@@ -154,12 +168,8 @@ class VllmWorker:
async def create_metrics_publisher_endpoint(self):
component = dynamo_context["component"]
lease = dynamo_context["lease"]
if lease is None:
logger.info("Creating metrics publisher endpoint with primary lease")
else:
logger.info(f"Creating metrics publisher endpoint with lease: {lease}")
await self.metrics_publisher.create_endpoint(component, lease)
await self.metrics_publisher.create_endpoint(component)
def get_remote_prefill_request_callback(self):
# TODO: integrate prefill_queue to dynamo endpoint
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import uvloop
from dynamo.runtime import DistributedRuntime, dynamo_worker
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
    """CLI entry point: read the lease id from argv and revoke it.

    The @dynamo_worker decorator supplies the DistributedRuntime instance.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("lease_id", type=int, help="Lease ID to revoke")
    parsed = arg_parser.parse_args()
    await init(runtime, parsed.lease_id)


async def init(runtime: DistributedRuntime, lease_id: int):
    """Revoke the given etcd lease through the runtime's etcd client."""
    etcd = runtime.etcd_client()
    await etcd.revoke_lease(lease_id)


if __name__ == "__main__":
    uvloop.install()
    asyncio.run(worker())
......@@ -14,6 +14,7 @@
# limitations under the License.
import asyncio
import signal
import uvloop
......@@ -37,25 +38,41 @@ async def worker(runtime: DistributedRuntime):
print(
f"Primary lease ID: {runtime.etcd_client().primary_lease_id()}/{runtime.etcd_client().primary_lease_id():#x}"
)
# Set up signal handler for graceful shutdown
loop = asyncio.get_running_loop()
def signal_handler():
# Schedule the shutdown coroutine instead of calling it directly
asyncio.create_task(graceful_shutdown(runtime))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
print("Signal handlers registered for graceful shutdown")
await init(runtime, "dynamo")
async def graceful_shutdown(runtime: DistributedRuntime):
print("Received shutdown signal, shutting down DistributedRuntime")
runtime.shutdown()
print("DistributedRuntime shutdown complete")
async def init(runtime: DistributedRuntime, ns: str):
"""
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
"""
component = runtime.namespace(ns).component("backend")
lease = await component.create_service_with_custom_lease(ttl=1)
lease_id = lease.id()
print(f"Created custom lease with ID: {lease_id}/{lease_id:#x}")
await component.create_service()
endpoint = component.endpoint("generate")
print("Started server instance")
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
await endpoint.serve_endpoint(RequestHandler().generate, lease)
await endpoint.serve_endpoint(RequestHandler().generate)
if __name__ == "__main__":
......
......@@ -185,12 +185,6 @@ struct Client {
router: rs::pipeline::PushRouter<serde_json::Value, serde_json::Value>,
}
#[pyclass]
#[derive(Clone)]
struct PyLease {
inner: rs::transports::etcd::Lease,
}
#[pyclass(eq, eq_int)]
#[derive(Clone, PartialEq)]
#[repr(i32)]
......@@ -200,25 +194,6 @@ enum ModelType {
Backend = 3,
}
#[pymethods]
impl PyLease {
fn id(&self) -> i64 {
self.inner.id()
}
fn revoke(&self) {
self.inner.revoke();
}
fn is_valid<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let lease = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let is_valid = lease.is_valid().await.map_err(to_pyerr)?;
Ok(is_valid)
})
}
}
#[pymethods]
impl DistributedRuntime {
#[new]
......@@ -263,16 +238,6 @@ impl DistributedRuntime {
}
}
fn primary_token(&self) -> CancellationToken {
let inner = self.inner.runtime().primary_token();
CancellationToken { inner }
}
fn child_token(&self) -> CancellationToken {
let inner = self.inner.runtime().child_token();
CancellationToken { inner }
}
fn shutdown(&self) {
self.inner.runtime().shutdown();
}
......@@ -461,61 +426,22 @@ impl Component {
Ok(())
})
}
#[pyo3(signature = (ttl=1))]
fn create_service_with_custom_lease<'p>(
&self,
py: Python<'p>,
ttl: i64,
) -> PyResult<Bound<'p, PyAny>> {
let component = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
// Get the etcd client from the runtime
let etcd_client = component
.drt()
.etcd_client()
.ok_or_else(|| to_pyerr("etcd client not found"))?;
// Create a custom lease with the specified TTL
let custom_lease = etcd_client.create_lease(ttl).await.map_err(to_pyerr)?;
tracing::info!("created custom lease: {:?}", custom_lease);
// Create a service
// TODO: tie the lease to service instead of endpoint
let _service = component
.service_builder()
.create()
.await
.map_err(to_pyerr)?;
// Return the lease
Ok(PyLease {
inner: custom_lease,
})
})
}
}
#[pymethods]
impl Endpoint {
#[pyo3(signature = (generator, lease=None))]
#[pyo3(signature = (generator))]
fn serve_endpoint<'p>(
&self,
py: Python<'p>,
generator: PyObject,
lease: Option<&PyLease>,
) -> PyResult<Bound<'p, PyAny>> {
let engine = Arc::new(engine::PythonAsyncEngine::new(
generator,
self.event_loop.clone(),
)?);
let ingress = JsonServerStreamingIngress::for_engine(engine).map_err(to_pyerr)?;
let mut builder = self.inner.endpoint_builder().handler(ingress);
if lease.is_some() {
builder = builder.lease(lease.map(|l| l.inner.clone()));
}
let builder = self.inner.endpoint_builder().handler(ingress);
pyo3_async_runtimes::tokio::future_into_py(py, async move {
builder.start().await.map_err(to_pyerr)?;
Ok(())
......
......@@ -73,19 +73,17 @@ impl KvMetricsPublisher {
})
}
#[pyo3(signature = (component, lease=None))]
#[pyo3(signature = (component))]
fn create_endpoint<'p>(
&self,
py: Python<'p>,
component: Component,
lease: Option<&PyLease>,
) -> PyResult<Bound<'p, PyAny>> {
let rs_publisher = self.inner.clone();
let rs_component = component.inner.clone();
let lease = lease.map(|l| l.inner.clone());
pyo3_async_runtimes::tokio::future_into_py(py, async move {
rs_publisher
.create_endpoint(rs_component, lease)
.create_endpoint(rs_component)
.await
.map_err(to_pyerr)?;
Ok(())
......
......@@ -49,31 +49,11 @@ class DistributedRuntime:
"""
...
class PyLease:
"""
A lease object
"""
def id(self) -> int:
"""
Return the id of the lease
Refer to https://etcd.io/docs/v3.4/learning/api/ for examples on how to use the lease id
"""
...
def revoke(self) -> None:
"""
Revoke the lease by triggering the cancellation token
This will invalidate the kv pairs associated with this lease
"""
...
def is_valid(self) -> bool:
def shutdown(self) -> None:
"""
Check if the lease is still valid (not revoked)
Shutdown the runtime by triggering the cancellation token
"""
...
class EtcdClient:
"""
Etcd is used for discovery in the DistributedRuntime
......@@ -232,14 +212,6 @@ class Component:
"""
...
def create_service_with_custom_lease(self, ttl: int) -> PyLease:
"""
Create a service with a custom lease
The lease needs to be tied to the endpoint of this services when creating the endpoints later
TODO: tie the lease to the service instead of the endpoint
"""
...
class Endpoint:
"""
An Endpoint is a single API endpoint
......
......@@ -23,7 +23,6 @@ use dynamo_runtime::{
SingleIn,
},
protocols::annotated::Annotated,
transports::etcd::Lease,
Error, Result,
};
use futures::stream;
......@@ -93,12 +92,12 @@ impl KvMetricsPublisher {
self.tx.send(metrics)
}
pub async fn create_endpoint(&self, component: Component, lease: Option<Lease>) -> Result<()> {
pub async fn create_endpoint(&self, component: Component) -> Result<()> {
let mut metrics_rx = self.rx.clone();
let handler = Arc::new(KvLoadEndpoingHander::new(metrics_rx.clone()));
let handler = Ingress::for_engine(handler)?;
let builder = component
component
.endpoint(KV_METRICS_ENDPOINT)
.endpoint_builder()
.stats_handler(move |_| {
......@@ -106,9 +105,8 @@ impl KvMetricsPublisher {
serde_json::to_value(&*metrics).unwrap()
})
.handler(handler)
.lease(lease);
builder.start().await
.start()
.await
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment