Unverified Commit 4c38680e authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

feat: gracefully shutdown endpoint by revoking etcd lease + python binding (#730)


Co-authored-by: default avatarishandhanani <82981111+ishandhanani@users.noreply.github.com>
parent a745a980
......@@ -158,6 +158,7 @@ class VllmWorker:
async def create_metrics_publisher_endpoint(self):
component = dynamo_context["component"]
# TODO: use the same child lease for metrics publisher endpoint and generate endpoint
await self.metrics_publisher.create_endpoint(component)
def get_remote_prefill_request_callback(self):
......@@ -171,6 +172,7 @@ class VllmWorker:
return callback
# TODO: use the same child lease for metrics publisher endpoint and generate endpoint
@dynamo_endpoint()
async def generate(self, request: vLLMGenerateRequest):
# TODO: consider prefix hit when deciding prefill locally or remotely
......
......@@ -20,7 +20,7 @@ import uvloop
from dynamo.runtime import DistributedRuntime, dynamo_worker
@dynamo_worker(static=True)
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
await init(runtime, "dynamo")
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import uvloop
from dynamo.runtime import DistributedRuntime, dynamo_worker
@dynamo_worker(static=True)
async def worker(runtime: DistributedRuntime):
await init(runtime, "dynamo")
async def init(runtime: DistributedRuntime, ns: str):
"""
Instantiate a `backend` client and call the `generate` endpoint
"""
# get endpoint
endpoint = runtime.namespace(ns).component("backend").endpoint("generate")
# create client
client = await endpoint.client()
# wait for an endpoint to be ready
await client.wait_for_endpoints()
# issue request
stream = await client.generate("hello world")
# process the stream
async for char in stream:
print(char)
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
......@@ -28,10 +28,11 @@ class RequestHandler:
async def generate(self, request):
print(f"Received request: {request}")
for char in request:
await asyncio.sleep(1)
yield char
@dynamo_worker(static=True)
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
await init(runtime, "dynamo")
......@@ -42,11 +43,16 @@ async def init(runtime: DistributedRuntime, ns: str):
A `Component` can serve multiple endpoints
"""
component = runtime.namespace(ns).component("backend")
await component.create_service()
lease = await component.create_service_with_custom_lease(ttl=1)
lease_id = lease.id()
print(f"Created custom lease with ID: {lease_id}/{lease_id:#x}")
endpoint = component.endpoint("generate")
print("Started server instance")
await endpoint.serve_endpoint(RequestHandler().generate)
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
await endpoint.serve_endpoint_with_lease(RequestHandler().generate, lease)
if __name__ == "__main__":
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import uvloop
from dynamo.runtime import DistributedRuntime, dynamo_worker
class RequestHandler:
"""
Request handler for the generate endpoint
"""
async def generate(self, request):
print(f"Received request: {request}")
for char in request:
yield char
@dynamo_worker(static=True)
async def worker(runtime: DistributedRuntime):
await init(runtime, "dynamo")
async def init(runtime: DistributedRuntime, ns: str):
"""
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
"""
component = runtime.namespace(ns).component("backend")
await component.create_service()
endpoint = component.endpoint("generate")
print("Started server instance")
await endpoint.serve_endpoint(RequestHandler().generate)
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
......@@ -149,6 +149,23 @@ struct Client {
inner: rs::component::Client<serde_json::Value, serde_json::Value>,
}
#[pyclass]
#[derive(Clone)]
struct PyLease {
inner: rs::transports::etcd::Lease,
}
#[pymethods]
impl PyLease {
fn id(&self) -> i64 {
self.inner.id()
}
fn revoke(&self) {
self.inner.revoke();
}
}
#[pymethods]
impl DistributedRuntime {
#[new]
......@@ -356,6 +373,41 @@ impl Component {
Ok(())
})
}
#[pyo3(signature = (ttl=1))]
fn create_service_with_custom_lease<'p>(
&self,
py: Python<'p>,
ttl: i64,
) -> PyResult<Bound<'p, PyAny>> {
let component = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
// Get the etcd client from the runtime
let etcd_client = component
.drt()
.etcd_client()
.ok_or_else(|| to_pyerr("etcd client not found"))?;
// Create a custom lease with the specified TTL
let custom_lease = etcd_client.create_lease(ttl).await.map_err(to_pyerr)?;
tracing::info!("created custom lease: {:?}", custom_lease);
// Create a service
// TODO: tie the lease to service instead of endpoint
let _service = component
.service_builder()
.create()
.await
.map_err(to_pyerr)?;
// Return the lease
Ok(PyLease {
inner: custom_lease,
})
})
}
}
#[pymethods]
......@@ -377,6 +429,33 @@ impl Endpoint {
})
}
fn serve_endpoint_with_lease<'p>(
&self,
py: Python<'p>,
generator: PyObject,
lease: &PyLease,
) -> PyResult<Bound<'p, PyAny>> {
let engine = Arc::new(engine::PythonAsyncEngine::new(
generator,
self.event_loop.clone(),
)?);
let ingress = JsonServerStreamingIngress::for_engine(engine).map_err(to_pyerr)?;
// Create the builder with the ingress
let builder = self
.inner
.endpoint_builder()
.handler(ingress)
.lease(Some(lease.inner.clone()));
pyo3_async_runtimes::tokio::future_into_py(py, async move {
// Start the endpoint
builder.start().await.map_err(to_pyerr)?;
Ok(())
})
}
fn client<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
......
......@@ -49,6 +49,25 @@ class DistributedRuntime:
"""
...
class PyLease:
"""
A lease object
"""
def id(self) -> int:
"""
Return the id of the lease
Refer to https://etcd.io/docs/v3.4/learning/api/ for examples on how to use the lease id
"""
...
def revoke(self) -> None:
"""
Revoke the lease by triggering the cancellation token
This will invalidate the kv pairs associated with this lease
"""
...
class EtcdClient:
"""
Etcd is used for discovery in the DistributedRuntime
......@@ -76,6 +95,12 @@ class EtcdClient:
"""
...
async def revoke_lease(self, lease_id: int) -> None:
"""
Revoke a lease
"""
...
class EtcdKvCache:
"""
A cache for key-value pairs stored in etcd.
......@@ -175,6 +200,14 @@ class Component:
"""
...
def create_service_with_custom_lease(self, ttl: int) -> PyLease:
"""
Create a service with a custom lease
The lease needs to be tied to the endpoint of this services when creating the endpoints later
TODO: tie the lease to the service instead of the endpoint
"""
...
class Endpoint:
"""
An Endpoint is a single API endpoint
......
......@@ -13,10 +13,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::atomic::{AtomicU64, Ordering};
use super::*;
use anyhow::Result;
use async_nats::service::endpoint::Endpoint;
use derive_builder::Builder;
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
#[derive(Builder)]
......@@ -36,6 +39,9 @@ impl PushEndpoint {
pub async fn start(self, endpoint: Endpoint) -> Result<()> {
let mut endpoint = endpoint;
let inflight = Arc::new(AtomicU64::new(0));
let notify = Arc::new(Notify::new());
loop {
let req = tokio::select! {
biased;
......@@ -47,7 +53,7 @@ impl PushEndpoint {
// process shutdown
_ = self.cancellation_token.cancelled() => {
// tracing::trace!(worker_id, "Shutting down service {}", self.endpoint.name);
tracing::info!("Shutting down service");
if let Err(e) = endpoint.stop().await {
tracing::warn!("Failed to stop NATS service: {:?}", e);
}
......@@ -63,6 +69,12 @@ impl PushEndpoint {
let ingress = self.service_handler.clone();
let worker_id = "".to_string();
// increment the inflight counter
inflight.fetch_add(1, Ordering::SeqCst);
let inflight_clone = inflight.clone();
let notify_clone = notify.clone();
tokio::spawn(async move {
tracing::trace!(worker_id, "handling new request");
let result = ingress.handle_payload(req.message.payload).await;
......@@ -74,12 +86,26 @@ impl PushEndpoint {
tracing::warn!("Failed to handle request: {:?}", e);
}
}
// decrease the inflight counter
inflight_clone.fetch_sub(1, Ordering::SeqCst);
notify_clone.notify_one();
});
} else {
break;
}
}
// await for all inflight requests to complete
tracing::info!(
"Waiting for {} inflight requests to complete",
inflight.load(Ordering::SeqCst)
);
while inflight.load(Ordering::SeqCst) > 0 {
notify.notified().await;
}
tracing::info!("All inflight requests completed");
Ok(())
}
}
......@@ -504,6 +504,7 @@ async fn tcp_listener(
}
_ = context.stopped(), if can_stop => {
tracing::trace!("context stop signal received; shutting down");
can_stop = false;
control_tx.send(ControlMessage::Stop).await.expect("the control channel should not be closed");
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment