"lib/vscode:/vscode.git/clone" did not exist on "2ef408ffd3b09ca06f3bbf8c8fd44d9164b2d090"
Unverified Commit 704c1dad authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

fix: fix vllm graceful shutdown (#5818)

parent 284f772b
......@@ -25,7 +25,12 @@ class VllmEngineMonitor:
Monitors the health of the vLLM engine and initiates a shutdown if the engine is dead.
"""
def __init__(self, runtime: DistributedRuntime, engine_client: AsyncLLM):
def __init__(
self,
runtime: DistributedRuntime,
engine_client: AsyncLLM,
shutdown_event: asyncio.Event = None,
):
if not isinstance(runtime, DistributedRuntime):
raise ValueError(
f"{self.__class__.__name__} requires an instance of DistributedRuntime."
......@@ -37,6 +42,7 @@ class VllmEngineMonitor:
self.runtime = runtime
self.engine_client = engine_client
self.shutdown_event = shutdown_event
self._monitor_task = asyncio.create_task(self._check_engine_health())
logger.info(
......@@ -66,10 +72,41 @@ class VllmEngineMonitor:
signal.alarm(0)
async def _check_engine_health(self):
"""
Continuously check engine health until:
1. Engine dies (EngineDeadError) - initiate shutdown
2. Shutdown event is triggered - stop monitoring gracefully
3. Task is cancelled - cleanup
"""
while True:
try:
# Check if shutdown event was triggered - stop monitoring
if self.shutdown_event and self.shutdown_event.is_set():
logger.info(
f"{self.__class__.__name__}: Shutdown event detected, stopping engine health monitoring."
)
break
await self.engine_client.check_health()
# Sleep with shutdown event awareness for faster response
if self.shutdown_event:
try:
await asyncio.wait_for(
self.shutdown_event.wait(), timeout=HEALTH_CHECK_INTERVAL
)
# Shutdown event was set during sleep
logger.info(
f"{self.__class__.__name__}: Shutdown event detected, stopping engine health monitoring."
)
break
except asyncio.TimeoutError:
# Normal timeout, continue monitoring
pass
else:
# No shutdown event, just sleep normally
await asyncio.sleep(HEALTH_CHECK_INTERVAL)
except EngineDeadError as e:
logger.error(f"Traceback: {traceback.format_exc()}")
logger.error(f"vLLM AsyncLLM health check failed: {e}")
......@@ -78,4 +115,5 @@ class VllmEngineMonitor:
self.runtime.shutdown()
os._exit(1)
except asyncio.CancelledError:
pass
logger.debug(f"{self.__class__.__name__}: Health check task cancelled.")
break
......@@ -243,6 +243,7 @@ class BaseWorkerHandler(ABC):
generate_endpoint=None,
config=None,
use_vllm_tokenizer: bool = False,
shutdown_event: asyncio.Event | None = None,
):
self.runtime = runtime
self.component = component
......@@ -251,7 +252,7 @@ class BaseWorkerHandler(ABC):
self.kv_publishers: list[ZmqKvEventPublisher] | None = None
self.generate_endpoint = generate_endpoint
self.config = config
self.engine_monitor = VllmEngineMonitor(runtime, engine)
self.engine_monitor = VllmEngineMonitor(runtime, engine, shutdown_event)
self.image_loader = ImageLoader()
self.temp_dirs: list[tempfile.TemporaryDirectory] = []
self.model_max_len = model_max_len
......@@ -272,6 +273,9 @@ class BaseWorkerHandler(ABC):
tokenizer = engine.tokenizer
self.input_param_manager = InputParamManager(tokenizer)
# Store shutdown event for graceful shutdown monitoring
self.shutdown_event = shutdown_event
async def sleep(self, body: dict) -> dict:
"""Sleep the engine to release GPU memory and unregister from discovery.
......@@ -339,14 +343,44 @@ class BaseWorkerHandler(ABC):
raise NotImplementedError
async def _monitor_abort(self, context, request_id, is_prefill):
"""Background task that monitors for context cancellation and aborts the request."""
"""
Background task that monitors for context cancellation and shutdown.
Aborts the request if either occurs. Raises GeneratorExit if shutdown was triggered.
"""
try:
# Build list of futures/tasks to wait for
wait_for = [context.async_killed_or_stopped()]
shutdown_task = None
if self.shutdown_event:
# Create task for shutdown monitoring and add to wait list
shutdown_task = asyncio.create_task(self.shutdown_event.wait())
wait_for.append(shutdown_task)
# Wait for whichever happens first
done, pending = await asyncio.wait(
wait_for,
return_when=asyncio.FIRST_COMPLETED,
)
# Cancel the pending task/future
for task in pending:
task.cancel()
try:
await context.async_killed_or_stopped()
# If we reach here, the context was stopped or killed
await task
except asyncio.CancelledError:
pass
# Abort the request
await self.engine_client.abort(request_id)
logger.debug(
f"Aborted {'Prefill ' if is_prefill else ''}Request ID: {request_id}"
)
# Check which event triggered and raise GeneratorExit if shutdown
if shutdown_task and shutdown_task in done:
raise GeneratorExit("Engine was shut down during generation.")
except asyncio.CancelledError:
# Task was cancelled, normal cleanup if not aborted
pass
......@@ -355,18 +389,24 @@ class BaseWorkerHandler(ABC):
@asynccontextmanager
async def _abort_monitor(self, context, request_id, is_prefill=False):
"""Context manager that creates and automatically cleans up an abort monitoring task."""
"""
Context manager that creates and automatically cleans up an abort monitoring task.
If shutdown event was triggered, raises GeneratorExit on exit.
"""
task = asyncio.create_task(self._monitor_abort(context, request_id, is_prefill))
try:
yield task
finally:
# Cancel the abort monitoring task when exiting the context
# Clean up the abort monitoring task
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
else:
# If the task completed, check if it raised GeneratorExit
task.result()
async def clear_kv_blocks(self, request=None):
try:
......@@ -389,6 +429,20 @@ class BaseWorkerHandler(ABC):
self._lora_load_locks[lora_name] = lock
return lock
def _normalize_finish_reason(self, finish_reason: str) -> str:
"""
Normalize vLLM finish reasons to Dynamo-compatible values.
vLLM may return finish reasons that aren't recognized by Dynamo's Rust layer.
This method maps them to compatible values.
[TODO]: Remove this method and add the right code in the Rust layer.
"""
# Map vLLM's "abort" to Dynamo's "cancelled"
if finish_reason.startswith("abort"):
logging.debug(f"Normalizing finish reason: {finish_reason} to cancelled")
return "cancelled"
return finish_reason
async def load_lora(self, request=None):
"""
Load a LoRA adapter dynamically into the vLLM's AsyncLLM engine.
......@@ -1112,7 +1166,6 @@ class BaseWorkerHandler(ABC):
)
num_output_tokens_so_far = 0
try:
async for res in gen:
# res is vllm's RequestOutput
......@@ -1144,10 +1197,10 @@ class BaseWorkerHandler(ABC):
out["top_logprobs"] = top_logprobs
if output.finish_reason:
out["finish_reason"] = output.finish_reason
out[
"completion_usage"
] = BaseWorkerHandler._build_completion_usage(
out["finish_reason"] = self._normalize_finish_reason(
output.finish_reason
)
out["completion_usage"] = BaseWorkerHandler._build_completion_usage(
request_output=res,
embedding_sequence_length=embedding_sequence_length,
)
......@@ -1164,11 +1217,6 @@ class BaseWorkerHandler(ABC):
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
except asyncio.CancelledError:
# raise EngineShGeneratorExit when engine exits so that frontend can migrate the request
raise GeneratorExit(
"Decode engine was shut down during token generation"
) from None
except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}")
......@@ -1189,6 +1237,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
generate_endpoint=None,
config=None,
use_vllm_tokenizer: bool = False,
shutdown_event: asyncio.Event | None = None,
):
super().__init__(
runtime,
......@@ -1200,6 +1249,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
generate_endpoint,
config,
use_vllm_tokenizer,
shutdown_event,
)
async def generate(self, request, context):
......@@ -1361,7 +1411,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
"role": "assistant",
"content": delta_text,
},
"finish_reason": output.finish_reason,
"finish_reason": self._normalize_finish_reason(
output.finish_reason
),
}
chunk = {
......@@ -1398,6 +1450,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
generate_endpoint=None,
config=None,
use_vllm_tokenizer: bool = False,
shutdown_event: asyncio.Event | None = None,
):
super().__init__(
runtime,
......@@ -1409,6 +1462,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
generate_endpoint,
config,
use_vllm_tokenizer,
shutdown_event,
)
async def generate(self, request, context):
......@@ -1501,7 +1555,6 @@ class PrefillWorkerHandler(BaseWorkerHandler):
self.runtime.shutdown()
os._exit(1)
try:
async for res in gen:
logger.debug(f"kv transfer params: {res.kv_transfer_params}")
......@@ -1532,8 +1585,3 @@ class PrefillWorkerHandler(BaseWorkerHandler):
)
yield output
except asyncio.CancelledError:
# raise the error because we cannot migrate prefill requests
raise GeneratorExit(
"Prefill engine was shut down during token generation"
) from None
......@@ -61,7 +61,7 @@ async def _handle_non_leader_node(dp_rank: int) -> None:
await asyncio.Event().wait()
async def graceful_shutdown(runtime):
async def graceful_shutdown(runtime, shutdown_event):
"""
Shutdown dynamo distributed runtime.
The endpoints will be immediately invalidated so no new requests will be accepted.
......@@ -69,6 +69,7 @@ async def graceful_shutdown(runtime):
For endpoints served with graceful_shutdown=False, the serving function will return immediately.
"""
logging.info("Received shutdown signal, shutting down DistributedRuntime")
shutdown_event.set()
runtime.shutdown()
logging.info("DistributedRuntime shutdown complete")
......@@ -79,6 +80,9 @@ async def worker():
loop = asyncio.get_running_loop()
overwrite_args(config)
# Create shutdown event
shutdown_event = asyncio.Event()
# Set DYN_EVENT_PLANE environment variable based on config
os.environ["DYN_EVENT_PLANE"] = config.event_plane
......@@ -95,7 +99,7 @@ async def worker():
# Set up signal handler for graceful shutdown
def signal_handler():
asyncio.create_task(graceful_shutdown(runtime))
asyncio.create_task(graceful_shutdown(runtime, shutdown_event))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
......@@ -123,29 +127,29 @@ async def worker():
# Route to appropriate initialization based on config flags
if config.vllm_native_encoder_worker:
await init_vllm_native_encoder(runtime, config)
await init_vllm_native_encoder(runtime, config, shutdown_event)
logger.debug("init_vllm_native_encoder completed")
elif config.ec_processor:
await init_ec_processor(runtime, config)
await init_ec_processor(runtime, config, shutdown_event)
logger.debug("init_ec_processor completed")
elif config.multimodal_processor:
await init_multimodal_processor(runtime, config)
await init_multimodal_processor(runtime, config, shutdown_event)
logger.debug("init_multimodal_processor completed")
elif config.multimodal_encode_worker:
await init_multimodal_encode_worker(runtime, config)
await init_multimodal_encode_worker(runtime, config, shutdown_event)
logger.debug("init_multimodal_encode_worker completed")
elif (
config.multimodal_worker
or config.multimodal_decode_worker
or config.multimodal_encode_prefill_worker
):
await init_multimodal_worker(runtime, config)
await init_multimodal_worker(runtime, config, shutdown_event)
logger.debug("init_multimodal_worker completed")
elif config.is_prefill_worker:
await init_prefill(runtime, config)
await init_prefill(runtime, config, shutdown_event)
logger.debug("init_prefill completed")
else:
await init(runtime, config)
await init(runtime, config, shutdown_event)
logger.debug("init completed")
logger.debug("Worker function completed, exiting...")
......@@ -415,7 +419,9 @@ async def register_vllm_model(
)
async def init_prefill(runtime: DistributedRuntime, config: Config):
async def init_prefill(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
"""
Instantiate and serve
"""
......@@ -441,6 +447,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
generate_endpoint=generate_endpoint,
config=config,
use_vllm_tokenizer=config.use_vllm_tokenizer,
shutdown_event=shutdown_event,
)
handler.add_temp_dir(prometheus_temp_dir)
......@@ -527,7 +534,9 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
handler.cleanup()
async def init(runtime: DistributedRuntime, config: Config):
async def init(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
"""
Instantiate and serve
"""
......@@ -566,6 +575,7 @@ async def init(runtime: DistributedRuntime, config: Config):
generate_endpoint=generate_endpoint,
config=config,
use_vllm_tokenizer=config.use_vllm_tokenizer,
shutdown_event=shutdown_event,
)
handler.add_temp_dir(prometheus_temp_dir)
......@@ -699,7 +709,9 @@ def get_engine_cache_info(engine: AsyncLLM):
raise
async def init_multimodal_processor(runtime: DistributedRuntime, config: Config):
async def init_multimodal_processor(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
"""Initialize multimodal processor component"""
component = runtime.namespace(config.namespace).component(config.component)
......@@ -754,7 +766,9 @@ async def init_multimodal_processor(runtime: DistributedRuntime, config: Config)
handler.cleanup()
async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Config):
async def init_multimodal_encode_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
"""Initialize multimodal encode worker component"""
component = runtime.namespace(config.namespace).component(config.component)
......@@ -792,7 +806,9 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con
handler.cleanup()
async def init_vllm_native_encoder(runtime: DistributedRuntime, config: Config):
async def init_vllm_native_encoder(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
"""
Initialize vLLM-native encoder worker component (ECConnector mode).
In this mode, vLLM handles encoder execution, caching, and storage automatically.
......@@ -853,7 +869,9 @@ async def init_vllm_native_encoder(runtime: DistributedRuntime, config: Config):
handler.cleanup()
async def init_ec_processor(runtime: DistributedRuntime, config: Config):
async def init_ec_processor(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
"""
Initialize ECConnector processor component.
......@@ -923,7 +941,9 @@ async def init_ec_processor(runtime: DistributedRuntime, config: Config):
handler.cleanup()
async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
async def init_multimodal_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
"""
Initialize multimodal worker component.
......@@ -983,11 +1003,16 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
# Choose handler based on worker type
if config.multimodal_decode_worker:
handler = MultimodalDecodeWorkerHandler(
runtime, component, engine_client, config
runtime, component, engine_client, config, shutdown_event
)
else:
handler = MultimodalPDWorkerHandler(
runtime, component, engine_client, config, decode_worker_client
runtime,
component,
engine_client,
config,
decode_worker_client,
shutdown_event,
)
handler.add_temp_dir(prometheus_temp_dir)
......
......@@ -37,6 +37,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
component,
engine_client,
config,
shutdown_event=None,
):
# Get default_sampling_params from config
default_sampling_params = (
......@@ -50,6 +51,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
engine_client,
default_sampling_params,
enable_multimodal=config.enable_multimodal,
shutdown_event=shutdown_event,
)
self.config = config
......@@ -117,6 +119,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
engine_client: AsyncLLM,
config,
decode_worker_client: Client = None,
shutdown_event=None,
):
# Get default_sampling_params from config
default_sampling_params = (
......@@ -130,6 +133,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
engine_client,
default_sampling_params,
enable_multimodal=config.enable_multimodal,
shutdown_event=shutdown_event,
)
self.config = config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment