Unverified commit 42d69805, authored by Tzu-Ling Kan, committed by GitHub
Browse files

feat: Remove public uses of CancellationToken (#6405)


Signed-off-by: tzulingk@nvidia.com <tzulingk@nvidia.com>
parent bb8fc8a4
...@@ -72,7 +72,7 @@ async def worker(runtime: DistributedRuntime): ...@@ -72,7 +72,7 @@ async def worker(runtime: DistributedRuntime):
service.add_completions_model(model_name, checksum, engine) service.add_completions_model(model_name, checksum, engine)
print("Starting KServe gRPC service...") print("Starting KServe gRPC service...")
shutdown_signal = service.run(runtime.child_token()) shutdown_signal = service.run(runtime)
try: try:
print( print(
...@@ -86,6 +86,7 @@ async def worker(runtime: DistributedRuntime): ...@@ -86,6 +86,7 @@ async def worker(runtime: DistributedRuntime):
print(f"Unexpected error occurred: {err}") print(f"Unexpected error occurred: {err}")
finally: finally:
print("Shutting down worker...") print("Shutting down worker...")
service.shutdown() # Shutdown service first
runtime.shutdown() runtime.shutdown()
......
...@@ -72,7 +72,7 @@ async def worker(runtime: DistributedRuntime): ...@@ -72,7 +72,7 @@ async def worker(runtime: DistributedRuntime):
service.add_chat_completions_model(served_model_name, "mdcsum", engine) service.add_chat_completions_model(served_model_name, "mdcsum", engine)
print("Starting service...") print("Starting service...")
shutdown_signal = service.run(runtime.child_token()) shutdown_signal = service.run(runtime)
try: try:
print(f"Serving endpoint: {host}:{port}/v1/models") print(f"Serving endpoint: {host}:{port}/v1/models")
...@@ -88,6 +88,7 @@ async def worker(runtime: DistributedRuntime): ...@@ -88,6 +88,7 @@ async def worker(runtime: DistributedRuntime):
print(f"Unexpected error occurred: {e}") print(f"Unexpected error occurred: {e}")
finally: finally:
print("Shutting down worker...") print("Shutting down worker...")
service.shutdown() # Shutdown service first
runtime.shutdown() runtime.shutdown()
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::Arc; use std::sync::{Arc, OnceLock};
use anyhow::{Error, Result, anyhow as error}; use anyhow::{Error, Result, anyhow as error};
use pyo3::prelude::*; use pyo3::prelude::*;
use crate::{CancellationToken, engine::*, to_pyerr}; use crate::{CancellationToken, DistributedRuntime, engine::*, to_pyerr};
pub use dynamo_llm::endpoint_type::EndpointType; pub use dynamo_llm::endpoint_type::EndpointType;
pub use dynamo_llm::http::service::{error as http_error, service_v2}; pub use dynamo_llm::http::service::{error as http_error, service_v2};
...@@ -18,6 +18,8 @@ pub use dynamo_runtime::{ ...@@ -18,6 +18,8 @@ pub use dynamo_runtime::{
#[pyclass] #[pyclass]
pub struct HttpService { pub struct HttpService {
inner: service_v2::HttpService, inner: service_v2::HttpService,
// CancellationToken is already Send + Sync + Clone, no Mutex needed
cancel_token: Arc<OnceLock<CancellationToken>>,
} }
#[pymethods] #[pymethods]
...@@ -27,7 +29,10 @@ impl HttpService { ...@@ -27,7 +29,10 @@ impl HttpService {
pub fn new(port: Option<u16>) -> PyResult<Self> { pub fn new(port: Option<u16>) -> PyResult<Self> {
let builder = service_v2::HttpService::builder().port(port.unwrap_or(8080)); let builder = service_v2::HttpService::builder().port(port.unwrap_or(8080));
let inner = builder.build().map_err(to_pyerr)?; let inner = builder.build().map_err(to_pyerr)?;
Ok(Self { inner }) Ok(Self {
inner,
cancel_token: Arc::new(OnceLock::new()),
})
} }
pub fn add_completions_model( pub fn add_completions_model(
...@@ -78,14 +83,42 @@ impl HttpService { ...@@ -78,14 +83,42 @@ impl HttpService {
Ok(self.inner.model_manager().list_completions_models()) Ok(self.inner.model_manager().list_completions_models())
} }
fn run<'p>(&self, py: Python<'p>, token: CancellationToken) -> PyResult<Bound<'p, PyAny>> { fn run<'p>(&self, py: Python<'p>, runtime: &DistributedRuntime) -> PyResult<Bound<'p, PyAny>> {
// Check if run() was already called to avoid creating unnecessary token
if self.cancel_token.get().is_some() {
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"HttpService.run() has already been called on this instance",
));
}
let service = self.inner.clone(); let service = self.inner.clone();
// Only create token if we passed the check above
let token = runtime.inner().child_token();
// Store the token for shutdown - should always succeed after the check above
self.cancel_token
.set(CancellationToken {
inner: token.clone(),
})
.map_err(|_| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"Race condition detected in HttpService.run()",
)
})?;
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
service.run(token.inner).await.map_err(to_pyerr)?; service.run(token).await.map_err(to_pyerr)?;
Ok(()) Ok(())
}) })
} }
fn shutdown(&self) {
// CancellationToken.cancel() is thread-safe, no lock needed
if let Some(token) = self.cancel_token.get() {
token.inner.cancel();
}
}
fn enable_endpoint(&self, endpoint_type: String, enabled: bool) -> PyResult<()> { fn enable_endpoint(&self, endpoint_type: String, enabled: bool) -> PyResult<()> {
let endpoint_type = EndpointType::all() let endpoint_type = EndpointType::all()
.iter() .iter()
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::Arc; use std::sync::{Arc, OnceLock};
use dynamo_llm::{self as llm_rs}; use dynamo_llm::{self as llm_rs};
use llm_rs::model_card::ModelDeploymentCard as RsModelDeploymentCard; use llm_rs::model_card::ModelDeploymentCard as RsModelDeploymentCard;
use llm_rs::model_type::{ModelInput, ModelType}; use llm_rs::model_type::{ModelInput, ModelType};
use pyo3::prelude::*; use pyo3::prelude::*;
use crate::{CancellationToken, engine::*, llm::local_model::ModelRuntimeConfig, to_pyerr}; use crate::{
CancellationToken, DistributedRuntime, engine::*, llm::local_model::ModelRuntimeConfig,
to_pyerr,
};
pub use dynamo_llm::grpc::service::kserve; pub use dynamo_llm::grpc::service::kserve;
#[pyclass] #[pyclass]
pub struct KserveGrpcService { pub struct KserveGrpcService {
inner: kserve::KserveService, inner: kserve::KserveService,
// CancellationToken is already Send + Sync + Clone, no Mutex needed
cancel_token: Arc<OnceLock<CancellationToken>>,
} }
#[pymethods] #[pymethods]
...@@ -30,7 +35,10 @@ impl KserveGrpcService { ...@@ -30,7 +35,10 @@ impl KserveGrpcService {
builder = builder.host(host); builder = builder.host(host);
} }
let inner = builder.build().map_err(to_pyerr)?; let inner = builder.build().map_err(to_pyerr)?;
Ok(Self { inner }) Ok(Self {
inner,
cancel_token: Arc::new(OnceLock::new()),
})
} }
pub fn add_completions_model( pub fn add_completions_model(
...@@ -128,11 +136,39 @@ impl KserveGrpcService { ...@@ -128,11 +136,39 @@ impl KserveGrpcService {
Ok(self.inner.model_manager().list_tensor_models()) Ok(self.inner.model_manager().list_tensor_models())
} }
fn run<'p>(&self, py: Python<'p>, token: CancellationToken) -> PyResult<Bound<'p, PyAny>> { fn run<'p>(&self, py: Python<'p>, runtime: &DistributedRuntime) -> PyResult<Bound<'p, PyAny>> {
// Check if run() was already called to avoid creating unnecessary token
if self.cancel_token.get().is_some() {
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"KserveGrpcService.run() has already been called on this instance",
));
}
let service = self.inner.clone(); let service = self.inner.clone();
// Only create token if we passed the check above
let token = runtime.inner().child_token();
// Store the token for shutdown - should always succeed after the check above
self.cancel_token
.set(CancellationToken {
inner: token.clone(),
})
.map_err(|_| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"Race condition detected in KserveGrpcService.run()",
)
})?;
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
service.run(token.inner).await.map_err(to_pyerr)?; service.run(token).await.map_err(to_pyerr)?;
Ok(()) Ok(())
}) })
} }
fn shutdown(&self) {
// CancellationToken.cancel() is thread-safe, no lock needed
if let Some(token) = self.cancel_token.get() {
token.inner.cancel();
}
}
} }
...@@ -149,7 +149,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -149,7 +149,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(llm::entrypoint::run_input, m)?)?; m.add_function(wrap_pyfunction!(llm::entrypoint::run_input, m)?)?;
m.add_class::<DistributedRuntime>()?; m.add_class::<DistributedRuntime>()?;
m.add_class::<CancellationToken>()?;
m.add_class::<Component>()?; m.add_class::<Component>()?;
m.add_class::<Endpoint>()?; m.add_class::<Endpoint>()?;
m.add_class::<ModelCardInstanceId>()?; m.add_class::<ModelCardInstanceId>()?;
...@@ -680,11 +679,6 @@ impl DistributedRuntime { ...@@ -680,11 +679,6 @@ impl DistributedRuntime {
self.event_loop.clone() self.event_loop.clone()
} }
fn child_token(&self) -> CancellationToken {
let inner = self.inner.runtime().child_token();
CancellationToken { inner }
}
/// Register an async Python callback for /engine/{route_name} /// Register an async Python callback for /engine/{route_name}
/// ///
/// Args: /// Args:
...@@ -779,21 +773,6 @@ impl DistributedRuntime { ...@@ -779,21 +773,6 @@ impl DistributedRuntime {
} }
} }
#[pymethods]
impl CancellationToken {
fn cancel(&self) {
self.inner.cancel();
}
fn cancelled<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let token = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
token.cancelled().await;
Ok(())
})
}
}
#[pymethods] #[pymethods]
impl Component { impl Component {
fn endpoint(&self, name: String) -> PyResult<Endpoint> { fn endpoint(&self, name: String) -> PyResult<Endpoint> {
......
...@@ -85,12 +85,6 @@ class DistributedRuntime: ...@@ -85,12 +85,6 @@ class DistributedRuntime:
""" """
... ...
def child_token(self) -> CancellationToken:
"""
Get a child cancellation token that can be passed to async tasks
"""
...
def register_engine_route( def register_engine_route(
self, self,
route_name: str, route_name: str,
...@@ -117,20 +111,6 @@ class DistributedRuntime: ...@@ -117,20 +111,6 @@ class DistributedRuntime:
""" """
... ...
class CancellationToken:
def cancel(self) -> None:
"""
Cancel the token and all its children
"""
...
async def cancelled(self) -> None:
"""
Await until the token is cancelled
"""
...
class Component: class Component:
""" """
A component is a collection of endpoints A component is a collection of endpoints
...@@ -784,7 +764,29 @@ class HttpService: ...@@ -784,7 +764,29 @@ class HttpService:
It is a OpenAI compatible http ingress into the Dynamo Distributed Runtime. It is a OpenAI compatible http ingress into the Dynamo Distributed Runtime.
""" """
... def __init__(self, port: Optional[int] = None) -> None:
"""
Create a new HTTP service.
Args:
port: Optional port number to bind the service to (default: 8080)
"""
...
async def run(self, runtime: DistributedRuntime) -> None:
"""
Run the HTTP service.
Args:
runtime: DistributedRuntime instance for token management
"""
...
def shutdown(self) -> None:
"""
Shutdown the HTTP service by cancelling its internal token.
"""
...
class PythonAsyncEngine: class PythonAsyncEngine:
""" """
...@@ -925,12 +927,18 @@ class KserveGrpcService: ...@@ -925,12 +927,18 @@ class KserveGrpcService:
""" """
... ...
async def run(self, token: CancellationToken) -> None: async def run(self, runtime: DistributedRuntime) -> None:
""" """
Run the KServe gRPC service. Run the KServe gRPC service.
Args: Args:
token: Cancellation token to stop the service runtime: DistributedRuntime instance for token management
"""
...
def shutdown(self) -> None:
"""
Shutdown the KServe gRPC service by cancelling its internal token.
""" """
... ...
......
...@@ -84,8 +84,8 @@ async def http_server(runtime: DistributedRuntime): ...@@ -84,8 +84,8 @@ async def http_server(runtime: DistributedRuntime):
port = 8008 port = 8008
model_name = "test_model" model_name = "test_model"
start_done = asyncio.Event() start_done = asyncio.Event()
child_token = runtime.child_token()
checksum = "abc123" # Checksum of ModelDeplomentCard for that model checksum = "abc123" # Checksum of ModelDeplomentCard for that model
service = HttpService(port=port) # Create service outside worker so we can shutdown
async def worker(): async def worker():
"""The server worker task.""" """The server worker task."""
...@@ -94,11 +94,10 @@ async def http_server(runtime: DistributedRuntime): ...@@ -94,11 +94,10 @@ async def http_server(runtime: DistributedRuntime):
python_engine = MockHttpEngine(model_name) python_engine = MockHttpEngine(model_name)
engine = HttpAsyncEngine(python_engine.generate, loop) engine = HttpAsyncEngine(python_engine.generate, loop)
service = HttpService(port=port)
service.add_chat_completions_model(model_name, checksum, engine) service.add_chat_completions_model(model_name, checksum, engine)
service.enable_endpoint("chat", True) service.enable_endpoint("chat", True)
shutdown_signal = service.run(child_token) shutdown_signal = service.run(runtime)
print("Starting service on port", port) print("Starting service on port", port)
start_done.set() start_done.set()
await shutdown_signal await shutdown_signal
...@@ -106,8 +105,6 @@ async def http_server(runtime: DistributedRuntime): ...@@ -106,8 +105,6 @@ async def http_server(runtime: DistributedRuntime):
print("Server encountered an error:", e) print("Server encountered an error:", e)
start_done.set() start_done.set()
raise ValueError(f"Server failed to start: {e}") raise ValueError(f"Server failed to start: {e}")
finally:
child_token.cancel()
server_task = asyncio.create_task(worker()) server_task = asyncio.create_task(worker())
await asyncio.wait_for(start_done.wait(), timeout=30.0) await asyncio.wait_for(start_done.wait(), timeout=30.0)
...@@ -116,7 +113,7 @@ async def http_server(runtime: DistributedRuntime): ...@@ -116,7 +113,7 @@ async def http_server(runtime: DistributedRuntime):
yield f"http://localhost:{port}", model_name yield f"http://localhost:{port}", model_name
# Teardown: Cancel the server task if it's still running # Teardown: Cancel the server task if it's still running
child_token.cancel() service.shutdown() # Shutdown service
await asyncio.sleep(0.1) # Give some time for graceful shutdown await asyncio.sleep(0.1) # Give some time for graceful shutdown
if not server_task.done(): if not server_task.done():
server_task.cancel() server_task.cancel()
......
...@@ -68,17 +68,15 @@ def tensor_service(runtime): ...@@ -68,17 +68,15 @@ def tensor_service(runtime):
model_name, checksum, engine, runtime_config=runtime_config model_name, checksum, engine, runtime_config=runtime_config
) )
cancel_token = runtime.child_token()
async def _serve(): async def _serve():
await tensor_model_service.run(cancel_token) await tensor_model_service.run(runtime)
server_task = asyncio.create_task(_serve()) server_task = asyncio.create_task(_serve())
try: try:
await asyncio.sleep(1) # wait service to start await asyncio.sleep(1) # wait service to start
yield host, port yield host, port
finally: finally:
cancel_token.cancel() tensor_model_service.shutdown()
with contextlib.suppress(asyncio.TimeoutError, asyncio.CancelledError): with contextlib.suppress(asyncio.TimeoutError, asyncio.CancelledError):
await asyncio.wait_for(server_task, timeout=5) await asyncio.wait_for(server_task, timeout=5)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment