Unverified Commit 42d69805 authored by Tzu-Ling Kan's avatar Tzu-Ling Kan Committed by GitHub
Browse files

feat: Remove public uses of CancellationToken (#6405)


Signed-off-by: default avatartzulingk@nvidia.com <tzulingk@nvidia.com>
parent bb8fc8a4
......@@ -72,7 +72,7 @@ async def worker(runtime: DistributedRuntime):
service.add_completions_model(model_name, checksum, engine)
print("Starting KServe gRPC service...")
shutdown_signal = service.run(runtime.child_token())
shutdown_signal = service.run(runtime)
try:
print(
......@@ -86,6 +86,7 @@ async def worker(runtime: DistributedRuntime):
print(f"Unexpected error occurred: {err}")
finally:
print("Shutting down worker...")
service.shutdown() # Shutdown service first
runtime.shutdown()
......
......@@ -72,7 +72,7 @@ async def worker(runtime: DistributedRuntime):
service.add_chat_completions_model(served_model_name, "mdcsum", engine)
print("Starting service...")
shutdown_signal = service.run(runtime.child_token())
shutdown_signal = service.run(runtime)
try:
print(f"Serving endpoint: {host}:{port}/v1/models")
......@@ -88,6 +88,7 @@ async def worker(runtime: DistributedRuntime):
print(f"Unexpected error occurred: {e}")
finally:
print("Shutting down worker...")
service.shutdown() # Shutdown service first
runtime.shutdown()
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::sync::{Arc, OnceLock};
use anyhow::{Error, Result, anyhow as error};
use pyo3::prelude::*;
use crate::{CancellationToken, engine::*, to_pyerr};
use crate::{CancellationToken, DistributedRuntime, engine::*, to_pyerr};
pub use dynamo_llm::endpoint_type::EndpointType;
pub use dynamo_llm::http::service::{error as http_error, service_v2};
......@@ -18,6 +18,8 @@ pub use dynamo_runtime::{
#[pyclass]
pub struct HttpService {
inner: service_v2::HttpService,
// CancellationToken is already Send + Sync + Clone, no Mutex needed
cancel_token: Arc<OnceLock<CancellationToken>>,
}
#[pymethods]
......@@ -27,7 +29,10 @@ impl HttpService {
pub fn new(port: Option<u16>) -> PyResult<Self> {
let builder = service_v2::HttpService::builder().port(port.unwrap_or(8080));
let inner = builder.build().map_err(to_pyerr)?;
Ok(Self { inner })
Ok(Self {
inner,
cancel_token: Arc::new(OnceLock::new()),
})
}
pub fn add_completions_model(
......@@ -78,14 +83,42 @@ impl HttpService {
Ok(self.inner.model_manager().list_completions_models())
}
fn run<'p>(&self, py: Python<'p>, token: CancellationToken) -> PyResult<Bound<'p, PyAny>> {
fn run<'p>(&self, py: Python<'p>, runtime: &DistributedRuntime) -> PyResult<Bound<'p, PyAny>> {
// Check if run() was already called to avoid creating unnecessary token
if self.cancel_token.get().is_some() {
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"HttpService.run() has already been called on this instance",
));
}
let service = self.inner.clone();
// Only create token if we passed the check above
let token = runtime.inner().child_token();
// Store the token for shutdown - should always succeed after the check above
self.cancel_token
.set(CancellationToken {
inner: token.clone(),
})
.map_err(|_| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"Race condition detected in HttpService.run()",
)
})?;
pyo3_async_runtimes::tokio::future_into_py(py, async move {
service.run(token.inner).await.map_err(to_pyerr)?;
service.run(token).await.map_err(to_pyerr)?;
Ok(())
})
}
fn shutdown(&self) {
// CancellationToken.cancel() is thread-safe, no lock needed
if let Some(token) = self.cancel_token.get() {
token.inner.cancel();
}
}
fn enable_endpoint(&self, endpoint_type: String, enabled: bool) -> PyResult<()> {
let endpoint_type = EndpointType::all()
.iter()
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::sync::{Arc, OnceLock};
use dynamo_llm::{self as llm_rs};
use llm_rs::model_card::ModelDeploymentCard as RsModelDeploymentCard;
use llm_rs::model_type::{ModelInput, ModelType};
use pyo3::prelude::*;
use crate::{CancellationToken, engine::*, llm::local_model::ModelRuntimeConfig, to_pyerr};
use crate::{
CancellationToken, DistributedRuntime, engine::*, llm::local_model::ModelRuntimeConfig,
to_pyerr,
};
pub use dynamo_llm::grpc::service::kserve;
#[pyclass]
pub struct KserveGrpcService {
inner: kserve::KserveService,
// CancellationToken is already Send + Sync + Clone, no Mutex needed
cancel_token: Arc<OnceLock<CancellationToken>>,
}
#[pymethods]
......@@ -30,7 +35,10 @@ impl KserveGrpcService {
builder = builder.host(host);
}
let inner = builder.build().map_err(to_pyerr)?;
Ok(Self { inner })
Ok(Self {
inner,
cancel_token: Arc::new(OnceLock::new()),
})
}
pub fn add_completions_model(
......@@ -128,11 +136,39 @@ impl KserveGrpcService {
Ok(self.inner.model_manager().list_tensor_models())
}
fn run<'p>(&self, py: Python<'p>, token: CancellationToken) -> PyResult<Bound<'p, PyAny>> {
fn run<'p>(&self, py: Python<'p>, runtime: &DistributedRuntime) -> PyResult<Bound<'p, PyAny>> {
// Check if run() was already called to avoid creating unnecessary token
if self.cancel_token.get().is_some() {
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"KserveGrpcService.run() has already been called on this instance",
));
}
let service = self.inner.clone();
// Only create token if we passed the check above
let token = runtime.inner().child_token();
// Store the token for shutdown - should always succeed after the check above
self.cancel_token
.set(CancellationToken {
inner: token.clone(),
})
.map_err(|_| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"Race condition detected in KserveGrpcService.run()",
)
})?;
pyo3_async_runtimes::tokio::future_into_py(py, async move {
service.run(token.inner).await.map_err(to_pyerr)?;
service.run(token).await.map_err(to_pyerr)?;
Ok(())
})
}
fn shutdown(&self) {
// CancellationToken.cancel() is thread-safe, no lock needed
if let Some(token) = self.cancel_token.get() {
token.inner.cancel();
}
}
}
......@@ -149,7 +149,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(llm::entrypoint::run_input, m)?)?;
m.add_class::<DistributedRuntime>()?;
m.add_class::<CancellationToken>()?;
m.add_class::<Component>()?;
m.add_class::<Endpoint>()?;
m.add_class::<ModelCardInstanceId>()?;
......@@ -680,11 +679,6 @@ impl DistributedRuntime {
self.event_loop.clone()
}
fn child_token(&self) -> CancellationToken {
let inner = self.inner.runtime().child_token();
CancellationToken { inner }
}
/// Register an async Python callback for /engine/{route_name}
///
/// Args:
......@@ -779,21 +773,6 @@ impl DistributedRuntime {
}
}
#[pymethods]
impl CancellationToken {
fn cancel(&self) {
self.inner.cancel();
}
fn cancelled<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let token = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
token.cancelled().await;
Ok(())
})
}
}
#[pymethods]
impl Component {
fn endpoint(&self, name: String) -> PyResult<Endpoint> {
......
......@@ -85,12 +85,6 @@ class DistributedRuntime:
"""
...
def child_token(self) -> CancellationToken:
"""
Get a child cancellation token that can be passed to async tasks
"""
...
def register_engine_route(
self,
route_name: str,
......@@ -117,20 +111,6 @@ class DistributedRuntime:
"""
...
class CancellationToken:
def cancel(self) -> None:
"""
Cancel the token and all its children
"""
...
async def cancelled(self) -> None:
"""
Await until the token is cancelled
"""
...
class Component:
"""
A component is a collection of endpoints
......@@ -784,7 +764,29 @@ class HttpService:
It is a OpenAI compatible http ingress into the Dynamo Distributed Runtime.
"""
...
def __init__(self, port: Optional[int] = None) -> None:
"""
Create a new HTTP service.
Args:
port: Optional port number to bind the service to (default: 8080)
"""
...
async def run(self, runtime: DistributedRuntime) -> None:
"""
Run the HTTP service.
Args:
runtime: DistributedRuntime instance for token management
"""
...
def shutdown(self) -> None:
"""
Shutdown the HTTP service by cancelling its internal token.
"""
...
class PythonAsyncEngine:
"""
......@@ -925,12 +927,18 @@ class KserveGrpcService:
"""
...
async def run(self, token: CancellationToken) -> None:
async def run(self, runtime: DistributedRuntime) -> None:
"""
Run the KServe gRPC service.
Args:
token: Cancellation token to stop the service
runtime: DistributedRuntime instance for token management
"""
...
def shutdown(self) -> None:
"""
Shutdown the KServe gRPC service by cancelling its internal token.
"""
...
......
......@@ -84,8 +84,8 @@ async def http_server(runtime: DistributedRuntime):
port = 8008
model_name = "test_model"
start_done = asyncio.Event()
child_token = runtime.child_token()
checksum = "abc123" # Checksum of ModelDeplomentCard for that model
service = HttpService(port=port) # Create service outside worker so we can shutdown
async def worker():
"""The server worker task."""
......@@ -94,11 +94,10 @@ async def http_server(runtime: DistributedRuntime):
python_engine = MockHttpEngine(model_name)
engine = HttpAsyncEngine(python_engine.generate, loop)
service = HttpService(port=port)
service.add_chat_completions_model(model_name, checksum, engine)
service.enable_endpoint("chat", True)
shutdown_signal = service.run(child_token)
shutdown_signal = service.run(runtime)
print("Starting service on port", port)
start_done.set()
await shutdown_signal
......@@ -106,8 +105,6 @@ async def http_server(runtime: DistributedRuntime):
print("Server encountered an error:", e)
start_done.set()
raise ValueError(f"Server failed to start: {e}")
finally:
child_token.cancel()
server_task = asyncio.create_task(worker())
await asyncio.wait_for(start_done.wait(), timeout=30.0)
......@@ -116,7 +113,7 @@ async def http_server(runtime: DistributedRuntime):
yield f"http://localhost:{port}", model_name
# Teardown: Cancel the server task if it's still running
child_token.cancel()
service.shutdown() # Shutdown service
await asyncio.sleep(0.1) # Give some time for graceful shutdown
if not server_task.done():
server_task.cancel()
......
......@@ -68,17 +68,15 @@ def tensor_service(runtime):
model_name, checksum, engine, runtime_config=runtime_config
)
cancel_token = runtime.child_token()
async def _serve():
await tensor_model_service.run(cancel_token)
await tensor_model_service.run(runtime)
server_task = asyncio.create_task(_serve())
try:
await asyncio.sleep(1) # wait service to start
yield host, port
finally:
cancel_token.cancel()
tensor_model_service.shutdown()
with contextlib.suppress(asyncio.TimeoutError, asyncio.CancelledError):
await asyncio.wait_for(server_task, timeout=5)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment