Unverified commit 62661cce, authored by Kris Hung, committed by GitHub

feat: Add multimodal EPD for Sglang (#3230)


Signed-off-by: krishung5 <krish@nvidia.com>
parent 91ba9026
......@@ -39,6 +39,7 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
| [**Conditional Disaggregation**](../../../docs/architecture/disagg_serving.md#conditional-disaggregation) | 🚧 | WIP [PR](https://github.com/sgl-project/sglang/pull/7730) |
| [**KV-Aware Routing**](../../../docs/architecture/kv_cache_routing.md) | ✅ | |
| [**SLA-Based Planner**](../../../docs/architecture/sla_planner.md) | ✅ | |
| [**Multimodal EPD Disaggregation**](docs/multimodal_epd.md) | ✅ | |
| [**Load Based Planner**](../../../docs/architecture/load_planner.md) | ❌ | Planned |
| [**KVBM**](../../../docs/architecture/kvbm_architecture.md) | ❌ | Planned |
......@@ -254,6 +255,9 @@ Below we provide a selected list of advanced examples. Please open up an issue i
### Hierarchical Cache (HiCache)
- **[Enable SGLang Hierarchical Cache (HiCache)](docs/sgl-hicache-example.md)**
### Multimodal Encode-Prefill-Decode (EPD) Disaggregation with NIXL
- **[Run a multimodal model with EPD Disaggregation](docs/multimodal_epd.md)**
## Deployment
We currently provide deployment examples for Kubernetes and SLURM.
......
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
-->
# Encode-Prefill-Decode (EPD) Flow with NIXL
For high-performance multimodal inference with large embeddings, Dynamo supports a specialized **Encode-Prefill-Decode (EPD)** flow using **NIXL (RDMA)** for zero-copy tensor transfer.
## Use the Latest Release
We recommend using the latest stable release of Dynamo to avoid breaking changes:
[![GitHub Release](https://img.shields.io/github/v/release/ai-dynamo/dynamo)](https://github.com/ai-dynamo/dynamo/releases/latest)
You can find the latest release [here](https://github.com/ai-dynamo/dynamo/releases/latest) and check out the corresponding release tag with:
```bash
git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
```
## Multimodal Aggregated Serving
### Components
- workers: For aggregated serving, we have two workers, [MultimodalEncodeWorker](src/dynamo/sglang/request_handlers/multimodal_encode_worker_handler.py) for encoding and [MultimodalWorker](src/dynamo/sglang/request_handlers/multimodal_worker_handler.py) for prefilling and decoding.
- processor: Tokenizes the prompt and passes it to the MultimodalEncodeWorker.
### Workflow
The MultimodalEncodeWorker is responsible for encoding the image and passing the embeddings to the MultimodalWorker via a combination of NATS and RDMA.
The work complete event is sent via NATS, while the embeddings tensor is transferred via RDMA through the NIXL interface.
The MultimodalWorker then prefills and decodes the prompt, just like the [LLM aggregated serving](../README.md) example.
By separating the encode stage from the prefill and decode stages, we can deploy more flexibly and scale the
MultimodalEncodeWorker independently from the prefill and decode workers if needed.
This figure illustrates the workflow:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --image_url--> encode_worker
encode_worker --> processor
encode_worker --embeddings descriptor--> worker
worker --> encode_worker
```
```bash
cd $DYNAMO_HOME/components/backends/sglang
./launch/multimodal_agg.sh
```
### Client
In another terminal:
```bash
curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "Qwen/Qwen2.5-VL-7B-Instruct",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Describe the image."
},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
}
}
]
}
],
"max_tokens": 50,
"stream": false
}' | jq
```
You should see a response similar to this:
```json
{
"id": "chatcmpl-2546f44756884a14916ce13ebaa09da8",
"choices": [
{
"index": 0,
"message": {
"content": "This image shows a public transit bus on a dimly lit, street-level track in what appears to be a quiet urban neighborhood or suburban area. The bus displays \"OUT OF SERVICE\" in red on its illuminated sign. It is positioned",
"role": "assistant",
"reasoning_content": null
},
"finish_reason": "length"
}
],
"created": 1758824222,
"model": "Qwen/Qwen2.5-VL-7B-Instruct",
"object": "chat.completion",
"usage": {
"prompt_tokens": 0,
"completion_tokens": 40,
"total_tokens": 40
}
}
```
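The same request can also be issued from Python. The sketch below uses only the standard library; the helper names `build_chat_request`, `send_chat_request`, and `extract_text` are hypothetical and not part of Dynamo, and the default base URL assumes the frontend is listening on port 8000 as in the launch script.

```python
import json
import urllib.request


def build_chat_request(
    image_url: str,
    prompt: str = "Describe the image.",
    model: str = "Qwen/Qwen2.5-VL-7B-Instruct",
    max_tokens: int = 50,
) -> dict:
    """Build the same JSON body that the curl example sends."""
    return {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": image_url}},
                ],
            }
        ],
        "max_tokens": max_tokens,
        "stream": False,
    }


def send_chat_request(payload: dict, base_url: str = "http://localhost:8000") -> dict:
    """POST the payload to the chat completions route and decode the JSON reply."""
    req = urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)


def extract_text(response: dict) -> str:
    """Return the assistant message from a non-streaming chat completion."""
    return response["choices"][0]["message"]["content"]


# Usage (requires the deployment above to be running):
#   payload = build_chat_request(
#       "http://images.cocodataset.org/test2017/000000155781.jpg"
#   )
#   print(extract_text(send_chat_request(payload)))
```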
## Multimodal Disaggregated Serving
### Components
- workers: For disaggregated serving, we have three workers, [MultimodalEncodeWorker](src/dynamo/sglang/request_handlers/multimodal_encode_worker_handler.py) for encoding, [MultimodalWorker](src/dynamo/sglang/request_handlers/multimodal_worker_handler.py) for decoding, and [MultimodalPrefillWorker](src/dynamo/sglang/request_handlers/multimodal_worker_handler.py) for prefilling.
- processor: Tokenizes the prompt and passes it to the MultimodalEncodeWorker.
### Workflow
For the Qwen2.5-VL model, embeddings are only required during the prefill stage. As such, the image embeddings are transferred using a NIXL descriptor from the encode worker to the decode worker (the MultimodalWorker), which then passes them to the prefill worker for processing.
The prefill worker performs the prefilling step and forwards the KV cache to the decode worker for decoding.
For more details on the roles of the prefill and decode workers, refer to the [LLM disaggregated serving](../README.md) example.
This figure illustrates the workflow:
```mermaid
flowchart LR
HTTP --> processor
processor --> HTTP
processor --image_url--> encode_worker
encode_worker --> processor
encode_worker --embeddings descriptor--> worker
worker --> encode_worker
worker --embeddings descriptor--> prefill_worker
prefill_worker --> worker
```
```bash
cd $DYNAMO_HOME/components/backends/sglang
./launch/multimodal_disagg.sh
```
### Client
In another terminal:
```bash
curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "Qwen/Qwen2.5-VL-7B-Instruct",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Describe the image."
},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
}
}
]
}
],
"max_tokens": 50,
"stream": false
}' | jq
```
You should see a response similar to this:
```json
{
"id": "chatcmpl-2546f44756884a14916ce13ebaa09da8",
"choices": [
{
"index": 0,
"message": {
"content": "This image shows a public transit bus on a dimly lit, street-level track in what appears to be a quiet urban neighborhood or suburban area. The bus displays \"OUT OF SERVICE\" in red on its illuminated sign. It is positioned",
"role": "assistant",
"reasoning_content": null
},
"finish_reason": "length"
}
],
"created": 1758824222,
"model": "Qwen/Qwen2.5-VL-7B-Instruct",
"object": "chat.completion",
"usage": {
"prompt_tokens": 0,
"completion_tokens": 40,
"total_tokens": 40
}
}
```
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -e
trap 'echo Cleaning up...; kill 0' EXIT
# Default values
MODEL_NAME="Qwen/Qwen2.5-VL-7B-Instruct"
CHAT_TEMPLATE="qwen2-vl"
PROVIDED_CHAT_TEMPLATE=""
# Parse command line arguments
while [[ $# -gt 0 ]]; do
    case $1 in
        --model)
            MODEL_NAME=$2
            shift 2
            ;;
        --chat-template)
            PROVIDED_CHAT_TEMPLATE=$2
            shift 2
            ;;
        -h|--help)
            echo "Usage: $0 [OPTIONS]"
            echo "Options:"
            echo "  --model <model_name>          Specify the model to use (default: $MODEL_NAME)"
            echo "  --chat-template <template>    Specify the SGLang chat template to use (default: $CHAT_TEMPLATE)"
            echo "  -h, --help                    Show this help message"
            exit 0
            ;;
        *)
            echo "Unknown option: $1"
            echo "Use --help for usage information"
            exit 1
            ;;
    esac
done
# Set CHAT_TEMPLATE if provided
if [[ -n "$PROVIDED_CHAT_TEMPLATE" ]]; then
    CHAT_TEMPLATE="$PROVIDED_CHAT_TEMPLATE"
fi
# Get the directory where this script is located
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
SGLANG_BACKEND_DIR="$SCRIPT_DIR/src"
# run ingress
python -m dynamo.frontend --http-port=8000 &
# run SGLang multimodal processor
python3 -m dynamo.sglang --multimodal-processor --model-path "$MODEL_NAME" --chat-template "$CHAT_TEMPLATE" &
# run SGLang multimodal encode worker
CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.sglang --multimodal-encode-worker --model-path "$MODEL_NAME" --chat-template "$CHAT_TEMPLATE" &
# run SGLang multimodal inference worker
CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \
--multimodal-worker \
--model-path "$MODEL_NAME" \
--page-size 16 \
--tp 1 \
--trust-remote-code \
--skip-tokenizer-init \
--disaggregation-transfer-backend nixl &
# Wait for all background processes to complete
wait
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -e
trap 'echo Cleaning up...; kill 0' EXIT
# Default values
MODEL_NAME="Qwen/Qwen2.5-VL-7B-Instruct"
CHAT_TEMPLATE="qwen2-vl"
PROVIDED_CHAT_TEMPLATE=""
# Parse command line arguments
while [[ $# -gt 0 ]]; do
    case $1 in
        --model)
            MODEL_NAME=$2
            shift 2
            ;;
        --chat-template)
            PROVIDED_CHAT_TEMPLATE=$2
            shift 2
            ;;
        -h|--help)
            echo "Usage: $0 [OPTIONS]"
            echo "Options:"
            echo "  --model <model_name>          Specify the model to use (default: $MODEL_NAME)"
            echo "  --chat-template <template>    Specify the SGLang chat template to use (default: $CHAT_TEMPLATE)"
            echo "  -h, --help                    Show this help message"
            exit 0
            ;;
        *)
            echo "Unknown option: $1"
            echo "Use --help for usage information"
            exit 1
            ;;
    esac
done
# Set CHAT_TEMPLATE if provided
if [[ -n "$PROVIDED_CHAT_TEMPLATE" ]]; then
    CHAT_TEMPLATE="$PROVIDED_CHAT_TEMPLATE"
fi
# Get the directory where this script is located
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
SGLANG_BACKEND_DIR="$SCRIPT_DIR/src"
# run ingress
python -m dynamo.frontend --http-port=8000 &
# run SGLang multimodal processor
python3 -m dynamo.sglang --multimodal-processor --model-path "$MODEL_NAME" --chat-template "$CHAT_TEMPLATE" &
# run SGLang multimodal encode worker
CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.sglang --multimodal-encode-worker --model-path "$MODEL_NAME" --chat-template "$CHAT_TEMPLATE" &
# run SGLang multimodal prefill worker
CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang \
--multimodal-worker \
--model-path "$MODEL_NAME" \
--page-size 16 \
--tp 1 \
--trust-remote-code \
--skip-tokenizer-init \
--disaggregation-mode prefill \
--disaggregation-transfer-backend nixl &
# run SGLang multimodal decode worker
CUDA_VISIBLE_DEVICES=2 python3 -m dynamo.sglang \
--multimodal-worker \
--model-path "$MODEL_NAME" \
--page-size 16 \
--tp 1 \
--trust-remote-code \
--skip-tokenizer-init \
--disaggregation-mode decode \
--disaggregation-transfer-backend nixl &
# Wait for all background processes to complete
wait
......@@ -10,7 +10,7 @@ import sys
from argparse import Namespace
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from sglang.srt.server_args import ServerArgs
......@@ -61,6 +61,24 @@ DYNAMO_ARGS: Dict[str, Dict[str, Any]] = {
"default": False,
"help": "Use SGLang's tokenizer. This will skip tokenization of the input and output and only v1/chat/completions will be available when using the dynamo frontend. Cannot be used with --custom-jinja-template.",
},
"multimodal-processor": {
"flags": ["--multimodal-processor"],
"action": "store_true",
"default": False,
"help": "Run as multimodal processor component for handling multimodal requests",
},
"multimodal-encode-worker": {
"flags": ["--multimodal-encode-worker"],
"action": "store_true",
"default": False,
"help": "Run as multimodal encode worker component for processing images/videos",
},
"multimodal-worker": {
"flags": ["--multimodal-worker"],
"action": "store_true",
"default": False,
"help": "Run as multimodal worker component for LLM inference with multimodal data",
},
}
......@@ -79,6 +97,11 @@ class DynamoArgs:
# preprocessing options
use_sglang_tokenizer: bool = False
# multimodal options
multimodal_processor: bool = False
multimodal_encode_worker: bool = False
multimodal_worker: bool = False
class DisaggregationMode(Enum):
AGGREGATED = "agg"
......@@ -99,6 +122,8 @@ class Config:
return DisaggregationMode.PREFILL
elif self.server_args.disaggregation_mode == "decode":
return DisaggregationMode.DECODE
else:
return DisaggregationMode.AGGREGATED
def _set_parser(
......@@ -180,6 +205,15 @@ def parse_args(args: list[str]) -> Config:
and parsed_args.disaggregation_mode == "prefill"
):
endpoint = f"dyn://{namespace}.prefill.generate"
elif parsed_args.multimodal_processor:
endpoint = f"dyn://{namespace}.processor.generate"
elif parsed_args.multimodal_encode_worker:
endpoint = f"dyn://{namespace}.encoder.generate"
elif (
parsed_args.multimodal_worker
and parsed_args.disaggregation_mode == "prefill"
):
endpoint = f"dyn://{namespace}.prefill.generate"
else:
endpoint = f"dyn://{namespace}.backend.generate"
......@@ -231,6 +265,9 @@ def parse_args(args: list[str]) -> Config:
reasoning_parser=reasoning_parser,
custom_jinja_template=expanded_template_path,
use_sglang_tokenizer=parsed_args.use_sglang_tokenizer,
multimodal_processor=parsed_args.multimodal_processor,
multimodal_encode_worker=parsed_args.multimodal_encode_worker,
multimodal_worker=parsed_args.multimodal_worker,
)
logging.debug(f"Dynamo args: {dynamo_args}")
......@@ -264,6 +301,21 @@ def reserve_free_port(host: str = "localhost"):
sock.close()
def parse_endpoint(endpoint: str) -> List[str]:
    """Parse endpoint string into namespace, component, and endpoint parts."""
    endpoint_str = endpoint.replace("dyn://", "", 1)
    endpoint_parts = endpoint_str.split(".")
    if len(endpoint_parts) != 3:
        error_msg = (
            f"Invalid endpoint format: '{endpoint}'. "
            f"Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
        )
        logging.error(error_msg)
        raise ValueError(error_msg)
    return endpoint_parts
def _reserve_disaggregation_bootstrap_port():
"""
Each worker requires a unique port for disaggregation_bootstrap_port.
......
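To make the endpoint naming concrete, the following standalone sketch re-implements `parse_endpoint` from the hunk above (without logging) and adds a hypothetical `default_endpoint` helper that mirrors only the multimodal branches of `parse_args`. The role strings and the `dynamo` namespace used in the comments are assumptions for illustration, not names taken from the code.

```python
from typing import List, Optional


def parse_endpoint(endpoint: str) -> List[str]:
    """Split 'dyn://namespace.component.endpoint' into its three parts."""
    endpoint_str = endpoint.replace("dyn://", "", 1)
    parts = endpoint_str.split(".")
    if len(parts) != 3:
        raise ValueError(
            f"Invalid endpoint format: '{endpoint}'. "
            "Expected 'dyn://namespace.component.endpoint' "
            "or 'namespace.component.endpoint'."
        )
    return parts


def default_endpoint(
    namespace: str, role: str, disaggregation_mode: Optional[str] = None
) -> str:
    """Hypothetical helper: map a multimodal role to its default endpoint.

    `role` is an illustrative stand-in for the --multimodal-* flags:
    "processor", "encode_worker", or "worker".
    """
    if role == "processor":
        return f"dyn://{namespace}.processor.generate"
    if role == "encode_worker":
        return f"dyn://{namespace}.encoder.generate"
    if role == "worker" and disaggregation_mode == "prefill":
        return f"dyn://{namespace}.prefill.generate"
    # Aggregated and decode workers fall through to the backend endpoint.
    return f"dyn://{namespace}.backend.generate"


# Example: the encode worker's default endpoint, split back into parts.
namespace, component, endpoint = parse_endpoint(
    default_endpoint("dynamo", "encode_worker")
)
```

With these defaults, the processor looks up `encoder.generate`, the encode worker looks up `backend.generate`, and a decode worker looks up `prefill.generate`, which matches the clients created in the `init_multimodal_*` functions of this commit.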
......@@ -11,17 +11,24 @@ import sglang as sgl
import uvloop
from sglang.srt.utils import get_ip
from dynamo.llm import ZmqKvEventPublisher, ZmqKvEventPublisherConfig
from dynamo.llm import ModelInput, ZmqKvEventPublisher, ZmqKvEventPublisherConfig
from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.args import Config, DisaggregationMode, parse_args
from dynamo.sglang.args import Config, DisaggregationMode, parse_args, parse_endpoint
from dynamo.sglang.health_check import (
SglangHealthCheckPayload,
SglangPrefillHealthCheckPayload,
)
from dynamo.sglang.publisher import setup_sgl_metrics
from dynamo.sglang.register import register_llm_with_runtime_config
from dynamo.sglang.request_handlers import DecodeWorkerHandler, PrefillWorkerHandler
from dynamo.sglang.request_handlers import (
DecodeWorkerHandler,
MultimodalEncodeWorkerHandler,
MultimodalPrefillWorkerHandler,
MultimodalProcessorHandler,
MultimodalWorkerHandler,
PrefillWorkerHandler,
)
configure_dynamo_logging()
......@@ -39,7 +46,16 @@ async def worker(runtime: DistributedRuntime):
logging.info("Signal handlers will trigger a graceful shutdown of the runtime")
config = parse_args(sys.argv[1:])
if config.serving_mode != DisaggregationMode.PREFILL:
if config.dynamo_args.multimodal_processor:
await init_multimodal_processor(runtime, config)
elif config.dynamo_args.multimodal_encode_worker:
await init_multimodal_encode_worker(runtime, config)
elif config.dynamo_args.multimodal_worker:
if config.serving_mode != DisaggregationMode.PREFILL:
await init_multimodal_worker(runtime, config)
else:
await init_multimodal_prefill_worker(runtime, config)
elif config.serving_mode != DisaggregationMode.PREFILL:
await init(runtime, config)
else:
await init_prefill(runtime, config)
......@@ -88,12 +104,6 @@ async def init(runtime: DistributedRuntime, config: Config):
# Readiness gate: requests wait until model is registered
ready_event = asyncio.Event()
async def gated_generate(request):
"""Queue requests until model registration completes"""
await ready_event.wait() # Block until model is ready
async for response in handler.generate(request):
yield response
handler = DecodeWorkerHandler(
component, engine, config, publisher, kv_publisher, prefill_client
)
......@@ -177,6 +187,202 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
handler.cleanup()
async def init_multimodal_processor(runtime: DistributedRuntime, config: Config):
    """Initialize multimodal processor component"""
    server_args, dynamo_args = config.server_args, config.dynamo_args
    component = runtime.namespace(dynamo_args.namespace).component(
        dynamo_args.component
    )
    await component.create_service()
    generate_endpoint = component.endpoint(dynamo_args.endpoint)

    # For processor, we need to connect to the encode worker
    # Default endpoint for encode worker
    encode_endpoint = f"dyn://{dynamo_args.namespace}.encoder.generate"
    parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
        encode_endpoint
    )
    encode_worker_client = (
        await runtime.namespace(parsed_namespace)
        .component(parsed_component_name)
        .endpoint(parsed_endpoint_name)
        .client()
    )
    handler = MultimodalProcessorHandler(component, config, encode_worker_client)
    logging.info("Waiting for Encoder Worker Instances ...")
    await encode_worker_client.wait_for_instances()

    async def register_model():
        """Register the model with the runtime"""
        registration_success = await register_llm_with_runtime_config(
            None,  # engine,
            generate_endpoint,
            server_args,
            dynamo_args,
            input_type=ModelInput.Text,
        )
        if not registration_success:
            logging.error("Model registration failed; shutting down")
            runtime.shutdown()
            raise RuntimeError("Model registration failed")
        logging.info("Model registration succeeded; processing queued requests")

    try:
        # Start endpoint immediately and register model concurrently.
        # Note: this function defines no ready_event, so requests are not
        # gated on registration here.
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=True,
                metrics_labels=[("model", server_args.served_model_name)],
            ),
            register_model(),
        )
    except Exception as e:
        logging.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()
async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Config):
    """Initialize multimodal encode worker component"""
    server_args, dynamo_args = config.server_args, config.dynamo_args
    component = runtime.namespace(dynamo_args.namespace).component(
        dynamo_args.component
    )
    await component.create_service()
    generate_endpoint = component.endpoint(dynamo_args.endpoint)

    # For encode worker, we need to connect to the downstream worker (LLM worker)
    # Default endpoint for LLM worker
    llm_endpoint = f"dyn://{dynamo_args.namespace}.backend.generate"
    parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
        llm_endpoint
    )
    pd_worker_client = (
        await runtime.namespace(parsed_namespace)
        .component(parsed_component_name)
        .endpoint(parsed_endpoint_name)
        .client()
    )
    handler = MultimodalEncodeWorkerHandler(component, config, pd_worker_client)
    await handler.async_init(runtime)
    await pd_worker_client.wait_for_instances()

    tasks = [
        generate_endpoint.serve_endpoint(
            handler.generate,
            graceful_shutdown=True,
            metrics_labels=[("model", server_args.served_model_name)],
        )
    ]
    try:
        await asyncio.gather(*tasks)
    except Exception as e:
        logging.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()
async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
    """Initialize multimodal worker component for aggregated or decode mode"""
    server_args, dynamo_args = config.server_args, config.dynamo_args
    component = runtime.namespace(dynamo_args.namespace).component(
        dynamo_args.component
    )
    await component.create_service()
    generate_endpoint = component.endpoint(dynamo_args.endpoint)
    engine = sgl.Engine(server_args=server_args)

    # Setup handler based on serving mode
    if config.serving_mode == DisaggregationMode.DECODE:
        # Decode mode: create prefill client
        prefill_endpoint = f"dyn://{dynamo_args.namespace}.prefill.generate"
        parsed_namespace, parsed_component_name, parsed_endpoint_name = parse_endpoint(
            prefill_endpoint
        )
        logging.info("Initializing prefill client for multimodal decode worker")
        prefill_client = (
            await runtime.namespace(parsed_namespace)
            .component(parsed_component_name)
            .endpoint(parsed_endpoint_name)
            .client()
        )
        handler = MultimodalWorkerHandler(
            component, engine, config, None, None, prefill_client
        )
    else:
        # Aggregated mode: no prefill client needed
        handler = MultimodalWorkerHandler(component, engine, config)

    # Initialize async components
    await handler.async_init()
    try:
        await generate_endpoint.serve_endpoint(
            handler.generate,
            metrics_labels=[("model", server_args.served_model_name)],
            graceful_shutdown=True,
        )
    except Exception as e:
        logging.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()
async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Config):
    """Initialize multimodal prefill worker component"""
    server_args, dynamo_args = config.server_args, config.dynamo_args
    engine = sgl.Engine(server_args=server_args)
    component = runtime.namespace(dynamo_args.namespace).component(
        dynamo_args.component
    )
    await component.create_service()
    generate_endpoint = component.endpoint(dynamo_args.endpoint)

    handler = MultimodalPrefillWorkerHandler(component, engine, config)
    await handler.async_init()
    health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict()
    try:
        await asyncio.gather(
            generate_endpoint.serve_endpoint(
                handler.generate,
                graceful_shutdown=True,
                metrics_labels=[("model", server_args.served_model_name)],
                health_check_payload=health_check_payload,
            )
        )
    except Exception as e:
        logging.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()
async def graceful_shutdown(runtime):
    logging.info("Received shutdown signal, shutting down DistributedRuntime")
    runtime.shutdown()
......
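The component dispatch that this commit adds to `worker()` can be summarized as a small decision function. The sketch below is illustrative only: `select_initializer` is a hypothetical helper that returns the name of the init function the diff would call, and the enum values for `PREFILL` and `DECODE` are assumed (only `AGGREGATED = "agg"` is visible in this diff).

```python
from enum import Enum


class DisaggregationMode(Enum):
    AGGREGATED = "agg"
    PREFILL = "prefill"  # assumed value, not shown in the diff
    DECODE = "decode"  # assumed value, not shown in the diff


def select_initializer(
    *,
    multimodal_processor: bool = False,
    multimodal_encode_worker: bool = False,
    multimodal_worker: bool = False,
    serving_mode: DisaggregationMode = DisaggregationMode.AGGREGATED,
) -> str:
    """Return the name of the init function chosen for the given flags."""
    if multimodal_processor:
        return "init_multimodal_processor"
    if multimodal_encode_worker:
        return "init_multimodal_encode_worker"
    if multimodal_worker:
        # Aggregated and decode workers share one initializer;
        # only prefill gets its own.
        if serving_mode != DisaggregationMode.PREFILL:
            return "init_multimodal_worker"
        return "init_multimodal_prefill_worker"
    # Non-multimodal path, unchanged by this commit.
    if serving_mode != DisaggregationMode.PREFILL:
        return "init"
    return "init_prefill"
```

The multimodal flags are checked before the serving mode, so a process started with `--multimodal-worker --disaggregation-mode prefill` reaches the multimodal prefill initializer rather than the text-only `init_prefill`.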
......@@ -17,6 +17,7 @@ async def register_llm_with_runtime_config(
endpoint: Endpoint,
server_args: ServerArgs,
dynamo_args: DynamoArgs,
input_type: Optional[ModelInput] = ModelInput.Tokens,
) -> bool:
"""Register LLM with runtime config
......@@ -24,7 +25,7 @@ async def register_llm_with_runtime_config(
bool: True if registration succeeded, False if it failed
"""
runtime_config = await _get_runtime_config(engine, server_args, dynamo_args)
input_type = ModelInput.Tokens
input_type = input_type
output_type = ModelType.Chat | ModelType.Completions
if not server_args.skip_tokenizer_init:
logging.warning(
......
......@@ -5,10 +5,22 @@ from .decode_handler import DecodeWorkerHandler
# Base handlers
from .handler_base import BaseWorkerHandler
# Multimodal handlers
from .multimodal_encode_worker_handler import MultimodalEncodeWorkerHandler
from .multimodal_processor_handler import MultimodalProcessorHandler
from .multimodal_worker_handler import (
MultimodalPrefillWorkerHandler,
MultimodalWorkerHandler,
)
from .prefill_handler import PrefillWorkerHandler
__all__ = [
"BaseWorkerHandler",
"DecodeWorkerHandler",
"PrefillWorkerHandler",
"MultimodalProcessorHandler",
"MultimodalEncodeWorkerHandler",
"MultimodalWorkerHandler",
"MultimodalPrefillWorkerHandler",
]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import AsyncIterator
import torch
from sglang.srt.conversation import chat_templates
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
import dynamo.nixl_connect as connect
from dynamo._core import Client, Component
from dynamo.runtime import DistributedRuntime
from dynamo.sglang.args import Config
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
from dynamo.sglang.utils.multimodal_encode_utils import encode_image_embeddings
from dynamo.sglang.utils.multimodal_image_loader import ImageLoader
from dynamo.sglang.utils.multimodal_protocol import SglangMultimodalRequest
logger = logging.getLogger(__name__)
try:
    import cupy as array_module

    if not array_module.cuda.is_available():
        raise ImportError("CUDA is not available.")
    DEVICE = "cuda"
    logger.info("Using cupy for array operations (GPU mode).")
except ImportError as e:
    logger.warning(f"Failed to import cupy, falling back to numpy: {e}.")
    import numpy as array_module

    DEVICE = "cpu"
CACHE_SIZE_MAXIMUM = 8
class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
    """
    Handler for multimodal encode worker component that processes images/videos
    and forwards them to the downstream worker.
    """

    def __init__(
        self,
        component: Component,
        config: Config,
        pd_worker_client: Client,
    ) -> None:
        super().__init__(component, engine=None, config=config)
        self.pd_worker_client = pd_worker_client
        self.model = config.server_args.model_path
        self.image_loader = ImageLoader(cache_size=CACHE_SIZE_MAXIMUM)
        self.image_processor = AutoImageProcessor.from_pretrained(
            self.model, trust_remote_code=True
        )
        self.vision_model = AutoModel.from_pretrained(
            self.model,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
        # Load tokenizer to convert image token string to integer ID
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model, trust_remote_code=True
        )
        # Get image token string and handle it properly
        image_token_str = (
            chat_templates[getattr(config.server_args, "chat_template")]
            .copy()
            .image_token
        )
        # For Qwen2.5-VL, the image token might be multiple tokens
        if image_token_str == "<|vision_start|><|image_pad|><|vision_end|>":
            # These are likely the individual special tokens for Qwen2.5-VL
            image_pad_id = self.tokenizer.convert_tokens_to_ids("<|image_pad|>")
            # Use the image_pad token as the main image token
            self.image_token_id = image_pad_id
        else:
            # Fallback for other models
            self.image_token_id = self.tokenizer.convert_tokens_to_ids(image_token_str)
        self.min_workers = 1

    def cleanup(self):
        pass
    async def generate(self, request: SglangMultimodalRequest) -> AsyncIterator[str]:
        if not isinstance(request, SglangMultimodalRequest):
            if isinstance(request, str):
                request = SglangMultimodalRequest.model_validate_json(request)
            else:
                request = SglangMultimodalRequest.model_validate(request)

        # The following steps encode the requested image for SGLang:
        # 1. Open the image from the provided URL.
        # 2. Process the image using the processor (which handles tokenization).
        # 3. Extract input_ids and image data from processed result.
        # 4. Run the image through the vision model to get precomputed embeddings.
        # 5. Create SGLang-specific multimodal data format.
        # 6. Create a descriptor for the embeddings and send to downstream worker.
        try:
            if not request.multimodal_input.image_url:
                raise ValueError("image_url is required for the encode worker.")
            image = await self.image_loader.load_image(
                request.multimodal_input.image_url
            )
            image_embeds = self.image_processor(images=image, return_tensors="pt")
            precomputed_embeddings = encode_image_embeddings(
                model_name=self.model,
                image_embeds=image_embeds,
                vision_encoder=self.vision_model,
                projector=None,
            )
            image_grid_thw = (
                image_embeds["image_grid_thw"].tolist()
                if "image_grid_thw" in image_embeds
                else None
            )
            # Store the image data info in the request for downstream
            request.image_grid_thw = image_grid_thw
            request.embeddings_shape = tuple(precomputed_embeddings.shape)
            # Replace the single image token with multiple image tokens based on embedding shape
            image_token_id_index = request.request.token_ids.index(self.image_token_id)
            num_image_tokens = precomputed_embeddings.shape[
                1
            ]  # Number of image patches
            # Replace single image token with multiple image tokens
            request.request.token_ids = (
                request.request.token_ids[:image_token_id_index]
                + [self.image_token_id] * num_image_tokens
                + request.request.token_ids[
                    image_token_id_index + 1 :
                ]  # Skip the original token
            )
            # Create descriptor for the multimodal data
            descriptor = connect.Descriptor(precomputed_embeddings)
            with self._connector.create_readable(descriptor) as readable:
                request.serialized_request = readable.metadata()
                logger.debug(f"Request: {request.model_dump_json()}")
                # Get the response generator from downstream worker
                response_generator = await self.pd_worker_client.round_robin(
                    request.model_dump_json()
                )
                await readable.wait_for_completion()
                async for response in response_generator:
                    yield response.data() if hasattr(response, "data") else str(
                        response
                    )
        except Exception as e:
            logger.error(f"Error processing request: {e}")
            raise
    async def async_init(self, runtime: DistributedRuntime):
        logger.info("Startup started.")
        # Create and initialize a dynamo connector for this worker.
        # We'll need this to move data between this worker and remote workers efficiently.
        self._connector = connect.Connector()
        await self._connector.initialize()
        logger.info("Startup completed.")
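The token replacement step in `generate` above can be isolated as a small pure function, which makes its effect on the prompt easier to see. This is a minimal sketch, not code from the commit: `expand_image_token` is a hypothetical helper, and the token IDs in the example are made up.

```python
from typing import List


def expand_image_token(
    token_ids: List[int], image_token_id: int, num_image_tokens: int
) -> List[int]:
    """Replace the first image placeholder with one token per image patch.

    Mirrors the slicing used in the encode worker: everything before the
    placeholder, then `num_image_tokens` copies of the image token, then
    everything after the placeholder.
    """
    # list.index raises ValueError if the placeholder is missing.
    index = token_ids.index(image_token_id)
    return (
        token_ids[:index]
        + [image_token_id] * num_image_tokens
        + token_ids[index + 1 :]
    )


# Illustrative IDs only: 99 stands in for the image pad token.
expanded = expand_image_token([1, 99, 2], image_token_id=99, num_image_tokens=3)
```

Only the first occurrence of the image token is expanded, which matches the single-image request shown in the documentation; a prompt with several images would need additional handling that this commit does not show.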
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import logging
import time
import uuid
from typing import Any, Dict
from transformers import AutoTokenizer
from dynamo._core import Client, Component
from dynamo.sglang.args import Config
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
from dynamo.sglang.utils.multimodal_chat_processor import (
multimodal_request_to_sglang,
process_sglang_stream_response,
)
from dynamo.sglang.utils.multimodal_protocol import (
MultiModalInput,
MultiModalRequest,
SglangMultimodalRequest,
)
logger = logging.getLogger(__name__)
class MultimodalProcessorHandler(BaseWorkerHandler):
    """
    Handler for multimodal processor component that processes multimodal requests
    and forwards them to the encode worker.
    """

    def __init__(
        self,
        component: Component,
        config: Config,
        encode_worker_client: Client,
    ):
        super().__init__(component, engine=None, config=config)
        self.encode_worker_client = encode_worker_client
        self.chat_template = getattr(config.server_args, "chat_template", "qwen2-vl")
        self.model = config.server_args.model_path
        # Initialize tokenizer for the model
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model,
            trust_remote_code=True,
            use_fast=True,
            padding_side="left",
            truncation_side="left",
        )

    def cleanup(self):
        pass
    async def generate(self, raw_request: MultiModalRequest):
        if not isinstance(raw_request, MultiModalRequest):
            # If the request is not MultiModalRequest, convert it to MultiModalRequest
            raw_request = MultiModalRequest.model_validate(raw_request)
        multimodal_input = MultiModalInput()
        for message in raw_request.messages:
            for item in message.content:
                if item.type == "image_url":
                    multimodal_input.image_url = item.image_url.url
                elif item.type == "video_url":
                    if multimodal_input.image_url is not None:
                        raise ValueError("Cannot provide both image and video URLs")
                    multimodal_input.video_url = item.video_url.url
        if multimodal_input.image_url is None and multimodal_input.video_url is None:
            raise ValueError("Either image URL or video URL is required")
        async for response in self._generate(raw_request, multimodal_input):
            logger.debug(
                f"Generated response type {type(response)}, content: {response}"
            )
            yield response
async def _generate(
self,
raw_request: MultiModalRequest,
multimodal_input: MultiModalInput,
):
# Generate a unique request ID for tracking
request_id = str(uuid.uuid4().hex)
logger.debug(f"Got raw request: {raw_request}")
# Create SGLang conversation prompt
sglang_request = multimodal_request_to_sglang(
raw_request, self.tokenizer, self.chat_template
)
worker_request = SglangMultimodalRequest(
request=sglang_request,
multimodal_input=multimodal_input,
)
# Send to encoder worker
response_generator = await self.encode_worker_client.round_robin(
worker_request.model_dump_json()
)
# Process and yield SGLang responses
finished_sent = False
accumulated_text = ""
async for resp in response_generator:
try:
# Handle Annotated response objects from Dynamo (like vLLM pattern but for SGLang)
if hasattr(resp, "data"):
# Extract data from Dynamo Annotated response
raw_data = resp.data
if callable(raw_data):
raw_data = raw_data()
if isinstance(raw_data, str):
try:
response_data = json.loads(raw_data)
except json.JSONDecodeError:
response_data = {"text": raw_data, "finished": False}
else:
response_data = raw_data
elif isinstance(resp, str):
try:
response_data = json.loads(resp)
except json.JSONDecodeError:
response_data = {"text": resp, "finished": False}
else:
response_data = resp
# Use SGLang chat_processor for detokenization
(
text_content,
accumulated_text,
is_finished,
) = process_sglang_stream_response(
response_data, self.tokenizer, accumulated_text
)
# Create OpenAI-compatible response (following vLLM-like pattern but for SGLang)
if text_content or is_finished:
choice: Dict[str, Any] = {
"index": 0,
"delta": {},
"finish_reason": None,
}
delta: Dict[str, str] = choice["delta"] # Type-safe access
# Add role for first message or when there's content
if text_content and not finished_sent:
delta["role"] = "assistant"
# Add content if available
if text_content:
delta["content"] = text_content
# Set finish reason if completed
if is_finished:
choice["finish_reason"] = response_data.get(
"finish_reason", "stop"
)
if not finished_sent and not text_content:
# Final chunk needs role if it's the first chunk
delta["role"] = "assistant"
response_json = {
"id": f"chatcmpl-{request_id}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": self.model,
"choices": [choice],
}
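# Add usage only for the final response. Note: prompt tokens are not tracked here,
# and completion tokens are approximated by a whitespace-separated word count.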
if is_finished:
response_json["usage"] = {
"prompt_tokens": 0,
"completion_tokens": len(accumulated_text.split())
if accumulated_text
else 0,
"total_tokens": len(accumulated_text.split())
if accumulated_text
else 0,
}
yield response_json
if is_finished:
finished_sent = True
break
except Exception as e:
logger.error(f"Error processing SGLang response: {e}")
error_response = {
"id": f"chatcmpl-{request_id}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": self.model,
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
"content": f"Error: {str(e)}",
},
"finish_reason": "stop",
}
],
}
yield error_response
break
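The chunk layout the handler above emits can be sketched in isolation. This is a minimal illustration, not the handler's actual code: `make_chunk` and its parameters are hypothetical names, and only the dictionary layout is taken from the source.

```python
import time
import uuid
from typing import Any, Dict, Optional


def make_chunk(
    request_id: str,
    model: str,
    text: str = "",
    finish_reason: Optional[str] = None,
) -> Dict[str, Any]:
    """Build one OpenAI-style `chat.completion.chunk` dictionary."""
    # Every emitted chunk carries the assistant role; content is optional.
    delta: Dict[str, str] = {"role": "assistant"}
    if text:
        delta["content"] = text
    return {
        "id": f"chatcmpl-{request_id}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
    }


request_id = uuid.uuid4().hex
model = "Qwen/Qwen2.5-VL-7B-Instruct"
first = make_chunk(request_id, model, text="A bus")
last = make_chunk(request_id, model, finish_reason="stop")
```

All chunks of one stream share the same `id`, and only the final chunk sets `finish_reason`.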
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from sglang.srt.conversation import chat_templates
logger = logging.getLogger(__name__)
def clean_addcriterion(text: str) -> str:
"""
Removes the addCriterion prefix from the text output.
To prevent addCriterion from appearing in outputs, an assistant placeholder must be added to the conversation.
However, adding the assistant placeholder causes subsequent requests to fail with shape mismatch errors on the engine side.
The root cause is still under investigation, so this temporary workaround is in place to maintain functionality.
"""
if text.startswith(" addCriterion"):
cleaned_text = text[13:].lstrip() # 13 = len(" addCriterion")
logger.debug(
f"🛠️ HACK: Removed ' addCriterion' prefix: '{text[:20]}...' -> '{cleaned_text[:20]}...'"
)
return cleaned_text
if text.startswith("addCriterion"):
cleaned_text = text[12:].lstrip() # 12 = len("addCriterion")
logger.debug(
f"🛠️ HACK: Removed 'addCriterion' prefix: '{text[:20]}...' -> '{cleaned_text[:20]}...'"
)
return cleaned_text
return text
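The same workaround can be written without hard-coded slice offsets by deriving the offset from the prefix itself. This is a hedged sketch, not the code in this change: `strip_known_prefix` and `KNOWN_PREFIXES` are hypothetical names.

```python
from typing import Tuple

# Longest prefix first, so the variant with a leading space wins when present.
KNOWN_PREFIXES: Tuple[str, ...] = (" addCriterion", "addCriterion")


def strip_known_prefix(text: str, prefixes: Tuple[str, ...] = KNOWN_PREFIXES) -> str:
    """Remove the first matching prefix and any whitespace that follows it."""
    for prefix in prefixes:
        if text.startswith(prefix):
            # len(prefix) keeps the slice offset in sync with the prefix.
            return text[len(prefix):].lstrip()
    return text


cleaned_with_space = strip_known_prefix(" addCriterion The image shows a bus.")
cleaned_without_space = strip_known_prefix("addCriterionHello")
untouched = strip_known_prefix("Hello")
```

Using `len(prefix)` avoids the comment and the offset drifting apart when a prefix changes.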
def multimodal_request_to_sglang(raw_request, tokenizer, chat_template):
conv = chat_templates[chat_template].copy()
conv.messages = []
# Convert messages into SGLang conversation
for msg in raw_request.messages:
if msg.role == "system":
conv.system_message = msg.content
elif msg.role == "user":
text_parts = []
for part in msg.content:
if part.type == "text":
text_parts.append(part.text)
elif part.type == "image_url":
text_parts.append(conv.image_token)
conv.append_message(conv.roles[0], " ".join(text_parts))
elif msg.role == "assistant":
conv.append_message(conv.roles[1], msg.content)
logger.debug(f"conv: {conv}")
# Tokenize and prepare input_ids
processed = tokenizer(text=conv.get_prompt(), return_tensors="pt")
input_ids = processed["input_ids"][0].tolist()
# Build the SGLang request dict
sglang_request = {
"model": raw_request.model,
"token_ids": input_ids,
"batch_token_ids": None,
"stop_conditions": {"max_tokens": raw_request.max_tokens or None},
# Compare against None so an explicit temperature of 0.0 is preserved
"sampling_options": {
"temperature": raw_request.temperature
if raw_request.temperature is not None
else 0.7
},
"eos_token_ids": [tokenizer.eos_token_id],
"annotations": [],
"stream": raw_request.stream if raw_request.stream is not None else False,
}
return sglang_request
def detokenize_sglang_response(response_data, tokenizer):
"""
Detokenize SGLang response token IDs to text.
Args:
response_data: Dictionary containing token_ids and other response data
tokenizer: The tokenizer to use for detokenization
Returns:
String containing the detokenized text, empty string if no tokens
"""
try:
# Handle Annotated objects from Dynamo (following vLLM-like pattern)
if hasattr(response_data, "data"):
try:
import json
raw_data = response_data.data
# Handle callable data method
if callable(raw_data):
raw_data = raw_data()
response_data = (
json.loads(raw_data) if isinstance(raw_data, str) else raw_data
)
except (json.JSONDecodeError, AttributeError):
try:
raw_data = response_data.data
if callable(raw_data):
raw_data = raw_data()
response_data = {"text": str(raw_data), "finished": False}
except Exception:
response_data = {"text": str(response_data), "finished": False}
# Ensure response_data is a dictionary
if not isinstance(response_data, dict):
return clean_addcriterion(str(response_data))
# Get text content - detokenize if needed
if "text" in response_data and response_data["text"]:
return clean_addcriterion(response_data["text"])
elif "token_ids" in response_data and response_data["token_ids"]:
token_ids = response_data["token_ids"]
if isinstance(token_ids, list) and token_ids:
# Detokenize token IDs to get text
text_content = tokenizer.decode(token_ids, skip_special_tokens=True)
logger.debug(
f"Detokenized {len(token_ids)} tokens to: '{text_content}'"
)
return clean_addcriterion(text_content)
# Return empty string if no content to detokenize
return ""
except Exception as e:
logger.error(f"Failed to detokenize response: {e}")
return f"[Detokenization error: {e}]"
def process_sglang_stream_response(response_data, tokenizer, accumulated_text=""):
"""
Process a single SGLang streaming response with efficient detokenization.
Args:
response_data: Dictionary containing SGLang response data
tokenizer: The tokenizer to use for detokenization
accumulated_text: Previously accumulated text for context
Returns:
Tuple of (text_content, updated_accumulated_text, is_finished)
"""
try:
# Handle Annotated objects from Dynamo (following vLLM-like pattern)
if hasattr(response_data, "data"):
try:
import json
raw_data = response_data.data
# Handle callable data method
if callable(raw_data):
raw_data = raw_data()
response_data = (
json.loads(raw_data) if isinstance(raw_data, str) else raw_data
)
except (json.JSONDecodeError, AttributeError):
try:
raw_data = response_data.data
if callable(raw_data):
raw_data = raw_data()
response_data = {"text": str(raw_data), "finished": False}
except Exception:
response_data = {"text": str(response_data), "finished": False}
# Ensure response_data is a dictionary
if not isinstance(response_data, dict):
response_data = {"text": str(response_data), "finished": False}
# Detokenize the current response
text_content = detokenize_sglang_response(response_data, tokenizer)
# Update accumulated text
new_accumulated = accumulated_text + text_content
# Check if this is the final response
is_finished = bool(
response_data.get("finished", False) or response_data.get("finish_reason")
)
return text_content, new_accumulated, is_finished
except Exception as e:
logger.error(f"Error processing SGLang stream response: {e}")
return f"[Processing error: {e}]", accumulated_text, True
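The unwrap-and-parse steps that recur in the handler and in both helpers above could be collected into one function. The sketch below is an assumption about how that might look, not code from this change: `normalize_chunk` and the `Wrapped` stand-in class are hypothetical.

```python
import json
from typing import Any, Dict


def normalize_chunk(resp: Any) -> Dict[str, Any]:
    """Coerce a raw stream item into a plain dictionary."""
    # Unwrap objects exposing `data` as an attribute or a zero-argument method.
    if hasattr(resp, "data"):
        resp = resp.data() if callable(resp.data) else resp.data
    # Strings may be JSON documents or bare text.
    if isinstance(resp, str):
        try:
            resp = json.loads(resp)
        except json.JSONDecodeError:
            return {"text": resp, "finished": False}
    # Anything that is still not a dict is treated as text.
    if not isinstance(resp, dict):
        return {"text": str(resp), "finished": False}
    return resp


class Wrapped:
    """Stand-in for an annotated response object with a `data` attribute."""

    def __init__(self, data: Any) -> None:
        self.data = data


from_wrapper = normalize_chunk(Wrapped('{"token_ids": [1, 2], "finished": true}'))
from_text = normalize_chunk("plain text")
from_dict = normalize_chunk({"text": "hi"})
```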
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Any, Dict, Optional
import torch
logger = logging.getLogger(__name__)
class SupportedModels:
"""Supported multimodal model identifiers"""
QWEN_2_5_VL_7B = "Qwen/Qwen2.5-VL-7B-Instruct"
def get_qwen_image_features(
vision_encoder: torch.nn.Module, image_embeds: Dict[str, Any]
) -> torch.Tensor:
"""
Extract image features using Qwen-style vision encoder.
Args:
vision_encoder: The vision encoder model
image_embeds: Dictionary containing pixel values and grid information
Returns:
Processed image features tensor
Raises:
ValueError: If grid_thw is not provided for Qwen model
"""
pixel_values = image_embeds["pixel_values"].to(vision_encoder.device)
grid_thw = image_embeds.get("image_grid_thw", None)
if grid_thw is not None:
grid_thw = grid_thw.to(vision_encoder.device)
logger.debug(f"Qwen grid_thw shape: {grid_thw.shape}")
else:
raise ValueError("grid_thw is not provided")
# grid_thw is guaranteed to be set here; the None case raised above
return vision_encoder.get_image_features(pixel_values, grid_thw) # type: ignore
def encode_image_embeddings(
model_name: str,
image_embeds: Dict[str, Any],
vision_encoder: torch.nn.Module,
projector: Optional[torch.nn.Module] = None,
) -> torch.Tensor:
"""
Encode image embeddings using the appropriate model-specific encoder.
Args:
model_name: The model identifier
image_embeds: Dictionary containing processed image data
vision_encoder: The vision encoder module
projector: The multimodal projector (currently unused; reserved for LLaVA-style models)
Returns:
Encoded embeddings tensor with normalized shape
Raises:
ValueError: If grid_thw is missing for Qwen models
NotImplementedError: If model is not supported
"""
with torch.no_grad():
# Route through the correct encoder based on model
if model_name == SupportedModels.QWEN_2_5_VL_7B:
embeddings = get_qwen_image_features(vision_encoder, image_embeds)
else:
raise NotImplementedError(f"Model not supported: {model_name}")
# Normalize output shape
if isinstance(embeddings, (tuple, list)):
embeddings = embeddings[0]
embeddings = embeddings.unsqueeze(0) if embeddings.ndim == 2 else embeddings
return embeddings
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import base64
import binascii
import logging
from io import BytesIO
from typing import Optional
from urllib.parse import urlparse
import httpx
from PIL import Image
logger = logging.getLogger(__name__)
# Global HTTP client instance
_global_http_client: Optional[httpx.AsyncClient] = None
def get_http_client(timeout: float = 60.0) -> httpx.AsyncClient:
"""
Get or create a shared HTTP client instance.
Args:
timeout: Timeout for HTTP requests
Returns:
Shared HTTP client instance
"""
global _global_http_client
if _global_http_client is None or _global_http_client.is_closed:
_global_http_client = httpx.AsyncClient(
timeout=timeout,
follow_redirects=True,
limits=httpx.Limits(max_keepalive_connections=20, max_connections=100),
)
logger.info(f"Shared HTTP client initialized with timeout={timeout}s")
return _global_http_client
class ImageLoader:
CACHE_SIZE_MAXIMUM = 8
def __init__(
self, cache_size: int = CACHE_SIZE_MAXIMUM, http_timeout: float = 30.0
):
self._http_timeout = http_timeout
self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: asyncio.Queue[str] = asyncio.Queue(maxsize=cache_size)
async def load_image(self, image_url: str) -> Image.Image:
parsed_url = urlparse(image_url)
# For HTTP(S) URLs, check cache first
if parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower()
if image_url_lower in self._image_cache:
logger.debug(f"Image found in cache for URL: {image_url}")
return self._image_cache[image_url_lower]
try:
if parsed_url.scheme == "data":
# Parse data URL format: data:[<media type>][;base64],<data>
if not parsed_url.path.startswith("image/"):
raise ValueError("Data URL must be an image type")
# Split the path into media type and data
media_type, data = parsed_url.path.split(",", 1)
if ";base64" not in media_type:
raise ValueError("Data URL must be base64 encoded")
try:
image_bytes = base64.b64decode(data)
image_data = BytesIO(image_bytes)
except binascii.Error as e:
raise ValueError(f"Invalid base64 encoding: {e}")
elif parsed_url.scheme in ("http", "https"):
http_client = get_http_client(self._http_timeout)
response = await http_client.get(image_url)
response.raise_for_status()
if not response.content:
raise ValueError("Empty response content from image URL")
image_data = BytesIO(response.content)
else:
raise ValueError(f"Invalid image source scheme: {parsed_url.scheme}")
# PIL is sync, so offload to a thread to avoid blocking the event loop
image = await asyncio.to_thread(Image.open, image_data)
# Validate image format and convert to RGB
if image.format not in ("JPEG", "PNG", "WEBP"):
raise ValueError(f"Unsupported image format: {image.format}")
image_converted = image.convert("RGB")
# Cache HTTP(S) URLs
if parsed_url.scheme in ("http", "https"):
image_url_lower = image_url.lower()
# Cache the image for future use, and evict the oldest image if the cache is full
if self._cache_queue.full():
oldest_image_url = await self._cache_queue.get()
del self._image_cache[oldest_image_url]
self._image_cache[image_url_lower] = image_converted
await self._cache_queue.put(image_url_lower)
return image_converted
except httpx.HTTPError as e:
logger.error(f"HTTP error loading image: {e}")
raise
except Exception as e:
logger.error(f"Error loading image: {e}")
raise ValueError(f"Failed to load image: {e}")
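The `data:` URL branch of the loader above can be exercised on its own with the standard library. This is a minimal sketch that mirrors the checks in `load_image`: `decode_image_data_url` is a hypothetical name, and the payload is deliberately not a real image, so nothing here goes through PIL.

```python
import base64
import binascii
from urllib.parse import urlparse


def decode_image_data_url(url: str) -> bytes:
    """Return the raw bytes carried by a base64 `data:image/...` URL."""
    parsed = urlparse(url)
    if parsed.scheme != "data":
        raise ValueError(f"Not a data URL: {parsed.scheme!r}")
    if not parsed.path.startswith("image/"):
        raise ValueError("Data URL must be an image type")
    # The path looks like "image/png;base64,<payload>".
    media_type, _, data = parsed.path.partition(",")
    if ";base64" not in media_type:
        raise ValueError("Data URL must be base64 encoded")
    try:
        return base64.b64decode(data, validate=True)
    except binascii.Error as e:
        raise ValueError(f"Invalid base64 encoding: {e}") from e


payload = base64.b64encode(b"not a real image").decode("ascii")
decoded = decode_image_data_url(f"data:image/png;base64,{payload}")

try:
    decode_image_data_url("data:text/plain;base64,AAAA")
    rejected_non_image = False
except ValueError:
    rejected_non_image = True
```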
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel, ConfigDict, Field
import dynamo.nixl_connect as connect
from dynamo.sglang.protocol import PreprocessedRequest
TokenIdType = int
class TextContent(BaseModel):
type: Literal["text"]
text: str
class ImageURLDetail(BaseModel):
url: str
class ImageContent(BaseModel):
type: Literal["image_url"]
image_url: ImageURLDetail
class VideoURLDetail(BaseModel):
url: str
class VideoContent(BaseModel):
type: Literal["video_url"]
video_url: VideoURLDetail
MessageContent = Union[TextContent, ImageContent, VideoContent]
class ChatMessage(BaseModel):
role: Literal["user", "system", "assistant"]
content: List[MessageContent]
class MultiModalRequest(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
model: str
messages: List[ChatMessage]
max_tokens: Optional[int] = None
temperature: Optional[float] = None
stream: Optional[bool] = False
class MultiModalInput(BaseModel):
image_url: Optional[str] = None
video_url: Optional[str] = None
class SglangMultimodalRequest(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
request: PreprocessedRequest
multimodal_input: Optional[MultiModalInput] = Field(default_factory=MultiModalInput)
image_grid_thw: Optional[List[Any]] = None
embeddings_shape: Optional[
Union[Tuple[int, int, int], Tuple[int, int, int, int]]
] = None
serialized_request: Optional[connect.RdmaMetadata] = None
class DisaggSglangMultimodalRequest(BaseModel):
request: SglangMultimodalRequest
sampling_params: dict
data_parallel_rank: Optional[int] = None
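A request body matching the `MultiModalRequest` shape above might look like the following. This is an illustrative sketch using plain dictionaries rather than the pydantic models: the model name and image URL are taken from the test configuration in this change, while `max_tokens=64` and the `media_urls` helper are assumptions added for the example.

```python
import json
from typing import Dict, Optional

payload = {
    "model": "Qwen/Qwen2.5-VL-7B-Instruct",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {
                    "type": "image_url",
                    "image_url": {
                        "url": "http://images.cocodataset.org/test2017/000000155781.jpg"
                    },
                },
            ],
        }
    ],
    "max_tokens": 64,  # assumed value for illustration
    "temperature": 0.0,
    "stream": False,
}


def media_urls(request: dict) -> Dict[str, Optional[str]]:
    """Collect the image or video URL, enforcing the handler's either-or rule."""
    found: Dict[str, Optional[str]] = {"image_url": None, "video_url": None}
    for message in request["messages"]:
        for item in message["content"]:
            if item["type"] in found:
                found[item["type"]] = item[item["type"]]["url"]
    if found["image_url"] and found["video_url"]:
        raise ValueError("Cannot provide both image and video URLs")
    if not (found["image_url"] or found["video_url"]):
        raise ValueError("Either image URL or video URL is required")
    return found


body = json.dumps(payload)  # what a client would send over HTTP
urls = media_urls(json.loads(body))
```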
......@@ -39,16 +39,16 @@ git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
- processor: Tokenizes the prompt and passes it to the VllmEncodeWorker.
- frontend: HTTP endpoint to handle incoming requests.
### Graph
### Workflow
In this graph, we have two workers, [VllmEncodeWorker](components/encode_worker.py) and [VllmPDWorker](components/worker.py).
In this workflow, we have two workers, [VllmEncodeWorker](components/encode_worker.py) and [VllmPDWorker](components/worker.py).
The VllmEncodeWorker is responsible for encoding the image and passing the embeddings to the VllmPDWorker via a combination of NATS and RDMA.
The work complete event is sent via NATS, while the embeddings tensor is transferred via RDMA through the NIXL interface.
The VllmPDWorker then prefills and decodes the prompt, just like the [LLM aggregated serving](/components/backends/vllm/README.md) example.
The VllmPDWorker then prefills and decodes the prompt, just like the [LLM aggregated serving](../../components/backends/vllm/README.md) example.
By separating the encode from the prefill and decode stages, we can have a more flexible deployment and scale the
VllmEncodeWorker independently from the prefill and decode workers if needed.
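The split between a small control message and a bulk tensor transfer can be illustrated with a toy simulation. This sketch does not use NATS or NIXL: an `asyncio.Queue` stands in for the control channel, a plain dictionary stands in for remotely readable memory, and every name in it is hypothetical.

```python
import asyncio
from typing import Dict, List

Embeddings = List[List[float]]


async def encode_worker(control: asyncio.Queue, buffers: Dict[str, Embeddings]) -> None:
    embeddings = [[0.1, 0.2], [0.3, 0.4]]  # placeholder for encoder output
    # Bulk data path: publish the tensor where the peer can read it.
    buffers["req-1"] = embeddings
    # Control path: send only metadata describing the finished work.
    await control.put({"request_id": "req-1", "shape": (2, 2)})


async def pd_worker(control: asyncio.Queue, buffers: Dict[str, Embeddings]) -> Embeddings:
    message = await control.get()  # wait for the work-complete event
    embeddings = buffers.pop(message["request_id"])  # then fetch the tensor
    if (len(embeddings), len(embeddings[0])) != message["shape"]:
        raise ValueError("Embeddings do not match the advertised shape")
    return embeddings


async def main() -> Embeddings:
    control: asyncio.Queue = asyncio.Queue()
    buffers: Dict[str, Embeddings] = {}
    await encode_worker(control, buffers)
    return await pd_worker(control, buffers)


result = asyncio.run(main())
```

Keeping the control message small means the messaging layer never carries the tensor itself, which is the motivation for pairing it with a separate transfer path.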
This figure shows the flow of the graph:
This figure illustrates the workflow:
```mermaid
flowchart LR
HTTP --> processor
......@@ -115,16 +115,16 @@ You should see a response similar to this:
- processor: Tokenizes the prompt and passes it to the VllmEncodeWorker.
- frontend: HTTP endpoint to handle incoming requests.
### Graph
### Workflow
In this graph, we have three workers, [VllmEncodeWorker](components/encode_worker.py), [VllmDecodeWorker](components/worker.py), and [VllmPDWorker](components/worker.py).
In this workflow, we have three workers, [VllmEncodeWorker](components/encode_worker.py), [VllmDecodeWorker](components/worker.py), and [VllmPDWorker](components/worker.py).
For the Llava model, embeddings are only required during the prefill stage. As such, the VllmEncodeWorker is connected directly to the prefill worker.
The VllmEncodeWorker is responsible for encoding the image and passing the embeddings to the prefill worker via a combination of NATS and RDMA.
Its work complete event is sent via NATS, while the embeddings tensor is transferred via RDMA through the NIXL interface.
The prefill worker performs the prefilling step and forwards the KV cache to the decode worker for decoding.
For more details on the roles of the prefill and decode workers, refer to the [LLM disaggregated serving](/components/backends/vllm/README.md) example.
For more details on the roles of the prefill and decode workers, refer to the [LLM disaggregated serving](../../components/backends/vllm/README.md) example.
This figure shows the flow of the graph:
This figure illustrates the workflow:
```mermaid
flowchart LR
HTTP --> processor
......@@ -201,11 +201,11 @@ of the model per node.
- processor: Tokenizes the prompt and passes it to the VllmPDWorker.
- frontend: HTTP endpoint to handle incoming requests.
#### Graph
#### Workflow
In this graph, we have [VllmPDWorker](components/worker.py) which will encode the image, prefill and decode the prompt, just like the [LLM aggregated serving](/components/backends/vllm/README.md) example.
In this workflow, we have [VllmPDWorker](components/worker.py) which will encode the image, prefill and decode the prompt, just like the [LLM aggregated serving](/components/backends/vllm/README.md) example.
This figure shows the flow of the graph:
This figure illustrates the workflow:
```mermaid
flowchart LR
HTTP --> processor
......@@ -263,13 +263,13 @@ You should see a response similar to this:
- processor: Tokenizes the prompt and passes it to the VllmPDWorker.
- frontend: HTTP endpoint to handle incoming requests.
#### Graph
#### Workflow
In this graph, we have two workers, [VllmDecodeWorker](components/worker.py), and [VllmPDWorker](components/worker.py).
In this workflow, we have two workers, [VllmDecodeWorker](components/worker.py), and [VllmPDWorker](components/worker.py).
The prefill worker performs the encoding and prefilling steps and forwards the KV cache to the decode worker for decoding.
For more details on the roles of the prefill and decode workers, refer to the [LLM disaggregated serving](/components/backends/vllm/README.md) example.
This figure shows the flow of the graph:
This figure illustrates the workflow:
```mermaid
flowchart LR
HTTP --> processor
......@@ -337,16 +337,16 @@ This example demonstrates deploying an aggregated multimodal model that can proc
- processor: Tokenizes the prompt and passes it to the VideoEncodeWorker.
- frontend: HTTP endpoint to handle incoming requests.
### Graph
### Workflow
In this graph, we have two workers, [VideoEncodeWorker](components/video_encode_worker.py) and [VllmPDWorker](components/worker.py).
In this workflow, we have two workers, [VideoEncodeWorker](components/video_encode_worker.py) and [VllmPDWorker](components/worker.py).
The VideoEncodeWorker is responsible for decoding the video into a series of frames. Unlike the image pipeline which generates embeddings,
this pipeline passes the raw frames directly to the VllmPDWorker via a combination of NATS and RDMA.
The VllmPDWorker then prefills and decodes the prompt, just like the [LLM aggregated serving](/components/backends/vllm/README.md) example.
By separating the video processing from the prefill and decode stages, we can have a more flexible deployment and scale the
VideoEncodeWorker independently from the prefill and decode workers if needed.
This figure shows the flow of the graph:
This figure illustrates the workflow:
```mermaid
flowchart LR
HTTP --> processor
......@@ -425,15 +425,15 @@ This example demonstrates deploying a disaggregated multimodal model that can pr
- processor: Tokenizes the prompt and passes it to the VideoEncodeWorker.
- frontend: HTTP endpoint to handle incoming requests.
### Graph
### Workflow
In this graph, we have three workers, [VideoEncodeWorker](components/video_encode_worker.py), [VllmDecodeWorker](components/worker.py), and [VllmPDWorker](components/worker.py).
In this workflow, we have three workers, [VideoEncodeWorker](components/video_encode_worker.py), [VllmDecodeWorker](components/worker.py), and [VllmPDWorker](components/worker.py).
For the LLaVA-NeXT-Video-7B model, frames are only required during the prefill stage. As such, the VideoEncodeWorker is connected directly to the prefill worker.
The VideoEncodeWorker is responsible for decoding the video into a series of frames and passing them to the prefill worker via RDMA.
The prefill worker performs the prefilling step and forwards the KV cache to the decode worker for decoding.
For more details on the roles of the prefill and decode workers, refer to the [LLM disaggregated serving](/components/backends/vllm/README.md) example.
This figure shows the flow of the graph:
This figure illustrates the workflow:
```mermaid
flowchart LR
HTTP --> processor
......
......@@ -13,7 +13,11 @@ from tests.serve.common import (
run_serve_deployment,
)
from tests.utils.engine_process import EngineConfig
from tests.utils.payload_builder import chat_payload_default, completion_payload_default
from tests.utils.payload_builder import (
chat_payload,
chat_payload_default,
completion_payload_default,
)
logger = logging.getLogger(__name__)
......@@ -86,6 +90,32 @@ sglang_configs = {
)
],
),
"multimodal_agg_qwen": SGLangConfig(
name="multimodal_agg_qwen",
directory=sglang_dir,
script_name="multimodal_agg.sh",
marks=[pytest.mark.gpu_2],
model="Qwen/Qwen2.5-VL-7B-Instruct",
delayed_start=0,
timeout=360,
models_port=8000,
request_payloads=[
chat_payload(
[
{"type": "text", "text": "What is in this image?"},
{
"type": "image_url",
"image_url": {
"url": "http://images.cocodataset.org/test2017/000000155781.jpg"
},
},
],
repeat_count=1,
expected_response=["bus"],
temperature=0.0,
)
],
),
}
......