Unverified Commit 6c539fbd authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: FT Request Cancellation feature and test for 0.5.0 (#2500)

parent 7fabe7bf
......@@ -32,7 +32,7 @@ class BaseWorkerHandler(ABC):
self.kv_publisher = None
@abstractmethod
async def generate(self, request) -> AsyncGenerator[dict, None]:
async def generate(self, request, context) -> AsyncGenerator[dict, None]:
raise NotImplementedError
async def clear_kv_blocks(self, request=None):
......@@ -110,7 +110,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
self._prefill_check_task.cancel()
super().cleanup()
async def generate(self, request):
async def generate(self, request, context):
request_id = str(uuid.uuid4().hex)
logger.debug(f"New Request ID: {request_id}")
......@@ -147,9 +147,20 @@ class DecodeWorkerHandler(BaseWorkerHandler):
# TODO Change to prefill queue
if self.prefill_worker_client is not None:
prefill_response = await anext(
await self.prefill_worker_client.round_robin(prefill_request)
)
try:
prefill_response = await anext(
await self.prefill_worker_client.round_robin(
prefill_request, context=context
)
)
except Exception as e:
# TODO: Cancellation does not propagate until the first token is received
if context.is_stopped() or context.is_killed():
logger.debug(f"Aborted Remote Prefill Request ID: {request_id}")
# TODO: Raise asyncio.CancelledError into bindings
return
raise e
prefill_response = MyRequestOutput.model_validate_json(
prefill_response.data()
)
......@@ -162,6 +173,12 @@ class DecodeWorkerHandler(BaseWorkerHandler):
] = prefill_response.kv_transfer_params
async for tok in self.generate_tokens(prompt, sampling_params, request_id):
if context.is_stopped() or context.is_killed():
await self.engine_client.abort(request_id)
logger.debug(f"Aborted Request ID: {request_id}")
# TODO: Raise asyncio.CancelledError into bindings
break
yield tok
......@@ -169,7 +186,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
def __init__(self, component, engine, default_sampling_params):
super().__init__(component, engine, default_sampling_params)
async def generate(self, request):
async def generate(self, request, context):
request_id = request["request_id"]
logger.debug(f"New Prefill Request ID: {request_id}")
......@@ -181,6 +198,12 @@ class PrefillWorkerHandler(BaseWorkerHandler):
# Generate only 1 token in prefill
try:
async for res in gen:
if context.is_stopped() or context.is_killed():
await self.engine_client.abort(request_id)
logger.debug(f"Aborted Prefill Request ID: {request_id}")
# TODO: Raise asyncio.CancelledError into bindings
break
logger.debug(f"kv transfer params: {res.kv_transfer_params}")
yield MyRequestOutput(
request_id=res.request_id,
......
# Request Cancellation Architecture
This document describes how Dynamo implements request cancellation to cancel in-flight requests between Dynamo workers. Request cancellation allows in-flight requests to terminate early, saving computational resources that would otherwise be spent on responses that are no longer needed.
## AsyncEngineContext Trait
At the core of Dynamo's request cancellation system is the `AsyncEngineContext` trait. This trait is associated with every request stream and provides lifecycle management for async operations, including stream identification, graceful shutdown capabilities, and immediate termination capabilities.
### Key Methods
#### Identification
- **`id()`**: Returns the unique identifier for the stream. This ID is set by the user for request identification, and the same ID can be used for sub-requests to associate them with the original user request.
#### Status Checking
- **`is_stopped()`**: Returns `true` if graceful cancellation has been requested via `stop_generating()`. This represents a signal to the worker that the request has been cancelled and it should return early.
- **`is_killed()`**: Returns `true` if a hard stop has been issued via `kill()`. This typically indicates that the network connection between client and server has been cut or an immediate termination is required.
#### Async Status Monitoring
- **`stopped()`**: An async method that completes when the context becomes stopped. If already stopped, returns immediately.
- **`killed()`**: An async method that completes when the context becomes killed. If already killed, returns immediately.
#### Cancellation Control
- **`stop_generating()`**: The recommended method for cancelling a request. This informs the engine to stop producing results for the stream gracefully. This method is idempotent and does not invalidate results currently in the stream.
- **`stop()`**: Alias for `stop_generating()`.
- **`kill()`**: Extends `stop_generating()` but also indicates a preference to terminate without draining remaining items in the stream. This is implementation-specific and may not be supported by all engines.
#### Child Request Management
- **`link_child(child: Arc<dyn AsyncEngineContext>)`**: Links a child `AsyncEngineContext` to this context. When `stop_generating()`, `stop()`, or `kill()` is called on the parent context, the same method is automatically called on all linked child contexts in the order they were linked. This is especially useful in disaggregated serving scenarios where a frontend receives cancellation notification and needs to cancel requests to workers, and the worker can then cancel its sub-requests (e.g., remote prefill operations).
### Thread Safety
The `AsyncEngineContext` trait ensures thread-safety with `Send + Sync` bounds, allowing safe concurrent access across multiple threads and async tasks.
## Python Bindings
The `AsyncEngineContext` functionality is exposed to Python through the `Context` class, which provides a largely one-to-one mapping from Rust methods to Python methods.
### Python Context Class
The Python `Context` class wraps the Rust `AsyncEngineContext` and exposes the following methods:
- **`id()`**: Returns the unique identifier for the context
- **`is_stopped()`**: Synchronous method equivalent to the Rust `is_stopped()`
- **`is_killed()`**: Synchronous method equivalent to the Rust `is_killed()`
- **`stop_generating()`**: Issues a stop generating signal, equivalent to the Rust method
- **`async_killed_or_stopped()`**: An async method that completes when the context becomes either killed or stopped, whichever happens first. This combines the functionality of the Rust `killed()` and `stopped()` async methods using `tokio::select!`.
### Context Usage in Python
The context is available optionally in both incoming and outgoing request scenarios:
#### Incoming Requests
For incoming requests, the generate method may optionally accept a `context` argument after the `request` argument. If the `context` parameter is specified in the method signature, it will receive the context object of the incoming request. Request handlers can:
- Check for cancellation synchronously using `context.is_stopped()` before beginning expensive operations
- Listen for cancellation asynchronously using `await context.async_killed_or_stopped()`
Example:
```python
async def generate(self, request, context):
for i in range(1000):
# Check for cancellation before expensive work
if context.is_stopped():
raise asyncio.CancelledError
# Perform work...
result = await expensive_computation()
yield result
```
#### Outgoing Requests
For outgoing requests, Python scripts may optionally provide a context object to outgoing runtime endpoint client router operations (such as `generate`, `round_robin`, `random`, `direct` methods) as a keyword argument. The script can cancel the outgoing request via the provided context object.
This is especially useful when child outgoing requests need to be cancelled when the parent incoming request is cancelled. In such cases, the script can simply pass the incoming context object to the outgoing request, automatically linking the cancellation behavior.
Example:
```python
async def generate(self, request, context):
# Forward the incoming context to outgoing request
# If the incoming request is cancelled, the outgoing request will be too
stream = await self.client.generate(request, context=context)
async for response in stream:
yield response
```
This design enables seamless cancellation propagation through multi-tier request chains, ensuring that when a client cancels a request, all associated sub-requests are automatically cancelled, saving computational resources across the entire request pipeline.
......@@ -137,3 +137,25 @@ class RequestHandler:
When `GeneratorExit` is raised, the frontend receives the incomplete response and can seamlessly continue generation on another available worker instance, preserving the user experience even during worker shutdowns.
For more information about how request migration works, see the [Request Migration Architecture](../architecture/request_migration.md) documentation.
## Request Cancellation
Your Python worker's request handler can optionally support request cancellation by accepting a `context` argument after the `request` argument. This context object allows you to check for cancellation signals and respond appropriately:
```python
class RequestHandler:
async def generate(self, request, context):
"""Generate response with cancellation support"""
for result in self.engine.generate_streaming(request):
# Check if the request has been cancelled
if context.is_stopped():
# Stop processing and clean up
break
yield result
```
The context parameter is optional - if your generate method doesn't include it in its signature, Dynamo will call your method without the context argument.
For detailed information about request cancellation, including async cancellation monitoring and context propagation patterns, see the [Request Cancellation Architecture](../architecture/request_cancellation.md) documentation.
......@@ -178,6 +178,12 @@ dynamo-run in=dyn://... out=<engine> ... --migration-limit=3
This allows a request to be migrated up to 3 times before failing. See the [Request Migration Architecture](../architecture/request_migration.md) documentation for details on how this works.
### Request Cancellation
When using the HTTP interface (`in=http`), if the HTTP request connection is dropped by the client, Dynamo automatically cancels the downstream request to the worker. This ensures that computational resources are not wasted on generating responses that are no longer needed.
For detailed information about how request cancellation works across the system, see the [Request Cancellation Architecture](../architecture/request_cancellation.md) documentation.
## Development
`dynamo-run` is also an example of what can be built in Rust with the `dynamo-llm` and `dynamo-runtime` crates. The following guide shows how to build from source with all the features.
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// PyContext is a wrapper around the AsyncEngineContext to allow for Python bindings.
// Context is a wrapper around the AsyncEngineContext to allow for Python bindings.
use dynamo_runtime::pipeline::context::Controller;
pub use dynamo_runtime::pipeline::AsyncEngineContext;
use pyo3::prelude::*;
use std::sync::Arc;
// PyContext is a wrapper around the AsyncEngineContext to allow for Python bindings.
// Context is a wrapper around the AsyncEngineContext to allow for Python bindings.
// Not all methods of the AsyncEngineContext are exposed, just the primary ones for tracing + cancellation.
// Kept as class, to allow for future expansion if needed.
#[derive(Clone)]
#[pyclass]
pub struct PyContext {
pub inner: Arc<dyn AsyncEngineContext>,
pub struct Context {
inner: Arc<dyn AsyncEngineContext>,
}
impl PyContext {
impl Context {
pub fn new(inner: Arc<dyn AsyncEngineContext>) -> Self {
Self { inner }
}
pub fn inner(&self) -> Arc<dyn AsyncEngineContext> {
self.inner.clone()
}
}
#[pymethods]
impl PyContext {
impl Context {
#[new]
#[pyo3(signature = (id=None))]
fn py_new(id: Option<String>) -> Self {
let controller = match id {
Some(id) => Controller::new(id),
None => Controller::default(),
};
Self {
inner: Arc::new(controller),
}
}
// sync method of `await async_is_stopped()`
fn is_stopped(&self) -> bool {
self.inner.is_stopped()
......
......@@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use super::context::{callable_accepts_kwarg, PyContext};
use super::context::{callable_accepts_kwarg, Context};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyModule};
use pyo3::{PyAny, PyErr};
......@@ -114,7 +114,7 @@ pub struct PythonServerStreamingEngine {
_cancel_token: CancellationToken,
generator: Arc<PyObject>,
event_loop: Arc<PyObject>,
has_pycontext: bool,
has_context: bool,
}
impl PythonServerStreamingEngine {
......@@ -123,7 +123,7 @@ impl PythonServerStreamingEngine {
generator: Arc<PyObject>,
event_loop: Arc<PyObject>,
) -> Self {
let has_pycontext = Python::with_gil(|py| {
let has_context = Python::with_gil(|py| {
let callable = generator.bind(py);
callable_accepts_kwarg(py, callable, "context").unwrap_or(false)
});
......@@ -132,7 +132,7 @@ impl PythonServerStreamingEngine {
_cancel_token: cancel_token,
generator,
event_loop,
has_pycontext,
has_context,
}
}
}
......@@ -175,7 +175,7 @@ where
let generator = self.generator.clone();
let event_loop = self.event_loop.clone();
let ctx_python = ctx.clone();
let has_pycontext = self.has_pycontext;
let has_context = self.has_context;
// Acquiring the GIL is similar to acquiring a standard lock/mutex
// Performing this in a tokio async task could block the thread for an undefined amount of time
......@@ -190,9 +190,9 @@ where
let stream = tokio::task::spawn_blocking(move || {
Python::with_gil(|py| {
let py_request = pythonize(py, &request)?;
let py_ctx = Py::new(py, PyContext::new(ctx_python.clone()))?;
let py_ctx = Py::new(py, Context::new(ctx_python.clone()))?;
let gen = if has_pycontext {
let gen = if has_context {
// Pass context as a kwarg
let kwarg = PyDict::new(py);
kwarg.set_item("context", &py_ctx)?;
......
......@@ -16,7 +16,8 @@ use tokio::sync::Mutex;
use dynamo_runtime::{
self as rs, logging,
pipeline::{
network::egress::push_router::RouterMode as RsRouterMode, EngineStream, ManyOut, SingleIn,
context::Context as RsContext, network::egress::push_router::RouterMode as RsRouterMode,
EngineStream, ManyOut, SingleIn,
},
protocols::annotated::Annotated as RsAnnotated,
traits::DistributedRuntimeProvider,
......@@ -104,7 +105,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<http::HttpService>()?;
m.add_class::<http::HttpError>()?;
m.add_class::<http::HttpAsyncEngine>()?;
m.add_class::<context::PyContext>()?;
m.add_class::<context::Context>()?;
m.add_class::<EtcdKvCache>()?;
m.add_class::<ModelType>()?;
m.add_class::<llm::kv::ForwardPassMetrics>()?;
......@@ -699,27 +700,29 @@ impl Client {
}
/// Issue a request to the endpoint using the default routing strategy.
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))]
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
fn generate<'p>(
&self,
py: Python<'p>,
request: PyObject,
annotated: Option<bool>,
context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> {
if self.router.client.is_static() {
self.r#static(py, request, annotated)
self.r#static(py, request, annotated, context)
} else {
self.random(py, request, annotated)
self.random(py, request, annotated, context)
}
}
/// Send a request to the next endpoint in a round-robin fashion.
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))]
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
fn round_robin<'p>(
&self,
py: Python<'p>,
request: PyObject,
annotated: Option<bool>,
context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> {
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
let annotated = annotated.unwrap_or(false);
......@@ -728,7 +731,15 @@ impl Client {
let client = self.router.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client.round_robin(request.into()).await.map_err(to_pyerr)?;
let stream = match context {
Some(context) => {
let request = RsContext::with_id(request, context.inner().id().to_string());
let stream = client.round_robin(request).await.map_err(to_pyerr)?;
context.inner().link_child(stream.context());
stream
}
_ => client.round_robin(request.into()).await.map_err(to_pyerr)?,
};
tokio::spawn(process_stream(stream, tx));
Ok(AsyncResponseStream {
rx: Arc::new(Mutex::new(rx)),
......@@ -738,12 +749,13 @@ impl Client {
}
/// Send a request to a random endpoint.
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))]
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
fn random<'p>(
&self,
py: Python<'p>,
request: PyObject,
annotated: Option<bool>,
context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> {
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
let annotated = annotated.unwrap_or(false);
......@@ -752,7 +764,15 @@ impl Client {
let client = self.router.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client.random(request.into()).await.map_err(to_pyerr)?;
let stream = match context {
Some(context) => {
let request = RsContext::with_id(request, context.inner().id().to_string());
let stream = client.random(request).await.map_err(to_pyerr)?;
context.inner().link_child(stream.context());
stream
}
_ => client.random(request.into()).await.map_err(to_pyerr)?,
};
tokio::spawn(process_stream(stream, tx));
Ok(AsyncResponseStream {
rx: Arc::new(Mutex::new(rx)),
......@@ -762,13 +782,14 @@ impl Client {
}
/// Directly send a request to a specific endpoint.
#[pyo3(signature = (request, instance_id, annotated=DEFAULT_ANNOTATED_SETTING))]
#[pyo3(signature = (request, instance_id, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
fn direct<'p>(
&self,
py: Python<'p>,
request: PyObject,
instance_id: i64,
annotated: Option<bool>,
context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> {
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
let annotated = annotated.unwrap_or(false);
......@@ -777,10 +798,21 @@ impl Client {
let client = self.router.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client
.direct(request.into(), instance_id)
.await
.map_err(to_pyerr)?;
let stream = match context {
Some(context) => {
let request = RsContext::with_id(request, context.inner().id().to_string());
let stream = client
.direct(request, instance_id)
.await
.map_err(to_pyerr)?;
context.inner().link_child(stream.context());
stream
}
_ => client
.direct(request.into(), instance_id)
.await
.map_err(to_pyerr)?,
};
tokio::spawn(process_stream(stream, tx));
......@@ -792,12 +824,13 @@ impl Client {
}
/// Directly send a request to a pre-defined static worker
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING))]
#[pyo3(signature = (request, annotated=DEFAULT_ANNOTATED_SETTING, context=None))]
fn r#static<'p>(
&self,
py: Python<'p>,
request: PyObject,
annotated: Option<bool>,
context: Option<context::Context>,
) -> PyResult<Bound<'p, PyAny>> {
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
let annotated = annotated.unwrap_or(false);
......@@ -806,7 +839,15 @@ impl Client {
let client = self.router.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client.r#static(request.into()).await.map_err(to_pyerr)?;
let stream = match context {
Some(context) => {
let request = RsContext::with_id(request, context.inner().id().to_string());
let stream = client.r#static(request).await.map_err(to_pyerr)?;
context.inner().link_child(stream.context());
stream
}
_ => client.r#static(request.into()).await.map_err(to_pyerr)?,
};
tokio::spawn(process_stream(stream, tx));
......
......@@ -25,12 +25,12 @@ from pydantic import BaseModel, ValidationError
from dynamo._core import Backend as Backend
from dynamo._core import Client as Client
from dynamo._core import Component as Component
from dynamo._core import Context as Context
from dynamo._core import DistributedRuntime as DistributedRuntime
from dynamo._core import Endpoint as Endpoint
from dynamo._core import EtcdKvCache as EtcdKvCache
from dynamo._core import ModelDeploymentCard as ModelDeploymentCard
from dynamo._core import OAIChatPreprocessor as OAIChatPreprocessor
from dynamo._core import PyContext as PyContext
def dynamo_worker(static=False):
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
from time import sleep
import pytest
@pytest.fixture(scope="module", autouse=True)
def nats_and_etcd():
    """
    Start `nats-server -js` and `etcd` once per test module and stop them afterwards.

    Setup and the test run are wrapped in try/finally so that a process that was
    already started is still terminated when a later setup step fails (for
    example the `etcd` binary is missing, or the startup wait is interrupted).
    """
    started = []
    try:
        # Setup code
        started.append(subprocess.Popen(["nats-server", "-js"]))
        started.append(subprocess.Popen(["etcd"]))
        print("Setting up resources")
        sleep(5)  # wait for nats-server and etcd to start
        yield
    finally:
        # Teardown code: stop each process in the order it was started
        print("Tearing down resources")
        for process in started:
            process.terminate()
            process.wait()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import random
import string
import pytest
from dynamo._core import DistributedRuntime
class MockServer:
    """
    Request handler used by the cancellation tests.

    The request payload is treated as the name of one of the scenario
    generators defined on this class, and ``generate`` forwards to it.
    Scenarios that inspect the context store what they saw in
    ``context_is_stopped`` / ``context_is_killed`` so a test can check the
    server side once the stream has ended.
    """

    def __init__(self):
        self.context_is_stopped = False
        self.context_is_killed = False

    async def generate(self, request, context):
        """Clear the recorded flags, then stream the scenario named by ``request``."""
        self.context_is_stopped = False
        self.context_is_killed = False

        scenario_name = request
        assert hasattr(
            self, scenario_name
        ), f"Method '{scenario_name}' not found on {self.__class__.__name__}"

        scenario = getattr(self, scenario_name)
        async for item in scenario(request, context):
            yield item

    async def _generate_until_context_cancelled(self, request, context):
        """
        Yield 0..999, sleeping 0.1 seconds before each yield.

        The context is polled at the top of every iteration. As soon as it
        reports stopped or killed, both flags are recorded and
        asyncio.CancelledError is raised. Exhausting the loop is a test failure.
        """
        for step in range(1000):
            print(f"Processing iteration {step}")

            if context.is_stopped():
                # Graceful stop observed first; also note whether a kill was issued.
                print(f"Context stopped at iteration {step}")
                self.context_is_stopped = True
                self.context_is_killed = context.is_killed()
                raise asyncio.CancelledError
            elif context.is_killed():
                # Killed without the stopped flag being visible on the first check.
                print(f"Context killed at iteration {step}")
                self.context_is_stopped = context.is_stopped()
                self.context_is_killed = True
                raise asyncio.CancelledError

            await asyncio.sleep(0.1)
            print(f"Sending iteration {step}")
            yield step

        assert (
            False
        ), "Test failed: generate_until_cancelled did not raise CancelledError"

    async def _generate_until_asyncio_cancelled(self, request, context):
        """
        Yield 0..999, sleeping 0.1 seconds before each yield, without polling the context.

        If asyncio.CancelledError is raised into the generator, the context
        flags are recorded and the error is re-raised. Exhausting the loop is a
        test failure.
        """
        step = 0  # pre-bound so the except branch can report it
        try:
            for step in range(1000):
                print(f"Processing iteration {step}")
                await asyncio.sleep(0.1)
                print(f"Sending iteration {step}")
                yield step
        except asyncio.CancelledError:
            print(f"Cancelled at iteration {step}")
            self.context_is_stopped = context.is_stopped()
            self.context_is_killed = context.is_killed()
            raise

        assert (
            False
        ), "Test failed: generate_until_cancelled did not raise CancelledError"

    async def _generate_and_cancel_context(self, request, context):
        """
        Yield 0 and 1, then call stop_generating() on the context and record its flags.
        """
        for step in (0, 1):
            print(f"Processing iteration {step}")
            await asyncio.sleep(0.1)
            print(f"Sending iteration {step}")
            yield step

        context.stop_generating()
        self.context_is_stopped = context.is_stopped()
        self.context_is_killed = context.is_killed()

    async def _generate_and_raise_cancelled(self, request, context):
        """
        Yield 0 and 1, then raise asyncio.CancelledError from the handler itself.
        """
        for step in (0, 1):
            print(f"Processing iteration {step}")
            await asyncio.sleep(0.1)
            print(f"Sending iteration {step}")
            yield step

        raise asyncio.CancelledError
def random_string(length=10):
    """
    Generate a random lowercase alphanumeric string for namespace isolation.

    Args:
        length: Total number of characters to produce; must be at least 1.

    Returns:
        A string of exactly ``length`` characters whose first character is a
        lowercase letter and whose remaining characters are lowercase letters
        or digits.

    Raises:
        ValueError: If ``length`` is smaller than 1. Previously such values
            silently produced a one-character string.
    """
    if length < 1:
        raise ValueError(f"length must be at least 1, got {length}")

    # Start with a letter to satisfy Prometheus naming requirements
    first_char = random.choice(string.ascii_lowercase)
    remaining_chars = string.ascii_lowercase + string.digits
    rest = "".join(random.choices(remaining_chars, k=length - 1))
    return first_char + rest
@pytest.fixture
async def runtime():
    """Yield a DistributedRuntime built on the running event loop, then shut it down."""
    distributed_runtime = DistributedRuntime(asyncio.get_running_loop(), True)
    yield distributed_runtime
    distributed_runtime.shutdown()
@pytest.fixture
def namespace():
    """Provide a freshly generated random namespace name so tests do not collide."""
    isolated_name = random_string()
    return isolated_name
@pytest.fixture
async def server(runtime, namespace):
    """
    Run a MockServer behind the ``backend``/``generate`` endpoint in a background task.

    Yields:
        A ``(server_task, handler)`` tuple: the asyncio task serving the
        endpoint and the MockServer instance whose flags tests inspect.
    """
    handler = MockServer()

    async def serve_generate_endpoint():
        """Create the backend service and serve ``handler.generate`` on it."""
        backend = runtime.namespace(namespace).component("backend")
        await backend.create_service()

        generate_endpoint = backend.endpoint("generate")
        print("Started test server instance")

        # Blocks until the endpoint is shut down or this task is cancelled
        await generate_endpoint.serve_endpoint(handler.generate)

    server_task = asyncio.create_task(serve_generate_endpoint())

    # Fixed grace period for the endpoint to come up before tests use it
    await asyncio.sleep(0.5)

    yield server_task, handler

    # Cleanup: nothing to do if the task already finished on its own
    if server_task.done():
        return

    server_task.cancel()
    try:
        await server_task
    except asyncio.CancelledError:
        pass
@pytest.fixture
async def client(runtime, namespace):
    """Return a client for the ``backend``/``generate`` endpoint after waiting for instances."""
    backend = runtime.namespace(namespace).component("backend")
    connected_client = await backend.endpoint("generate").client()
    await connected_client.wait_for_instances()
    return connected_client
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import pytest
pytestmark = pytest.mark.pre_merge
def _run_test_in_subprocess(test_name: str):
    """
    Run ``<test_name>.py`` from this directory with pytest in a separate process.

    pytest is launched as ``<current interpreter> -m pytest`` rather than via a
    bare ``pytest`` looked up on PATH, so the child uses the same interpreter
    and installed packages as the parent test run.

    Args:
        test_name: File name of the test module, without the ``.py`` suffix.

    Raises:
        AssertionError: If the child pytest process exits with a non-zero code.
    """
    import sys  # local import: only this helper needs the interpreter path

    # Absolute path so `cwd` is valid even if __file__ is a bare relative name
    test_dir = os.path.dirname(os.path.abspath(__file__))
    test_file = os.path.join(test_dir, f"{test_name}.py")

    result = subprocess.run(
        [sys.executable, "-m", "pytest", test_file, "-v"],
        capture_output=True,
        text=True,
        cwd=test_dir,
    )

    # Echo the child's output so failures are diagnosable from the parent log
    print("STDOUT:", result.stdout)
    print("STDERR:", result.stderr)
    print("Return code:", result.returncode)

    assert (
        result.returncode == 0
    ), f"Test {test_name} failed with return code {result.returncode}"
def test_client_context_cancel():
    """Run test_client_context_cancel.py in its own pytest process."""
    _run_test_in_subprocess(test_name="test_client_context_cancel")
def test_client_loop_break():
    """Run test_client_loop_break.py in its own pytest process."""
    _run_test_in_subprocess(test_name="test_client_loop_break")
def test_server_context_cancel():
    """Run test_server_context_cancel.py in its own pytest process."""
    _run_test_in_subprocess(test_name="test_server_context_cancel")
def test_server_raise_cancelled():
    """Run test_server_raise_cancelled.py in its own pytest process."""
    _run_test_in_subprocess(test_name="test_server_raise_cancelled")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import pytest
from dynamo._core import Context
@pytest.mark.asyncio
async def test_client_context_cancel(server, client):
    """
    The client passes its own Context, stops it mid-stream, and the server
    handler is expected to have observed the cancellation.
    """
    _, handler = server

    context = Context()
    stream = await client.generate("_generate_until_context_cancelled", context=context)

    expected_number = 0
    async for annotated in stream:
        number = annotated.data()
        print(f"Received iteration: {number}")

        # Responses must arrive in order: 0, 1, 2, ...
        assert number == expected_number

        if expected_number < 2:
            expected_number += 1
            continue

        # Response with index 2 received: cancel through the context and stop reading
        print("Cancelling after 2 responses...")
        context.stop_generating()
        break

    # Give server a moment to process the cancellation
    await asyncio.sleep(0.2)

    # Verify server detected the cancellation
    assert handler.context_is_stopped
    # NOTE(review): only stop_generating() is called above, yet the killed flag
    # is asserted as well - confirm this is the intended server-side behaviour.
    assert handler.context_is_killed
# TODO: Test with _generate_until_asyncio_cancelled server handler
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import pytest
@pytest.mark.asyncio
async def test_client_loop_break(server, client):
    """
    The client stops reading the stream without passing or cancelling a
    context; the server handler must not record any cancellation.
    """
    _, handler = server

    stream = await client.generate("_generate_until_context_cancelled")

    expected_number = 0
    async for annotated in stream:
        number = annotated.data()
        print(f"Received iteration: {number}")

        # Responses must arrive in order: 0, 1, 2, ...
        assert number == expected_number

        if expected_number < 2:
            expected_number += 1
            continue

        # Response with index 2 received: just stop reading, no explicit cancel
        print("Cancelling after 2 responses...")
        break

    # Give server a moment to process the cancellation
    await asyncio.sleep(0.2)

    # TODO: Implicit cancellation is not yet implemented, so the server context will not
    #       show any cancellation.
    assert not handler.context_is_stopped
    assert not handler.context_is_killed
# TODO: Test with _generate_until_asyncio_cancelled server handler
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
@pytest.mark.asyncio
async def test_server_context_cancel(server, client):
    """Check what the client observes when the server handler cancels its own context.

    The stream is expected to yield consecutive integers starting at 0 and then
    fail with a ``ValueError`` carrying the message asserted below. Afterwards the
    handler must report its context as stopped but not killed.
    """
    _, handler = server
    stream = await client.generate("_generate_and_cancel_context")

    iteration_count = 0
    try:
        async for annotated in stream:
            number = annotated.data()
            print(f"Received iteration: {number}")
            assert number == iteration_count
            iteration_count += 1
        # Reaching this point means the stream ended normally. Use pytest.fail
        # rather than `assert False`, which is removed when Python runs with -O.
        pytest.fail("Stream completed without cancellation")
    except ValueError as e:
        # Verify the expected cancellation exception is received
        # TODO: Should this be an asyncio.CancelledError?
        assert str(e) == "Stream ended before generation completed"

    # Verify server context cancellation status
    assert handler.context_is_stopped
    assert not handler.context_is_killed
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
@pytest.mark.asyncio
async def test_server_raise_cancelled(server, client):
    """Check what the client observes when the server handler raises a cancellation.

    The stream is expected to yield consecutive integers starting at 0 and then
    fail with a ``ValueError`` whose message wraps the ``CancelledError`` (asserted
    verbatim below). The handler's context is expected to be neither stopped nor
    killed afterwards.
    """
    _, handler = server
    stream = await client.generate("_generate_and_raise_cancelled")

    iteration_count = 0
    try:
        async for annotated in stream:
            number = annotated.data()
            print(f"Received iteration: {number}")
            assert number == iteration_count
            iteration_count += 1
        # Reaching this point means the stream ended normally. Use pytest.fail
        # rather than `assert False`, which is removed when Python runs with -O.
        pytest.fail("Stream completed without cancellation")
    except ValueError as e:
        # Verify the expected cancellation exception is received
        # TODO: Should this be an asyncio.CancelledError?
        assert (
            str(e)
            == "a python exception was caught while processing the async generator: CancelledError: "
        )

    # Verify server context cancellation status
    # TODO: Server to gracefully stop the stream?
    assert not handler.context_is_stopped
    assert not handler.context_is_killed
......@@ -15,8 +15,6 @@
import asyncio
import subprocess
from time import sleep
from typing import List
import pytest
......@@ -37,24 +35,6 @@ from dynamo.runtime import Component, DistributedRuntime
pytestmark = pytest.mark.pre_merge
@pytest.fixture(scope="module", autouse=True)
def setup_and_teardown():
    """Start ``nats-server -js`` and ``etcd`` for the module and stop them afterwards.

    Both processes are launched with ``subprocess.Popen``; the fixture then waits
    a fixed five seconds before yielding to the tests. On teardown both
    processes are terminated and reaped.
    """
    # Setup code
    nats_server = subprocess.Popen(["nats-server", "-js"])
    try:
        etcd = subprocess.Popen(["etcd"])
    except Exception:
        # etcd could not be launched: do not leave nats-server running behind us
        nats_server.terminate()
        nats_server.wait()
        raise

    try:
        print("Setting up resources")
        sleep(5)  # wait for nats-server and etcd to start
        yield
    finally:
        # Teardown code; also runs if the startup wait above is interrupted
        print("Tearing down resources")
        # Signal both processes before waiting so one slow shutdown does not
        # delay the termination request to the other.
        nats_server.terminate()
        etcd.terminate()
        nats_server.wait()
        etcd.wait()
@pytest.fixture(scope="module")
async def distributed_runtime():
loop = asyncio.get_running_loop()
......
......@@ -8,7 +8,7 @@
//! for performance analysis.
use std::pin::Pin;
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Instant;
......@@ -64,6 +64,8 @@ pub struct HttpRequestContext {
created_at: Instant,
/// Whether the request has been stopped
stopped: Arc<std::sync::atomic::AtomicBool>,
/// Child contexts to be stopped if this is stopped
child_context: Arc<Mutex<Vec<Arc<dyn AsyncEngineContext>>>>,
}
impl HttpRequestContext {
......@@ -74,6 +76,7 @@ impl HttpRequestContext {
cancel_token: CancellationToken::new(),
created_at: Instant::now(),
stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)),
child_context: Arc::new(Mutex::new(Vec::new())),
}
}
......@@ -84,6 +87,7 @@ impl HttpRequestContext {
cancel_token: CancellationToken::new(),
created_at: Instant::now(),
stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)),
child_context: Arc::new(Mutex::new(Vec::new())),
}
}
......@@ -95,6 +99,7 @@ impl HttpRequestContext {
cancel_token: self.cancel_token.child_token(),
created_at: Instant::now(),
stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)),
child_context: Arc::new(Mutex::new(Vec::new())),
}
}
......@@ -105,6 +110,7 @@ impl HttpRequestContext {
cancel_token: self.cancel_token.child_token(),
created_at: Instant::now(),
stopped: Arc::new(std::sync::atomic::AtomicBool::new(false)),
child_context: Arc::new(Mutex::new(Vec::new())),
}
}
......@@ -144,17 +150,55 @@ impl AsyncEngineContext for HttpRequestContext {
}
/// Stops every linked child context, then marks this context as stopped and
/// cancels its token.
///
/// NOTE(review): the children are stopped before the `stopped` flag is stored and
/// nothing guards against re-entry, so a real parent/child cycle would presumably
/// recurse without bound — confirm that such a cycle cannot be built.
fn stop(&self) {
    // Snapshot the linked children and release the mutex before calling into any
    // of them, so a child that reaches back into this context never finds the
    // lock still held.
    let guard = self
        .child_context
        .lock()
        .expect("Failed to lock child context");
    let linked = guard.to_vec();
    drop(guard);

    for child in &linked {
        child.stop();
    }

    let flag = &self.stopped;
    flag.store(true, std::sync::atomic::Ordering::Release);
    self.cancel_token.cancel();
}
/// Forwards `stop_generating` to every linked child context, then stops this
/// context itself.
///
/// NOTE(review): `self.stop()` below already stops the children, stores the
/// `stopped` flag and cancels the token, so the store/cancel pair that follows it
/// repeats that work — confirm whether the `self.stop()` line is a leftover that
/// was meant to be removed.
fn stop_generating(&self) {
// Clone child Arcs to avoid deadlock if parent is accidentally linked under child
// (the lock guard is a temporary, so it is released at the end of this statement)
let children = self
.child_context
.lock()
.expect("Failed to lock child context")
.iter()
.cloned()
.collect::<Vec<_>>();
for child in children {
child.stop_generating();
}
// For HTTP clients, stop_generating is the same as stop
self.stop();
self.stopped
.store(true, std::sync::atomic::Ordering::Release);
self.cancel_token.cancel();
}
fn kill(&self) {
// Clone child Arcs to avoid deadlock if parent is accidentally linked under child
let children = self
.child_context
.lock()
.expect("Failed to lock child context")
.iter()
.cloned()
.collect::<Vec<_>>();
for child in children {
child.kill();
}
self.stopped
.store(true, std::sync::atomic::Ordering::Release);
self.cancel_token.cancel();
......@@ -176,6 +220,13 @@ impl AsyncEngineContext for HttpRequestContext {
// For HTTP clients, killed is the same as stopped
self.cancel_token.cancelled().await;
}
/// Appends `child` to this context's list of linked children, which `stop`,
/// `stop_generating` and `kill` walk to forward the call.
///
/// Panics with "Failed to lock child context" if the mutex is poisoned.
fn link_child(&self, child: Arc<dyn AsyncEngineContext>) {
    let mut children = self
        .child_context
        .lock()
        .expect("Failed to lock child context");
    children.push(child);
}
}
/// Base HTTP client with common functionality
......
......@@ -17,8 +17,8 @@ use crate::{
use dynamo_runtime::{
pipeline::{
AsyncEngineContextProvider, ManyOut, Operator, ResponseStream, ServerStreamingEngine,
SingleIn, async_trait,
AsyncEngineContextProvider, Context, ManyOut, Operator, ResponseStream,
ServerStreamingEngine, SingleIn, async_trait,
},
protocols::{annotated::Annotated, maybe_error::MaybeError},
};
......@@ -29,6 +29,11 @@ pub struct Migration {
impl Migration {
pub async fn from_mdc(mdc: ModelDeploymentCard) -> Result<Arc<Self>> {
tracing::debug!(
"model {} migration limit {}",
mdc.display_name,
mdc.migration_limit
);
Ok(Arc::new(Self {
migration_limit: mdc.migration_limit,
}))
......@@ -50,20 +55,30 @@ impl
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
let (preprocessed_request, context) = request.transfer(());
let context_id = context.id().to_string();
let engine_ctx = context.context();
let engine_ctx_ = engine_ctx.clone();
let retry_manager =
RetryManager::build(preprocessed_request, next, self.migration_limit).await?;
let response_stream = stream::unfold(retry_manager, |mut retry_manager| async move {
retry_manager
.next()
.await
.map(|response| (response, retry_manager))
RetryManager::build(context_id, preprocessed_request, next, self.migration_limit)
.await?;
let response_stream = stream::unfold(retry_manager, move |mut retry_manager| {
let engine_ctx = engine_ctx_.clone();
async move {
if engine_ctx.is_stopped() || engine_ctx.is_killed() {
return None; // Stop if the context is cancelled or stopped
}
retry_manager
.next()
.await
.map(|response| (response, retry_manager))
}
});
Ok(ResponseStream::new(Box::pin(response_stream), engine_ctx))
}
}
struct RetryManager {
context_id: String,
request: PreprocessedRequest,
next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
next_stream: Option<ManyOut<Annotated<LLMEngineOutput>>>,
......@@ -72,11 +87,13 @@ struct RetryManager {
impl RetryManager {
pub async fn build(
context_id: String,
preprocessed_request: PreprocessedRequest,
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
retries_left: u32,
) -> Result<Self> {
let mut slf = Self {
context_id,
request: preprocessed_request,
next_generate: next,
next_stream: None,
......@@ -123,8 +140,7 @@ impl RetryManager {
let mut response_stream: Option<Result<ManyOut<Annotated<LLMEngineOutput>>>> = None;
while self.retries_left > 0 {
self.retries_left -= 1;
// TODO: Is there anything needed to pass between context?
let request = SingleIn::new(self.request.clone());
let request = Context::with_id(self.request.clone(), self.context_id.clone());
response_stream = Some(self.next_generate.generate(request).await);
if let Some(err) = response_stream.as_ref().unwrap().as_ref().err()
&& let Some(req_err) = err.downcast_ref::<NatsRequestError>()
......@@ -227,15 +243,22 @@ mod tests {
num_responses: usize,
token_offset: u32,
call_count: Arc<AtomicU32>,
context_id: String,
}
impl MockEngine {
fn new(behavior: MockBehavior, num_responses: usize, token_offset: u32) -> Self {
fn new(
behavior: MockBehavior,
num_responses: usize,
token_offset: u32,
context_id: String,
) -> Self {
Self {
behavior,
num_responses,
token_offset,
call_count: Arc::new(AtomicU32::new(0)),
context_id,
}
}
}
......@@ -253,7 +276,14 @@ mod tests {
request: SingleIn<PreprocessedRequest>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
let call_num = self.call_count.fetch_add(1, Ordering::SeqCst);
let (preprocessed_request, _) = request.transfer(());
let (preprocessed_request, context) = request.transfer(());
// Assert that the context_id matches the expected one
assert_eq!(
context.id().to_string(),
self.context_id,
"Context ID mismatch"
);
// Calculate how many responses we've already generated based on request token_ids
// Initial request has [1, 2, 3], so anything beyond that are generated responses
......@@ -328,7 +358,7 @@ mod tests {
}
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let ctx = Arc::new(Controller::default());
let ctx = Arc::new(Controller::new(self.context_id.clone()));
Ok(dynamo_runtime::pipeline::ResponseStream::new(
Box::pin(stream),
ctx,
......@@ -359,7 +389,7 @@ mod tests {
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let ctx = Arc::new(Controller::default());
let ctx = Arc::new(Controller::new(self.context_id.clone()));
Ok(dynamo_runtime::pipeline::ResponseStream::new(
Box::pin(stream),
ctx,
......@@ -395,7 +425,7 @@ mod tests {
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let ctx = Arc::new(Controller::default());
let ctx = Arc::new(Controller::new(self.context_id.clone()));
Ok(dynamo_runtime::pipeline::ResponseStream::new(
Box::pin(stream),
ctx,
......@@ -412,7 +442,7 @@ mod tests {
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let ctx = Arc::new(Controller::default());
let ctx = Arc::new(Controller::new(self.context_id.clone()));
Ok(dynamo_runtime::pipeline::ResponseStream::new(
Box::pin(stream),
ctx,
......@@ -447,7 +477,7 @@ mod tests {
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let ctx = Arc::new(Controller::default());
let ctx = Arc::new(Controller::new(self.context_id.clone()));
Ok(dynamo_runtime::pipeline::ResponseStream::new(
Box::pin(stream),
ctx,
......@@ -462,12 +492,18 @@ mod tests {
#[tokio::test]
async fn test_retry_manager_no_migration() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(MockBehavior::Success, 10, 100));
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::Success,
10,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;
let mut retry_manager = RetryManager::build(request, next_generate, 0)
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 0)
.await
.expect("Failed to build RetryManager");
......@@ -494,12 +530,18 @@ mod tests {
#[tokio::test]
async fn test_retry_manager_new_request_migration() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(MockBehavior::FailThenSuccess, 10, 100));
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::FailThenSuccess,
10,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;
let mut retry_manager = RetryManager::build(request, next_generate, 3)
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3)
.await
.expect("Failed to build RetryManager");
......@@ -527,16 +569,18 @@ mod tests {
async fn test_retry_manager_ongoing_request_migration() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::MidStreamFail { fail_after: 5 },
10,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;
let mut retry_manager = RetryManager::build(request, next_generate, 3)
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3)
.await
.expect("Failed to build RetryManager");
......@@ -564,13 +608,19 @@ mod tests {
#[tokio::test]
async fn test_retry_manager_new_request_migration_indefinite_failure() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
let request = create_mock_request(0);
let mock_engine = Arc::new(MockEngine::new(MockBehavior::AlwaysFail, 0, 100));
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::AlwaysFail,
0,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;
// Should fail to build due to initial stream creation failure after exhausting all 3 retries
let retry_manager_result = RetryManager::build(request, next_generate, 3).await;
let retry_manager_result = RetryManager::build(context_id, request, next_generate, 3).await;
assert!(retry_manager_result.is_err());
if let Err(error) = retry_manager_result {
......@@ -585,16 +635,18 @@ mod tests {
#[tokio::test]
async fn test_retry_manager_ongoing_request_migration_indefinite_failure() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::MidStreamFailAlways { fail_after: 3 },
10,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;
let mut retry_manager = RetryManager::build(request, next_generate, 3) // 3 retries
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3) // 3 retries
.await
.expect("Failed to build RetryManager");
......@@ -635,16 +687,18 @@ mod tests {
#[tokio::test]
async fn test_retry_manager_ongoing_request_migration_indefinite_failure_stream_error() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::MidStreamFailAlwaysStreamError { fail_after: 3 },
10,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>> =
mock_engine;
let mut retry_manager = RetryManager::build(request, next_generate, 3) // 3 retries
let mut retry_manager = RetryManager::build(context_id, request, next_generate, 3) // 3 retries
.await
.expect("Failed to build RetryManager");
......
......@@ -552,5 +552,9 @@ pub mod tests {
/// Test stub: the body is empty, so awaiting this completes immediately instead
/// of waiting for a kill signal.
async fn killed(&self) {
// No-op for testing
}
/// Test stub: the child context is accepted and discarded; nothing is recorded,
/// so no call on this mock is ever forwarded to it.
fn link_child(&self, _: Arc<dyn AsyncEngineContext>) {
// No-op for testing
}
}
}
......@@ -1613,5 +1613,9 @@ mod tests {
/// Test stub: the body is empty, so awaiting this completes immediately instead
/// of waiting for a kill signal.
async fn killed(&self) {
// No-op for testing
}
/// Test stub: the child context is accepted and discarded; nothing is recorded,
/// so no call on this mock is ever forwarded to it.
fn link_child(&self, _: Arc<dyn dynamo_runtime::engine::AsyncEngineContext>) {
// No-op for testing
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment