Unverified Commit 347620a1 authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: Allow Python Engine to end stream before final (#2270)

parent b48d4c3b
......@@ -109,3 +109,31 @@ Example 4: Multiple component in a pipeline.
In the P/D disaggregated setup you would have `deepseek-distill-llama8b.prefill.generate` (possibly multiple instances of this) and `deepseek-distill-llama8b.decode.generate`.
## Migrate Ongoing Requests
A Python worker may need to be shut down promptly, for example when the node running the worker is to be reclaimed and there isn't enough time to complete all ongoing requests before the shutdown deadline.
In such cases, you can signal incomplete responses by raising a `GeneratorExit` exception in your generate loop. This will immediately close the response stream, signaling to the frontend that the stream is incomplete. With request migration enabled (see the [`migration_limit`](../architecture/request_migration.md) parameter), the frontend will automatically migrate the partially completed request to another worker instance, if available, to be completed.
> [!WARNING]
> We will update the `GeneratorExit` exception to a new Dynamo exception. Please expect minor code breaking change in the near future.
Here's an example of how to implement this in your `RequestHandler`:
```python
class RequestHandler:
async def generate(self, request):
"""Generate response, with support for request migration"""
for result in self.engine.generate_streaming(request):
# Check if we need to migrate before yielding each token
if is_shutting_down():
# Raising GeneratorExit closes the stream and triggers migration
raise GeneratorExit("Worker shutting down, migrating request")
yield result
```
When `GeneratorExit` is raised, the frontend receives the incomplete response and can seamlessly continue generation on another available worker instance, preserving the user experience even during worker shutdowns.
For more information about how request migration works, see the [Request Migration Architecture](../architecture/request_migration.md) documentation.
......@@ -134,6 +134,9 @@ enum ResponseProcessingError {
#[error("python exception: {0}")]
PythonException(String),
#[error("python generator exit: {0}")]
PyGeneratorExit(String),
#[error("deserialize error: {0}")]
DeserializeError(String),
......@@ -225,6 +228,9 @@ where
let msg = format!("critical error: invalid response object from python async generator; application-logic-mismatch: {}", e);
msg
}
ResponseProcessingError::PyGeneratorExit(_) => {
"Stream ended before generation completed".to_string()
}
ResponseProcessingError::PythonException(e) => {
let msg = format!("a python exception was caught while processing the async generator: {}", e);
msg
......@@ -276,8 +282,16 @@ where
{
let item = item.map_err(|e| {
println!();
Python::with_gil(|py| e.display(py));
ResponseProcessingError::PythonException(e.to_string())
let mut is_py_generator_exit = false;
Python::with_gil(|py| {
e.display(py);
is_py_generator_exit = e.is_instance_of::<pyo3::exceptions::PyGeneratorExit>(py);
});
if is_py_generator_exit {
ResponseProcessingError::PyGeneratorExit(e.to_string())
} else {
ResponseProcessingError::PythonException(e.to_string())
}
})?;
let response = tokio::task::spawn_blocking(move || {
Python::with_gil(|py| depythonize::<Resp>(&item.into_bound(py)))
......
......@@ -14,6 +14,7 @@
// limitations under the License.
use super::*;
use crate::protocols::maybe_error::MaybeError;
use prometheus::{Histogram, IntCounter, IntCounterVec, IntGauge};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
......@@ -105,7 +106,7 @@ impl WorkHandlerMetrics {
impl<T: Data, U: Data> PushWorkHandler for Ingress<SingleIn<T>, ManyOut<U>>
where
T: Data + for<'de> Deserialize<'de> + std::fmt::Debug,
U: Data + Serialize + std::fmt::Debug,
U: Data + Serialize + MaybeError + std::fmt::Debug,
{
fn add_metrics(&self, endpoint: &crate::component::Endpoint) -> Result<()> {
// Call the Ingress-specific add_metrics implementation
......@@ -220,6 +221,14 @@ where
let mut send_complete_final = true;
while let Some(resp) = stream.next().await {
tracing::trace!("Sending response: {:?}", resp);
if let Some(err) = resp.err() {
const STREAM_ERR_MSG: &str = "Stream ended before generation completed";
if format!("{:?}", err) == STREAM_ERR_MSG {
tracing::warn!(STREAM_ERR_MSG);
send_complete_final = false;
break;
}
}
let resp_wrapper = NetworkStreamWrapper {
data: Some(resp),
complete_final: false,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment