Unverified commit b94f9dcd, authored by jh-nv, committed by GitHub
Browse files

feat: request migration for SGLang (#5659)

parent 2ec1c3f5
......@@ -14,6 +14,20 @@ Submodules:
- prometheus: Prometheus metrics collection and logging utilities
"""
from dynamo.common.utils import endpoint_types, otel_tracing, paths, prometheus, runtime
from dynamo.common.utils import (
endpoint_types,
engine_response,
otel_tracing,
paths,
prometheus,
runtime,
)
__all__ = ["endpoint_types", "otel_tracing", "paths", "prometheus", "runtime"]
__all__ = [
"endpoint_types",
"engine_response",
"otel_tracing",
"paths",
"prometheus",
"runtime",
]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Utilities for engine response processing."""
import logging
def normalize_finish_reason(finish_reason: str) -> str:
"""
Normalize engine finish reasons to Dynamo-compatible values.
Engine may return finish reasons that aren't recognized by Dynamo's Rust layer.
This method maps them to compatible values.
[TODO]: Remove this method and add the right code in the Rust layer.
"""
# Map engine's "abort" to Dynamo's "cancelled"
if finish_reason and finish_reason.startswith("abort"):
logging.debug(f"Normalizing finish reason: {finish_reason} to cancelled")
return "cancelled"
return finish_reason
......@@ -2,10 +2,14 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import inspect
import logging
import os
import signal
import sys
import time
from collections import defaultdict
from typing import Any, Awaitable, Callable, DefaultDict
import sglang as sgl
import uvloop
......@@ -43,6 +47,8 @@ from dynamo.sglang.request_handlers import (
configure_dynamo_logging()
RUN_DEFERRED_HANDLERS: Callable[[], Awaitable[None]] | None = None
async def _handle_non_leader_node(
engine: sgl.Engine,
......@@ -80,6 +86,108 @@ async def _handle_non_leader_node(
publisher.cleanup()
SignalCallback = Callable[..., Any]
def install_graceful_shutdown(
loop: asyncio.AbstractEventLoop,
runtime: Any,
*,
signals: tuple[int, ...] = (signal.SIGTERM, signal.SIGINT),
) -> tuple[asyncio.Event, Callable[[], Awaitable[None]]]:
"""
Set up graceful shutdown + callback chaining.
What it does:
- Owns OS-level SIGTERM/SIGINT via signal.signal(...)
- Captures (suppresses) loop.add_signal_handler(SIGTERM/SIGINT, ...) registrations
and runs them during shutdown (sync or async)
- Calls runtime.shutdown() during shutdown (sync or async)
- Sets and returns an asyncio.Event you can await to know shutdown was requested
Returns:
(shutdown_event, run_deferred_handlers)
"""
shutdown_event = asyncio.Event()
# Deferred handlers registered via loop.add_signal_handler for these signals
deferred_handlers: DefaultDict[int, list[tuple[SignalCallback, tuple[Any, ...]]]] = defaultdict(list) # type: ignore[assignment]
# Previous OS handlers (for optional chaining)
old_os_handlers: dict[int, Any] = {}
shutdown_started = False
shutdown_signum: int | None = None
deferred_handlers_ran = False
async def run_deferred_handlers() -> None:
nonlocal deferred_handlers_ran
if not shutdown_started or deferred_handlers_ran:
return
deferred_handlers_ran = True
signums = (
[shutdown_signum]
if shutdown_signum is not None
else list(deferred_handlers.keys())
)
for sig in signums:
for cb, args in list(deferred_handlers.get(sig, [])):
try:
res = cb(*args)
if inspect.isawaitable(res):
await res
except Exception:
logging.exception("Deferred signal callback failed: %r", cb)
async def _shutdown_sequence(signum: int, frame: Any | None) -> None:
nonlocal shutdown_started, shutdown_signum
if shutdown_started:
return
shutdown_signum = signum
shutdown_started = True
logging.info("Received signal %s, starting graceful shutdown", signum)
shutdown_event.set()
try:
runtime.shutdown()
except Exception:
logging.exception("runtime.shutdown() failed")
def _schedule_shutdown(signum: int, frame: Any | None) -> None:
def _kick() -> None:
asyncio.create_task(_shutdown_sequence(signum, frame))
loop.call_soon_threadsafe(_kick)
def _os_signal_handler(signum: int, frame: Any) -> None:
# Keep the OS handler tiny; do real work in the loop thread.
_schedule_shutdown(signum, frame)
# Install OS-level handlers
for sig in signals:
old_os_handlers[sig] = signal.signal(sig, _os_signal_handler)
# Intercept loop.add_signal_handler for SIGTERM/SIGINT and defer them
orig_add = loop.add_signal_handler
def watching_add_signal_handler(sig: int, callback: SignalCallback, *args: Any):
if sig in signals:
logging.info(
"Captured loop.add_signal_handler(%s, %r, ...) (deferred).",
sig,
callback,
)
deferred_handlers[sig].append((callback, args))
return None
return orig_add(sig, callback, *args)
loop.add_signal_handler = watching_add_signal_handler # type: ignore[assignment]
return shutdown_event, run_deferred_handlers
async def worker():
config = await parse_args(sys.argv[1:])
dump_config(config.dynamo_args.dump_config_to, config)
......@@ -91,36 +199,42 @@ async def worker():
config.server_args.load_format = setup_gms(config.server_args)
dynamo_args = config.dynamo_args
runtime, _ = create_runtime(
runtime, loop = create_runtime(
store_kv=dynamo_args.store_kv,
request_plane=dynamo_args.request_plane,
event_plane=dynamo_args.event_plane,
use_kv_events=dynamo_args.use_kv_events,
)
# Set up signal handlers using signal module to allow chaining
global RUN_DEFERRED_HANDLERS
shutdown_event, RUN_DEFERRED_HANDLERS = install_graceful_shutdown(loop, runtime)
logging.info("Signal handlers set up for graceful shutdown (with chaining)")
if config.dynamo_args.image_diffusion_worker:
await init_image_diffusion(runtime, config)
elif config.dynamo_args.embedding_worker:
await init_embedding(runtime, config)
await init_embedding(runtime, config, shutdown_event)
elif config.dynamo_args.multimodal_processor:
await init_multimodal_processor(runtime, config)
await init_multimodal_processor(runtime, config, shutdown_event)
elif config.dynamo_args.multimodal_encode_worker:
await init_multimodal_encode_worker(runtime, config)
await init_multimodal_encode_worker(runtime, config, shutdown_event)
elif config.dynamo_args.multimodal_worker:
if config.serving_mode != DisaggregationMode.PREFILL:
await init_multimodal_worker(runtime, config)
await init_multimodal_worker(runtime, config, shutdown_event)
else:
await init_multimodal_prefill_worker(runtime, config)
await init_multimodal_prefill_worker(runtime, config, shutdown_event)
elif config.dynamo_args.diffusion_worker:
await init_diffusion(runtime, config)
await init_diffusion(runtime, config, shutdown_event)
elif config.serving_mode != DisaggregationMode.PREFILL:
await init(runtime, config)
await init(runtime, config, shutdown_event)
else:
await init_prefill(runtime, config)
await init_prefill(runtime, config, shutdown_event)
async def init(runtime: DistributedRuntime, config: Config):
async def init(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
server_args, dynamo_args = config.server_args, config.dynamo_args
# Prevent SGLang from blocking on non-leader nodes
......@@ -158,7 +272,7 @@ async def init(runtime: DistributedRuntime, config: Config):
ready_event = asyncio.Event()
handler = DecodeWorkerHandler(
component, engine, config, publisher, generate_endpoint
component, engine, config, publisher, generate_endpoint, shutdown_event
)
handler.register_engine_routes(runtime)
......@@ -205,12 +319,17 @@ async def init(runtime: DistributedRuntime, config: Config):
try:
await metrics_task
except asyncio.CancelledError:
logging.info("Metrics task succesfully cancelled")
logging.info("Metrics task successfully cancelled")
pass
handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None:
logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS()
async def init_prefill(runtime: DistributedRuntime, config: Config):
async def init_prefill(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
server_args, dynamo_args = config.server_args, config.dynamo_args
# Prevent SGLang from blocking on non-leader nodes
......@@ -242,7 +361,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
await _warmup_prefill_engine(engine, server_args)
handler = PrefillWorkerHandler(
component, engine, config, publisher, generate_endpoint
component, engine, config, publisher, generate_endpoint, shutdown_event
)
handler.register_engine_routes(runtime)
......@@ -282,9 +401,14 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
logging.info("Metrics task successfully cancelled")
pass
handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None:
logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS()
async def init_diffusion(runtime: DistributedRuntime, config: Config):
async def init_diffusion(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
"""Initialize diffusion language model worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
......@@ -324,7 +448,7 @@ async def init_diffusion(runtime: DistributedRuntime, config: Config):
ready_event = asyncio.Event()
handler = DiffusionWorkerHandler(
component, engine, config, publisher, generate_endpoint
component, engine, config, publisher, generate_endpoint, shutdown_event
)
handler.register_engine_routes(runtime)
......@@ -365,9 +489,14 @@ async def init_diffusion(runtime: DistributedRuntime, config: Config):
logging.info("Metrics task successfully cancelled")
pass
handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None:
logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS()
async def init_embedding(runtime: DistributedRuntime, config: Config):
async def init_embedding(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
"""Initialize embedding worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
......@@ -387,7 +516,9 @@ async def init_embedding(runtime: DistributedRuntime, config: Config):
# Readiness gate: requests wait until model is registered
ready_event = asyncio.Event()
handler = EmbeddingWorkerHandler(component, engine, config, publisher)
handler = EmbeddingWorkerHandler(
component, engine, config, publisher, shutdown_event
)
health_check_payload = SglangHealthCheckPayload(
engine, use_text_input=dynamo_args.use_sglang_tokenizer
).to_dict()
......@@ -423,6 +554,9 @@ async def init_embedding(runtime: DistributedRuntime, config: Config):
logging.info("Metrics task successfully cancelled")
pass
handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None:
logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS()
async def init_image_diffusion(runtime: DistributedRuntime, config: Config):
......@@ -504,9 +638,14 @@ async def init_image_diffusion(runtime: DistributedRuntime, config: Config):
raise
finally:
handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None:
logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS()
async def init_multimodal_processor(runtime: DistributedRuntime, config: Config):
async def init_multimodal_processor(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
"""Initialize multimodal processor component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
component = runtime.namespace(dynamo_args.namespace).component(
......@@ -525,7 +664,9 @@ async def init_multimodal_processor(runtime: DistributedRuntime, config: Config)
ready_event = asyncio.Event()
handler = MultimodalProcessorHandler(component, config, encode_worker_client)
handler = MultimodalProcessorHandler(
component, config, encode_worker_client, shutdown_event
)
logging.info("Waiting for Encoder Worker Instances ...")
await encode_worker_client.wait_for_instances()
......@@ -554,9 +695,14 @@ async def init_multimodal_processor(runtime: DistributedRuntime, config: Config)
raise
finally:
handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None:
logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS()
async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Config):
async def init_multimodal_encode_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
"""Initialize multimodal encode worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
......@@ -574,7 +720,9 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con
.client()
)
handler = MultimodalEncodeWorkerHandler(component, config, pd_worker_client)
handler = MultimodalEncodeWorkerHandler(
component, config, pd_worker_client, shutdown_event
)
await handler.async_init(runtime)
await pd_worker_client.wait_for_instances()
......@@ -595,9 +743,14 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con
raise
finally:
handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None:
logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS()
async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
async def init_multimodal_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
"""Initialize multimodal worker component.
This worker is always an internal component that should not register with
......@@ -622,9 +775,13 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
.endpoint("generate")
.client()
)
handler = MultimodalWorkerHandler(component, engine, config, prefill_client)
handler = MultimodalWorkerHandler(
component, engine, config, prefill_client, shutdown_event
)
else:
handler = MultimodalWorkerHandler(component, engine, config)
handler = MultimodalWorkerHandler(
component, engine, config, None, shutdown_event
)
await handler.async_init()
......@@ -644,9 +801,14 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
raise
finally:
handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None:
logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS()
async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Config):
async def init_multimodal_prefill_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
"""Initialize multimodal prefill worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
......@@ -658,7 +820,7 @@ async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Co
generate_endpoint = component.endpoint(dynamo_args.endpoint)
handler = MultimodalPrefillWorkerHandler(component, engine, config)
handler = MultimodalPrefillWorkerHandler(component, engine, config, shutdown_event)
await handler.async_init()
health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict()
......@@ -677,6 +839,9 @@ async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Co
raise
finally:
handler.cleanup()
if RUN_DEFERRED_HANDLERS is not None:
logging.info("Running deferred handlers")
await RUN_DEFERRED_HANDLERS()
async def _warmup_prefill_engine(engine: sgl.Engine, server_args) -> None:
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
from typing import Optional
......@@ -20,8 +21,9 @@ class EmbeddingWorkerHandler(BaseWorkerHandler):
engine: sgl.Engine,
config: Config,
publisher: Optional[DynamoSglangPublisher] = None,
shutdown_event: Optional[asyncio.Event] = None,
):
super().__init__(component, engine, config, publisher)
super().__init__(component, engine, config, publisher, None, shutdown_event)
logging.info("Embedding worker handler initialized")
def cleanup(self):
......
......@@ -102,6 +102,7 @@ class BaseWorkerHandler(BaseGenerativeHandler):
config: Config,
publisher: Optional[DynamoSglangPublisher] = None,
generate_endpoint=None,
shutdown_event: Optional[asyncio.Event] = None,
) -> None:
"""Initialize base worker handler.
......@@ -111,6 +112,7 @@ class BaseWorkerHandler(BaseGenerativeHandler):
config: SGLang and Dynamo configuration.
publisher: Optional metrics publisher for the worker.
generate_endpoint: The endpoint handle for discovery registration.
shutdown_event: Optional event to signal shutdown.
"""
# Call parent constructor
super().__init__(component, config, publisher)
......@@ -120,6 +122,7 @@ class BaseWorkerHandler(BaseGenerativeHandler):
self.config = config
self.generate_endpoint = generate_endpoint
self.publisher = publisher
self.shutdown_event = shutdown_event
if publisher is not None:
self.metrics_publisher = publisher.metrics_publisher
self.kv_publisher = publisher.kv_publisher
......@@ -436,12 +439,15 @@ class BaseWorkerHandler(BaseGenerativeHandler):
async def _handle_cancellation(
self, request_id_future: asyncio.Future, context: Context
):
"""Background task to handle cancellation by monitoring context state.
"""Background task to handle cancellation and shutdown by monitoring both signals.
Args:
request_id_future: Future that will be set with the SGLang request ID
when the first response arrives.
context: Context object for cancellation handling.
Raises:
GeneratorExit: If shutdown event was triggered.
"""
try:
logging.debug(f"Cancellation monitor started for Context: {context.id()}")
......@@ -453,10 +459,34 @@ class BaseWorkerHandler(BaseGenerativeHandler):
)
logging.debug(f"Request ID future cancelled for Context: {context.id()}")
await context.async_killed_or_stopped()
# Get the cancellation future
cancellation_future = context.async_killed_or_stopped()
# Build list of futures/tasks to wait for
wait_for = [cancellation_future]
shutdown_task = None
if self.shutdown_event:
# Create task for shutdown monitoring and add to wait list
shutdown_task = asyncio.create_task(self.shutdown_event.wait())
wait_for.append(shutdown_task)
# Wait for whichever happens first
done, pending = await asyncio.wait(
wait_for,
return_when=asyncio.FIRST_COMPLETED,
)
# Cancel the pending task/future
for task in pending:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
logging.info(
f"Cancellation signal received for SGLang Request ID {sglang_request_id}, Context: {context.id()}"
f"Cancellation or shutdown signal received for SGLang Request ID {sglang_request_id}, Context: {context.id()}"
)
# Call abort_request on the tokenizer_manager through the engine
......@@ -475,6 +505,11 @@ class BaseWorkerHandler(BaseGenerativeHandler):
logging.error(
f"SGLang tokenizer_manager not found for abort request: {context.id()}"
)
# Check which event triggered and raise GeneratorExit if shutdown
if shutdown_task and shutdown_task in done:
raise GeneratorExit("Engine was shut down during token generation")
except asyncio.CancelledError:
# Task was cancelled, which is expected when generation completes
request_id = "unknown"
......@@ -493,9 +528,11 @@ class BaseWorkerHandler(BaseGenerativeHandler):
self, request_id_future: asyncio.Future, context: Context
) -> AsyncGenerator[asyncio.Task, None]:
"""
Context manager for monitoring request cancellation.
Context manager for monitoring request cancellation and shutdown.
Automatically creates a background task to monitor for cancellation and
cleans it up when the context exits.
shutdown events, cleaning it up when the context exits.
If shutdown event was triggered, raises GeneratorExit on exit.
Args:
request_id_future: Future that will be set with the SGLang request ID
......@@ -533,6 +570,4 @@ class BaseWorkerHandler(BaseGenerativeHandler):
except asyncio.CancelledError:
pass
else:
logging.debug(
f"Cancellation monitor task already completed for SGLang Request ID {request_id}, Context: {context.id()}"
)
cancellation_task.result()
......@@ -4,11 +4,12 @@
import asyncio
import logging
import time
from typing import Any, AsyncGenerator, Dict
from typing import Any, AsyncGenerator, Dict, Optional
import sglang as sgl
from dynamo._core import Component, Context
from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.sglang.args import Config, DisaggregationMode
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
......@@ -24,6 +25,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
config: Config,
publisher: DynamoSglangPublisher,
generate_endpoint=None,
shutdown_event: Optional[asyncio.Event] = None,
) -> None:
"""Initialize decode worker handler.
......@@ -32,6 +34,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
engine: The SGLang engine instance.
config: SGLang and Dynamo configuration.
publisher: Metrics publisher for the worker.
shutdown_event: Optional event to signal shutdown.
generate_endpoint: The endpoint handle for discovery registration.
"""
super().__init__(
......@@ -40,6 +43,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
config,
publisher,
generate_endpoint,
shutdown_event,
)
if self.serving_mode == DisaggregationMode.DECODE:
logging.info(
......@@ -222,7 +226,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
out = {}
finish_reason = res["meta_info"]["finish_reason"]
if finish_reason:
out["finish_reason"] = finish_reason["type"]
out["finish_reason"] = normalize_finish_reason(
finish_reason["type"]
)
# With stream_output=True, output_ids contains only new tokens (disjoint)
output_ids = res.get("output_ids", [])
......@@ -287,7 +293,11 @@ class DecodeWorkerHandler(BaseWorkerHandler):
text = res.get("text", "")
finish_reason = res["meta_info"]["finish_reason"]
finish_reason_type = finish_reason["type"] if finish_reason else None
finish_reason_type = (
normalize_finish_reason(finish_reason["type"])
if finish_reason
else None
)
next_count = len(text)
delta = text[count:]
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
from typing import Any, AsyncGenerator, Dict
from typing import Any, AsyncGenerator, Dict, Optional
import sglang as sgl
......@@ -23,6 +24,7 @@ class DiffusionWorkerHandler(DecodeWorkerHandler):
config: Config,
publisher: DynamoSglangPublisher = None,
generate_endpoint=None,
shutdown_event: Optional[asyncio.Event] = None,
) -> None:
"""Initialize diffusion worker handler.
......@@ -32,8 +34,11 @@ class DiffusionWorkerHandler(DecodeWorkerHandler):
config: SGLang and Dynamo configuration.
publisher: Optional metrics publisher.
generate_endpoint: The endpoint handle for discovery.
shutdown_event: Optional event to signal shutdown.
"""
super().__init__(component, engine, config, publisher, generate_endpoint)
super().__init__(
component, engine, config, publisher, generate_endpoint, shutdown_event
)
# Validate that diffusion algorithm is configured
if (
......
......@@ -3,7 +3,7 @@
import asyncio
import logging
from typing import Any, AsyncGenerator, Dict
from typing import Any, AsyncGenerator, Dict, Optional
import sglang as sgl
......@@ -23,6 +23,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
config: Config,
publisher: DynamoSglangPublisher,
generate_endpoint=None,
shutdown_event: Optional[asyncio.Event] = None,
) -> None:
"""Initialize prefill worker handler.
......@@ -32,10 +33,13 @@ class PrefillWorkerHandler(BaseWorkerHandler):
config: SGLang and Dynamo configuration.
publisher: The SGLang publisher instance.
generate_endpoint: The endpoint handle for discovery registration.
shutdown_event: Optional event to signal shutdown.
"""
self.engine = engine
self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info(self.engine)
super().__init__(component, engine, config, publisher, generate_endpoint)
super().__init__(
component, engine, config, publisher, generate_endpoint, shutdown_event
)
self._consume_tasks = set()
logging.info(
f"Prefill worker handler initialized - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}"
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
from typing import AsyncIterator
from typing import AsyncIterator, Optional
import torch
from sglang.srt.parser.conversation import chat_templates
......@@ -45,8 +46,11 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
component: Component,
config: Config,
pd_worker_client: Client,
shutdown_event: Optional[asyncio.Event] = None,
) -> None:
super().__init__(component, engine=None, config=config)
super().__init__(
component, engine=None, config=config, shutdown_event=shutdown_event
)
self.pd_worker_client = pd_worker_client
self.model = config.server_args.model_path
self.served_model_name = config.server_args.served_model_name
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json
import logging
import time
import uuid
from typing import Any, Dict
from typing import Any, Dict, Optional
from transformers import AutoTokenizer
......@@ -36,8 +37,11 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
component: Component,
config: Config,
encode_worker_client: Client,
shutdown_event: Optional[asyncio.Event] = None,
):
super().__init__(component, engine=None, config=config)
super().__init__(
component, engine=None, config=config, shutdown_event=shutdown_event
)
self.encode_worker_client = encode_worker_client
self.chat_template = getattr(config.server_args, "chat_template", "qwen2-vl")
self.model = config.server_args.model_path
......
......@@ -4,13 +4,14 @@
import asyncio
import json
import logging
from typing import AsyncIterator
from typing import AsyncIterator, Optional
import sglang as sgl
import torch
import dynamo.nixl_connect as connect
from dynamo._core import Client, Component, Context
from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.sglang.args import Config, DisaggregationMode
from dynamo.sglang.protocol import (
DisaggSglangMultimodalRequest,
......@@ -165,7 +166,9 @@ class StreamProcessor:
if finish_reason:
output.update(
{
"finish_reason": finish_reason.get("type", "stop"),
"finish_reason": normalize_finish_reason(
finish_reason.get("type", "stop")
),
"finished": True,
}
)
......@@ -248,8 +251,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
engine: sgl.Engine,
config: Config,
prefill_client: Client = None,
shutdown_event: Optional[asyncio.Event] = None,
):
super().__init__(component, engine, config, None)
super().__init__(component, engine, config, None, None, shutdown_event)
# Initialize processors
self.embeddings_processor = EmbeddingsProcessor()
......@@ -423,8 +427,14 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
Processes multimodal inputs and coordinates with decode worker.
"""
def __init__(self, component: Component, engine: sgl.Engine, config: Config):
super().__init__(component, engine, config)
def __init__(
self,
component: Component,
engine: sgl.Engine,
config: Config,
shutdown_event: Optional[asyncio.Event] = None,
):
super().__init__(component, engine, config, None, None, shutdown_event)
# Initialize processors
self.embeddings_processor = EmbeddingsProcessor()
......
......@@ -22,6 +22,7 @@ from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.engine.exceptions import EngineDeadError
import dynamo.nixl_connect as nixl_connect
from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.common.utils.input_params import InputParamManager
from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl
from dynamo.common.utils.otel_tracing import build_trace_headers
......@@ -436,20 +437,6 @@ class BaseWorkerHandler(ABC):
self._lora_load_locks[lora_name] = lock
return lock
def _normalize_finish_reason(self, finish_reason: str) -> str:
"""
Normalize vLLM finish reasons to Dynamo-compatible values.
vLLM may return finish reasons that aren't recognized by Dynamo's Rust layer.
This method maps them to compatible values.
[TODO]: Remove this method and add the right code in the Rust layer.
"""
# Map vLLM's "abort" to Dynamo's "cancelled"
if finish_reason.startswith("abort"):
logging.debug(f"Normalizing finish reason: {finish_reason} to cancelled")
return "cancelled"
return finish_reason
async def load_lora(self, request=None):
"""
Load a LoRA adapter dynamically into the vLLM's AsyncLLM engine.
......@@ -1223,9 +1210,7 @@ class BaseWorkerHandler(ABC):
out["top_logprobs"] = top_logprobs
if output.finish_reason:
out["finish_reason"] = self._normalize_finish_reason(
output.finish_reason
)
out["finish_reason"] = normalize_finish_reason(output.finish_reason)
out["completion_usage"] = BaseWorkerHandler._build_completion_usage(
request_output=res,
embedding_sequence_length=embedding_sequence_length,
......@@ -1438,9 +1423,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
"role": "assistant",
"content": delta_text,
},
"finish_reason": self._normalize_finish_reason(
output.finish_reason
),
"finish_reason": normalize_finish_reason(output.finish_reason),
}
chunk = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment