Unverified commit 36c4ef5e, authored by Hongkuan Zhou and committed by GitHub
Browse files

feat: migrate requests when planner shutdown decode engine (vllm) (#2280)


Signed-off-by: Hongkuan Zhou <tedzhouhk@gmail.com>
Co-authored-by: Jacky <18255193+kthui@users.noreply.github.com>
Co-authored-by: hhzhang16 <54051230+hhzhang16@users.noreply.github.com>
parent c8f6d4d9
......@@ -190,7 +190,7 @@ spec:
- /bin/sh
- -c
args:
- "python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B 2>&1 | tee /tmp/vllm.log"
- "python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --migration-limit=3 2>&1 | tee /tmp/vllm.log"
VllmPrefillWorker:
dynamoNamespace: vllm-disagg-planner
envFromSecret: hf-token-secret
......@@ -240,4 +240,4 @@ spec:
- /bin/sh
- -c
args:
- python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --is-prefill-worker 2>&1 | tee /tmp/vllm.log
- python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --is-prefill-worker --migration-limit=3 2>&1 | tee /tmp/vllm.log
......@@ -50,6 +50,7 @@ class BaseWorkerHandler(ABC):
gen = self.engine_client.generate(prompt, sampling_params, request_id)
num_output_tokens_so_far = 0
try:
async for res in gen:
# res is vllm's RequestOutput
......@@ -72,6 +73,11 @@ class BaseWorkerHandler(ABC):
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
except asyncio.CancelledError:
# raise GeneratorExit when the engine shuts down so that the frontend can migrate the request
raise GeneratorExit(
"Decode engine was shut down during token generation"
) from None
class DecodeWorkerHandler(BaseWorkerHandler):
......@@ -173,6 +179,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
gen = self.engine_client.generate(prompt, sampling_params, request_id)
# Generate only 1 token in prefill
try:
async for res in gen:
logger.debug(f"kv transfer params: {res.kv_transfer_params}")
yield MyRequestOutput(
......@@ -185,3 +192,8 @@ class PrefillWorkerHandler(BaseWorkerHandler):
metrics=res.metrics,
kv_transfer_params=res.kv_transfer_params,
).model_dump_json()
except asyncio.CancelledError:
# raise the error because we cannot migrate prefill requests
raise GeneratorExit(
"Prefill engine was shut down during token generation"
) from None
......@@ -30,10 +30,10 @@ logger = logging.getLogger(__name__)
async def graceful_shutdown(runtime):
"""
By calling `runtime.shutdown()`, the endpoints will immediately be unavailable.
However, in-flight requests will still be processed until they are finished.
After all in-flight requests are finished, the `serve_endpoint` functions will return
and the engine will be shutdown by Python's garbage collector.
Shutdown dynamo distributed runtime.
The endpoints will be immediately invalidated so no new requests will be accepted.
For endpoints served with graceful_shutdown=True, the serving function will wait until all in-flight requests are finished.
For endpoints served with graceful_shutdown=False, the serving function will return immediately.
"""
logging.info("Received shutdown signal, shutting down DistributedRuntime")
runtime.shutdown()
......@@ -113,7 +113,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate),
# for prefill, we want to shut down the engine only after all prefill requests are finished because
# (temp reason): we don't support re-routing prefill requests
# (long-term reason): the prefill engine should pull from a global queue so that there are
# only a few in-flight requests, which can be finished quickly
generate_endpoint.serve_endpoint(handler.generate, graceful_shutdown=True),
clear_endpoint.serve_endpoint(handler.clear_kv_blocks),
)
except Exception as e:
......@@ -188,7 +192,9 @@ async def init(runtime: DistributedRuntime, config: Config):
try:
await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate),
# for decode, we want to transfer the in-flight requests to other decode engines,
# because waiting for them to finish can take a long time for long OSLs
generate_endpoint.serve_endpoint(handler.generate, graceful_shutdown=False),
clear_endpoint.serve_endpoint(handler.clear_kv_blocks),
)
except Exception as e:
......
......@@ -484,11 +484,12 @@ impl Component {
#[pymethods]
impl Endpoint {
#[pyo3(signature = (generator))]
#[pyo3(signature = (generator, graceful_shutdown = true))]
fn serve_endpoint<'p>(
&self,
py: Python<'p>,
generator: PyObject,
graceful_shutdown: Option<bool>,
) -> PyResult<Bound<'p, PyAny>> {
let engine = Arc::new(engine::PythonAsyncEngine::new(
generator,
......@@ -496,8 +497,13 @@ impl Endpoint {
)?);
let ingress = JsonServerStreamingIngress::for_engine(engine).map_err(to_pyerr)?;
let builder = self.inner.endpoint_builder().handler(ingress);
let graceful_shutdown = graceful_shutdown.unwrap_or(true);
pyo3_async_runtimes::tokio::future_into_py(py, async move {
builder.start().await.map_err(to_pyerr)?;
builder
.graceful_shutdown(graceful_shutdown)
.start()
.await
.map_err(to_pyerr)?;
Ok(())
})
}
......
......@@ -216,10 +216,14 @@ class Endpoint:
...
async def serve_endpoint(self, handler: RequestHandler) -> None:
async def serve_endpoint(self, handler: RequestHandler, graceful_shutdown: bool = True) -> None:
"""
Serve an endpoint discoverable by all connected clients at
`{{ namespace }}/components/{{ component_name }}/endpoints/{{ endpoint_name }}`
Args:
handler: The request handler function
graceful_shutdown: Whether to wait for inflight requests to complete during shutdown (default: True)
"""
...
......
......@@ -40,6 +40,10 @@ pub struct EndpointConfig {
#[educe(Debug(ignore))]
#[builder(default, private)]
_stats_handler: Option<EndpointStatsHandler>,
/// Whether to wait for inflight requests to complete during shutdown
#[builder(default = "true")]
graceful_shutdown: bool,
}
impl EndpointConfigBuilder {
......@@ -55,7 +59,8 @@ impl EndpointConfigBuilder {
}
pub async fn start(self) -> Result<()> {
let (endpoint, lease, handler, stats_handler) = self.build_internal()?.dissolve();
let (endpoint, lease, handler, stats_handler, graceful_shutdown) =
self.build_internal()?.dissolve();
let lease = lease.or(endpoint.drt().primary_lease());
let lease_id = lease.as_ref().map(|l| l.id()).unwrap_or(0);
......@@ -109,6 +114,7 @@ impl EndpointConfigBuilder {
let push_endpoint = PushEndpoint::builder()
.service_handler(handler)
.cancellation_token(cancel_token.clone())
.graceful_shutdown(graceful_shutdown)
.build()
.map_err(|e| anyhow::anyhow!("Failed to build push endpoint: {e}"))?;
......
......@@ -31,6 +31,8 @@ use tokio_util::sync::CancellationToken;
/// Push-style endpoint settings: a work handler, a cancellation token, and a shutdown policy.
pub struct PushEndpoint {
/// Handler for work delivered to this endpoint (any `PushWorkHandler` implementation).
pub service_handler: Arc<dyn PushWorkHandler>,
/// Cancellation token supplied by the caller.
/// NOTE(review): presumably its cancellation triggers the shutdown path — confirm against `start`.
pub cancellation_token: CancellationToken,
/// Whether shutdown waits for all inflight requests to complete before returning.
/// When `false`, the wait is skipped and shutdown returns without draining inflight requests.
/// The builder defaults this to `true`.
#[builder(default = "true")]
pub graceful_shutdown: bool,
}
/// version of crate
......@@ -116,7 +118,8 @@ impl PushEndpoint {
.unwrap()
.set_endpoint_health_status(endpoint_name.clone(), HealthStatus::NotReady);
// await for all inflight requests to complete
// wait for all inflight requests to complete if graceful shutdown is enabled
if self.graceful_shutdown {
tracing::info!(
"Waiting for {} inflight requests to complete",
inflight.load(Ordering::SeqCst)
......@@ -125,6 +128,9 @@ impl PushEndpoint {
notify.notified().await;
}
tracing::info!("All inflight requests completed");
} else {
tracing::info!("Skipping graceful shutdown, not waiting for inflight requests");
}
Ok(())
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment