Unverified Commit 36c4ef5e authored by Hongkuan Zhou, committed by GitHub
Browse files

feat: migrate requests when planner shutdown decode engine (vllm) (#2280)


Signed-off-by: Hongkuan Zhou <tedzhouhk@gmail.com>
Co-authored-by: Jacky <18255193+kthui@users.noreply.github.com>
Co-authored-by: hhzhang16 <54051230+hhzhang16@users.noreply.github.com>
parent c8f6d4d9
...@@ -190,7 +190,7 @@ spec: ...@@ -190,7 +190,7 @@ spec:
- /bin/sh - /bin/sh
- -c - -c
args: args:
- "python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B 2>&1 | tee /tmp/vllm.log" - "python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --migration-limit=3 2>&1 | tee /tmp/vllm.log"
VllmPrefillWorker: VllmPrefillWorker:
dynamoNamespace: vllm-disagg-planner dynamoNamespace: vllm-disagg-planner
envFromSecret: hf-token-secret envFromSecret: hf-token-secret
...@@ -240,4 +240,4 @@ spec: ...@@ -240,4 +240,4 @@ spec:
- /bin/sh - /bin/sh
- -c - -c
args: args:
- python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --is-prefill-worker 2>&1 | tee /tmp/vllm.log - python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --is-prefill-worker --migration-limit=3 2>&1 | tee /tmp/vllm.log
...@@ -50,28 +50,34 @@ class BaseWorkerHandler(ABC): ...@@ -50,28 +50,34 @@ class BaseWorkerHandler(ABC):
gen = self.engine_client.generate(prompt, sampling_params, request_id) gen = self.engine_client.generate(prompt, sampling_params, request_id)
num_output_tokens_so_far = 0 num_output_tokens_so_far = 0
async for res in gen: try:
# res is vllm's RequestOutput async for res in gen:
# res is vllm's RequestOutput
# This is the expected way for a request to end.
# The new token ID will be eos, don't forward it. # This is the expected way for a request to end.
if res.finished: # The new token ID will be eos, don't forward it.
yield {"finish_reason": "stop", "token_ids": []} if res.finished:
break yield {"finish_reason": "stop", "token_ids": []}
break
if not res.outputs:
yield {"finish_reason": "error", "token_ids": []} if not res.outputs:
break yield {"finish_reason": "error", "token_ids": []}
break
output = res.outputs[0]
next_total_toks = len(output.token_ids) output = res.outputs[0]
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]} next_total_toks = len(output.token_ids)
if output.finish_reason: out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
out["finish_reason"] = output.finish_reason if output.finish_reason:
if output.stop_reason: out["finish_reason"] = output.finish_reason
out["stop_reason"] = output.stop_reason if output.stop_reason:
yield out out["stop_reason"] = output.stop_reason
num_output_tokens_so_far = next_total_toks yield out
num_output_tokens_so_far = next_total_toks
except asyncio.CancelledError:
# raise GeneratorExit when the engine exits so that the frontend can migrate the request
raise GeneratorExit(
"Decode engine was shut down during token generation"
) from None
class DecodeWorkerHandler(BaseWorkerHandler): class DecodeWorkerHandler(BaseWorkerHandler):
...@@ -173,15 +179,21 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -173,15 +179,21 @@ class PrefillWorkerHandler(BaseWorkerHandler):
gen = self.engine_client.generate(prompt, sampling_params, request_id) gen = self.engine_client.generate(prompt, sampling_params, request_id)
# Generate only 1 token in prefill # Generate only 1 token in prefill
async for res in gen: try:
logger.debug(f"kv transfer params: {res.kv_transfer_params}") async for res in gen:
yield MyRequestOutput( logger.debug(f"kv transfer params: {res.kv_transfer_params}")
request_id=res.request_id, yield MyRequestOutput(
prompt=res.prompt, request_id=res.request_id,
prompt_token_ids=res.prompt_token_ids, prompt=res.prompt,
prompt_logprobs=res.prompt_logprobs, prompt_token_ids=res.prompt_token_ids,
outputs=res.outputs, prompt_logprobs=res.prompt_logprobs,
finished=res.finished, outputs=res.outputs,
metrics=res.metrics, finished=res.finished,
kv_transfer_params=res.kv_transfer_params, metrics=res.metrics,
).model_dump_json() kv_transfer_params=res.kv_transfer_params,
).model_dump_json()
except asyncio.CancelledError:
# raise the error because we cannot migrate prefill requests
raise GeneratorExit(
"Prefill engine was shut down during token generation"
) from None
...@@ -30,10 +30,10 @@ logger = logging.getLogger(__name__) ...@@ -30,10 +30,10 @@ logger = logging.getLogger(__name__)
async def graceful_shutdown(runtime): async def graceful_shutdown(runtime):
""" """
By calling `runtime.shutdown()`, the endpoints will immediately be unavailable. Shutdown dynamo distributed runtime.
However, in-flight requests will still be processed until they are finished. The endpoints will be immediately invalidated so no new requests will be accepted.
After all in-flight requests are finished, the `serve_endpoint` functions will return For endpoints served with graceful_shutdown=True, the serving function will wait until all in-flight requests are finished.
and the engine will be shutdown by Python's garbage collector. For endpoints served with graceful_shutdown=False, the serving function will return immediately.
""" """
logging.info("Received shutdown signal, shutting down DistributedRuntime") logging.info("Received shutdown signal, shutting down DistributedRuntime")
runtime.shutdown() runtime.shutdown()
...@@ -113,7 +113,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -113,7 +113,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
try: try:
await asyncio.gather( await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate), # for prefill, we want to shutdown the engine after all prefill requests are finished because
# (temp reason): we don't support re-routing prefill requests
# (long-term reason): prefill engine should pull from a global queue so there are
# only a few in-flight requests that can be quickly finished
generate_endpoint.serve_endpoint(handler.generate, graceful_shutdown=True),
clear_endpoint.serve_endpoint(handler.clear_kv_blocks), clear_endpoint.serve_endpoint(handler.clear_kv_blocks),
) )
except Exception as e: except Exception as e:
...@@ -188,7 +192,9 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -188,7 +192,9 @@ async def init(runtime: DistributedRuntime, config: Config):
try: try:
await asyncio.gather( await asyncio.gather(
generate_endpoint.serve_endpoint(handler.generate), # for decode, we want to transfer the in-flight requests to other decode engines,
# because waiting for them to finish can take a long time for long OSLs
generate_endpoint.serve_endpoint(handler.generate, graceful_shutdown=False),
clear_endpoint.serve_endpoint(handler.clear_kv_blocks), clear_endpoint.serve_endpoint(handler.clear_kv_blocks),
) )
except Exception as e: except Exception as e:
......
...@@ -484,11 +484,12 @@ impl Component { ...@@ -484,11 +484,12 @@ impl Component {
#[pymethods] #[pymethods]
impl Endpoint { impl Endpoint {
#[pyo3(signature = (generator))] #[pyo3(signature = (generator, graceful_shutdown = true))]
fn serve_endpoint<'p>( fn serve_endpoint<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
generator: PyObject, generator: PyObject,
graceful_shutdown: Option<bool>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
let engine = Arc::new(engine::PythonAsyncEngine::new( let engine = Arc::new(engine::PythonAsyncEngine::new(
generator, generator,
...@@ -496,8 +497,13 @@ impl Endpoint { ...@@ -496,8 +497,13 @@ impl Endpoint {
)?); )?);
let ingress = JsonServerStreamingIngress::for_engine(engine).map_err(to_pyerr)?; let ingress = JsonServerStreamingIngress::for_engine(engine).map_err(to_pyerr)?;
let builder = self.inner.endpoint_builder().handler(ingress); let builder = self.inner.endpoint_builder().handler(ingress);
let graceful_shutdown = graceful_shutdown.unwrap_or(true);
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
builder.start().await.map_err(to_pyerr)?; builder
.graceful_shutdown(graceful_shutdown)
.start()
.await
.map_err(to_pyerr)?;
Ok(()) Ok(())
}) })
} }
......
...@@ -216,10 +216,14 @@ class Endpoint: ...@@ -216,10 +216,14 @@ class Endpoint:
... ...
async def serve_endpoint(self, handler: RequestHandler) -> None: async def serve_endpoint(self, handler: RequestHandler, graceful_shutdown: bool = True) -> None:
""" """
Serve an endpoint discoverable by all connected clients at Serve an endpoint discoverable by all connected clients at
`{{ namespace }}/components/{{ component_name }}/endpoints/{{ endpoint_name }}` `{{ namespace }}/components/{{ component_name }}/endpoints/{{ endpoint_name }}`
Args:
handler: The request handler function
graceful_shutdown: Whether to wait for inflight requests to complete during shutdown (default: True)
""" """
... ...
......
...@@ -40,6 +40,10 @@ pub struct EndpointConfig { ...@@ -40,6 +40,10 @@ pub struct EndpointConfig {
#[educe(Debug(ignore))] #[educe(Debug(ignore))]
#[builder(default, private)] #[builder(default, private)]
_stats_handler: Option<EndpointStatsHandler>, _stats_handler: Option<EndpointStatsHandler>,
/// Whether to wait for inflight requests to complete during shutdown
#[builder(default = "true")]
graceful_shutdown: bool,
} }
impl EndpointConfigBuilder { impl EndpointConfigBuilder {
...@@ -55,7 +59,8 @@ impl EndpointConfigBuilder { ...@@ -55,7 +59,8 @@ impl EndpointConfigBuilder {
} }
pub async fn start(self) -> Result<()> { pub async fn start(self) -> Result<()> {
let (endpoint, lease, handler, stats_handler) = self.build_internal()?.dissolve(); let (endpoint, lease, handler, stats_handler, graceful_shutdown) =
self.build_internal()?.dissolve();
let lease = lease.or(endpoint.drt().primary_lease()); let lease = lease.or(endpoint.drt().primary_lease());
let lease_id = lease.as_ref().map(|l| l.id()).unwrap_or(0); let lease_id = lease.as_ref().map(|l| l.id()).unwrap_or(0);
...@@ -109,6 +114,7 @@ impl EndpointConfigBuilder { ...@@ -109,6 +114,7 @@ impl EndpointConfigBuilder {
let push_endpoint = PushEndpoint::builder() let push_endpoint = PushEndpoint::builder()
.service_handler(handler) .service_handler(handler)
.cancellation_token(cancel_token.clone()) .cancellation_token(cancel_token.clone())
.graceful_shutdown(graceful_shutdown)
.build() .build()
.map_err(|e| anyhow::anyhow!("Failed to build push endpoint: {e}"))?; .map_err(|e| anyhow::anyhow!("Failed to build push endpoint: {e}"))?;
......
...@@ -31,6 +31,8 @@ use tokio_util::sync::CancellationToken; ...@@ -31,6 +31,8 @@ use tokio_util::sync::CancellationToken;
pub struct PushEndpoint { pub struct PushEndpoint {
pub service_handler: Arc<dyn PushWorkHandler>, pub service_handler: Arc<dyn PushWorkHandler>,
pub cancellation_token: CancellationToken, pub cancellation_token: CancellationToken,
#[builder(default = "true")]
pub graceful_shutdown: bool,
} }
/// version of crate /// version of crate
...@@ -116,15 +118,19 @@ impl PushEndpoint { ...@@ -116,15 +118,19 @@ impl PushEndpoint {
.unwrap() .unwrap()
.set_endpoint_health_status(endpoint_name.clone(), HealthStatus::NotReady); .set_endpoint_health_status(endpoint_name.clone(), HealthStatus::NotReady);
// await for all inflight requests to complete // await for all inflight requests to complete if graceful shutdown
tracing::info!( if self.graceful_shutdown {
"Waiting for {} inflight requests to complete", tracing::info!(
inflight.load(Ordering::SeqCst) "Waiting for {} inflight requests to complete",
); inflight.load(Ordering::SeqCst)
while inflight.load(Ordering::SeqCst) > 0 { );
notify.notified().await; while inflight.load(Ordering::SeqCst) > 0 {
notify.notified().await;
}
tracing::info!("All inflight requests completed");
} else {
tracing::info!("Skipping graceful shutdown, not waiting for inflight requests");
} }
tracing::info!("All inflight requests completed");
Ok(()) Ok(())
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment