Unverified Commit 91ddf418 authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

feat: v0 diffusion handler support (#5533)


Signed-off-by: default avatarayushag <ayushag@nvidia.com>
parent 224f63f5
...@@ -154,6 +154,10 @@ class DynamoArgs: ...@@ -154,6 +154,10 @@ class DynamoArgs:
# embedding options # embedding options
embedding_worker: bool = False embedding_worker: bool = False
# diffusion language model options (derived from server_args.dllm_algorithm)
diffusion_worker: bool = False
# config dump options # config dump options
dump_config_to: Optional[str] = None dump_config_to: Optional[str] = None
# local indexer option # local indexer option
...@@ -535,6 +539,9 @@ async def parse_args(args: list[str]) -> Config: ...@@ -535,6 +539,9 @@ async def parse_args(args: list[str]) -> Config:
f"Derived use_kv_events={use_kv_events} from kv_events_config={server_args.kv_events_config}" f"Derived use_kv_events={use_kv_events} from kv_events_config={server_args.kv_events_config}"
) )
# Auto-detect diffusion worker mode: enabled whenever dllm_algorithm is set
diffusion_worker = server_args.dllm_algorithm is not None
dynamo_args = DynamoArgs( dynamo_args = DynamoArgs(
namespace=parsed_namespace, namespace=parsed_namespace,
component=parsed_component_name, component=parsed_component_name,
...@@ -551,6 +558,7 @@ async def parse_args(args: list[str]) -> Config: ...@@ -551,6 +558,7 @@ async def parse_args(args: list[str]) -> Config:
multimodal_encode_worker=parsed_args.multimodal_encode_worker, multimodal_encode_worker=parsed_args.multimodal_encode_worker,
multimodal_worker=parsed_args.multimodal_worker, multimodal_worker=parsed_args.multimodal_worker,
embedding_worker=parsed_args.embedding_worker, embedding_worker=parsed_args.embedding_worker,
diffusion_worker=diffusion_worker,
dump_config_to=parsed_args.dump_config_to, dump_config_to=parsed_args.dump_config_to,
enable_local_indexer=str(parsed_args.enable_local_indexer).lower() == "true", enable_local_indexer=str(parsed_args.enable_local_indexer).lower() == "true",
use_kv_events=use_kv_events, use_kv_events=use_kv_events,
......
...@@ -24,6 +24,7 @@ from dynamo.sglang.publisher import setup_prometheus_registry, setup_sgl_metrics ...@@ -24,6 +24,7 @@ from dynamo.sglang.publisher import setup_prometheus_registry, setup_sgl_metrics
from dynamo.sglang.register import register_llm_with_readiness_gate from dynamo.sglang.register import register_llm_with_readiness_gate
from dynamo.sglang.request_handlers import ( from dynamo.sglang.request_handlers import (
DecodeWorkerHandler, DecodeWorkerHandler,
DiffusionWorkerHandler,
EmbeddingWorkerHandler, EmbeddingWorkerHandler,
MultimodalEncodeWorkerHandler, MultimodalEncodeWorkerHandler,
MultimodalPrefillWorkerHandler, MultimodalPrefillWorkerHandler,
...@@ -97,6 +98,8 @@ async def worker(): ...@@ -97,6 +98,8 @@ async def worker():
await init_multimodal_worker(runtime, config) await init_multimodal_worker(runtime, config)
else: else:
await init_multimodal_prefill_worker(runtime, config) await init_multimodal_prefill_worker(runtime, config)
elif config.dynamo_args.diffusion_worker:
await init_diffusion(runtime, config)
elif config.serving_mode != DisaggregationMode.PREFILL: elif config.serving_mode != DisaggregationMode.PREFILL:
await init(runtime, config) await init(runtime, config)
else: else:
...@@ -267,6 +270,91 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -267,6 +270,91 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
handler.cleanup() handler.cleanup()
async def init_diffusion(runtime: DistributedRuntime, config: Config) -> None:
    """Initialize the diffusion language model worker component.

    Creates the SGLang engine, resolves the Dynamo component/endpoint, and then
    concurrently serves the generate endpoint and registers the model. On
    non-leader nodes (``node_rank >= 1``) the function hands off to
    ``_handle_non_leader_node`` and returns without serving.

    Args:
        runtime: Distributed runtime used to resolve the namespace/component
            and to register engine routes.
        config: Combined configuration; ``config.server_args`` and
            ``config.dynamo_args`` are read.

    Raises:
        Exception: Any failure from serving the endpoint or registering the
            model is logged and re-raised.
    """
    server_args, dynamo_args = config.server_args, config.dynamo_args

    logging.info(
        f"Initializing diffusion worker with algorithm: {server_args.dllm_algorithm}"
    )
    if server_args.dllm_algorithm_config:
        logging.info(
            f"Using diffusion algorithm config: {server_args.dllm_algorithm_config}"
        )

    # Prevent SGLang from blocking on non-leader nodes
    if server_args.node_rank >= 1:
        os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"

    engine = sgl.Engine(server_args=server_args)

    component = runtime.namespace(dynamo_args.namespace).component(
        dynamo_args.component
    )
    generate_endpoint = component.endpoint(dynamo_args.endpoint)

    # Handle non-leader nodes (multi-node parallelism)
    if server_args.node_rank >= 1:
        await _handle_non_leader_node(engine, generate_endpoint)
        return

    # Setup metrics publisher
    publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
        engine, config, component, generate_endpoint
    )

    # Register Prometheus metrics callback if enabled
    if engine.server_args.enable_metrics:
        setup_prometheus_registry(engine, generate_endpoint)

    # Readiness gate: requests wait until model is registered
    ready_event = asyncio.Event()

    handler = DiffusionWorkerHandler(
        component, engine, config, publisher, generate_endpoint
    )
    handler.register_engine_routes(runtime)

    health_check_payload = SglangHealthCheckPayload(
        engine, use_text_input=dynamo_args.use_sglang_tokenizer
    ).to_dict()

    logging.info(
        f"Registering diffusion model with endpoint types: {dynamo_args.dyn_endpoint_types}"
    )

    try:
        # Start endpoint and register model
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=True,
                metrics_labels=metrics_labels,
                health_check_payload=health_check_payload,
            ),
            register_llm_with_readiness_gate(
                engine,
                generate_endpoint,
                server_args,
                dynamo_args,
                output_type=parse_endpoint_types(dynamo_args.dyn_endpoint_types),
                readiness_gate=ready_event,
            ),
        )
    except Exception as e:
        logging.error(f"Failed to serve diffusion endpoints: {e}")
        raise
    finally:
        try:
            metrics_task.cancel()
            try:
                await metrics_task
            except asyncio.CancelledError:
                logging.info("Metrics task successfully cancelled")
        finally:
            # Awaiting the metrics task can surface an exception other than
            # CancelledError (e.g. if the task had already failed); the handler
            # must still be cleaned up in that case.
            handler.cleanup()
async def init_embedding(runtime: DistributedRuntime, config: Config): async def init_embedding(runtime: DistributedRuntime, config: Config):
"""Initialize embedding worker component""" """Initialize embedding worker component"""
server_args, dynamo_args = config.server_args, config.dynamo_args server_args, dynamo_args = config.server_args, config.dynamo_args
......
...@@ -8,7 +8,7 @@ from .embedding import EmbeddingWorkerHandler ...@@ -8,7 +8,7 @@ from .embedding import EmbeddingWorkerHandler
from .handler_base import BaseWorkerHandler from .handler_base import BaseWorkerHandler
# LLM handlers # LLM handlers
from .llm import DecodeWorkerHandler, PrefillWorkerHandler from .llm import DecodeWorkerHandler, DiffusionWorkerHandler, PrefillWorkerHandler
# Multimodal handlers # Multimodal handlers
from .multimodal import ( from .multimodal import (
...@@ -22,6 +22,7 @@ __all__ = [ ...@@ -22,6 +22,7 @@ __all__ = [
"BaseWorkerHandler", "BaseWorkerHandler",
# LLM handlers # LLM handlers
"DecodeWorkerHandler", "DecodeWorkerHandler",
"DiffusionWorkerHandler",
"PrefillWorkerHandler", "PrefillWorkerHandler",
# Embedding handlers # Embedding handlers
"EmbeddingWorkerHandler", "EmbeddingWorkerHandler",
......
...@@ -2,9 +2,11 @@ ...@@ -2,9 +2,11 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from .decode_handler import DecodeWorkerHandler from .decode_handler import DecodeWorkerHandler
from .diffusion_handler import DiffusionWorkerHandler
from .prefill_handler import PrefillWorkerHandler from .prefill_handler import PrefillWorkerHandler
__all__ = [ __all__ = [
"DecodeWorkerHandler", "DecodeWorkerHandler",
"DiffusionWorkerHandler",
"PrefillWorkerHandler", "PrefillWorkerHandler",
] ]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Any, AsyncGenerator, Dict
import sglang as sgl
from dynamo._core import Component, Context
from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher
from dynamo.sglang.request_handlers.llm.decode_handler import DecodeWorkerHandler
class DiffusionWorkerHandler(DecodeWorkerHandler):
    """Handler for diffusion language model workers.

    Builds on ``DecodeWorkerHandler``: the request-parsing, sampling-parameter
    and stream-processing helpers used below are not defined in this class and
    come from the base handler. Every request is forwarded to the SGLang
    engine's ``async_generate`` in streaming mode.
    """

    def __init__(
        self,
        component: Component,
        engine: sgl.Engine,
        config: Config,
        publisher: DynamoSglangPublisher = None,
        generate_endpoint=None,
    ) -> None:
        """Initialize diffusion worker handler.

        Args:
            component: The Dynamo runtime component.
            engine: SGLang engine with diffusion algorithm configured.
            config: SGLang and Dynamo configuration.
            publisher: Optional metrics publisher.
            generate_endpoint: The endpoint handle for discovery.
        """
        super().__init__(component, engine, config, publisher, generate_endpoint)

        # Validate that a diffusion algorithm is configured. getattr with a
        # default covers both "attribute missing" and "attribute set to None".
        dllm_algorithm = getattr(
            engine.tokenizer_manager.server_args, "dllm_algorithm", None
        )
        if dllm_algorithm is None:
            # Not fatal: the handler still works, but only logs the problem.
            logging.error(
                "SGLang engine does not have dllm_algorithm configured. "
                "Diffusion LM behavior may not be active. "
                "Please check the SGLang engine configuration."
            )
        else:
            logging.info(
                f"Diffusion worker initialized with algorithm: {dllm_algorithm}"
            )

    async def generate(
        self, request: Dict[str, Any], context: Context
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Generate response using diffusion LM.

        Args:
            request: Request dict with input and sampling parameters.
            context: Context object for cancellation handling.

        Yields:
            Response dicts with token_ids or OpenAI-formatted chunks.
        """
        # `or []` also covers a request that carries an explicit
        # ``"token_ids": None``; the f-string below is evaluated even when
        # debug logging is disabled, so len(None) would otherwise fail the
        # request before generation starts.
        input_token_ids = request.get("token_ids") or []
        logging.debug(
            f"Starting diffusion generation for request {context.id()}, "
            f"input_tokens={len(input_token_ids)}"
        )

        # Get input parameters (tokens or text)
        input_param = self._get_input_param(request)

        # Build sampling parameters
        sampling_params = self._build_sampling_params(request)

        # Generate trace info if tracing is enabled; the request id is only
        # forwarded as `rid` when a trace header was actually produced.
        trace_header = self._get_trace_header(context) if self.enable_trace else None
        trace_id = context.id() if trace_header else None

        async_gen = await self.engine.async_generate(
            **input_param,
            sampling_params=sampling_params,
            stream=True,  # Always stream for Dynamo
            external_trace_header=trace_header,
            rid=trace_id,
        )

        # Process stream output (token-based or text-based)
        if self.skip_tokenizer_init:
            async for out in self._process_token_stream(async_gen, context):
                yield out
        else:
            async for out in self._process_text_stream(async_gen, context):
                yield out

    def cleanup(self) -> None:
        """Release handler resources by delegating to the base handler."""
        super().cleanup()
<!--
SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
-->
# Running Diffusion LMs with SGLang
Diffusion Language Models (Diffusion LMs) are a class of generative models that use diffusion processes for text generation. This guide shows how to deploy diffusion models like LLaDA2.0 using SGLang as the backend with Dynamo. Diffusion LMs work differently from autoregressive models - they iteratively refine generated text through a diffusion process.
## Launch the Deployment
### Using the Launch Script (Recommended)
The easiest way to start the diffusion LM service is using the provided launch script:
```bash
bash examples/backends/sglang/launch/diffusion_llada.sh
```
### Manual Launch Steps
If you prefer to launch components manually:
**Start frontend**
```bash
python -m dynamo.frontend --http-port 8001 &
```
**Run diffusion worker**
```bash
export CUDA_VISIBLE_DEVICES=0,1
python -m dynamo.sglang \
--model-path inclusionAI/LLaDA2.0-mini-preview \
--tp-size 2 \
--skip-tokenizer-init \
--trust-remote-code \
--endpoint dyn://dynamo.backend.generate \
--enable-metrics \
--disable-cuda-graph \
--disable-overlap-schedule \
--attention-backend triton \
--dllm-algorithm LowConfidence
```
## Diffusion Algorithms
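Diffusion worker mode is enabled automatically whenever `--dllm-algorithm` is passed to `python -m dynamo.sglang`; there is no separate diffusion flag. The chosen algorithm controls the iterative refinement process:
- **LowConfidence** (the default used by the launch script): Refines tokens with low confidence scores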
- **HighConfidence**: Alternative algorithm focusing on high-confidence refinement
For more details on diffusion algorithms and configuration options, refer to the [SGLang Diffusion Language Models documentation](https://github.com/sgl-project/sglang/blob/main/docs/supported_models/diffusion_language_models.md).
## Testing the Deployment
Once deployed, you can test the service using curl:
```bash
curl -X POST http://localhost:8001/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "inclusionAI/LLaDA2.0-mini-preview",
"messages": [
{
"role": "user",
"content": "Hello! How are you?"
}
],
"temperature": 0.7,
"max_tokens": 512
}'
```
Or use the completions endpoint:
```bash
curl -X POST http://localhost:8001/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "inclusionAI/LLaDA2.0-mini-preview",
"prompt": "Once upon a time",
"max_tokens": 256
}'
```
\ No newline at end of file
...@@ -60,6 +60,7 @@ ...@@ -60,6 +60,7 @@
backends/sglang/expert-distribution-eplb.md backends/sglang/expert-distribution-eplb.md
backends/sglang/gpt-oss.md backends/sglang/gpt-oss.md
backends/sglang/diffusion-lm.md
backends/sglang/profiling.md backends/sglang/profiling.md
backends/sglang/sgl-hicache-example.md backends/sglang/sgl-hicache-example.md
backends/sglang/sglang-disaggregation.md backends/sglang/sglang-disaggregation.md
......
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

# Launches the Dynamo frontend in the background and a diffusion LM worker
# (SGLang backend) in the foreground. All settings below can be overridden
# through environment variables of the same name.

set -e

# Initialised up front so the cleanup trap is safe even if it fires before the
# frontend has been started.
FRONTEND_PID=""

# Setup cleanup trap
cleanup() {
    echo "Cleaning up background processes..."
    if [ -n "$FRONTEND_PID" ]; then
        kill "$FRONTEND_PID" 2>/dev/null || true
        wait "$FRONTEND_PID" 2>/dev/null || true
    fi
    echo "Cleanup complete."
}
trap cleanup EXIT INT TERM

# Model configuration
MODEL_PATH="${MODEL_PATH:-inclusionAI/LLaDA2.0-mini-preview}"

# Diffusion algorithm configuration
DLLM_ALGORITHM="${DLLM_ALGORITHM:-LowConfidence}"
DLLM_ALGORITHM_CONFIG="${DLLM_ALGORITHM_CONFIG:-}"  # Optional: path to YAML config file

# Dynamo configuration
NAMESPACE="${NAMESPACE:-dynamo}"
COMPONENT="${COMPONENT:-backend}"
ENDPOINT="${ENDPOINT:-generate}"
HTTP_PORT="${HTTP_PORT:-8001}"
TP_SIZE="${TP_SIZE:-1}"

echo "=========================================="
echo "Launching Diffusion LM Worker (LLaDA2.0)"
echo "=========================================="
echo "Model: $MODEL_PATH"
echo "Namespace: $NAMESPACE"
echo "Component: $COMPONENT"
echo "Frontend Port: $HTTP_PORT"
echo "TP Size: $TP_SIZE"
echo "Diffusion Algorithm: $DLLM_ALGORITHM"
echo "Algorithm Config: ${DLLM_ALGORITHM_CONFIG:-default}"
echo "=========================================="

# Launch frontend (OpenAI-compatible API server)
echo "Starting Dynamo Frontend on port $HTTP_PORT..."
python -m dynamo.frontend \
    --http-port "$HTTP_PORT" &
FRONTEND_PID=$!

# Wait for frontend to start
sleep 2

# Launch diffusion worker
echo "Starting Diffusion LM Worker..."

# Respect a caller-provided device list (needed when TP_SIZE > 1); fall back to
# GPU 0, which matches the default TP_SIZE of 1.
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0}"

# Build the command as an array so every value is passed as exactly one
# argument, with no re-splitting or re-evaluation by the shell (unlike eval).
WORKER_CMD=(
    python -m dynamo.sglang
    --model-path "$MODEL_PATH"
    --tp-size "$TP_SIZE"
    --skip-tokenizer-init
    --trust-remote-code
    --endpoint "dyn://${NAMESPACE}.${COMPONENT}.${ENDPOINT}"
    --enable-metrics
    --disable-cuda-graph
    --disable-overlap-schedule
    --attention-backend triton
    --dllm-algorithm "$DLLM_ALGORITHM"
)

# Add optional algorithm config if provided
if [ -n "$DLLM_ALGORITHM_CONFIG" ]; then
    WORKER_CMD+=(--dllm-algorithm-config "$DLLM_ALGORITHM_CONFIG")
fi

# Execute the command
"${WORKER_CMD[@]}"
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment