Unverified Commit ea86df29 authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

feat: Backend accept new requests during shutdown grace period (#6093)


Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
Co-authored-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
parent 4ebb244b
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
import os
import signal
from typing import Iterable, Optional
logger = logging.getLogger(__name__)
# TODO: make this using cli flag
_DEFAULT_GRACE_PERIOD_SECS = 5.0
_GRACE_PERIOD_ENV = "DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS"
_shutdown_started = asyncio.Event()
def get_grace_period_seconds() -> float:
value = os.getenv(_GRACE_PERIOD_ENV)
if value is None or value == "":
return _DEFAULT_GRACE_PERIOD_SECS
try:
parsed = float(value)
except ValueError:
logger.warning(
"Invalid %s=%r; using default %s",
_GRACE_PERIOD_ENV,
value,
_DEFAULT_GRACE_PERIOD_SECS,
)
return _DEFAULT_GRACE_PERIOD_SECS
if parsed < 0:
logger.warning(
"Negative %s=%r; using 0",
_GRACE_PERIOD_ENV,
value,
)
return 0.0
return parsed
async def _unregister_endpoints(endpoints: Iterable) -> None:
seen = set()
tasks = []
for endpoint in endpoints:
endpoint_id = id(endpoint)
if endpoint_id in seen:
continue
seen.add(endpoint_id)
tasks.append(endpoint.unregister_endpoint_instance())
if not tasks:
return
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
logger.warning(
"Failed to unregister endpoint instance from discovery: %s",
result,
)
async def graceful_shutdown_with_discovery(
runtime,
endpoints: Iterable,
shutdown_event: Optional[asyncio.Event] = None,
grace_period_s: Optional[float] = None,
) -> None:
if _shutdown_started.is_set():
return
_shutdown_started.set()
if grace_period_s is None:
grace_period_s = get_grace_period_seconds()
logger.info("Received shutdown signal; unregistering endpoints from discovery")
await _unregister_endpoints(list(endpoints))
if grace_period_s > 0:
logger.info("Grace period %.2fs before stopping endpoints", grace_period_s)
await asyncio.sleep(grace_period_s)
if shutdown_event is not None:
shutdown_event.set()
logger.info("Initiating runtime shutdown")
runtime.shutdown()
def install_signal_handlers(
loop: asyncio.AbstractEventLoop,
runtime,
endpoints: Iterable,
shutdown_event: Optional[asyncio.Event] = None,
grace_period_s: Optional[float] = None,
) -> None:
shutdown_task: Optional[asyncio.Task[None]] = None
def _on_shutdown_done(task: asyncio.Task[None]) -> None:
nonlocal shutdown_task
try:
task.result()
except asyncio.CancelledError:
logger.info("Graceful shutdown task cancelled")
except Exception:
logger.exception("Graceful shutdown task failed")
finally:
if shutdown_task is task:
shutdown_task = None
def signal_handler() -> None:
nonlocal shutdown_task
if shutdown_task is not None and not shutdown_task.done():
logger.debug("Shutdown already in progress; ignoring duplicate signal")
return
shutdown_task = asyncio.create_task(
graceful_shutdown_with_discovery(
runtime,
endpoints,
shutdown_event=shutdown_event,
grace_period_s=grace_period_s,
)
)
shutdown_task.add_done_callback(_on_shutdown_done)
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
logger.info(
"Signal handlers set up for graceful shutdown "
"(discovery unregister + grace period)"
)
...@@ -7,14 +7,12 @@ Common runtime utilities shared across Dynamo engine backends. ...@@ -7,14 +7,12 @@ Common runtime utilities shared across Dynamo engine backends.
Provides: Provides:
- parse_endpoint: Parse 'dyn://namespace.component.endpoint' strings - parse_endpoint: Parse 'dyn://namespace.component.endpoint' strings
- graceful_shutdown: Shutdown DistributedRuntime with optional event signaling - graceful_shutdown: Shutdown DistributedRuntime with optional event signaling
- create_runtime: Create DistributedRuntime with signal handlers - create_runtime: Create DistributedRuntime.
""" """
import asyncio import asyncio
import logging
import os import os
import signal from typing import Tuple
from typing import Optional, Tuple
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
...@@ -43,42 +41,22 @@ def parse_endpoint(endpoint: str) -> Tuple[str, str, str]: ...@@ -43,42 +41,22 @@ def parse_endpoint(endpoint: str) -> Tuple[str, str, str]:
return namespace, component, endpoint_name return namespace, component, endpoint_name
async def graceful_shutdown(
runtime: DistributedRuntime,
shutdown_event: Optional[asyncio.Event] = None,
) -> None:
"""Shutdown DistributedRuntime with optional event signaling.
Args:
runtime: The DistributedRuntime instance to shut down.
shutdown_event: Optional event to set before shutting down,
signaling in-flight handlers to finish.
"""
logging.info("Received shutdown signal, shutting down DistributedRuntime")
if shutdown_event is not None:
shutdown_event.set()
runtime.shutdown()
logging.info("DistributedRuntime shutdown complete")
def create_runtime( def create_runtime(
discovery_backend: str, discovery_backend: str,
request_plane: str, request_plane: str,
event_plane: str, event_plane: str,
use_kv_events: bool, use_kv_events: bool,
shutdown_event: Optional[asyncio.Event] = None,
) -> Tuple[DistributedRuntime, asyncio.AbstractEventLoop]: ) -> Tuple[DistributedRuntime, asyncio.AbstractEventLoop]:
"""Create a DistributedRuntime and register signal handlers for graceful shutdown. """Create a DistributedRuntime.
Sets DYN_EVENT_PLANE in the environment, computes whether NATS is needed, Sets DYN_EVENT_PLANE in the environment, computes whether NATS is needed,
creates the runtime, and installs SIGTERM/SIGINT handlers. and creates the runtime.
Args: Args:
discovery_backend: Discovery backend type (kubernetes, etcd, file, mem). discovery_backend: Discovery backend type (kubernetes, etcd, file, mem).
request_plane: Request distribution method (nats, http, tcp). request_plane: Request distribution method (nats, http, tcp).
event_plane: Event publishing method (nats, zmq). event_plane: Event publishing method (nats, zmq).
use_kv_events: Whether KV events are enabled. use_kv_events: Whether KV events are enabled.
shutdown_event: Optional event to set on shutdown signal.
Returns: Returns:
Tuple of (runtime, event_loop). Tuple of (runtime, event_loop).
...@@ -91,12 +69,4 @@ def create_runtime( ...@@ -91,12 +69,4 @@ def create_runtime(
runtime = DistributedRuntime(loop, discovery_backend, request_plane, enable_nats) runtime = DistributedRuntime(loop, discovery_backend, request_plane, enable_nats)
def signal_handler():
asyncio.create_task(graceful_shutdown(runtime, shutdown_event))
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
logging.debug("Signal handlers set up for graceful shutdown")
return runtime, loop return runtime, loop
...@@ -18,6 +18,7 @@ from dynamo import prometheus_names ...@@ -18,6 +18,7 @@ from dynamo import prometheus_names
from dynamo.common.config_dump import dump_config from dynamo.common.config_dump import dump_config
from dynamo.common.storage import get_fs from dynamo.common.storage import get_fs
from dynamo.common.utils.endpoint_types import parse_endpoint_types from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.common.utils.graceful_shutdown import graceful_shutdown_with_discovery
from dynamo.common.utils.runtime import create_runtime from dynamo.common.utils.runtime import create_runtime
from dynamo.llm import ModelInput, ModelType from dynamo.llm import ModelInput, ModelType
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
...@@ -50,8 +51,6 @@ from dynamo.sglang.request_handlers import ( ...@@ -50,8 +51,6 @@ from dynamo.sglang.request_handlers import (
configure_dynamo_logging() configure_dynamo_logging()
RUN_DEFERRED_HANDLERS: Callable[[], Awaitable[None]] | None = None
async def _handle_non_leader_node( async def _handle_non_leader_node(
engine: sgl.Engine, engine: sgl.Engine,
...@@ -95,30 +94,23 @@ SignalCallback = Callable[..., Any] ...@@ -95,30 +94,23 @@ SignalCallback = Callable[..., Any]
def install_graceful_shutdown( def install_graceful_shutdown(
loop: asyncio.AbstractEventLoop, loop: asyncio.AbstractEventLoop,
runtime: Any, runtime: Any,
endpoints: list,
shutdown_event: asyncio.Event,
*, *,
signals: tuple[int, ...] = (signal.SIGTERM, signal.SIGINT), signals: tuple[int, ...] = (signal.SIGTERM, signal.SIGINT),
) -> tuple[asyncio.Event, Callable[[], Awaitable[None]]]: ) -> Callable[[], Awaitable[None]]:
""" """
Set up graceful shutdown + callback chaining. Set up graceful shutdown with discovery unregister and grace period.
What it does:
- Owns OS-level SIGTERM/SIGINT via signal.signal(...)
- Captures (suppresses) loop.add_signal_handler(SIGTERM/SIGINT, ...) registrations
and runs them during shutdown (sync or async)
- Calls runtime.shutdown() during shutdown (sync or async)
- Sets and returns an asyncio.Event you can await to know shutdown was requested
Returns: Owns OS-level SIGTERM/SIGINT via signal.signal() so SGLang's internal
(shutdown_event, run_deferred_handlers) loop.add_signal_handler registrations cannot replace our handler.
Monkey-patches loop.add_signal_handler to capture (defer) those
registrations. Returns run_deferred_handlers to be invoked in init
finally blocks (after the asyncio loop / serve_endpoint is done).
""" """
shutdown_event = asyncio.Event()
# Deferred handlers registered via loop.add_signal_handler for these signals # Deferred handlers registered via loop.add_signal_handler for these signals
deferred_handlers: DefaultDict[int, list[tuple[SignalCallback, tuple[Any, ...]]]] = defaultdict(list) # type: ignore[assignment] deferred_handlers: DefaultDict[int, list[tuple[SignalCallback, tuple[Any, ...]]]] = defaultdict(list) # type: ignore[assignment]
# Previous OS handlers (for optional chaining)
old_os_handlers: dict[int, Any] = {}
shutdown_started = False shutdown_started = False
shutdown_signum: int | None = None shutdown_signum: int | None = None
deferred_handlers_ran = False deferred_handlers_ran = False
...@@ -151,12 +143,12 @@ def install_graceful_shutdown( ...@@ -151,12 +143,12 @@ def install_graceful_shutdown(
shutdown_started = True shutdown_started = True
logging.info("Received signal %s, starting graceful shutdown", signum) logging.info("Received signal %s, starting graceful shutdown", signum)
shutdown_event.set() await graceful_shutdown_with_discovery(
runtime,
try: endpoints,
runtime.shutdown() shutdown_event=shutdown_event,
except Exception: grace_period_s=None,
logging.exception("runtime.shutdown() failed") )
def _schedule_shutdown(signum: int, frame: Any | None) -> None: def _schedule_shutdown(signum: int, frame: Any | None) -> None:
def _kick() -> None: def _kick() -> None:
...@@ -165,20 +157,17 @@ def install_graceful_shutdown( ...@@ -165,20 +157,17 @@ def install_graceful_shutdown(
loop.call_soon_threadsafe(_kick) loop.call_soon_threadsafe(_kick)
def _os_signal_handler(signum: int, frame: Any) -> None: def _os_signal_handler(signum: int, frame: Any) -> None:
# Keep the OS handler tiny; do real work in the loop thread.
_schedule_shutdown(signum, frame) _schedule_shutdown(signum, frame)
# Install OS-level handlers
for sig in signals: for sig in signals:
old_os_handlers[sig] = signal.signal(sig, _os_signal_handler) signal.signal(sig, _os_signal_handler)
# Intercept loop.add_signal_handler for SIGTERM/SIGINT and defer them
orig_add = loop.add_signal_handler orig_add = loop.add_signal_handler
def watching_add_signal_handler(sig: int, callback: SignalCallback, *args: Any): def watching_add_signal_handler(sig: int, callback: SignalCallback, *args: Any):
if sig in signals: if sig in signals:
logging.info( logging.debug(
"Captured loop.add_signal_handler(%s, %r, ...) (deferred).", "Captured underlying service trying to register for loop.add_signal_handler(%s, %r, ...).",
sig, sig,
callback, callback,
) )
...@@ -188,7 +177,7 @@ def install_graceful_shutdown( ...@@ -188,7 +177,7 @@ def install_graceful_shutdown(
loop.add_signal_handler = watching_add_signal_handler # type: ignore[assignment] loop.add_signal_handler = watching_add_signal_handler # type: ignore[assignment]
return shutdown_event, run_deferred_handlers return run_deferred_handlers
async def worker(): async def worker():
...@@ -202,6 +191,8 @@ async def worker(): ...@@ -202,6 +191,8 @@ async def worker():
config.server_args.load_format = setup_gms(config.server_args) config.server_args.load_format = setup_gms(config.server_args)
dynamo_args = config.dynamo_args dynamo_args = config.dynamo_args
shutdown_event = asyncio.Event()
shutdown_endpoints: list = []
runtime, loop = create_runtime( runtime, loop = create_runtime(
discovery_backend=dynamo_args.discovery_backend, discovery_backend=dynamo_args.discovery_backend,
request_plane=dynamo_args.request_plane, request_plane=dynamo_args.request_plane,
...@@ -209,36 +200,95 @@ async def worker(): ...@@ -209,36 +200,95 @@ async def worker():
use_kv_events=dynamo_args.use_kv_events, use_kv_events=dynamo_args.use_kv_events,
) )
# Set up signal handlers using signal module to allow chaining run_deferred_handlers = install_graceful_shutdown(
global RUN_DEFERRED_HANDLERS loop, runtime, shutdown_endpoints, shutdown_event
shutdown_event, RUN_DEFERRED_HANDLERS = install_graceful_shutdown(loop, runtime) )
logging.info("Signal handlers set up for graceful shutdown (with chaining)") logging.info(
"Signal handlers set up for graceful shutdown "
"(discovery unregister + grace period, with chaining)"
)
if config.dynamo_args.image_diffusion_worker: if config.dynamo_args.image_diffusion_worker:
await init_image_diffusion(runtime, config) await init_image_diffusion(
runtime, config, shutdown_endpoints, run_deferred_handlers
)
elif config.dynamo_args.video_generation_worker: elif config.dynamo_args.video_generation_worker:
await init_video_generation(runtime, config) await init_video_generation(
runtime, config, shutdown_endpoints, run_deferred_handlers
)
elif config.dynamo_args.embedding_worker: elif config.dynamo_args.embedding_worker:
await init_embedding(runtime, config, shutdown_event) await init_embedding(
runtime,
config,
shutdown_event,
shutdown_endpoints,
run_deferred_handlers,
)
elif config.dynamo_args.multimodal_processor: elif config.dynamo_args.multimodal_processor:
await init_multimodal_processor(runtime, config, shutdown_event) await init_multimodal_processor(
runtime,
config,
shutdown_event,
shutdown_endpoints,
run_deferred_handlers,
)
elif config.dynamo_args.multimodal_encode_worker: elif config.dynamo_args.multimodal_encode_worker:
await init_multimodal_encode_worker(runtime, config, shutdown_event) await init_multimodal_encode_worker(
runtime,
config,
shutdown_event,
shutdown_endpoints,
run_deferred_handlers,
)
elif config.dynamo_args.multimodal_worker: elif config.dynamo_args.multimodal_worker:
if config.serving_mode != DisaggregationMode.PREFILL: if config.serving_mode != DisaggregationMode.PREFILL:
await init_multimodal_worker(runtime, config, shutdown_event) await init_multimodal_worker(
runtime,
config,
shutdown_event,
shutdown_endpoints,
run_deferred_handlers,
)
else: else:
await init_multimodal_prefill_worker(runtime, config, shutdown_event) await init_multimodal_prefill_worker(
runtime,
config,
shutdown_event,
shutdown_endpoints,
run_deferred_handlers,
)
elif config.dynamo_args.diffusion_worker: elif config.dynamo_args.diffusion_worker:
await init_diffusion(runtime, config, shutdown_event) await init_diffusion(
runtime,
config,
shutdown_event,
shutdown_endpoints,
run_deferred_handlers,
)
elif config.serving_mode != DisaggregationMode.PREFILL: elif config.serving_mode != DisaggregationMode.PREFILL:
await init(runtime, config, shutdown_event) await init(
runtime,
config,
shutdown_event,
shutdown_endpoints,
run_deferred_handlers,
)
else: else:
await init_prefill(runtime, config, shutdown_event) await init_prefill(
runtime,
config,
shutdown_event,
shutdown_endpoints,
run_deferred_handlers,
)
async def init( async def init(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
): ):
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
...@@ -255,6 +305,7 @@ async def init( ...@@ -255,6 +305,7 @@ async def init(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}" f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component() component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# Setup metrics and KV events for ALL nodes (including non-leader) # Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks # Non-leader nodes need KV event publishing for their local DP ranks
...@@ -321,13 +372,17 @@ async def init( ...@@ -321,13 +372,17 @@ async def init(
logging.info("Metrics task successfully cancelled") logging.info("Metrics task successfully cancelled")
pass pass
handler.cleanup() handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None: if run_deferred_handlers is not None:
logging.info("Running deferred handlers") logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS() await run_deferred_handlers()
async def init_prefill( async def init_prefill(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
): ):
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
...@@ -341,6 +396,7 @@ async def init_prefill( ...@@ -341,6 +396,7 @@ async def init_prefill(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}" f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component() component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# Setup metrics and KV events for ALL nodes (including non-leader) # Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks # Non-leader nodes need KV event publishing for their local DP ranks
...@@ -399,13 +455,17 @@ async def init_prefill( ...@@ -399,13 +455,17 @@ async def init_prefill(
logging.info("Metrics task successfully cancelled") logging.info("Metrics task successfully cancelled")
pass pass
handler.cleanup() handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None: if run_deferred_handlers is not None:
logging.info("Running deferred handlers") logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS() await run_deferred_handlers()
async def init_diffusion( async def init_diffusion(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
): ):
"""Initialize diffusion language model worker component""" """Initialize diffusion language model worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
...@@ -428,6 +488,7 @@ async def init_diffusion( ...@@ -428,6 +488,7 @@ async def init_diffusion(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}" f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component() component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# Setup metrics and KV events for ALL nodes (including non-leader) # Setup metrics and KV events for ALL nodes (including non-leader)
# Non-leader nodes need KV event publishing for their local DP ranks # Non-leader nodes need KV event publishing for their local DP ranks
...@@ -486,13 +547,17 @@ async def init_diffusion( ...@@ -486,13 +547,17 @@ async def init_diffusion(
logging.info("Metrics task successfully cancelled") logging.info("Metrics task successfully cancelled")
pass pass
handler.cleanup() handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None: if run_deferred_handlers is not None:
logging.info("Running deferred handlers") logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS() await run_deferred_handlers()
async def init_embedding( async def init_embedding(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
): ):
"""Initialize embedding worker component""" """Initialize embedding worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
...@@ -503,6 +568,7 @@ async def init_embedding( ...@@ -503,6 +568,7 @@ async def init_embedding(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}" f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component() component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# publisher instantiates the metrics and kv event publishers # publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics( publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
...@@ -550,12 +616,17 @@ async def init_embedding( ...@@ -550,12 +616,17 @@ async def init_embedding(
logging.info("Metrics task successfully cancelled") logging.info("Metrics task successfully cancelled")
pass pass
handler.cleanup() handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None: if run_deferred_handlers is not None:
logging.info("Running deferred handlers") logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS() await run_deferred_handlers()
async def init_image_diffusion(runtime: DistributedRuntime, config: Config): async def init_image_diffusion(
runtime: DistributedRuntime,
config: Config,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
):
"""Initialize image diffusion worker component""" """Initialize image diffusion worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
...@@ -589,6 +660,7 @@ async def init_image_diffusion(runtime: DistributedRuntime, config: Config): ...@@ -589,6 +660,7 @@ async def init_image_diffusion(runtime: DistributedRuntime, config: Config):
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}" f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component() component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# Image diffusion doesn't have metrics publisher like LLM # Image diffusion doesn't have metrics publisher like LLM
# Could add custom metrics for images/sec, steps/sec later # Could add custom metrics for images/sec, steps/sec later
...@@ -629,12 +701,17 @@ async def init_image_diffusion(runtime: DistributedRuntime, config: Config): ...@@ -629,12 +701,17 @@ async def init_image_diffusion(runtime: DistributedRuntime, config: Config):
raise raise
finally: finally:
handler.cleanup() handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None: if run_deferred_handlers is not None:
logging.info("Running deferred handlers") logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS() await run_deferred_handlers()
async def init_video_generation(runtime: DistributedRuntime, config: Config): async def init_video_generation(
runtime: DistributedRuntime,
config: Config,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
):
"""Initialize video generation worker component""" """Initialize video generation worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
...@@ -668,6 +745,7 @@ async def init_video_generation(runtime: DistributedRuntime, config: Config): ...@@ -668,6 +745,7 @@ async def init_video_generation(runtime: DistributedRuntime, config: Config):
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}" f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component() component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
handler = VideoGenerationWorkerHandler( handler = VideoGenerationWorkerHandler(
component, component,
...@@ -704,10 +782,17 @@ async def init_video_generation(runtime: DistributedRuntime, config: Config): ...@@ -704,10 +782,17 @@ async def init_video_generation(runtime: DistributedRuntime, config: Config):
raise raise
finally: finally:
handler.cleanup() handler.cleanup()
if run_deferred_handlers is not None:
logging.info("Running deferred handlers")
await run_deferred_handlers()
async def init_multimodal_processor( async def init_multimodal_processor(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
): ):
"""Initialize multimodal processor component""" """Initialize multimodal processor component"""
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
...@@ -715,6 +800,7 @@ async def init_multimodal_processor( ...@@ -715,6 +800,7 @@ async def init_multimodal_processor(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}" f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component() component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# For processor, we need to connect to the encode worker # For processor, we need to connect to the encode worker
encode_worker_client = await runtime.endpoint( encode_worker_client = await runtime.endpoint(
...@@ -754,13 +840,17 @@ async def init_multimodal_processor( ...@@ -754,13 +840,17 @@ async def init_multimodal_processor(
raise raise
finally: finally:
handler.cleanup() handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None: if run_deferred_handlers is not None:
logging.info("Running deferred handlers") logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS() await run_deferred_handlers()
async def init_multimodal_encode_worker( async def init_multimodal_encode_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
): ):
"""Initialize multimodal encode worker component""" """Initialize multimodal encode worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
...@@ -769,6 +859,7 @@ async def init_multimodal_encode_worker( ...@@ -769,6 +859,7 @@ async def init_multimodal_encode_worker(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}" f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component() component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# For encode worker, we need to connect to the downstream LLM worker # For encode worker, we need to connect to the downstream LLM worker
pd_worker_client = await runtime.endpoint( pd_worker_client = await runtime.endpoint(
...@@ -798,13 +889,17 @@ async def init_multimodal_encode_worker( ...@@ -798,13 +889,17 @@ async def init_multimodal_encode_worker(
raise raise
finally: finally:
handler.cleanup() handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None: if run_deferred_handlers is not None:
logging.info("Running deferred handlers") logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS() await run_deferred_handlers()
async def init_multimodal_worker( async def init_multimodal_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
): ):
"""Initialize multimodal worker component. """Initialize multimodal worker component.
...@@ -818,6 +913,7 @@ async def init_multimodal_worker( ...@@ -818,6 +913,7 @@ async def init_multimodal_worker(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}" f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component() component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
engine = sgl.Engine(server_args=server_args) engine = sgl.Engine(server_args=server_args)
...@@ -852,13 +948,17 @@ async def init_multimodal_worker( ...@@ -852,13 +948,17 @@ async def init_multimodal_worker(
raise raise
finally: finally:
handler.cleanup() handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None: if run_deferred_handlers is not None:
logging.info("Running deferred handlers") logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS() await run_deferred_handlers()
async def init_multimodal_prefill_worker( async def init_multimodal_prefill_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
): ):
"""Initialize multimodal prefill worker component""" """Initialize multimodal prefill worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
...@@ -869,6 +969,7 @@ async def init_multimodal_prefill_worker( ...@@ -869,6 +969,7 @@ async def init_multimodal_prefill_worker(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}" f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
) )
component = generate_endpoint.component() component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
handler = MultimodalPrefillWorkerHandler(component, engine, config, shutdown_event) handler = MultimodalPrefillWorkerHandler(component, engine, config, shutdown_event)
await handler.async_init() await handler.async_init()
...@@ -889,9 +990,9 @@ async def init_multimodal_prefill_worker( ...@@ -889,9 +990,9 @@ async def init_multimodal_prefill_worker(
raise raise
finally: finally:
handler.cleanup() handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None: if run_deferred_handlers is not None:
logging.info("Running deferred handlers") logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS() await run_deferred_handlers()
async def _warmup_prefill_engine(engine: sgl.Engine, server_args) -> None: async def _warmup_prefill_engine(engine: sgl.Engine, server_args) -> None:
......
...@@ -18,28 +18,31 @@ if "TLLM_LOG_LEVEL" not in os.environ and os.getenv( ...@@ -18,28 +18,31 @@ if "TLLM_LOG_LEVEL" not in os.environ and os.getenv(
os.environ["TLLM_LOG_LEVEL"] = tllm_level os.environ["TLLM_LOG_LEVEL"] = tllm_level
import uvloop import uvloop
from dynamo.common.utils.graceful_shutdown import install_signal_handlers
from dynamo.common.utils.runtime import create_runtime from dynamo.common.utils.runtime import create_runtime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.args import parse_args from dynamo.trtllm.args import parse_args
from dynamo.trtllm.workers import init_worker from dynamo.trtllm.workers import init_worker
configure_dynamo_logging() configure_dynamo_logging()
shutdown_endpoints: list = []
async def worker(): async def worker():
config = parse_args() config = parse_args()
shutdown_event = asyncio.Event() shutdown_event = asyncio.Event()
runtime, _ = create_runtime( runtime, loop = create_runtime(
discovery_backend=config.discovery_backend, discovery_backend=config.discovery_backend,
request_plane=config.request_plane, request_plane=config.request_plane,
event_plane=config.event_plane, event_plane=config.event_plane,
use_kv_events=config.use_kv_events, use_kv_events=config.use_kv_events,
shutdown_event=shutdown_event,
) )
install_signal_handlers(loop, runtime, shutdown_endpoints, shutdown_event)
logging.info(f"Initializing the worker with config: {config}") logging.info(f"Initializing the worker with config: {config}")
await init_worker(runtime, config, shutdown_event) await init_worker(runtime, config, shutdown_event, shutdown_endpoints)
def main(): def main():
......
...@@ -18,6 +18,7 @@ Note on import strategy: ...@@ -18,6 +18,7 @@ Note on import strategy:
import asyncio import asyncio
import logging import logging
from typing import Optional
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.trtllm.args import Config from dynamo.trtllm.args import Config
...@@ -26,7 +27,10 @@ from dynamo.trtllm.workers.llm_worker import init_llm_worker ...@@ -26,7 +27,10 @@ from dynamo.trtllm.workers.llm_worker import init_llm_worker
async def init_worker( async def init_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: Optional[list] = None,
) -> None: ) -> None:
"""Initialize the appropriate worker based on modality. """Initialize the appropriate worker based on modality.
...@@ -37,6 +41,7 @@ async def init_worker( ...@@ -37,6 +41,7 @@ async def init_worker(
runtime: The Dynamo distributed runtime. runtime: The Dynamo distributed runtime.
config: Configuration parsed from command line. config: Configuration parsed from command line.
shutdown_event: Event to signal shutdown. shutdown_event: Event to signal shutdown.
shutdown_endpoints: Optional list to populate with endpoints for graceful shutdown.
""" """
logging.info(f"Initializing worker with modality={config.modality}") logging.info(f"Initializing worker with modality={config.modality}")
...@@ -48,13 +53,15 @@ async def init_worker( ...@@ -48,13 +53,15 @@ async def init_worker(
init_video_diffusion_worker, init_video_diffusion_worker,
) )
await init_video_diffusion_worker(runtime, config, shutdown_event) await init_video_diffusion_worker(
runtime, config, shutdown_event, shutdown_endpoints
)
return return
# TODO: Add IMAGE_DIFFUSION support in follow-up PR # TODO: Add IMAGE_DIFFUSION support in follow-up PR
raise ValueError(f"Unsupported diffusion modality: {modality}") raise ValueError(f"Unsupported diffusion modality: {modality}")
# LLM modalities (text, multimodal) # LLM modalities (text, multimodal)
await init_llm_worker(runtime, config, shutdown_event) await init_llm_worker(runtime, config, shutdown_event, shutdown_endpoints)
__all__ = ["init_worker"] __all__ = ["init_worker"]
...@@ -12,6 +12,7 @@ import json ...@@ -12,6 +12,7 @@ import json
import logging import logging
import os import os
import sys import sys
from typing import Optional
from prometheus_client import REGISTRY from prometheus_client import REGISTRY
from tensorrt_llm.llmapi import ( from tensorrt_llm.llmapi import (
...@@ -109,7 +110,10 @@ def build_kv_connector_config(config: Config): ...@@ -109,7 +110,10 @@ def build_kv_connector_config(config: Config):
async def init_llm_worker( async def init_llm_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: Optional[list] = None,
) -> None: ) -> None:
"""Initialize and run the LLM worker. """Initialize and run the LLM worker.
...@@ -119,6 +123,7 @@ async def init_llm_worker( ...@@ -119,6 +123,7 @@ async def init_llm_worker(
runtime: The Dynamo distributed runtime. runtime: The Dynamo distributed runtime.
config: Configuration parsed from command line. config: Configuration parsed from command line.
shutdown_event: Event to signal shutdown. shutdown_event: Event to signal shutdown.
shutdown_endpoints: Optional list to populate with endpoints for graceful shutdown.
""" """
encode_client = None encode_client = None
...@@ -333,6 +338,8 @@ async def init_llm_worker( ...@@ -333,6 +338,8 @@ async def init_llm_worker(
f"{config.namespace}.{config.component}.{config.endpoint}" f"{config.namespace}.{config.component}.{config.endpoint}"
) )
component = endpoint.component() component = endpoint.component()
if shutdown_endpoints is not None:
shutdown_endpoints[:] = [endpoint]
# should ideally call get_engine_runtime_config # should ideally call get_engine_runtime_config
# this is because we don't have a good way to # this is because we don't have a good way to
......
...@@ -9,6 +9,7 @@ workers using diffusion models (Wan, Flux, Cosmos, etc.). ...@@ -9,6 +9,7 @@ workers using diffusion models (Wan, Flux, Cosmos, etc.).
import asyncio import asyncio
import logging import logging
from typing import Optional
from dynamo.llm import ModelInput, ModelType, register_model from dynamo.llm import ModelInput, ModelType, register_model
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
...@@ -16,7 +17,10 @@ from dynamo.trtllm.args import Config ...@@ -16,7 +17,10 @@ from dynamo.trtllm.args import Config
async def init_video_diffusion_worker( async def init_video_diffusion_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: Optional[list] = None,
) -> None: ) -> None:
"""Initialize and run the video diffusion worker. """Initialize and run the video diffusion worker.
...@@ -27,6 +31,7 @@ async def init_video_diffusion_worker( ...@@ -27,6 +31,7 @@ async def init_video_diffusion_worker(
runtime: The Dynamo distributed runtime. runtime: The Dynamo distributed runtime.
config: Configuration parsed from command line. config: Configuration parsed from command line.
shutdown_event: Event to signal shutdown. shutdown_event: Event to signal shutdown.
shutdown_endpoints: Optional list to populate with endpoints for graceful shutdown.
""" """
# Check visual_gen availability early with a clear error message. # Check visual_gen availability early with a clear error message.
# visual_gen is part of TensorRT-LLM but only available on the feat/visual_gen # visual_gen is part of TensorRT-LLM but only available on the feat/visual_gen
...@@ -87,6 +92,8 @@ async def init_video_diffusion_worker( ...@@ -87,6 +92,8 @@ async def init_video_diffusion_worker(
f"{config.namespace}.{config.component}.{config.endpoint}" f"{config.namespace}.{config.component}.{config.endpoint}"
) )
component = endpoint.component() component = endpoint.component()
if shutdown_endpoints is not None:
shutdown_endpoints[:] = [endpoint]
# Initialize the diffusion engine (auto-detects pipeline from model_index.json) # Initialize the diffusion engine (auto-detects pipeline from model_index.json)
engine = DiffusionEngine(diffusion_config) engine = DiffusionEngine(diffusion_config)
......
...@@ -21,6 +21,7 @@ from dynamo import prometheus_names ...@@ -21,6 +21,7 @@ from dynamo import prometheus_names
from dynamo.common.config_dump import dump_config from dynamo.common.config_dump import dump_config
from dynamo.common.storage import get_fs from dynamo.common.storage import get_fs
from dynamo.common.utils.endpoint_types import parse_endpoint_types from dynamo.common.utils.endpoint_types import parse_endpoint_types
from dynamo.common.utils.graceful_shutdown import install_signal_handlers
from dynamo.common.utils.output_modalities import get_output_modalities from dynamo.common.utils.output_modalities import get_output_modalities
from dynamo.common.utils.prometheus import ( from dynamo.common.utils.prometheus import (
LLMBackendMetrics, LLMBackendMetrics,
...@@ -62,6 +63,7 @@ from .publisher import DYNAMO_COMPONENT_REGISTRY, StatLoggerFactory ...@@ -62,6 +63,7 @@ from .publisher import DYNAMO_COMPONENT_REGISTRY, StatLoggerFactory
configure_dynamo_logging() configure_dynamo_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
shutdown_endpoints: list = []
CHECKPOINT_SLEEP_MODE_LEVEL = 1 CHECKPOINT_SLEEP_MODE_LEVEL = 1
...@@ -78,19 +80,6 @@ async def _handle_non_leader_node(dp_rank: int) -> None: ...@@ -78,19 +80,6 @@ async def _handle_non_leader_node(dp_rank: int) -> None:
await asyncio.Event().wait() await asyncio.Event().wait()
async def graceful_shutdown(runtime, shutdown_event):
"""
Shutdown dynamo distributed runtime.
The endpoints will be immediately invalidated so no new requests will be accepted.
For endpoints served with graceful_shutdown=True, the serving function will wait until all in-flight requests are finished.
For endpoints served with graceful_shutdown=False, the serving function will return immediately.
"""
logging.info("Received shutdown signal, shutting down DistributedRuntime")
shutdown_event.set()
runtime.shutdown()
logging.info("DistributedRuntime shutdown complete")
def build_headless_namespace(config: Config) -> argparse.Namespace: def build_headless_namespace(config: Config) -> argparse.Namespace:
"""Build an argparse Namespace from engine_args for vLLM's run_headless(). """Build an argparse Namespace from engine_args for vLLM's run_headless().
...@@ -172,14 +161,15 @@ async def worker(): ...@@ -172,14 +161,15 @@ async def worker():
return return
shutdown_event = asyncio.Event() shutdown_event = asyncio.Event()
runtime, _ = create_runtime( runtime, loop = create_runtime(
discovery_backend=config.discovery_backend, discovery_backend=config.discovery_backend,
request_plane=config.request_plane, request_plane=config.request_plane,
event_plane=config.event_plane, event_plane=config.event_plane,
use_kv_events=config.use_kv_events, use_kv_events=config.use_kv_events,
shutdown_event=shutdown_event,
) )
install_signal_handlers(loop, runtime, shutdown_endpoints, shutdown_event)
# Route to appropriate initialization based on config flags # Route to appropriate initialization based on config flags
if WorkerFactory.handles(config): if WorkerFactory.handles(config):
# Create worker factory with setup functions # Create worker factory with setup functions
...@@ -189,7 +179,11 @@ async def worker(): ...@@ -189,7 +179,11 @@ async def worker():
register_vllm_model_fn=register_vllm_model, register_vllm_model_fn=register_vllm_model,
) )
await factory.create( await factory.create(
runtime, config, shutdown_event, pre_created_engine=pre_created_engine runtime,
config,
shutdown_event,
shutdown_endpoints,
pre_created_engine=pre_created_engine,
) )
logger.debug("multimodal worker completed") logger.debug("multimodal worker completed")
elif config.omni: elif config.omni:
...@@ -653,6 +647,7 @@ async def init_prefill( ...@@ -653,6 +647,7 @@ async def init_prefill(
if config.engine_args.data_parallel_rank: if config.engine_args.data_parallel_rank:
await _handle_non_leader_node(config.engine_args.data_parallel_rank) await _handle_non_leader_node(config.engine_args.data_parallel_rank)
return return
shutdown_endpoints[:] = [generate_endpoint, clear_endpoint]
# Register prefill model with ModelType.Prefill # Register prefill model with ModelType.Prefill
model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens model_input = ModelInput.Text if config.use_vllm_tokenizer else ModelInput.Tokens
...@@ -725,12 +720,25 @@ async def init( ...@@ -725,12 +720,25 @@ async def init(
component = generate_endpoint.component() component = generate_endpoint.component()
clear_endpoint = component.endpoint("clear_kv_blocks") clear_endpoint = component.endpoint("clear_kv_blocks")
shutdown_endpoints[:] = [
generate_endpoint,
clear_endpoint,
]
lora_enabled = config.engine_args.enable_lora lora_enabled = config.engine_args.enable_lora
if lora_enabled: if lora_enabled:
load_lora_endpoint = component.endpoint("load_lora") load_lora_endpoint = component.endpoint("load_lora")
unload_lora_endpoint = component.endpoint("unload_lora") unload_lora_endpoint = component.endpoint("unload_lora")
list_loras_endpoint = component.endpoint("list_loras") list_loras_endpoint = component.endpoint("list_loras")
shutdown_endpoints.extend(
[
load_lora_endpoint,
unload_lora_endpoint,
list_loras_endpoint,
]
)
model_name = config.served_model_name or config.model model_name = config.served_model_name or config.model
# Use pre-created engine if provided (checkpoint mode), otherwise create new # Use pre-created engine if provided (checkpoint mode), otherwise create new
...@@ -950,6 +958,7 @@ async def init_omni( ...@@ -950,6 +958,7 @@ async def init_omni(
f"{config.namespace}.{config.component}.{config.endpoint}" f"{config.namespace}.{config.component}.{config.endpoint}"
) )
component = generate_endpoint.component() component = generate_endpoint.component()
shutdown_endpoints[:] = [generate_endpoint]
# Initialize media filesystem for storing generated images/videos # Initialize media filesystem for storing generated images/videos
media_fs = ( media_fs = (
......
...@@ -56,14 +56,22 @@ class WorkerFactory: ...@@ -56,14 +56,22 @@ class WorkerFactory:
runtime: DistributedRuntime, runtime: DistributedRuntime,
config: Config, config: Config,
shutdown_event: asyncio.Event, shutdown_event: asyncio.Event,
shutdown_endpoints: list,
pre_created_engine: Optional[EngineSetupResult] = None, pre_created_engine: Optional[EngineSetupResult] = None,
) -> None: ) -> None:
"""Create the appropriate multimodal worker based on config flags.""" """Create the appropriate multimodal worker based on config flags."""
if config.multimodal_encode_worker: if config.multimodal_encode_worker:
await self._create_multimodal_encode_worker(runtime, config, shutdown_event) await self._create_multimodal_encode_worker(
runtime, config, shutdown_event, shutdown_endpoints
)
elif config.multimodal_worker or config.multimodal_decode_worker: elif config.multimodal_worker or config.multimodal_decode_worker:
await self._create_multimodal_worker( await self._create_multimodal_worker(
runtime, config, shutdown_event, pre_created_engine=pre_created_engine runtime,
config,
shutdown_event,
shutdown_endpoints,
pre_created_engine=pre_created_engine,
) )
else: else:
raise ValueError( raise ValueError(
...@@ -75,6 +83,7 @@ class WorkerFactory: ...@@ -75,6 +83,7 @@ class WorkerFactory:
runtime: DistributedRuntime, runtime: DistributedRuntime,
config: Config, config: Config,
shutdown_event: asyncio.Event, shutdown_event: asyncio.Event,
shutdown_endpoints: list, # mutated in place
pre_created_engine: Optional[EngineSetupResult] = None, pre_created_engine: Optional[EngineSetupResult] = None,
) -> None: ) -> None:
""" """
...@@ -93,13 +102,15 @@ class WorkerFactory: ...@@ -93,13 +102,15 @@ class WorkerFactory:
) )
component = generate_endpoint.component() component = generate_endpoint.component()
clear_endpoint = component.endpoint("clear_kv_blocks") clear_endpoint = component.endpoint("clear_kv_blocks")
shutdown_endpoints[:] = [generate_endpoint, clear_endpoint]
lora_enabled = config.engine_args.enable_lora lora_enabled = config.engine_args.enable_lora
if lora_enabled: if lora_enabled:
load_lora_endpoint = component.endpoint("load_lora") load_lora_endpoint = component.endpoint("load_lora")
unload_lora_endpoint = component.endpoint("unload_lora") unload_lora_endpoint = component.endpoint("unload_lora")
list_loras_endpoint = component.endpoint("list_loras") list_loras_endpoint = component.endpoint("list_loras")
shutdown_endpoints.extend(
[load_lora_endpoint, unload_lora_endpoint, list_loras_endpoint]
)
# Use pre-created engine if provided (checkpoint mode), otherwise create new # Use pre-created engine if provided (checkpoint mode), otherwise create new
if pre_created_engine is not None: if pre_created_engine is not None:
( (
...@@ -226,11 +237,13 @@ class WorkerFactory: ...@@ -226,11 +237,13 @@ class WorkerFactory:
runtime: DistributedRuntime, runtime: DistributedRuntime,
config: Config, config: Config,
shutdown_event: asyncio.Event, shutdown_event: asyncio.Event,
shutdown_endpoints: list, # mutated in place
) -> None: ) -> None:
"""Initialize standalone multimodal encode worker.""" """Initialize standalone multimodal encode worker."""
generate_endpoint = runtime.endpoint( generate_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}" f"{config.namespace}.{config.component}.{config.endpoint}"
) )
shutdown_endpoints[:] = [generate_endpoint]
handler = EncodeWorkerHandler(config.engine_args) handler = EncodeWorkerHandler(config.engine_args)
await handler.async_init(runtime) await handler.async_init(runtime)
......
...@@ -12,10 +12,11 @@ This document describes how Dynamo components handle shutdown signals to ensure ...@@ -12,10 +12,11 @@ This document describes how Dynamo components handle shutdown signals to ensure
Graceful shutdown in Dynamo ensures that: Graceful shutdown in Dynamo ensures that:
1. **No new requests are accepted** - Endpoints are immediately invalidated 1. **Routing stops quickly** - Endpoints are unregistered from discovery first
2. **In-flight requests complete** - Existing requests finish processing (configurable) 2. **In-flight requests can finish** - Workers keep serving during a short grace period
3. **Resources are cleaned up** - Engines, connections, and temporary files are released 3. **Endpoints drain** - After the grace period, endpoints are invalidated and optionally wait for in-flight work
4. **Pods restart cleanly** - Exit codes signal Kubernetes for proper restart behavior 4. **Resources are cleaned up** - Engines, connections, and temporary files are released
5. **Pods restart cleanly** - Exit codes signal Kubernetes for proper restart behavior
## Signal Handling ## Signal Handling
...@@ -32,7 +33,7 @@ Each component registers signal handlers at startup: ...@@ -32,7 +33,7 @@ Each component registers signal handlers at startup:
```python ```python
def signal_handler(): def signal_handler():
asyncio.create_task(graceful_shutdown(runtime)) asyncio.create_task(graceful_shutdown(runtime, endpoints))
for sig in (signal.SIGTERM, signal.SIGINT): for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler) loop.add_signal_handler(sig, signal_handler)
...@@ -40,13 +41,15 @@ for sig in (signal.SIGTERM, signal.SIGINT): ...@@ -40,13 +41,15 @@ for sig in (signal.SIGTERM, signal.SIGINT):
The `graceful_shutdown()` function: The `graceful_shutdown()` function:
1. Logs the shutdown signal 1. Logs the shutdown signal
2. Calls `runtime.shutdown()` to invalidate endpoints 2. Unregisters all endpoints from discovery
3. Waits for in-flight requests (based on configuration) 3. Waits for a configurable grace period (`DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS`, default 5s)
4. Returns to allow cleanup to proceed 4. Calls `runtime.shutdown()` to invalidate endpoints and stop accepting new requests
5. Waits for in-flight requests (based on `graceful_shutdown` per endpoint)
6. Returns to allow cleanup to proceed
## Endpoint Draining ## Endpoint Draining
When `runtime.shutdown()` is called, endpoints are immediately invalidated so no new requests are accepted. The behavior for in-flight requests depends on the `graceful_shutdown` parameter when serving the endpoint. After the grace period, `runtime.shutdown()` invalidates endpoints so no new requests are accepted. The behavior for in-flight requests depends on the `graceful_shutdown` parameter when serving the endpoint.
### Configuration ### Configuration
......
...@@ -149,6 +149,9 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -149,6 +149,9 @@ class DynamoWorkerProcess(ManagedProcess):
env["DYN_SYSTEM_PORT"] = str(self.system_port) env["DYN_SYSTEM_PORT"] = str(self.system_port)
env["DYN_HTTP_PORT"] = str(frontend_port) env["DYN_HTTP_PORT"] = str(frontend_port)
# Disable backend shutdown grace period for all migration tests
env["DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS"] = "0"
# Configure health check based on worker type # Configure health check based on worker type
health_check_urls = [ health_check_urls = [
(f"http://localhost:{self.system_port}/health", self.is_ready) (f"http://localhost:{self.system_port}/health", self.is_ready)
......
...@@ -138,6 +138,9 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -138,6 +138,9 @@ class DynamoWorkerProcess(ManagedProcess):
env["DYN_SYSTEM_PORT"] = str(self.system_port) env["DYN_SYSTEM_PORT"] = str(self.system_port)
env["DYN_HTTP_PORT"] = str(frontend_port) env["DYN_HTTP_PORT"] = str(frontend_port)
# Disable backend shutdown grace period for all migration tests
env["DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS"] = "0"
# Configure health check based on worker type # Configure health check based on worker type
health_check_urls = [ health_check_urls = [
(f"http://localhost:{self.system_port}/health", self.is_ready) (f"http://localhost:{self.system_port}/health", self.is_ready)
......
...@@ -144,6 +144,9 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -144,6 +144,9 @@ class DynamoWorkerProcess(ManagedProcess):
env["DYN_SYSTEM_PORT"] = str(self.system_port) env["DYN_SYSTEM_PORT"] = str(self.system_port)
env["DYN_HTTP_PORT"] = str(frontend_port) env["DYN_HTTP_PORT"] = str(frontend_port)
# Disable backend shutdown grace period for all migration tests
env["DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS"] = "0"
# Configure health check based on worker type # Configure health check based on worker type
health_check_urls = [ health_check_urls = [
(f"http://localhost:{self.system_port}/health", self.is_ready) (f"http://localhost:{self.system_port}/health", self.is_ready)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment