"lib/llm/vscode:/vscode.git/clone" did not exist on "2be83be2d8b6236860bfb0611ea5782fba5255c4"
Unverified Commit 70892fc1 authored by Wang, Yi's avatar Wang, Yi Committed by GitHub
Browse files

feat: add NixlEmbeddingSender / NixlEmbeddingReceiver in sglang backend (#7153)


Signed-off-by: default avatarWang, Yi <yi.a.wang@intel.com>
Co-authored-by: default avatarKrishnanPrash <140860868+KrishnanPrash@users.noreply.github.com>
parent 17abc9de
......@@ -3,6 +3,7 @@
"""Multimodal utilities for Dynamo components."""
from dynamo.common.constants import EmbeddingTransferMode
from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache
from dynamo.common.multimodal.embedding_transfer import (
LocalEmbeddingReceiver,
......@@ -15,8 +16,24 @@ from dynamo.common.multimodal.embedding_transfer import (
)
from dynamo.common.multimodal.image_loader import ImageLoader
EMBEDDING_SENDER_FACTORIES = {
EmbeddingTransferMode.LOCAL: LocalEmbeddingSender,
EmbeddingTransferMode.NIXL_WRITE: NixlWriteEmbeddingSender,
EmbeddingTransferMode.NIXL_READ: NixlReadEmbeddingSender,
}
EMBEDDING_RECEIVER_FACTORIES = {
EmbeddingTransferMode.LOCAL: LocalEmbeddingReceiver,
EmbeddingTransferMode.NIXL_WRITE: NixlWriteEmbeddingReceiver,
# [gluo FIXME] can't use pre-registered tensor as NIXL requires descriptors
# to be at matching size, need to overwrite nixl connect library
EmbeddingTransferMode.NIXL_READ: lambda: NixlReadEmbeddingReceiver(max_items=0),
}
__all__ = [
"AsyncEncoderCache",
"EMBEDDING_RECEIVER_FACTORIES",
"EMBEDDING_SENDER_FACTORIES",
"ImageLoader",
"NixlReadEmbeddingReceiver",
"NixlReadEmbeddingSender",
......
......@@ -4,11 +4,12 @@
"""Dynamo SGLang wrapper configuration ArgGroup."""
import argparse
from typing import Optional
from typing import Optional, Union
from dynamo.common.configuration.arg_group import ArgGroup
from dynamo.common.configuration.config_base import ConfigBase
from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument
from dynamo.common.constants import EmbeddingTransferMode
from . import __version__
......@@ -62,6 +63,15 @@ class DynamoSGLangArgGroup(ArgGroup):
help="Run as multimodal worker component for LLM inference with multimodal data.",
)
add_argument(
g,
flag_name="--embedding-transfer-mode",
env_var="DYN_SGL_EMBEDDING_TRANSFER_MODE",
default=EmbeddingTransferMode.NIXL_WRITE.value,
help="Worker embedding transfer mode: 'local', 'nixl-write', or 'nixl-read'. Can also be set with environment variable DYN_SGL_EMBEDDING_TRANSFER_MODE.",
choices=[m.value for m in EmbeddingTransferMode],
)
add_negatable_bool_argument(
g,
flag_name="--embedding-worker",
......@@ -107,6 +117,7 @@ class DynamoSGLangConfig(ConfigBase):
multimodal_processor: bool
multimodal_encode_worker: bool
multimodal_worker: bool
embedding_transfer_mode: Union[str, EmbeddingTransferMode]
embedding_worker: bool
image_diffusion_worker: bool
......@@ -116,6 +127,10 @@ class DynamoSGLangConfig(ConfigBase):
video_generation_worker: bool
def validate(self) -> None:
if isinstance(self.embedding_transfer_mode, str):
self.embedding_transfer_mode = EmbeddingTransferMode(
self.embedding_transfer_mode
)
if (self.disagg_config is not None) ^ (self.disagg_config_key is not None):
raise ValueError(
"Both 'disagg_config' and 'disagg_config_key' must be provided together."
......
......@@ -102,7 +102,6 @@ async def init_multimodal_encode_worker(
).client()
handler = MultimodalEncodeWorkerHandler(config, pd_worker_client, shutdown_event)
await handler.async_init(runtime)
await pd_worker_client.wait_for_instances()
......@@ -159,8 +158,6 @@ async def init_multimodal_worker(
else:
handler = MultimodalWorkerHandler(engine, config, None, shutdown_event)
await handler.async_init()
if config.serving_mode == DisaggregationMode.DECODE:
health_check_payload = SglangDisaggHealthCheckPayload(engine).to_dict()
else:
......@@ -203,8 +200,6 @@ async def init_multimodal_prefill_worker(
shutdown_endpoints[:] = [generate_endpoint]
await handler.async_init()
health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict()
try:
......
......@@ -6,7 +6,7 @@ from typing import Any, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel, ConfigDict, Field
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
import dynamo.nixl_connect as connect
from dynamo.common.multimodal import TransferRequest
TokenIdType = int
......@@ -129,7 +129,7 @@ class SglangMultimodalRequest(BaseModel):
embeddings_shape: Optional[
Union[Tuple[int, int], Tuple[int, int, int], Tuple[int, int, int, int]]
] = None
serialized_request: Optional[connect.RdmaMetadata] = None
transfer_payload: Optional[TransferRequest] = None
class DisaggSglangMultimodalRequest(BaseModel):
......
......@@ -15,10 +15,9 @@ except (ImportError, OSError):
from sglang.srt.parser.conversation import chat_templates
from transformers import AutoTokenizer
import dynamo.nixl_connect as connect
from dynamo._core import Client, Context
from dynamo.common.multimodal import EMBEDDING_SENDER_FACTORIES
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.runtime import DistributedRuntime
from dynamo.sglang.args import Config
from dynamo.sglang.protocol import SglangMultimodalRequest
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
......@@ -94,6 +93,16 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
self.min_workers = 1
sender = EMBEDDING_SENDER_FACTORIES.get(
config.dynamo_args.embedding_transfer_mode
)
if sender is None:
raise ValueError(
"Invalid embedding transfer mode: "
f"{config.dynamo_args.embedding_transfer_mode}"
)
self.embedding_sender = sender()
def cleanup(self) -> None:
pass
......@@ -223,9 +232,9 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
mm_group.image_grid_thw = image_grid_thw
mm_group.multimodal_input.image_url = None
# Store shared serialized tensor metadata at request level.
# Store shared tensor transfer metadata at request level.
request.embeddings_shape = tuple(precomputed_embeddings.shape)
request.serialized_request = None
request.transfer_payload = None
search_start = 0
for num_image_tokens in token_counts:
......@@ -245,31 +254,24 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
)
search_start = image_token_id_index + num_image_tokens
descriptor = connect.Descriptor(precomputed_embeddings)
with await self._connector.create_readable(descriptor) as readable:
request.serialized_request = readable.metadata()
with _nvtx.annotate("mm:enc:embedding_transfer", color="purple"):
(
transfer_request,
transfer_future,
) = await self.embedding_sender.send_embeddings(precomputed_embeddings)
request.transfer_payload = transfer_request
logger.debug(f"Request: {request.model_dump_json()}")
# Get the response generator from downstream worker
response_generator = await self.pd_worker_client.round_robin(
request.model_dump_json()
)
with _nvtx.annotate("mm:enc:embedding_transfer", color="purple"):
await readable.wait_for_completion()
# Get the response generator from downstream worker
response_generator = await self.pd_worker_client.round_robin(
request.model_dump_json()
)
async for response in response_generator:
yield response.data() if hasattr(response, "data") else str(
response
)
async for response in response_generator:
yield response.data() if hasattr(response, "data") else str(response)
await transfer_future
except Exception as e:
logger.error(f"Error processing request: {e}")
raise
async def async_init(self, runtime: DistributedRuntime) -> None:
logger.info("Startup started.")
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector()
logger.info("Startup completed.")
......@@ -9,9 +9,9 @@ from typing import Any, AsyncIterator, Callable, Optional
import sglang as sgl
import torch
import dynamo.nixl_connect as connect
from dynamo._core import Client, Context
from dynamo.common.constants import DisaggregationMode
from dynamo.common.constants import DisaggregationMode, EmbeddingTransferMode
from dynamo.common.multimodal import EMBEDDING_RECEIVER_FACTORIES, TransferRequest
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.sglang.args import Config
......@@ -74,14 +74,17 @@ class SglangUtils:
class EmbeddingsProcessor:
"""Handles multimodal embeddings processing and multimodal item creation"""
def __init__(self):
self._connector = None
async def initialize(self):
"""Initialize the connector for embeddings processing"""
self._connector = connect.Connector()
def __init__(self, embedding_transfer_mode: EmbeddingTransferMode):
receiver = EMBEDDING_RECEIVER_FACTORIES.get(embedding_transfer_mode)
if receiver is None:
raise ValueError(
f"Invalid embedding transfer mode: {embedding_transfer_mode}"
)
self.embedding_receiver = receiver()
async def process_embeddings(self, request: SglangMultimodalRequest):
async def process_embeddings(
self, request: SglangMultimodalRequest
) -> tuple[torch.Tensor, int]:
"""Process one concatenated embedding tensor from serialized request."""
logger.debug("Processing embeddings with shape: " f"{request.embeddings_shape}")
......@@ -89,37 +92,26 @@ class EmbeddingsProcessor:
if not multimodal_groups:
raise ValueError("multimodal_inputs is required")
serialized_request = request.serialized_request
embeddings_shape = request.embeddings_shape
if serialized_request is None:
raise ValueError("serialized_request is required on request")
if embeddings_shape is None:
raise ValueError("embeddings_shape is required on request")
transfer_request = request.transfer_payload
if transfer_request is None:
raise ValueError("transfer_payload is required on request")
if not isinstance(transfer_request, TransferRequest):
transfer_request = TransferRequest.model_validate(transfer_request)
embeddings_shape = request.embeddings_shape or tuple(
transfer_request.embeddings_shape
)
if len(embeddings_shape) < 2:
raise ValueError(f"Invalid embeddings shape: {embeddings_shape}")
embeddings = torch.empty(
embeddings_shape,
dtype=MultimodalConfig.EMBEDDINGS_DTYPE,
device=MultimodalConfig.EMBEDDINGS_DEVICE,
tensor_id, embeddings = await self.embedding_receiver.receive_embeddings(
transfer_request
)
return embeddings, tensor_id
descriptor = connect.Descriptor(embeddings)
if descriptor is None:
raise RuntimeError("Descriptor is None - cannot process embeddings")
if self._connector is None:
logger.warning(
"Connector is None - this should not happen after initialization"
)
self._connector = connect.Connector()
with _nvtx.annotate("mm:nixl:begin_read", color="blue"):
read_op = await self._connector.begin_read(serialized_request, descriptor)
with _nvtx.annotate("mm:nixl:wait_completion", color="cyan"):
await read_op.wait_for_completion()
return embeddings, descriptor
def release_embeddings(self, tensor_id: int) -> None:
self.embedding_receiver.release_tensor(tensor_id)
@staticmethod
def create_multimodal_item(embeddings: torch.Tensor, image_grid_thw) -> dict:
......@@ -249,9 +241,9 @@ class ErrorResponseBuilder:
async def _build_mm_items(
request: SglangMultimodalRequest, embeddings_processor: EmbeddingsProcessor
) -> tuple[list[dict], torch.Tensor]:
) -> tuple[list[dict], torch.Tensor, int]:
"""Process embeddings and build a single multimodal item for SGLang."""
embeddings, _ = await embeddings_processor.process_embeddings(request)
embeddings, tensor_id = await embeddings_processor.process_embeddings(request)
image_grid_thw_list = [group.image_grid_thw for group in request.multimodal_inputs]
if any(item is None for item in image_grid_thw_list):
......@@ -261,7 +253,7 @@ async def _build_mm_items(
embeddings_processor.create_multimodal_item(embeddings, image_grid_thw_list)
]
return mm_items, embeddings
return mm_items, embeddings, tensor_id
class MultimodalWorkerHandler(BaseWorkerHandler):
......@@ -280,7 +272,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
super().__init__(engine, config, None, None, shutdown_event)
# Initialize processors
self.embeddings_processor = EmbeddingsProcessor()
self.embeddings_processor = EmbeddingsProcessor(
config.dynamo_args.embedding_transfer_mode
)
# Store serving mode and prefill client (like regular SGLang)
self.serving_mode = config.serving_mode
......@@ -296,10 +290,6 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
else:
logger.info("Multimodal aggregated worker handler initialized")
async def async_init(self):
"""Initialize async components"""
await self.embeddings_processor.initialize()
def _validate_and_parse_request(self, request) -> SglangMultimodalRequest:
"""Validate and parse incoming request"""
if type(request) is not SglangMultimodalRequest:
......@@ -408,11 +398,11 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
input_ids = request.request.token_ids
if not input_ids:
raise ValueError("input_ids is required")
tensor_id: int | None = None
try:
sampling_params = SglangUtils.build_sampling_params(request)
with _nvtx.annotate("mm:pd:load_multimodal", color="cyan"):
mm_items, combined_embeddings = await _build_mm_items(
mm_items, combined_embeddings, tensor_id = await _build_mm_items(
request, self.embeddings_processor
)
......@@ -434,6 +424,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
try:
async for output in StreamProcessor.process_sglang_stream(agg_stream):
if first_token:
if tensor_id is not None:
self.embeddings_processor.release_embeddings(tensor_id)
tensor_id = None
end_ttft()
_nvtx.end_range(rng_first)
first_token = False
......@@ -461,6 +454,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
yield ErrorResponseBuilder.build_error_response(RuntimeError(error_msg))
else:
yield ErrorResponseBuilder.build_error_response(e)
finally:
if tensor_id is not None:
self.embeddings_processor.release_embeddings(tensor_id)
async def _get_bootstrap_from_prefill(
self, request: SglangMultimodalRequest, sampling_params: dict
......@@ -509,7 +505,9 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
super().__init__(engine, config, None, None, shutdown_event)
# Initialize processors
self.embeddings_processor = EmbeddingsProcessor()
self.embeddings_processor = EmbeddingsProcessor(
config.dynamo_args.embedding_transfer_mode
)
# Get bootstrap info using BootstrapManager
self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info(engine)
......@@ -518,10 +516,6 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
f"Multimodal prefill worker handler initialized - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}"
)
async def async_init(self):
"""Initialize async components like connector"""
await self.embeddings_processor.initialize()
async def generate(
self, disagg_request: DisaggSglangMultimodalRequest, context: Context
) -> AsyncIterator[str]:
......@@ -592,10 +586,13 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
request = disagg_request.request
input_ids = request.request.token_ids
sampling_params = disagg_request.sampling_params
tensor_id: int | None = None
# Process embeddings from encode worker using our embeddings processor
with _nvtx.annotate("mm:prefill:load_multimodal", color="cyan"):
mm_items, _ = await _build_mm_items(request, self.embeddings_processor)
mm_items, _, tensor_id = await _build_mm_items(
request, self.embeddings_processor
)
# Start SGLang prefill generation (like regular SGLang)
with _nvtx.annotate("mm:prefill:engine_async_generate", color="blue"):
......@@ -610,12 +607,19 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
)
# Consume results without yielding (prefill doesn't return text, just coordinates)
asyncio.create_task(self._consume_results(results))
asyncio.create_task(self._consume_results(results, tensor_id))
async def _consume_results(self, results):
async def _consume_results(self, results, tensor_id: int):
"""Consume prefill results without returning them (like regular SGLang)"""
async for _ in results:
pass
released = False
try:
async for _ in results:
if not released:
self.embeddings_processor.release_embeddings(tensor_id)
released = True
finally:
if not released:
self.embeddings_processor.release_embeddings(tensor_id)
def cleanup(self):
super().cleanup()
......
......@@ -182,6 +182,8 @@ sglang_configs = {
"DYN_WORKER_GPU": "0",
"DYN_ENCODE_GPU_MEM": "0.1",
"DYN_WORKER_GPU_MEM": "0.4",
# FIXME: NIXL Agent Initialization (shared memory interface) causes segfault
"UCX_TLS": "^mm",
},
frontend_port=DefaultPort.FRONTEND.value,
request_payloads=[
......@@ -218,7 +220,10 @@ sglang_configs = {
model="Qwen/Qwen3-VL-2B-Instruct",
script_args=["--model", "Qwen/Qwen3-VL-2B-Instruct", "--single-gpu"],
timeout=360,
env={},
env={
# FIXME: NIXL Agent Initialization (shared memory interface) causes segfault
"UCX_TLS": "^mm",
},
frontend_port=DefaultPort.FRONTEND.value,
request_payloads=[
chat_payload(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment