Unverified Commit 435c8024 authored by KrishnanPrash's avatar KrishnanPrash Committed by GitHub
Browse files

refactor: fold SGLang multimodal processor into encode worker (#7517)


Signed-off-by: default avatarKrishnan Prashanth <kprashanth@nvidia.com>
parent 6b67c1c4
......@@ -18,7 +18,6 @@ Worker dispatch (main.py:60-132):
--image-diffusion-worker -> init_diffusion.init_image_diffusion()
--video-generation-worker -> init_diffusion.init_video_diffusion()
--embedding-worker -> init_embedding.init_embedding()
--multimodal-processor -> init_multimodal.init_multimodal_processor()
--multimodal-encode-worker -> init_multimodal.init_multimodal_encode_worker()
--multimodal-worker -> init_multimodal.init_multimodal_worker() or _prefill_worker()
--dllm-algorithm <algo> -> init_diffusion.init_llm_diffusion()
......@@ -43,7 +42,7 @@ Worker dispatch (main.py:60-132):
**DynamoConfig** combines `DynamoRuntimeConfig` (common flags like `--namespace`,
`--output-modalities`, `--media-output-fs-url`) with `DynamoSGLangConfig` (sglang-specific
flags like `--multimodal-processor`, `--embedding-worker`).
flags like `--multimodal-encode-worker`, `--embedding-worker`).
Key gotcha: `--output-modalities` defaults to `["text"]` globally. Image/video diffusion
workers override this in their init functions to `["image"]`/`["video"]` to ensure correct
......@@ -80,11 +79,10 @@ BaseGenerativeHandler (handler_base.py)
MultimodalPrefillWorkerHandler (multimodal/worker_handler.py)
Multimodal prefill phase. Yields bootstrap info.
MultimodalProcessorHandler (multimodal/processor_handler.py)
Front-facing. No engine. Routes to encode worker.
MultimodalEncodeWorkerHandler (multimodal/encode_worker_handler.py)
No engine. Uses MMEncoder from SGLang. NIXL for embeddings transfer.
Front-facing. No engine. Uses MMEncoder from SGLang. Receives
pre-tokenized requests (ModelInput.Tokens) from Rust frontend,
encodes images, NIXL for embeddings transfer.
```
## Engine Types by Worker
......@@ -93,8 +91,7 @@ BaseGenerativeHandler (handler_base.py)
|--------|--------|-------|
| decode, prefill, dllm, embedding | `sgl.Engine` | Full SGLang inference engine |
| multimodal-worker, multimodal-prefill | `sgl.Engine` | Plus EmbeddingsProcessor |
| multimodal-processor | None | Tokenizer only, routes to encoder |
| multimodal-encode-worker | None | `MMEncoder` from SGLang |
| multimodal-encode-worker | None | `MMEncoder` from SGLang, pre-tokenized input |
| image-diffusion-worker | `DiffGenerator` | From `sglang.multimodal_gen` |
| video-generation-worker | `DiffGenerator` | From `sglang.multimodal_gen` |
......@@ -218,7 +215,7 @@ text-to-video-diffusion.sh # 1-2 GPUs - Text-to-video (Wan2.1)
- **SimpleNamespace vs ServerArgs**: Image/video diffusion workers use SimpleNamespace
stubs. Always use `getattr(server_args, field, default)` for fields that may not exist.
- **engine=None**: Multimodal processor and encode worker pass `engine=None` to
- **engine=None**: Multimodal encode worker passes `engine=None` to
BaseWorkerHandler. Any code in the base class that touches engine must guard with
`if engine is not None`.
- **GenerationResult is a dataclass**: SGLang 0.5.9 changed `DiffGenerator.generate()`
......@@ -264,7 +261,7 @@ Checklist for adding a new worker (e.g., a new modality or serving mode):
- **Check nvidia-smi**: If a launch OOMs, check for orphaned GPU processes from prior runs.
- **SimpleNamespace stubs**: When touching args.py or code that reads server_args, always
use `getattr(server_args, field, default)` -- image/video workers don't have full ServerArgs.
- **engine can be None**: Encode-only workers (multimodal-processor, multimodal-encode-worker)
- **engine can be None**: Encode-only workers (multimodal-encode-worker)
pass engine=None. Guard any engine access in shared base class code.
- **Rebuild after Rust changes**: If changing registration (register.py interacts with Rust
bindings), rebuild: `cd lib/bindings/python && maturin develop --uv && cd <root> && uv pip install -e .`
......@@ -281,7 +278,7 @@ sglang/
backend_args.py # Dynamo-specific SGLang CLI flags
init_llm.py # init_decode(), init_prefill()
init_diffusion.py # init_llm_diffusion(), init_image_diffusion(), init_video_diffusion()
init_multimodal.py # init_multimodal_{processor,encode_worker,worker,prefill_worker}()
init_multimodal.py # init_multimodal_{encode_worker,worker,prefill_worker}()
init_embedding.py # init_embedding()
register.py # Model registration (LLM, image, video)
publisher.py # Metrics + KV event publishing
......@@ -301,7 +298,6 @@ sglang/
video_generation/
video_generation_handler.py # VideoGenerationWorkerHandler (DiffGenerator)
multimodal/
processor_handler.py # MultimodalProcessorHandler (no engine)
encode_worker_handler.py # MultimodalEncodeWorkerHandler (MMEncoder)
encode_worker_handler.py # MultimodalEncodeWorkerHandler (MMEncoder, front-facing)
worker_handler.py # MultimodalWorkerHandler + PrefillWorkerHandler
```
......@@ -262,8 +262,6 @@ async def parse_args(args: list[str]) -> Config:
and parsed_args.disaggregation_mode == "prefill"
):
endpoint = f"dyn://{namespace}.prefill.generate"
elif dynamo_config.multimodal_processor:
endpoint = f"dyn://{namespace}.processor.generate"
elif dynamo_config.multimodal_encode_worker:
endpoint = f"dyn://{namespace}.encoder.generate"
elif (
......
......@@ -41,13 +41,6 @@ class DynamoSGLangArgGroup(ArgGroup):
"the same SGLang-native pre/post processing with KV router support.",
)
add_negatable_bool_argument(
g,
flag_name="--multimodal-processor",
env_var="DYN_SGL_MULTIMODAL_PROCESSOR",
default=False,
help="Run as multimodal processor component for handling multimodal requests.",
)
add_negatable_bool_argument(
g,
flag_name="--multimodal-encode-worker",
......@@ -114,7 +107,6 @@ class DynamoSGLangConfig(ConfigBase):
"""Configuration for Dynamo SGLang wrapper (SGLang-specific only)."""
use_sglang_tokenizer: bool
multimodal_processor: bool
multimodal_encode_worker: bool
multimodal_worker: bool
embedding_transfer_mode: EmbeddingTransferMode
......
......@@ -21,36 +21,35 @@ from dynamo.sglang.register import register_model_with_readiness_gate
from dynamo.sglang.request_handlers import (
MultimodalEncodeWorkerHandler,
MultimodalPrefillWorkerHandler,
MultimodalProcessorHandler,
MultimodalWorkerHandler,
)
async def init_multimodal_processor(
async def init_multimodal_encode_worker(
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
) -> None:
"""Initialize multimodal processor component"""
"""Initialize multimodal encode worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
shutdown_endpoints[:] = [generate_endpoint]
encode_worker_client = await runtime.endpoint(
f"{dynamo_args.namespace}.encoder.generate"
pd_worker_client = await runtime.endpoint(
f"{dynamo_args.namespace}.backend.generate"
).client()
ready_event = asyncio.Event()
handler = MultimodalEncodeWorkerHandler(config, pd_worker_client, shutdown_event)
handler = MultimodalProcessorHandler(config, encode_worker_client, shutdown_event)
await pd_worker_client.wait_for_instances()
logging.info("Waiting for Encoder Worker Instances ...")
await encode_worker_client.wait_for_instances()
ready_event = asyncio.Event()
try:
_ = await asyncio.gather(
......@@ -67,7 +66,7 @@ async def init_multimodal_processor(
generate_endpoint,
server_args,
dynamo_args,
input_type=ModelInput.Text,
input_type=ModelInput.Tokens,
readiness_gate=ready_event,
),
)
......@@ -81,49 +80,6 @@ async def init_multimodal_processor(
await run_deferred_handlers()
async def init_multimodal_encode_worker(
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
) -> None:
"""Initialize multimodal encode worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
)
shutdown_endpoints[:] = [generate_endpoint]
pd_worker_client = await runtime.endpoint(
f"{dynamo_args.namespace}.backend.generate"
).client()
handler = MultimodalEncodeWorkerHandler(config, pd_worker_client, shutdown_event)
await pd_worker_client.wait_for_instances()
try:
await generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=[
(prometheus_names.labels.MODEL, server_args.served_model_name),
(prometheus_names.labels.MODEL_NAME, server_args.served_model_name),
],
)
except Exception as e:
logging.error(f"Failed to serve endpoints: {e}")
raise
finally:
handler.cleanup()
if run_deferred_handlers is not None:
logging.info("Running deferred handlers")
await run_deferred_handlers()
async def init_multimodal_worker(
runtime: DistributedRuntime,
config: Config,
......@@ -134,8 +90,8 @@ async def init_multimodal_worker(
"""Initialize multimodal worker component.
This worker is always an internal component that should not register with
the Frontend. Public registration is handled by the Processor component
(--multimodal-processor). For standalone serving, use init() (default).
the Frontend. Public registration is handled by the Encode Worker component
(--multimodal-encode-worker). For standalone serving, use init() (default).
"""
server_args, dynamo_args = config.server_args, config.dynamo_args
......
......@@ -22,7 +22,6 @@ from dynamo.sglang.init_llm import init_decode, init_prefill
from dynamo.sglang.init_multimodal import (
init_multimodal_encode_worker,
init_multimodal_prefill_worker,
init_multimodal_processor,
init_multimodal_worker,
)
from dynamo.sglang.shutdown import install_graceful_shutdown
......@@ -86,14 +85,6 @@ async def worker():
shutdown_endpoints,
run_deferred_handlers,
)
elif config.dynamo_args.multimodal_processor:
await init_multimodal_processor(
runtime,
config,
shutdown_event,
shutdown_endpoints,
run_deferred_handlers,
)
elif config.dynamo_args.multimodal_encode_worker:
await init_multimodal_encode_worker(
runtime,
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from dynamo.sglang.multimodal_utils.multimodal_chat_processor import (
multimodal_request_to_sglang,
process_sglang_stream_response,
)
__all__ = [
"multimodal_request_to_sglang",
"process_sglang_stream_response",
]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Any, Dict, Tuple
from sglang.srt.parser.conversation import chat_templates
logger = logging.getLogger(__name__)
def multimodal_request_to_sglang(
raw_request: Any, tokenizer: Any, chat_template: str
) -> Dict[str, Any]:
conv = chat_templates[chat_template].copy()
conv.messages = []
# Convert messages into SGLang conversation
for msg in raw_request.messages:
if msg.role == "system":
conv.system_message = msg.content
elif msg.role == "user":
text_parts = []
for part in msg.content:
if part.type == "text":
text_parts.append(part.text)
elif part.type == "image_url":
text_parts.append(conv.image_token)
conv.append_message(conv.roles[0], " ".join(text_parts))
elif msg.role == "assistant":
conv.append_message(conv.roles[1], msg.content)
conv.append_message(conv.roles[1], "")
logger.debug(f"conv: {conv}")
# Tokenize and prepare input_ids
processed = tokenizer(text=conv.get_prompt(), return_tensors="pt")
input_ids = processed["input_ids"][0].tolist()
# Build the SGLang request dict
sglang_request = {
"model": raw_request.model,
"token_ids": input_ids,
"stop_conditions": {"max_tokens": raw_request.max_tokens or None},
"sampling_options": {"temperature": raw_request.temperature or 0.7},
"eos_token_ids": [tokenizer.eos_token_id],
"annotations": [],
"stream": raw_request.stream if raw_request.stream is not None else False,
}
return sglang_request
def detokenize_sglang_response(response_data: Any, tokenizer: Any) -> str:
"""
Detokenize SGLang response token IDs to text.
Args:
response_data: Dictionary containing token_ids and other response data
tokenizer: The tokenizer to use for detokenization
Returns:
String containing the detokenized text, empty string if no tokens
"""
try:
# Handle Annotated objects from Dynamo (following vLLM-like pattern)
if hasattr(response_data, "data"):
try:
import json
raw_data = response_data.data
# Handle callable data method
if callable(raw_data):
raw_data = raw_data()
response_data = (
json.loads(raw_data) if isinstance(raw_data, str) else raw_data
)
except (json.JSONDecodeError, AttributeError):
try:
raw_data = response_data.data
if callable(raw_data):
raw_data = raw_data()
response_data = {"text": str(raw_data), "finished": False}
except Exception:
response_data = {"text": str(response_data), "finished": False}
# Ensure response_data is a dictionary
if not isinstance(response_data, dict):
return str(response_data)
# Get text content - detokenize if needed
if "text" in response_data and response_data["text"]:
return response_data["text"]
elif "token_ids" in response_data and response_data["token_ids"]:
token_ids = response_data["token_ids"]
if isinstance(token_ids, list) and token_ids:
# Detokenize token IDs to get text
text_content = tokenizer.decode(token_ids, skip_special_tokens=True)
logger.debug(
f"Detokenized {len(token_ids)} tokens to: '{text_content}'"
)
return text_content
# Return empty string if no content to detokenize
return ""
except Exception as e:
logger.error(f"Failed to detokenize response: {e}")
return f"[Detokenization error: {e}]"
def process_sglang_stream_response(
response_data: Any, tokenizer: Any, accumulated_text: str = ""
) -> Tuple[str, str, bool]:
"""
Process a single SGLang streaming response with efficient detokenization.
Args:
response_data: Dictionary containing SGLang response data
tokenizer: The tokenizer to use for detokenization
accumulated_text: Previously accumulated text for context
Returns:
Tuple of (text_content, updated_accumulated_text, is_finished)
"""
try:
# Handle Annotated objects from Dynamo (following vLLM-like pattern)
if hasattr(response_data, "data"):
try:
import json
raw_data = response_data.data
# Handle callable data method
if callable(raw_data):
raw_data = raw_data()
response_data = (
json.loads(raw_data) if isinstance(raw_data, str) else raw_data
)
except (json.JSONDecodeError, AttributeError):
try:
raw_data = response_data.data
if callable(raw_data):
raw_data = raw_data()
response_data = {"text": str(raw_data), "finished": False}
except Exception:
response_data = {"text": str(response_data), "finished": False}
# Ensure response_data is a dictionary
if not isinstance(response_data, dict):
response_data = {"text": str(response_data), "finished": False}
# Detokenize the current response
text_content = detokenize_sglang_response(response_data, tokenizer)
# Update accumulated text
new_accumulated = accumulated_text + text_content
# Check if this is the final response
is_finished = response_data.get("finished", False) or response_data.get(
"finish_reason"
)
return text_content, new_accumulated, is_finished
except Exception as e:
logger.error(f"Error processing SGLang stream response: {e}")
return f"[Processing error: {e}]", accumulated_text, True
......@@ -17,7 +17,6 @@ from .llm import DecodeWorkerHandler, DiffusionWorkerHandler, PrefillWorkerHandl
from .multimodal import (
MultimodalEncodeWorkerHandler,
MultimodalPrefillWorkerHandler,
MultimodalProcessorHandler,
MultimodalWorkerHandler,
)
......@@ -41,6 +40,5 @@ __all__ = [
# Multimodal handlers
"MultimodalEncodeWorkerHandler",
"MultimodalPrefillWorkerHandler",
"MultimodalProcessorHandler",
"MultimodalWorkerHandler",
]
......@@ -2,12 +2,10 @@
# SPDX-License-Identifier: Apache-2.0
from .encode_worker_handler import MultimodalEncodeWorkerHandler
from .processor_handler import MultimodalProcessorHandler
from .worker_handler import MultimodalPrefillWorkerHandler, MultimodalWorkerHandler
__all__ = [
"MultimodalEncodeWorkerHandler",
"MultimodalProcessorHandler",
"MultimodalWorkerHandler",
"MultimodalPrefillWorkerHandler",
]
......@@ -2,8 +2,9 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json
import logging
from typing import AsyncIterator, Optional
from typing import Any, AsyncIterator, Dict, Optional
import torch
......@@ -19,7 +20,11 @@ from dynamo._core import Client, Context
from dynamo.common.multimodal import EMBEDDING_SENDER_FACTORIES
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import SglangMultimodalRequest
from dynamo.sglang.protocol import (
MultiModalGroup,
MultiModalInput,
SglangMultimodalRequest,
)
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
logger = logging.getLogger(__name__)
......@@ -37,11 +42,18 @@ except ImportError as e:
DEVICE = "cpu"
IMAGE_URL_KEY = "image_url"
class MultimodalEncodeWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, str]):
"""
Handler for multimodal encode worker component that processes images/videos
and forwards them to the downstream worker.
Receives pre-tokenized requests from the Rust frontend (ModelInput.Tokens)
with token_ids and multi_modal_data containing image URLs. Encodes images
via MMEncoder, expands placeholder tokens, transfers embeddings via NIXL,
and forwards to the PD worker.
"""
def __init__(
......@@ -113,49 +125,78 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, s
def cleanup(self) -> None:
pass
def _extract_image_urls(self, request: Dict[str, Any]) -> list[str]:
"""
Extract image URLs from the multi_modal_data field of a PreprocessedRequest.
The Rust frontend populates multi_modal_data with the format:
{"image_url": [{"Url": "https://..."}, ...]}
"""
mm_data = request.get("multi_modal_data")
if not mm_data:
raise ValueError("multi_modal_data is required for the encode worker.")
image_items = mm_data.get(IMAGE_URL_KEY)
if not image_items:
raise ValueError("multi_modal_data must contain image_url entries.")
image_urls: list[str] = []
for item in image_items:
if isinstance(item, str):
image_urls.append(item)
elif isinstance(item, dict) and "Url" in item:
image_urls.append(item["Url"])
elif isinstance(item, dict) and "Decoded" in item:
raise ValueError(
"Frontend-decoded media (Decoded variant) is incompatible "
"with the multimodal encode worker. The encode worker "
"requires image URLs to run vision encoding via MMEncoder. "
"Disable --frontend-decoding when using EPD serving."
)
else:
raise ValueError(f"Unsupported multimodal data variant: {item}")
return image_urls
@_nvtx.range_decorator("mm:enc:generate", color="blue")
async def generate(
self, request: SglangMultimodalRequest, context: Context
) -> AsyncIterator[str]:
self, raw_request: Dict[str, Any], context: Context
) -> AsyncIterator[Dict[str, Any]]:
"""
Generate precomputed embeddings for multimodal input.
Encode images from a pre-tokenized multimodal request, expand placeholder
tokens, transfer embeddings via NIXL, and stream PD worker responses.
The Rust frontend (ModelInput.Tokens) sends a PreprocessedRequest dict
with token_ids and multi_modal_data. This handler:
1. Extracts image URLs from multi_modal_data.
2. Runs vision encoding via MMEncoder.
3. Expands image placeholder tokens to match patch counts.
4. Creates a NIXL descriptor for embedding transfer.
5. Forwards the request to the PD worker and streams responses back.
Args:
request: Multimodal request with image/video data.
raw_request: PreprocessedRequest dict from the Rust frontend.
context: Context object for cancellation handling.
"""
if not isinstance(request, SglangMultimodalRequest):
if isinstance(request, str):
request = SglangMultimodalRequest.model_validate_json(request)
else:
request = SglangMultimodalRequest.model_validate(request)
# The following steps encode the requested image for SGLang:
# 1. Pass the image URL to MMEncoder which loads, preprocesses, and
# runs the vision encoder.
# 2. Expand each image placeholder token to match patch count.
# 3. Create a single NIXL descriptor for concatenated embeddings.
# 4. Send request + metadata to downstream worker.
# 5. Stream the downstream worker's response back to the caller.
if isinstance(raw_request, str):
raw_request = json.loads(raw_request)
# Extract image URLs from the frontend's multi_modal_data
image_urls = self._extract_image_urls(raw_request)
# Build MultiModalGroup objects for the downstream SglangMultimodalRequest
multimodal_groups = [
MultiModalGroup(multimodal_input=MultiModalInput(image_url=url))
for url in image_urls
]
# Build SglangMultimodalRequest from the pre-tokenized request
request = SglangMultimodalRequest(
request=raw_request,
multimodal_inputs=multimodal_groups,
)
try:
multimodal_groups = request.multimodal_inputs
if not multimodal_groups:
raise ValueError("multimodal_inputs is required for the encode worker.")
image_urls = []
for idx, mm_group in enumerate(multimodal_groups):
mm_input = mm_group.multimodal_input
if not mm_input or not mm_input.image_url:
raise ValueError(
f"image_url is required for the encode worker (index={idx})."
)
if mm_input.video_url is not None:
raise NotImplementedError(
"video_url encoding is not supported in SGLang encode worker"
)
image_urls.append(mm_input.image_url)
with _nvtx.annotate("mm:enc:vision_encode", color="red"):
image_grid_dim, precomputed_embeddings = await self.encoder._encode(
image_urls
......@@ -275,8 +316,23 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler[SglangMultimodalRequest, s
request.model_dump_json()
)
# Parse PD worker responses and yield as LLMEngineOutput-
# compatible dicts for the Rust frontend to post-process.
async for response in response_generator:
yield response.data() if hasattr(response, "data") else str(response)
raw = response.data() if hasattr(response, "data") else str(response)
try:
data = json.loads(raw) if isinstance(raw, str) else raw
except json.JSONDecodeError:
logger.warning("Non-JSON response from PD worker: %r", raw[:200])
data = {"token_ids": [], "text": raw}
# Strip the internal 'finished' flag — the Rust frontend
# uses 'finish_reason' (present when finished=True).
data.pop("finished", None)
# Remove empty 'text' so the Rust frontend detokenizes
# from token_ids instead of using the empty string.
if not data.get("text"):
data.pop("text", None)
yield data
await transfer_future
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json
import logging
import time
import uuid
from typing import Any, AsyncGenerator, Dict, Optional
from transformers import AutoTokenizer
from dynamo._core import Client, Context
from dynamo.sglang.args import Config
from dynamo.sglang.multimodal_utils import (
multimodal_request_to_sglang,
process_sglang_stream_response,
)
from dynamo.sglang.protocol import (
MultiModalGroup,
MultiModalInput,
MultiModalRequest,
PreprocessedRequest,
SglangMultimodalRequest,
)
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
logger = logging.getLogger(__name__)
class MultimodalProcessorHandler(BaseWorkerHandler[MultiModalRequest, Dict[str, Any]]):
"""
Handler for multimodal processor component that processes multimodal requests
and forwards them to the encode worker.
"""
def __init__(
self,
config: Config,
encode_worker_client: Client,
shutdown_event: Optional[asyncio.Event] = None,
):
super().__init__(engine=None, config=config, shutdown_event=shutdown_event)
self.encode_worker_client = encode_worker_client
self.chat_template = getattr(config.server_args, "chat_template", "qwen2-vl")
self.model = config.server_args.model_path
# Initialize tokenizer for the model
self.tokenizer = AutoTokenizer.from_pretrained(
self.model,
trust_remote_code=True,
use_fast=True,
padding_side="left",
truncation_side="left",
)
def cleanup(self):
pass
async def generate(
self, raw_request: MultiModalRequest, context: Context
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Process multimodal request and forward to encode worker.
Args:
raw_request: Raw multimodal request to process.
context: Context object for cancellation handling.
"""
if not isinstance(raw_request, MultiModalRequest):
# If the request is not MultiModalRequest, convert it to MultiModalRequest
raw_request = MultiModalRequest.model_validate(raw_request)
image_urls: list[str] = []
video_url: str | None = None
for message in raw_request.messages:
for item in message.content:
if item.type == "image_url":
if video_url is not None:
raise ValueError("Cannot provide both image and video URLs")
image_urls.append(item.image_url.url)
elif item.type == "video_url":
if image_urls:
raise ValueError("Cannot provide both image and video URLs")
if video_url is not None:
raise ValueError("Multiple video URLs are not supported")
video_url = item.video_url.url
if not image_urls and video_url is None:
raise ValueError("Either image URL or video URL is required")
multimodal_groups: list[MultiModalGroup] = []
if image_urls:
multimodal_groups = [
MultiModalGroup(multimodal_input=MultiModalInput(image_url=url))
for url in image_urls
]
elif video_url is not None:
multimodal_groups = [
MultiModalGroup(multimodal_input=MultiModalInput(video_url=video_url))
]
async for response in self._generate(raw_request, multimodal_groups):
logger.debug(
f"Generated response type {type(response)}, content: {response}"
)
yield response
async def _generate(
self,
raw_request: MultiModalRequest,
multimodal_groups: list[MultiModalGroup],
):
# Generate a unique request ID for tracking
request_id = str(uuid.uuid4().hex)
logger.debug(f"Got raw request: {raw_request}")
# Create SGLang conversation prompt
sglang_request = multimodal_request_to_sglang(
raw_request, self.tokenizer, self.chat_template
)
worker_request = SglangMultimodalRequest(
request=PreprocessedRequest(**sglang_request),
multimodal_inputs=multimodal_groups,
)
# Send to encoder worker
response_generator = await self.encode_worker_client.round_robin(
worker_request.model_dump_json()
)
# Process and yield SGLang responses
finished_sent = False
accumulated_text = ""
async for resp in response_generator:
try:
# Handle Annotated response objects from Dynamo (like vLLM pattern but for SGLang)
if hasattr(resp, "data"):
# Extract data from Dynamo Annotated response
raw_data = resp.data
if callable(raw_data):
raw_data = raw_data()
if isinstance(raw_data, str):
try:
response_data = json.loads(raw_data)
except json.JSONDecodeError:
response_data = {"text": raw_data, "finished": False}
else:
response_data = raw_data
elif isinstance(resp, str):
try:
response_data = json.loads(resp)
except json.JSONDecodeError:
response_data = {"text": resp, "finished": False}
else:
response_data = resp
# Use SGLang chat_processor for detokenization
(
text_content,
accumulated_text,
is_finished,
) = process_sglang_stream_response(
response_data, self.tokenizer, accumulated_text
)
# Create OpenAI-compatible response (following vLLM-like pattern but for SGLang)
if text_content or is_finished:
choice: Dict[str, Any] = {
"index": 0,
"delta": {},
"finish_reason": None,
}
delta: Dict[str, str] = choice["delta"] # Type-safe access
# Add role for first message or when there's content
if text_content and not finished_sent:
delta["role"] = "assistant"
# Add content if available
if text_content:
delta["content"] = text_content
# Set finish reason if completed
if is_finished:
choice["finish_reason"] = response_data.get(
"finish_reason", "stop"
)
if not finished_sent and not text_content:
# Final chunk needs role if it's the first chunk
delta["role"] = "assistant"
response_json = {
"id": f"chatcmpl-{request_id}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": self.model,
"choices": [choice],
}
# Add usage only for final response
if is_finished:
response_json["usage"] = {
"prompt_tokens": 0,
"completion_tokens": len(accumulated_text.split())
if accumulated_text
else 0,
"total_tokens": len(accumulated_text.split())
if accumulated_text
else 0,
}
yield response_json
if is_finished:
finished_sent = True
break
except Exception as e:
logger.error(f"Error processing SGLang response: {e}")
error_response = {
"id": f"chatcmpl-{request_id}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": self.model,
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": f"Error: {str(e)}",
},
"finish_reason": "stop",
}
],
}
yield error_response
break
......@@ -18,8 +18,7 @@ Dynamo SGLang uses SGLang's native argument parser -- all SGLang engine argument
| **Decode** *(default)* | Standard LLM inference (aggregated or disaggregated decode) |
| **Prefill** | Disaggregated prefill phase (`--disaggregation-mode prefill`) |
| **Embedding** | Text embedding models (`--embedding-worker`) |
| **Multimodal Processor** | HTTP entry point for multimodal, OpenAI-to-SGLang conversion (`--multimodal-processor`) |
| **Multimodal Encode** | Vision encoder and embeddings generation (`--multimodal-encode-worker`) |
| **Multimodal Encode** | Frontend-facing: vision encoding, embeddings generation (`--multimodal-encode-worker`) |
| **Multimodal Worker** | LLM inference with multimodal data (`--multimodal-worker`) |
| **Multimodal Prefill** | Prefill phase for multimodal disaggregation (`--multimodal-worker --disaggregation-mode prefill`) |
| **Image Diffusion** | Image generation via DiffGenerator (`--image-diffusion-worker`) |
......@@ -40,8 +39,7 @@ These arguments are added by Dynamo on top of SGLang's native arguments.
| `--dyn-reasoning-parser` | `DYN_REASONING_PARSER` | `None` | Reasoning parser for chain-of-thought models |
| `--custom-jinja-template` | `DYN_CUSTOM_JINJA_TEMPLATE` | `None` | Custom chat template path (incompatible with `--use-sglang-tokenizer`) |
| `--embedding-worker` | `DYN_SGL_EMBEDDING_WORKER` | `false` | Run as embedding worker (also sets SGLang's `--is-embedding`) |
| `--multimodal-processor` | `DYN_SGL_MULTIMODAL_PROCESSOR` | `false` | Run as [multimodal](../../features/multimodal/multimodal-sglang.md) processor |
| `--multimodal-encode-worker` | `DYN_SGL_MULTIMODAL_ENCODE_WORKER` | `false` | Run as multimodal encode worker |
| `--multimodal-encode-worker` | `DYN_SGL_MULTIMODAL_ENCODE_WORKER` | `false` | Run as [multimodal](../../features/multimodal/multimodal-sglang.md) encode worker (frontend-facing) |
| `--multimodal-worker` | `DYN_SGL_MULTIMODAL_WORKER` | `false` | Run as multimodal LLM worker |
| `--image-diffusion-worker` | `DYN_SGL_IMAGE_DIFFUSION_WORKER` | `false` | Run as [image diffusion](sglang-diffusion.md#image-diffusion) worker |
| `--video-generation-worker` | `DYN_SGL_VIDEO_GENERATION_WORKER` | `false` | Run as [video generation](sglang-diffusion.md#video-generation) worker |
......
......@@ -36,8 +36,7 @@ SGLang supports EPD, E/PD, and E/P/D patterns. See [Multimodal Architecture Patt
| Component | Flag | Purpose |
|-----------|------|---------|
| Processor | `--multimodal-processor` | HTTP entry, OpenAI→SGLang conversion |
| Encode Worker | `--multimodal-encode-worker` | Vision encoder, embeddings generation |
| Encode Worker | `--multimodal-encode-worker` | Frontend-facing, vision encoding, embeddings generation (Rust frontend tokenizes) |
| PD Worker | `--multimodal-worker` | Prefill + Decode with embeddings |
| Decode Worker | `--multimodal-worker --serving-mode=decode` | Entry point for disaggregation |
| Prefill Worker | `--multimodal-worker --serving-mode=prefill` | Called by Decode, bootstrap coordination |
......@@ -118,25 +117,20 @@ curl http://localhost:8000/v1/chat/completions \
### Components
- workers:
- [MultimodalEncodeWorkerHandler](https://github.com/ai-dynamo/dynamo/blob/main/components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py) for encoding
- [MultimodalEncodeWorkerHandler](https://github.com/ai-dynamo/dynamo/blob/main/components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py) for image encoding and embeddings generation
- [MultimodalWorkerHandler](https://github.com/ai-dynamo/dynamo/blob/main/components/src/dynamo/sglang/request_handlers/multimodal/worker_handler.py) for prefilling and decoding.
- processor: [MultimodalProcessorHandler](https://github.com/ai-dynamo/dynamo/blob/main/components/src/dynamo/sglang/request_handlers/multimodal/processor_handler.py)
- tokenizes the prompt using the chat template
- passes the text and image url to the MultimodalEncodeWorker.
### Workflow
The `MultimodalEncodeWorker` downloads and encodes the image and passes the embeddings to the MultimodalWorker. The work complete event is sent via NATS, while the embeddings tensor is transferred via RDMA through the NIXL interface. The `MultimodalWorker` then prefills and decodes the prompt in the same engine, as in the [LLM aggregated serving](../../backends/sglang/README.md) example. Only the processor is registered to the Dynamo frontend as an available endpoint. Workers do NOT register - they are internal components and communicate via NATS.
The Rust frontend tokenizes the request and extracts image URLs into `multi_modal_data`. The `MultimodalEncodeWorker` receives the pre-tokenized request, downloads and encodes the image, and passes the embeddings to the MultimodalWorker. The work complete event is sent via NATS, while the embeddings tensor is transferred via RDMA through the NIXL interface. The `MultimodalWorker` then prefills and decodes the prompt in the same engine, as in the [LLM aggregated serving](../../backends/sglang/README.md) example. Only the encode worker is registered to the Dynamo frontend as an available endpoint. The PD worker does NOT register - it is an internal component and communicates via NATS.
```mermaid
flowchart LR
HTTP --> processor
processor --tokenized request + image_url--> encode_worker
HTTP --> encode_worker
encode_worker --request + embeddings--> worker
worker -.-> encode_worker
encode_worker -.-> processor
processor -.-> HTTP
encode_worker -.-> HTTP
```
......@@ -181,26 +175,23 @@ curl http://localhost:8000/v1/chat/completions \
### Components
- workers:
- [MultimodalEncodeWorkerHandler](https://github.com/ai-dynamo/dynamo/blob/main/components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py) for encoding
- [MultimodalEncodeWorkerHandler](https://github.com/ai-dynamo/dynamo/blob/main/components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py) for image encoding and embeddings generation
- [MultimodalWorkerHandler](https://github.com/ai-dynamo/dynamo/blob/main/components/src/dynamo/sglang/request_handlers/multimodal/worker_handler.py) for decoding
- [MultimodalPrefillWorkerHandler](https://github.com/ai-dynamo/dynamo/blob/main/components/src/dynamo/sglang/request_handlers/multimodal/worker_handler.py) for prefilling
- processor: [MultimodalProcessorHandler](https://github.com/ai-dynamo/dynamo/blob/main/components/src/dynamo/sglang/request_handlers/multimodal/processor_handler.py) tokenizes the prompt and passes it to the MultimodalEncodeWorker.
### Workflow
In models like Qwen2.5-VL, embeddings are only required during the prefill stage. The image embeddings are transferred via NIXL from the Encode Worker to the Decode Worker (the entry point for disaggregation), which then coordinates with the Prefill Worker. The Prefill Worker processes the embeddings and forwards the KV cache back to the Decode Worker for token generation.
In models like Qwen2.5-VL, embeddings are only required during the prefill stage. The Rust frontend tokenizes and extracts image URLs. The `MultimodalEncodeWorker` receives the pre-tokenized request, encodes images, and transfers embeddings via NIXL to the Decode Worker (the entry point for disaggregation), which then coordinates with the Prefill Worker. The Prefill Worker processes the embeddings and forwards the KV cache back to the Decode Worker for token generation.
```mermaid
flowchart LR
HTTP --> processor
processor --tokenized request + image_url--> encode_worker
HTTP --> encode_worker
encode_worker --request + embeddings--> worker
worker --request + embeddings--> prefill_worker
prefill_worker --KV Cache--> worker
encode_worker -.-> processor
worker -.-> encode_worker
processor -.-> HTTP
encode_worker -.-> HTTP
```
### Launch
......@@ -458,10 +449,8 @@ SGLang multimodal **only supports image-based vision-language models**:
| File | Description |
|------|-------------|
| `components/src/dynamo/sglang/main.py` | Component initialization, only Processor registers |
| `components/src/dynamo/sglang/request_handlers/multimodal/processor_handler.py` | Processor implementation, OpenAI→SGLang |
| `components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py` | Vision encoder, embeddings generation |
| `components/src/dynamo/sglang/main.py` | Component initialization, Encode Worker registers |
| `components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py` | Frontend-facing: vision encoding, embeddings generation (receives pre-tokenized input) |
| `components/src/dynamo/sglang/request_handlers/multimodal/worker_handler.py` | PD/Prefill/Decode workers, NIXL read |
| `components/src/dynamo/sglang/multimodal_utils/multimodal_chat_processor.py` | Chat template processing |
| `components/src/dynamo/sglang/protocol.py` | Request/response data structures |
| `components/src/dynamo/sglang/register.py` | Registration logic (only called for Processor) |
| `components/src/dynamo/sglang/register.py` | Registration logic (called for Encode Worker) |
......@@ -119,13 +119,10 @@ print_launch_banner --multimodal "Launching Disaggregated Multimodal E/P/D" "$MO
# dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000)
python3 -m dynamo.frontend &
# run SGLang multimodal processor
python3 -m dynamo.sglang --multimodal-processor --model-path "$MODEL_NAME" $SERVED_MODEL_ARG --chat-template "$CHAT_TEMPLATE" &
# run SGLang multimodal encode worker
# run SGLang multimodal encode worker (frontend-facing: encodes images, routes to worker)
echo "Starting encode worker on GPU $DYN_ENCODE_WORKER_GPU (GPU mem: $DYN_ENCODE_GPU_MEM)..."
DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT1:-8081} \
CUDA_VISIBLE_DEVICES=$DYN_ENCODE_WORKER_GPU python3 -m dynamo.sglang --multimodal-encode-worker --model-path "$MODEL_NAME" $SERVED_MODEL_ARG --chat-template "$CHAT_TEMPLATE" $ENCODE_EXTRA_ARGS &
CUDA_VISIBLE_DEVICES=$DYN_ENCODE_WORKER_GPU python3 -m dynamo.sglang --multimodal-encode-worker --model-path "$MODEL_NAME" $SERVED_MODEL_ARG --chat-template "$CHAT_TEMPLATE" --skip-tokenizer-init $ENCODE_EXTRA_ARGS &
if [[ "$SINGLE_GPU" == "true" ]]; then
# Wait for encode worker to initialize before starting prefill worker.
......
......@@ -111,13 +111,10 @@ print_launch_banner --multimodal "Launching Multimodal E/PD ($GPU_LABEL)" "$MODE
# dynamo.frontend accepts either --http-port flag or DYN_HTTP_PORT env var (defaults to 8000)
python3 -m dynamo.frontend &
# run SGLang multimodal processor
python3 -m dynamo.sglang --multimodal-processor --model-path "$MODEL_NAME" $SERVED_MODEL_ARG --chat-template "$CHAT_TEMPLATE" &
# run SGLang multimodal encode worker
# run SGLang multimodal encode worker (frontend-facing: encodes images, routes to worker)
echo "Starting encode worker on GPU $DYN_ENCODE_WORKER_GPU (GPU mem: $DYN_ENCODE_GPU_MEM)..."
DYN_SYSTEM_PORT=${DYN_SYSTEM_PORT1:-8081} \
CUDA_VISIBLE_DEVICES=$DYN_ENCODE_WORKER_GPU python3 -m dynamo.sglang --multimodal-encode-worker --model-path "$MODEL_NAME" $SERVED_MODEL_ARG --chat-template "$CHAT_TEMPLATE" $ENCODE_EXTRA_ARGS &
CUDA_VISIBLE_DEVICES=$DYN_ENCODE_WORKER_GPU python3 -m dynamo.sglang --multimodal-encode-worker --model-path "$MODEL_NAME" $SERVED_MODEL_ARG --chat-template "$CHAT_TEMPLATE" --skip-tokenizer-init $ENCODE_EXTRA_ARGS &
if [[ "$SINGLE_GPU" == "true" ]]; then
# Wait for encode worker to initialize before starting PD worker.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment