Unverified Commit 626d7e18 authored by Michael Feil's avatar Michael Feil Committed by GitHub
Browse files

feat(request cancellation): pycontext, propagating the `is_stopped` into python land. (#2158)

parent 49958435
......@@ -1446,9 +1446,9 @@ dependencies = [
[[package]]
name = "cudarc"
version = "0.17.1"
version = "0.17.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "018e09f92e57618dbae5a3a0dcc2026547eed0e5b6a503a32c11ee1a94890830"
checksum = "8147ca46109d41cc513fd629b52bbea9bd09b972034c2f32954ce84a92895a91"
dependencies = [
"libloading",
]
......@@ -1924,7 +1924,7 @@ dependencies = [
"candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
"chrono",
"criterion",
"cudarc 0.17.1",
"cudarc 0.17.2",
"dashmap",
"derive-getters",
"derive_builder",
......
......@@ -18,7 +18,7 @@ import logging
import uvloop
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
from dynamo.runtime import DistributedRuntime, PyContext, dynamo_endpoint, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
logger = logging.getLogger(__name__)
......@@ -26,10 +26,13 @@ configure_dynamo_logging(service_name="backend")
@dynamo_endpoint(str, str)
async def content_generator(request: str):
logger.info(f"Received request: {request}")
async def content_generator(request: str, context: PyContext):
logger.info(f"Received request: {request} with `id={context.id()}`")
for word in request.split(","):
await asyncio.sleep(1)
if context.is_stopped() or context.is_killed():
print("request got cancelled.")
return
yield f"Hello {word}!"
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// PyContext is a wrapper around the AsyncEngineContext to allow for Python bindings.
pub use dynamo_runtime::pipeline::AsyncEngineContext;
use pyo3::prelude::*;
use std::sync::Arc;
// PyContext is a wrapper around the AsyncEngineContext to allow for Python bindings.
// Not all methods of the AsyncEngineContext are exposed, jsut the primary ones for tracing + cancellation.
// Kept as class, to allow for future expansion if needed.
#[pyclass]
pub struct PyContext {
pub inner: Arc<dyn AsyncEngineContext>,
}
impl PyContext {
pub fn new(inner: Arc<dyn AsyncEngineContext>) -> Self {
Self { inner }
}
}
#[pymethods]
impl PyContext {
// sync method of `await async_is_stopped()`
fn is_stopped(&self) -> bool {
self.inner.is_stopped()
}
// sync method of `await async_is_killed()`
fn is_killed(&self) -> bool {
self.inner.is_killed()
}
// issues a stop generating
fn stop_generating(&self) {
self.inner.stop_generating();
}
fn id(&self) -> &str {
self.inner.id()
}
// allows building a async callback.
fn async_killed_or_stopped<'a>(&self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
tokio::select! {
_ = inner.killed() => {
Ok(true)
}
_ = inner.stopped() => {
Ok(true)
}
}
})
}
}
// PyO3 equivalent for verify if signature contains target_name
// def callable_accepts_kwarg(target_name: str):
// import inspect
// return target_name in inspect.signature(func).parameters
pub fn callable_accepts_kwarg(
py: Python,
callable: &Bound<'_, PyAny>,
target_name: &str,
) -> PyResult<bool> {
let inspect: Bound<'_, PyModule> = py.import("inspect")?;
let signature = inspect.call_method1("signature", (callable,))?;
let params_any: Bound<'_, PyAny> = signature.getattr("parameters")?;
params_any
.call_method1("__contains__", (target_name,))?
.extract::<bool>()
}
......@@ -13,11 +13,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use super::context::{callable_accepts_kwarg, PyContext};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyModule};
use pyo3::{PyAny, PyErr};
use pyo3_async_runtimes::TaskLocals;
use pythonize::{depythonize, pythonize};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_stream::{wrappers::ReceiverStream, StreamExt};
......@@ -36,7 +38,6 @@ pub fn add_to_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PythonAsyncEngine>()?;
Ok(())
}
// todos:
// - [ ] enable context cancellation
// - this will likely require a change to the function signature python calling arguments
......@@ -113,6 +114,7 @@ pub struct PythonServerStreamingEngine {
_cancel_token: CancellationToken,
generator: Arc<PyObject>,
event_loop: Arc<PyObject>,
has_pycontext: bool,
}
impl PythonServerStreamingEngine {
......@@ -121,10 +123,16 @@ impl PythonServerStreamingEngine {
generator: Arc<PyObject>,
event_loop: Arc<PyObject>,
) -> Self {
let has_pycontext = Python::with_gil(|py| {
let callable = generator.bind(py);
callable_accepts_kwarg(py, callable, "context").unwrap_or(false)
});
PythonServerStreamingEngine {
_cancel_token: cancel_token,
generator,
event_loop,
has_pycontext,
}
}
}
......@@ -166,6 +174,8 @@ where
let generator = self.generator.clone();
let event_loop = self.event_loop.clone();
let ctx_python = ctx.clone();
let has_pycontext = self.has_pycontext;
// Acquiring the GIL is similar to acquiring a standard lock/mutex
// Performing this in an tokio async task could block the thread for an undefined amount of time
......@@ -180,7 +190,18 @@ where
let stream = tokio::task::spawn_blocking(move || {
Python::with_gil(|py| {
let py_request = pythonize(py, &request)?;
let gen = generator.call1(py, (py_request,))?;
let py_ctx = Py::new(py, PyContext::new(ctx_python.clone()))?;
let gen = if has_pycontext {
// Pass context as a kwarg
let kwarg = PyDict::new(py);
kwarg.set_item("context", &py_ctx)?;
generator.call(py, (py_request,), Some(&kwarg))
} else {
// Legacy: No `context` arg
generator.call1(py, (py_request,))
}?;
let locals = TaskLocals::new(event_loop.bind(py).clone());
pyo3_async_runtimes::tokio::into_stream_with_locals_v1(locals, gen.into_bound(py))
})
......
......@@ -45,6 +45,7 @@ impl From<RouterMode> for RsRouterMode {
}
}
mod context;
mod engine;
mod http;
mod llm;
......@@ -103,6 +104,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<http::HttpService>()?;
m.add_class::<http::HttpError>()?;
m.add_class::<http::HttpAsyncEngine>()?;
m.add_class::<context::PyContext>()?;
m.add_class::<EtcdKvCache>()?;
m.add_class::<ModelType>()?;
m.add_class::<llm::kv::ForwardPassMetrics>()?;
......
......@@ -30,6 +30,7 @@ from dynamo._core import Endpoint as Endpoint
from dynamo._core import EtcdKvCache as EtcdKvCache
from dynamo._core import ModelDeploymentCard as ModelDeploymentCard
from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor
from dynamo._core import PyContext as PyContext
def dynamo_worker(static=False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment