Unverified Commit 734d2f87 authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: Enable cancellation during or before a stream is established (#3635)


Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
Co-authored-by: default avatarRyan McCormick <rmccormick@nvidia.com>
parent f978f4d1
......@@ -21,7 +21,7 @@ use tracing::Instrument;
use dynamo_runtime::{
self as rs, logging,
pipeline::{
EngineStream, ManyOut, SingleIn, context::Context as RsContext,
AsyncEngineContextProvider, EngineStream, ManyOut, SingleIn, context::Context as RsContext,
network::egress::push_router::RouterMode as RsRouterMode,
},
protocols::annotated::Annotated as RsAnnotated,
......@@ -90,6 +90,29 @@ fn get_span_for_direct_context(
)
}
// Helper to create request context with proper linking and cancellation handling
fn create_request_context(
request: serde_json::Value,
parent_ctx: &Option<context::Context>,
) -> RsContext<serde_json::Value> {
match parent_ctx {
// If there is a parent context, link the request as a child context of it
Some(parent_ctx) => {
let child_ctx = RsContext::with_id(request, parent_ctx.inner().id().to_string());
parent_ctx.inner().link_child(child_ctx.context());
if parent_ctx.inner().is_stopped() || parent_ctx.inner().is_killed() {
// Let the server handle the cancellation for now since not all backends are
// properly handling request exceptions
// TODO: (DIS-830) Return an error if context is cancelled
child_ctx.context().stop_generating();
}
child_ctx
}
// Otherwise if there is no parent context, use the request as-is
_ => request.into(),
}
}
/// A Python module implemented in Rust. The name of this function must match
/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
/// import the module.
......@@ -794,6 +817,7 @@ impl Client {
context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> {
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
let request_ctx = create_request_context(request, &context);
let annotated = annotated.unwrap_or(false);
let (tx, rx) = tokio::sync::mpsc::channel(32);
......@@ -802,17 +826,15 @@ impl Client {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = match context {
Some(context) => {
let request_ctx = RsContext::with_id(request, context.inner().id().to_string());
// Always instrument with appropriate span (none if no trace context)
let span = get_span_for_context(&context, "round_robin");
let stream_future = client.round_robin(request_ctx).instrument(span);
let stream = stream_future.await.map_err(to_pyerr)?;
context.inner().link_child(stream.context());
stream
client
.round_robin(request_ctx)
.instrument(span)
.await
.map_err(to_pyerr)?
}
_ => client.round_robin(request.into()).await.map_err(to_pyerr)?,
_ => client.round_robin(request_ctx).await.map_err(to_pyerr)?,
};
tokio::spawn(process_stream(stream, tx));
Ok(AsyncResponseStream {
......@@ -832,6 +854,7 @@ impl Client {
context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> {
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
let request_ctx = create_request_context(request, &context);
let annotated = annotated.unwrap_or(false);
let (tx, rx) = tokio::sync::mpsc::channel(32);
......@@ -840,16 +863,15 @@ impl Client {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = match context {
Some(context) => {
let request_ctx = RsContext::with_id(request, context.inner().id().to_string());
// Always instrument with appropriate span (none if no trace context)
let span = get_span_for_context(&context, "random");
let stream_future = client.random(request_ctx).instrument(span);
let stream = stream_future.await.map_err(to_pyerr)?;
context.inner().link_child(stream.context());
stream
client
.random(request_ctx)
.instrument(span)
.await
.map_err(to_pyerr)?
}
_ => client.random(request.into()).await.map_err(to_pyerr)?,
_ => client.random(request_ctx).await.map_err(to_pyerr)?,
};
tokio::spawn(process_stream(stream, tx));
Ok(AsyncResponseStream {
......@@ -870,6 +892,7 @@ impl Client {
context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> {
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
let request_ctx = create_request_context(request, &context);
let annotated = annotated.unwrap_or(false);
let (tx, rx) = tokio::sync::mpsc::channel(32);
......@@ -878,18 +901,17 @@ impl Client {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = match context {
Some(context) => {
let request_ctx = RsContext::with_id(request, context.inner().id().to_string());
// Always instrument with appropriate span (none if no trace context)
let span =
get_span_for_direct_context(&context, "direct", &instance_id.to_string());
let stream_future = client.direct(request_ctx, instance_id).instrument(span);
let stream = stream_future.await.map_err(to_pyerr)?;
context.inner().link_child(stream.context());
stream
client
.direct(request_ctx, instance_id)
.instrument(span)
.await
.map_err(to_pyerr)?
}
_ => client
.direct(request.into(), instance_id)
.direct(request_ctx, instance_id)
.await
.map_err(to_pyerr)?,
};
......@@ -913,6 +935,7 @@ impl Client {
context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> {
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
let request_ctx = create_request_context(request, &context);
let annotated = annotated.unwrap_or(false);
let (tx, rx) = tokio::sync::mpsc::channel(32);
......@@ -921,16 +944,15 @@ impl Client {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = match context {
Some(context) => {
let request_ctx = RsContext::with_id(request, context.inner().id().to_string());
// Always instrument with appropriate span (none if no trace context)
let span = get_span_for_context(&context, "static");
let stream_future = client.r#static(request_ctx).instrument(span);
let stream = stream_future.await.map_err(to_pyerr)?;
context.inner().link_child(stream.context());
stream
client
.r#static(request_ctx)
.instrument(span)
.await
.map_err(to_pyerr)?
}
_ => client.r#static(request.into()).await.map_err(to_pyerr)?,
_ => client.r#static(request_ctx).await.map_err(to_pyerr)?,
};
tokio::spawn(process_stream(stream, tx));
......
......@@ -289,3 +289,48 @@ async def test_server_raise_cancelled(server, client):
# TODO: Server to gracefully stop the stream?
assert not handler.context_is_stopped
assert not handler.context_is_killed
@pytest.mark.forked
@pytest.mark.asyncio
async def test_client_context_already_cancelled(server, client):
_, handler = server
context = Context()
context.stop_generating()
# TODO: (DIS-830) The outgoing call should raise if context is cancelled
stream = await client.generate("_generate_until_context_cancelled", context=context)
async for _ in stream:
raise AssertionError(
"Request should be cancelled before any responses are generated"
)
# Give server a moment to update status
await asyncio.sleep(0.2)
# Verify server context cancellation status
assert handler.context_is_stopped
assert not handler.context_is_killed
@pytest.mark.forked
@pytest.mark.asyncio
async def test_client_context_cancel_before_await_request(server, client):
_, handler = server
context = Context()
request = client.generate("_generate_until_context_cancelled", context=context)
context.stop_generating()
# TODO: (DIS-830) The outgoing call should raise if context is cancelled
stream = await request
async for _ in stream:
raise AssertionError(
"Request should be cancelled before any responses are generated"
)
# Give server a moment to update status
await asyncio.sleep(0.2)
# Verify server context cancellation status
assert handler.context_is_stopped
assert not handler.context_is_killed
......@@ -133,6 +133,13 @@ impl RetryManager {
self.retries_left -= 1;
let request = Context::with_id(self.request.clone(), self.context.id().to_string());
self.context.link_child(request.context());
if self.context.is_stopped() || self.context.is_killed() {
tracing::debug!("Abort creating new stream after context is stopped or killed");
return Err(Error::msg(format!(
"Context id {} is stopped or killed",
self.context.id()
)));
}
response_stream = Some(self.next_generate.generate(request).await);
if let Some(err) = response_stream.as_ref().unwrap().as_ref().err()
&& let Some(req_err) = err.downcast_ref::<NatsRequestError>()
......@@ -715,4 +722,40 @@ mod tests {
assert!(error.to_string().contains(STREAM_ERR_MSG));
}
}
/// Test case 7: Request cancelled when creating new stream
/// Tests the scenario where context.stop_generating() is called when creating a new stream.
/// The RetryManager should detect that the context is stopped and abort creating new streams.
/// Expected behavior: Should fail to build RetryManager with "Context is stopped or killed" error.
#[tokio::test]
async fn test_retry_manager_context_stopped_before_stream() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::Success,
10,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone()));
// Stop the context before building RetryManager
ctx.stop_generating();
// Should fail to build due to stopped context
let retry_manager_result = RetryManager::build(ctx, request, next_generate, 3).await;
assert!(retry_manager_result.is_err());
if let Err(error) = retry_manager_result {
assert!(
error
.to_string()
.contains(&format!("Context id {} is stopped or killed", context_id))
);
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment