Unverified commit ea86df29, authored by jh-nv, committed by GitHub
Browse files

feat: Backend accept new requests during shutdown grace period (#6093)


Signed-off-by: Jacky <18255193+kthui@users.noreply.github.com>
Co-authored-by: Jacky <18255193+kthui@users.noreply.github.com>
parent 4ebb244b
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
import math
import os
import signal
from typing import Iterable, Optional
logger = logging.getLogger(__name__)

# TODO: expose the grace period as a CLI flag in addition to the env var.
# Grace period (seconds) used when the environment variable below is unset,
# empty, or holds a value that cannot be used.
_DEFAULT_GRACE_PERIOD_SECS = 5.0
_GRACE_PERIOD_ENV = "DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS"

# Module-wide latch: set by the first graceful_shutdown_with_discovery() call
# so that later calls return immediately instead of shutting down twice.
_shutdown_started = asyncio.Event()


def get_grace_period_seconds() -> float:
    """Return the shutdown grace period, in seconds, read from the environment.

    The value comes from ``DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS``.

    Returns:
        * the default (5.0) when the variable is unset or empty;
        * the default, with a warning, when the value is not a number or is
          not finite (``nan``, ``inf``);
        * ``0.0``, with a warning, when the value is negative;
        * the parsed value otherwise.
    """
    raw = os.getenv(_GRACE_PERIOD_ENV)
    if not raw:
        return _DEFAULT_GRACE_PERIOD_SECS

    try:
        seconds = float(raw)
    except ValueError:
        logger.warning(
            "Invalid %s=%r; using default %s",
            _GRACE_PERIOD_ENV,
            raw,
            _DEFAULT_GRACE_PERIOD_SECS,
        )
        return _DEFAULT_GRACE_PERIOD_SECS

    if seconds < 0:
        logger.warning(
            "Negative %s=%r; using 0",
            _GRACE_PERIOD_ENV,
            raw,
        )
        return 0.0

    # float() also parses "nan" and "inf". NaN fails every comparison, so a
    # caller testing `> 0` would silently skip the grace period, and an
    # infinite grace period would stall shutdown indefinitely. Treat both as
    # misconfiguration rather than handing them to the caller.
    if not math.isfinite(seconds):
        logger.warning(
            "Non-finite %s=%r; using default %s",
            _GRACE_PERIOD_ENV,
            raw,
            _DEFAULT_GRACE_PERIOD_SECS,
        )
        return _DEFAULT_GRACE_PERIOD_SECS

    return seconds
async def _unregister_endpoints(endpoints: Iterable) -> None:
seen = set()
tasks = []
for endpoint in endpoints:
endpoint_id = id(endpoint)
if endpoint_id in seen:
continue
seen.add(endpoint_id)
tasks.append(endpoint.unregister_endpoint_instance())
if not tasks:
return
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
logger.warning(
"Failed to unregister endpoint instance from discovery: %s",
result,
)
async def graceful_shutdown_with_discovery(
    runtime,
    endpoints: Iterable,
    shutdown_event: Optional[asyncio.Event] = None,
    grace_period_s: Optional[float] = None,
) -> None:
    """Unregister from discovery, wait out a grace period, then shut down.

    Steps, in order:
      1. Unregister ``endpoints`` from discovery.
      2. Sleep for the grace period (skipped when it is not positive).
      3. Set ``shutdown_event`` if one was given.
      4. Call ``runtime.shutdown()``.

    Only the first call in the process does any work; later calls return
    immediately because the module-level latch is already set.

    Args:
        runtime: Object exposing ``shutdown()``.
        endpoints: Endpoints to unregister. The iterable is snapshotted with
            ``list()`` when this coroutine runs, not when it is created.
        shutdown_event: Optional event set just before the runtime shutdown.
        grace_period_s: Grace period in seconds; ``None`` means read it via
            ``get_grace_period_seconds()``.
    """
    if _shutdown_started.is_set():
        return
    _shutdown_started.set()

    if grace_period_s is None:
        grace_period_s = get_grace_period_seconds()

    logger.info("Received shutdown signal; unregistering endpoints from discovery")
    try:
        await _unregister_endpoints(list(endpoints))
    except Exception:
        # The latch above is already set, so a retry would be a no-op. If this
        # step were allowed to propagate, runtime.shutdown() would never be
        # reached. Unregistering is best-effort; keep going.
        logger.exception("Discovery unregister failed; continuing with shutdown")

    if grace_period_s > 0:
        logger.info("Grace period %.2fs before stopping endpoints", grace_period_s)
        await asyncio.sleep(grace_period_s)

    if shutdown_event is not None:
        shutdown_event.set()

    logger.info("Initiating runtime shutdown")
    runtime.shutdown()
def install_signal_handlers(
    loop: asyncio.AbstractEventLoop,
    runtime,
    endpoints: Iterable,
    shutdown_event: Optional[asyncio.Event] = None,
    grace_period_s: Optional[float] = None,
) -> None:
    """Register SIGTERM/SIGINT handlers on ``loop`` that start graceful shutdown.

    Each signal schedules ``graceful_shutdown_with_discovery`` as a task,
    unless a previously scheduled shutdown task is still running, in which
    case the signal is ignored.

    Args:
        loop: Event loop on which the signal handlers are registered.
        runtime: Forwarded to ``graceful_shutdown_with_discovery``.
        endpoints: Forwarded as-is (not copied), so a list the caller fills
            in after installation is seen when a signal arrives.
        shutdown_event: Forwarded to ``graceful_shutdown_with_discovery``.
        grace_period_s: Forwarded to ``graceful_shutdown_with_discovery``.
    """
    # The shutdown task currently in flight, or None when there is none.
    active: Optional[asyncio.Task[None]] = None

    def _finalize(finished: asyncio.Task[None]) -> None:
        """Log how the shutdown task ended and forget it."""
        nonlocal active
        try:
            finished.result()
        except asyncio.CancelledError:
            logger.info("Graceful shutdown task cancelled")
        except Exception:
            logger.exception("Graceful shutdown task failed")
        finally:
            # Only clear the slot if it still refers to this very task.
            if active is finished:
                active = None

    def signal_handler() -> None:
        """Start the shutdown task unless one is already running."""
        nonlocal active
        idle = active is None or active.done()
        if not idle:
            logger.debug("Shutdown already in progress; ignoring duplicate signal")
            return

        shutdown_coro = graceful_shutdown_with_discovery(
            runtime,
            endpoints,
            shutdown_event=shutdown_event,
            grace_period_s=grace_period_s,
        )
        active = asyncio.create_task(shutdown_coro)
        active.add_done_callback(_finalize)

    handled_signals = (signal.SIGTERM, signal.SIGINT)
    for signum in handled_signals:
        loop.add_signal_handler(signum, signal_handler)

    logger.info(
        "Signal handlers set up for graceful shutdown (discovery unregister + grace period)"
    )
......@@ -7,14 +7,12 @@ Common runtime utilities shared across Dynamo engine backends.
Provides:
- parse_endpoint: Parse 'dyn://namespace.component.endpoint' strings
- graceful_shutdown: Shutdown DistributedRuntime with optional event signaling
- create_runtime: Create DistributedRuntime with signal handlers
- create_runtime: Create DistributedRuntime.
"""
import asyncio
import logging
import os
import signal
from typing import Optional, Tuple
from typing import Tuple
from dynamo.runtime import DistributedRuntime
......@@ -43,42 +41,22 @@ def parse_endpoint(endpoint: str) -> Tuple[str, str, str]:
return namespace, component, endpoint_name
async def graceful_shutdown(
runtime: DistributedRuntime,
shutdown_event: Optional[asyncio.Event] = None,
) -> None:
"""Shutdown DistributedRuntime with optional event signaling.
Args:
runtime: The DistributedRuntime instance to shut down.
shutdown_event: Optional event to set before shutting down,
signaling in-flight handlers to finish.
"""
logging.info("Received shutdown signal, shutting down DistributedRuntime")
if shutdown_event is not None:
shutdown_event.set()
runtime.shutdown()
logging.info("DistributedRuntime shutdown complete")
def create_runtime(
discovery_backend: str,
request_plane: str,
event_plane: str,
use_kv_events: bool,
shutdown_event: Optional[asyncio.Event] = None,
) -> Tuple[DistributedRuntime, asyncio.AbstractEventLoop]:
"""Create a DistributedRuntime and register signal handlers for graceful shutdown.
"""Create a DistributedRuntime.
Sets DYN_EVENT_PLANE in the environment, computes whether NATS is needed,
creates the runtime, and installs SIGTERM/SIGINT handlers.
and creates the runtime.
Args:
discovery_backend: Discovery backend type (kubernetes, etcd, file, mem).
request_plane: Request distribution method (nats, http, tcp).
event_plane: Event publishing method (nats, zmq).
use_kv_events: Whether KV events are enabled.
shutdown_event: Optional event to set on shutdown signal.
Returns:
Tuple of (runtime, event_loop).
......@@ -91,12 +69,4 @@ def create_runtime(
runtime = DistributedRuntime(loop, discovery_backend, request_plane, enable_nats)
def signal_handler():
asyncio.create_task(graceful_shutdown(runtime, shutdown_event))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
logging.debug("Signal handlers set up for graceful shutdown")
return runtime, loop
......@@ -18,6 +18,7 @@ from dynamo import prometheus_names
from dynamo.common.config_dump import dump_config
from dynamo.common.storage import get_fs
from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.common.utils.graceful_shutdown import graceful_shutdown_with_discovery
from dynamo.common.utils.runtime import create_runtime
from dynamo.llm import ModelInput, ModelType
from dynamo.runtime import DistributedRuntime
......@@ -50,8 +51,6 @@ from dynamo.sglang.request_handlers import (
configure_dynamo_logging()
RUN_DEFERRED_HANDLERS: Callable[[], Awaitable[None]] | None = None
async def _handle_non_leader_node(
engine: sgl.Engine,
......@@ -95,30 +94,23 @@ SignalCallback = Callable[..., Any]
def install_graceful_shutdown(
loop: asyncio.AbstractEventLoop,
runtime: Any,
endpoints: list,
shutdown_event: asyncio.Event,
*,
signals: tuple[int, ...] = (signal.SIGTERM, signal.SIGINT),
) -> tuple[asyncio.Event, Callable[[], Awaitable[None]]]:
) -> Callable[[], Awaitable[None]]:
"""
Set up graceful shutdown + callback chaining.
What it does:
- Owns OS-level SIGTERM/SIGINT via signal.signal(...)
- Captures (suppresses) loop.add_signal_handler(SIGTERM/SIGINT, ...) registrations
and runs them during shutdown (sync or async)
- Calls runtime.shutdown() during shutdown (sync or async)
- Sets and returns an asyncio.Event you can await to know shutdown was requested
Set up graceful shutdown with discovery unregister and grace period.
Returns:
(shutdown_event, run_deferred_handlers)
Owns OS-level SIGTERM/SIGINT via signal.signal() so SGLang's internal
loop.add_signal_handler registrations cannot replace our handler.
Monkey-patches loop.add_signal_handler to capture (defer) those
registrations. Returns run_deferred_handlers to be invoked in init
finally blocks (after the asyncio loop / serve_endpoint is done).
"""
shutdown_event = asyncio.Event()
# Deferred handlers registered via loop.add_signal_handler for these signals
deferred_handlers: DefaultDict[int, list[tuple[SignalCallback, tuple[Any, ...]]]] = defaultdict(list) # type: ignore[assignment]
# Previous OS handlers (for optional chaining)
old_os_handlers: dict[int, Any] = {}
shutdown_started = False
shutdown_signum: int | None = None
deferred_handlers_ran = False
......@@ -151,12 +143,12 @@ def install_graceful_shutdown(
shutdown_started = True
logging.info("Received signal %s, starting graceful shutdown", signum)
shutdown_event.set()
try:
runtime.shutdown()
except Exception:
logging.exception("runtime.shutdown() failed")
await graceful_shutdown_with_discovery(
runtime,
endpoints,
shutdown_event=shutdown_event,
grace_period_s=None,
)
def _schedule_shutdown(signum: int, frame: Any | None) -> None:
def _kick() -> None:
......@@ -165,20 +157,17 @@ def install_graceful_shutdown(
loop.call_soon_threadsafe(_kick)
def _os_signal_handler(signum: int, frame: Any) -> None:
# Keep the OS handler tiny; do real work in the loop thread.
_schedule_shutdown(signum, frame)
# Install OS-level handlers
for sig in signals:
old_os_handlers[sig] = signal.signal(sig, _os_signal_handler)
signal.signal(sig, _os_signal_handler)
# Intercept loop.add_signal_handler for SIGTERM/SIGINT and defer them
orig_add = loop.add_signal_handler
def watching_add_signal_handler(sig: int, callback: SignalCallback, *args: Any):
if sig in signals:
logging.info(
"Captured loop.add_signal_handler(%s, %r, ...) (deferred).",
logging.debug(
"Captured underlying service trying to register for loop.add_signal_handler(%s, %r, ...).",
sig,
callback,
)
......@@ -188,7 +177,7 @@ def install_graceful_shutdown(
loop.add_signal_handler = watching_add_signal_handler # type: ignore[assignment]
return shutdown_event, run_deferred_handlers
return run_deferred_handlers
async def worker():
......@@ -202,6 +191,8 @@ async def worker():
config.server_args.load_format = setup_gms(config.server_args)
dynamo_args = config.dynamo_args
shutdown_event = asyncio.Event()
shutdown_endpoints: list = []
runtime, loop = create_runtime(
discovery_backend=dynamo_args.discovery_backend,
request_plane=dynamo_args.request_plane,
......@@ -209,36 +200,95 @@ async def worker():
use_kv_events=dynamo_args.use_kv_events,
)
# Set up signal handlers using signal module to allow chaining
global RUN_DEFERRED_HANDLERS
shutdown_event, RUN_DEFERRED_HANDLERS = install_graceful_shutdown(loop, runtime)
logging.info("Signal handlers set up for graceful shutdown (with chaining)")
run_deferred_handlers = install_graceful_shutdown(
loop, runtime, shutdown_endpoints, shutdown_event
)
logging.info(
"Signal handlers set up for graceful shutdown "
"(discovery unregister + grace period, with chaining)"
)
if config.dynamo_args.image_diffusion_worker:
await init_image_diffusion(runtime, config)
await init_image_diffusion(
runtime, config, shutdown_endpoints, run_deferred_handlers
)
elif config.dynamo_args.video_generation_worker:
await init_video_generation(runtime, config)
await init_video_generation(
runtime, config, shutdown_endpoints, run_deferred_handlers
)
elif config.dynamo_args.embedding_worker:
await init_embedding(runtime, config, shutdown_event)
await init_embedding(
runtime,
config,
shutdown_event,
shutdown_endpoints,
run_deferred_handlers,
)
elif config.dynamo_args.multimodal_processor:
await init_multimodal_processor(runtime, config, shutdown_event)
await init_multimodal_processor(
runtime,
config,
shutdown_event,
shutdown_endpoints,
run_deferred_handlers,
)
elif config.dynamo_args.multimodal_encode_worker:
await init_multimodal_encode_worker(runtime, config, shutdown_event)
await init_multimodal_encode_worker(
runtime,
config,
shutdown_event,
shutdown_endpoints,
run_deferred_handlers,
)
elif config.dynamo_args.multimodal_worker:
if config.serving_mode != DisaggregationMode.PREFILL:
await init_multimodal_worker(runtime, config, shutdown_event)
await init_multimodal_worker(
runtime,
config,
shutdown_event,
shutdown_endpoints,
run_deferred_handlers,
)
else:
await init_multimodal_prefill_worker(runtime, config, shutdown_event)
await init_multimodal_prefill_worker(
runtime,
config,
shutdown_event,
shutdown_endpoints,
run_deferred_handlers,
)
elif config.dynamo_args.diffusion_worker:
await init_diffusion(runtime, config, shutdown_event)
await init_diffusion(
runtime,
config,
shutdown_event,
shutdown_endpoints,
run_deferred_handlers,
)
elif config.serving_mode != DisaggregationMode.PREFILL:
await init(runtime, config, shutdown_event)
await init(
runtime,
config,
shutdown_event,
shutdown_endpoints,
run_deferred_handlers,
)
else:
await init_prefill(runtime, config, shutdown_event)
await init_prefill(
runtime,
config,
shutdown_event,
shutdown_endpoints,
run_deferred_handlers,
)
async def init(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
):
server_args, dynamo_args = config.server_args, config.dynamo_args
......@@ -255,6 +305,7 @@ async def init(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks
......@@ -321,13 +372,17 @@ async def init(
logging.info("Metrics task successfully cancelled")
pass
handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None:
if run_deferred_handlers is not None:
logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS()
await run_deferred_handlers()
async def init_prefill(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
):
server_args, dynamo_args = config.server_args, config.dynamo_args
......@@ -341,6 +396,7 @@ async def init_prefill(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks
......@@ -399,13 +455,17 @@ async def init_prefill(
logging.info("Metrics task successfully cancelled")
pass
handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None:
if run_deferred_handlers is not None:
logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS()
await run_deferred_handlers()
async def init_diffusion(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
):
"""Initialize diffusion language model worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
......@@ -428,6 +488,7 @@ async def init_diffusion(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks
......@@ -486,13 +547,17 @@ async def init_diffusion(
logging.info("Metrics task successfully cancelled")
pass
handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None:
if run_deferred_handlers is not None:
logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS()
await run_deferred_handlers()
async def init_embedding(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
):
"""Initialize embedding worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
......@@ -503,6 +568,7 @@ async def init_embedding(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
......@@ -550,12 +616,17 @@ async def init_embedding(
logging.info("Metrics task successfully cancelled")
pass
handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None:
if run_deferred_handlers is not None:
logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS()
await run_deferred_handlers()
async def init_image_diffusion(runtime: DistributedRuntime, config: Config):
async def init_image_diffusion(
runtime: DistributedRuntime,
config: Config,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
):
"""Initialize image diffusion worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
......@@ -589,6 +660,7 @@ async def init_image_diffusion(runtime: DistributedRuntime, config: Config):
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# Image diffusion doesn't have metrics publisher like LLM
# Could add custom metrics for images/sec, steps/sec later
......@@ -629,12 +701,17 @@ async def init_image_diffusion(runtime: DistributedRuntime, config: Config):
raise
finally:
handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None:
if run_deferred_handlers is not None:
logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS()
await run_deferred_handlers()
async def init_video_generation(runtime: DistributedRuntime, config: Config):
async def init_video_generation(
runtime: DistributedRuntime,
config: Config,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
):
"""Initialize video generation worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
......@@ -668,6 +745,7 @@ async def init_video_generation(runtime: DistributedRuntime, config: Config):
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
handler = VideoGenerationWorkerHandler(
component,
......@@ -704,10 +782,17 @@ async def init_video_generation(runtime: DistributedRuntime, config: Config):
raise
finally:
handler.cleanup()
if run_deferred_handlers is not None:
logging.info("Running deferred handlers")
await run_deferred_handlers()
async def init_multimodal_processor(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
):
"""Initialize multimodal processor component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
......@@ -715,6 +800,7 @@ async def init_multimodal_processor(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# For processor, we need to connect to the encode worker
encode_worker_client = await runtime.endpoint(
......@@ -754,13 +840,17 @@ async def init_multimodal_processor(
raise
finally:
handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None:
if run_deferred_handlers is not None:
logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS()
await run_deferred_handlers()
async def init_multimodal_encode_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
):
"""Initialize multimodal encode worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
......@@ -769,6 +859,7 @@ async def init_multimodal_encode_worker(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# For encode worker, we need to connect to the downstream LLM worker
pd_worker_client = await runtime.endpoint(
......@@ -798,13 +889,17 @@ async def init_multimodal_encode_worker(
raise
finally:
handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None:
if run_deferred_handlers is not None:
logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS()
await run_deferred_handlers()
async def init_multimodal_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
):
"""Initialize multimodal worker component.
......@@ -818,6 +913,7 @@ async def init_multimodal_worker(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
engine = sgl.Engine(server_args=server_args)
......@@ -852,13 +948,17 @@ async def init_multimodal_worker(
raise
finally:
handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None:
if run_deferred_handlers is not None:
logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS()
await run_deferred_handlers()
async def init_multimodal_prefill_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
):
"""Initialize multimodal prefill worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
......@@ -869,6 +969,7 @@ async def init_multimodal_prefill_worker(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
handler = MultimodalPrefillWorkerHandler(component, engine, config, shutdown_event)
await handler.async_init()
......@@ -889,9 +990,9 @@ async def init_multimodal_prefill_worker(
raise
finally:
handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None:
if run_deferred_handlers is not None:
logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS()
await run_deferred_handlers()
async def _warmup_prefill_engine(engine: sgl.Engine, server_args) -> None:
......
......@@ -18,28 +18,31 @@ if "TLLM_LOG_LEVEL" not in os.environ and os.getenv(
os.environ["TLLM_LOG_LEVEL"] = tllm_level
import uvloop
from dynamo.common.utils.graceful_shutdown import install_signal_handlers
from dynamo.common.utils.runtime import create_runtime
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.args import parse_args
from dynamo.trtllm.workers import init_worker
configure_dynamo_logging()
shutdown_endpoints: list = []
async def worker():
config = parse_args()
shutdown_event = asyncio.Event()
runtime, _ = create_runtime(
runtime, loop = create_runtime(
discovery_backend=config.discovery_backend,
request_plane=config.request_plane,
event_plane=config.event_plane,
use_kv_events=config.use_kv_events,
shutdown_event=shutdown_event,
)
install_signal_handlers(loop, runtime, shutdown_endpoints, shutdown_event)
logging.info(f"Initializing the worker with config: {config}")
await init_worker(runtime, config, shutdown_event)
await init_worker(runtime, config, shutdown_event, shutdown_endpoints)
def main():
......
......@@ -18,6 +18,7 @@ Note on import strategy:
import asyncio
import logging
from typing import Optional
from dynamo.runtime import DistributedRuntime
from dynamo.trtllm.args import Config
......@@ -26,7 +27,10 @@ from dynamo.trtllm.workers.llm_worker import init_llm_worker
async def init_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: Optional[list] = None,
) -> None:
"""Initialize the appropriate worker based on modality.
......@@ -37,6 +41,7 @@ async def init_worker(
runtime: The Dynamo distributed runtime.
config: Configuration parsed from command line.
shutdown_event: Event to signal shutdown.
shutdown_endpoints: Optional list to populate with endpoints for graceful shutdown.
"""
logging.info(f"Initializing worker with modality={config.modality}")
......@@ -48,13 +53,15 @@ async def init_worker(
init_video_diffusion_worker,
)
await init_video_diffusion_worker(runtime, config, shutdown_event)
await init_video_diffusion_worker(
runtime, config, shutdown_event, shutdown_endpoints
)
return
# TODO: Add IMAGE_DIFFUSION support in follow-up PR
raise ValueError(f"Unsupported diffusion modality: {modality}")
# LLM modalities (text, multimodal)
await init_llm_worker(runtime, config, shutdown_event)
await init_llm_worker(runtime, config, shutdown_event, shutdown_endpoints)
__all__ = ["init_worker"]
......@@ -12,6 +12,7 @@ import json
import logging
import os
import sys
from typing import Optional
from prometheus_client import REGISTRY
from tensorrt_llm.llmapi import (
......@@ -109,7 +110,10 @@ def build_kv_connector_config(config: Config):
async def init_llm_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: Optional[list] = None,
) -> None:
"""Initialize and run the LLM worker.
......@@ -119,6 +123,7 @@ async def init_llm_worker(
runtime: The Dynamo distributed runtime.
config: Configuration parsed from command line.
shutdown_event: Event to signal shutdown.
shutdown_endpoints: Optional list to populate with endpoints for graceful shutdown.
"""
encode_client = None
......@@ -333,6 +338,8 @@ async def init_llm_worker(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
component = endpoint.component()
if shutdown_endpoints is not None:
shutdown_endpoints[:] = [endpoint]
# should ideally call get_engine_runtime_config
# this is because we don't have a good way to
......
......@@ -9,6 +9,7 @@ workers using diffusion models (Wan, Flux, Cosmos, etc.).
import asyncio
import logging
from typing import Optional
from dynamo.llm import ModelInput, ModelType, register_model
from dynamo.runtime import DistributedRuntime
......@@ -16,7 +17,10 @@ from dynamo.trtllm.args import Config
async def init_video_diffusion_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: Optional[list] = None,
) -> None:
"""Initialize and run the video diffusion worker.
......@@ -27,6 +31,7 @@ async def init_video_diffusion_worker(
runtime: The Dynamo distributed runtime.
config: Configuration parsed from command line.
shutdown_event: Event to signal shutdown.
shutdown_endpoints: Optional list to populate with endpoints for graceful shutdown.
"""
# Check visual_gen availability early with a clear error message.
# visual_gen is part of TensorRT-LLM but only available on the feat/visual_gen
......@@ -87,6 +92,8 @@ async def init_video_diffusion_worker(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
component = endpoint.component()
if shutdown_endpoints is not None:
shutdown_endpoints[:] = [endpoint]
# Initialize the diffusion engine (auto-detects pipeline from model_index.json)
engine = DiffusionEngine(diffusion_config)
......
......@@ -21,6 +21,7 @@ from dynamo import prometheus_names
from dynamo.common.config_dump import dump_config
from dynamo.common.storage import get_fs
from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.common.utils.graceful_shutdown import install_signal_handlers
from dynamo.common.utils.output_modalities import get_output_modalities
from dynamo.common.utils.prometheus import (
LLMBackendMetrics,
......@@ -62,6 +63,7 @@ from .publisher import DYNAMO_COMPONENT_REGISTRY, StatLoggerFactory
configure_dynamo_logging()
logger = logging.getLogger(__name__)
shutdown_endpoints: list = []
CHECKPOINT_SLEEP_MODE_LEVEL = 1
......@@ -78,19 +80,6 @@ async def _handle_non_leader_node(dp_rank: int) -> None:
await asyncio.Event().wait()
async def graceful_shutdown(runtime, shutdown_event):
"""
Shutdown dynamo distributed runtime.
The endpoints will be immediately invalidated so no new requests will be accepted.
For endpoints served with graceful_shutdown=True, the serving function will wait until all in-flight requests are finished.
For endpoints served with graceful_shutdown=False, the serving function will return immediately.
"""
logging.info("Received shutdown signal, shutting down DistributedRuntime")
shutdown_event.set()
runtime.shutdown()
logging.info("DistributedRuntime shutdown complete")
def build_headless_namespace(config: Config) -> argparse.Namespace:
"""Build an argparse Namespace from engine_args for vLLM's run_headless().
......@@ -172,14 +161,15 @@ async def worker():
return
shutdown_event = asyncio.Event()
runtime, _ = create_runtime(
runtime, loop = create_runtime(
discovery_backend=config.discovery_backend,
request_plane=config.request_plane,
event_plane=config.event_plane,
use_kv_events=config.use_kv_events,
shutdown_event=shutdown_event,
)
install_signal_handlers(loop, runtime, shutdown_endpoints, shutdown_event)
# Route to appropriate initialization based on config flags
if WorkerFactory.handles(config):
# Create worker factory with setup functions
......@@ -189,7 +179,11 @@ async def worker():
register_vllm_model_fn=register_vllm_model,
)
await factory.create(
runtime, config, shutdown_event, pre_created_engine=pre_created_engine
runtime,
config,
shutdown_event,
shutdown_endpoints,
pre_created_engine=pre_created_engine,
)
logger.debug("multimodal worker completed")
elif config.omni:
......@@ -653,6 +647,7 @@ async def init_prefill(
if config.engine_args.data_parallel_rank:
await _handle_non_leader_node(config.engine_args.data_parallel_rank)
return
shutdown_endpoints[:] = [generate_endpoint, clear_endpoint]
# Register prefill model with ModelType.Prefill
model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
......@@ -725,12 +720,25 @@ async def init(
component = generate_endpoint.component()
clear_endpoint = component.endpoint("clear_kv_blocks")
shutdown_endpoints[:] = [
generate_endpoint,
clear_endpoint,
]
lora_enabled = config.engine_args.enable_lora
if lora_enabled:
load_lora_endpoint = component.endpoint("load_lora")
unload_lora_endpoint = component.endpoint("unload_lora")
list_loras_endpoint = component.endpoint("list_loras")
shutdown_endpoints.extend(
[
load_lora_endpoint,
unload_lora_endpoint,
list_loras_endpoint,
]
)
model_name = config.served_model_name or config.model
# Use pre-created engine if provided (checkpoint mode), otherwise create new
......@@ -950,6 +958,7 @@ async def init_omni(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# Initialize media filesystem for storing generated images/videos
media_fs = (
......
......@@ -56,14 +56,22 @@ class WorkerFactory:
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
pre_created_engine: Optional[EngineSetupResult] = None,
) -> None:
"""Create the appropriate multimodal worker based on config flags."""
if config.multimodal_encode_worker:
await self._create_multimodal_encode_worker(runtime, config, shutdown_event)
await self._create_multimodal_encode_worker(
runtime, config, shutdown_event, shutdown_endpoints
)
elif config.multimodal_worker or config.multimodal_decode_worker:
await self._create_multimodal_worker(
runtime, config, shutdown_event, pre_created_engine=pre_created_engine
runtime,
config,
shutdown_event,
shutdown_endpoints,
pre_created_engine=pre_created_engine,
)
else:
raise ValueError(
......@@ -75,6 +83,7 @@ class WorkerFactory:
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list, # mutated in place
pre_created_engine: Optional[EngineSetupResult] = None,
) -> None:
"""
......@@ -93,13 +102,15 @@ class WorkerFactory:
)
component = generate_endpoint.component()
clear_endpoint = component.endpoint("clear_kv_blocks")
shutdown_endpoints[:] = [generate_endpoint, clear_endpoint]
lora_enabled = config.engine_args.enable_lora
if lora_enabled:
load_lora_endpoint = component.endpoint("load_lora")
unload_lora_endpoint = component.endpoint("unload_lora")
list_loras_endpoint = component.endpoint("list_loras")
shutdown_endpoints.extend(
[load_lora_endpoint, unload_lora_endpoint, list_loras_endpoint]
)
# Use pre-created engine if provided (checkpoint mode), otherwise create new
if pre_created_engine is not None:
(
......@@ -226,11 +237,13 @@ class WorkerFactory:
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list, # mutated in place
) -> None:
"""Initialize standalone multimodal encode worker."""
generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}"
)
shutdown_endpoints[:] = [generate_endpoint]
handler = EncodeWorkerHandler(config.engine_args)
await handler.async_init(runtime)
......
......@@ -12,10 +12,11 @@ This document describes how Dynamo components handle shutdown signals to ensure
Graceful shutdown in Dynamo ensures that:
1. **No new requests are accepted** - Endpoints are immediately invalidated
2. **In-flight requests complete** - Existing requests finish processing (configurable)
3. **Resources are cleaned up** - Engines, connections, and temporary files are released
4. **Pods restart cleanly** - Exit codes signal Kubernetes for proper restart behavior
1. **Routing stops quickly** - Endpoints are unregistered from discovery first
2. **In-flight requests can finish** - Workers keep serving during a short grace period
3. **Endpoints drain** - After the grace period, endpoints are invalidated and optionally wait for in-flight work
4. **Resources are cleaned up** - Engines, connections, and temporary files are released
5. **Pods restart cleanly** - Exit codes signal Kubernetes for proper restart behavior
## Signal Handling
......@@ -32,7 +33,7 @@ Each component registers signal handlers at startup:
```python
def signal_handler():
asyncio.create_task(graceful_shutdown(runtime))
asyncio.create_task(graceful_shutdown(runtime, endpoints))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
......@@ -40,13 +41,15 @@ for sig in (signal.SIGTERM, signal.SIGINT):
The `graceful_shutdown()` function:
1. Logs the shutdown signal
2. Calls `runtime.shutdown()` to invalidate endpoints
3. Waits for in-flight requests (based on configuration)
4. Returns to allow cleanup to proceed
2. Unregisters all endpoints from discovery
3. Waits for a configurable grace period (`DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS`, default 5s)
4. Calls `runtime.shutdown()` to invalidate endpoints and stop accepting new requests
5. Waits for in-flight requests (according to each endpoint's `graceful_shutdown` setting)
6. Returns to allow cleanup to proceed
## Endpoint Draining
When `runtime.shutdown()` is called, endpoints are immediately invalidated so no new requests are accepted. The behavior for in-flight requests depends on the `graceful_shutdown` parameter when serving the endpoint.
After the grace period, `runtime.shutdown()` invalidates endpoints so no new requests are accepted. The behavior for in-flight requests depends on the `graceful_shutdown` parameter set when serving the endpoint.
### Configuration
......
......@@ -149,6 +149,9 @@ class DynamoWorkerProcess(ManagedProcess):
env["DYN_SYSTEM_PORT"] = str(self.system_port)
env["DYN_HTTP_PORT"] = str(frontend_port)
# Disable backend shutdown grace period for all migration tests
env["DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS"] = "0"
# Configure health check based on worker type
health_check_urls = [
(f"http://localhost:{self.system_port}/health", self.is_ready)
......
......@@ -138,6 +138,9 @@ class DynamoWorkerProcess(ManagedProcess):
env["DYN_SYSTEM_PORT"] = str(self.system_port)
env["DYN_HTTP_PORT"] = str(frontend_port)
# Disable backend shutdown grace period for all migration tests
env["DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS"] = "0"
# Configure health check based on worker type
health_check_urls = [
(f"http://localhost:{self.system_port}/health", self.is_ready)
......
......@@ -144,6 +144,9 @@ class DynamoWorkerProcess(ManagedProcess):
env["DYN_SYSTEM_PORT"] = str(self.system_port)
env["DYN_HTTP_PORT"] = str(frontend_port)
# Disable backend shutdown grace period for all migration tests
env["DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS"] = "0"
# Configure health check based on worker type
health_check_urls = [
(f"http://localhost:{self.system_port}/health", self.is_ready)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment