Unverified Commit 70892fc1 authored by Wang, Yi's avatar Wang, Yi Committed by GitHub
Browse files

feat: add NixlEmbeddingSender / NixlEmbeddingReceiver in sglang backend (#7153)


Signed-off-by: default avatarWang, Yi <yi.a.wang@intel.com>
Co-authored-by: default avatarKrishnanPrash <140860868+KrishnanPrash@users.noreply.github.com>
parent 17abc9de
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
"""Multimodal utilities for Dynamo components.""" """Multimodal utilities for Dynamo components."""
from dynamo.common.constants import EmbeddingTransferMode
from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache from dynamo.common.multimodal.async_encoder_cache import AsyncEncoderCache
from dynamo.common.multimodal.embedding_transfer import ( from dynamo.common.multimodal.embedding_transfer import (
LocalEmbeddingReceiver, LocalEmbeddingReceiver,
...@@ -15,8 +16,24 @@ from dynamo.common.multimodal.embedding_transfer import ( ...@@ -15,8 +16,24 @@ from dynamo.common.multimodal.embedding_transfer import (
) )
from dynamo.common.multimodal.image_loader import ImageLoader from dynamo.common.multimodal.image_loader import ImageLoader
EMBEDDING_SENDER_FACTORIES = {
EmbeddingTransferMode.LOCAL: LocalEmbeddingSender,
EmbeddingTransferMode.NIXL_WRITE: NixlWriteEmbeddingSender,
EmbeddingTransferMode.NIXL_READ: NixlReadEmbeddingSender,
}
EMBEDDING_RECEIVER_FACTORIES = {
EmbeddingTransferMode.LOCAL: LocalEmbeddingReceiver,
EmbeddingTransferMode.NIXL_WRITE: NixlWriteEmbeddingReceiver,
# [gluo FIXME] can't use pre-registered tensor as NIXL requires descriptors
# to be at matching size, need to overwrite nixl connect library
EmbeddingTransferMode.NIXL_READ: lambda: NixlReadEmbeddingReceiver(max_items=0),
}
__all__ = [ __all__ = [
"AsyncEncoderCache", "AsyncEncoderCache",
"EMBEDDING_RECEIVER_FACTORIES",
"EMBEDDING_SENDER_FACTORIES",
"ImageLoader", "ImageLoader",
"NixlReadEmbeddingReceiver", "NixlReadEmbeddingReceiver",
"NixlReadEmbeddingSender", "NixlReadEmbeddingSender",
......
...@@ -4,11 +4,12 @@ ...@@ -4,11 +4,12 @@
"""Dynamo SGLang wrapper configuration ArgGroup.""" """Dynamo SGLang wrapper configuration ArgGroup."""
import argparse import argparse
from typing import Optional from typing import Optional, Union
from dynamo.common.configuration.arg_group import ArgGroup from dynamo.common.configuration.arg_group import ArgGroup
from dynamo.common.configuration.config_base import ConfigBase from dynamo.common.configuration.config_base import ConfigBase
from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument
from dynamo.common.constants import EmbeddingTransferMode
from . import __version__ from . import __version__
...@@ -62,6 +63,15 @@ class DynamoSGLangArgGroup(ArgGroup): ...@@ -62,6 +63,15 @@ class DynamoSGLangArgGroup(ArgGroup):
help="Run as multimodal worker component for LLM inference with multimodal data.", help="Run as multimodal worker component for LLM inference with multimodal data.",
) )
add_argument(
g,
flag_name="--embedding-transfer-mode",
env_var="DYN_SGL_EMBEDDING_TRANSFER_MODE",
default=EmbeddingTransferMode.NIXL_WRITE.value,
help="Worker embedding transfer mode: 'local', 'nixl-write', or 'nixl-read'. Can also be set with environment variable DYN_SGL_EMBEDDING_TRANSFER_MODE.",
choices=[m.value for m in EmbeddingTransferMode],
)
add_negatable_bool_argument( add_negatable_bool_argument(
g, g,
flag_name="--embedding-worker", flag_name="--embedding-worker",
...@@ -107,6 +117,7 @@ class DynamoSGLangConfig(ConfigBase): ...@@ -107,6 +117,7 @@ class DynamoSGLangConfig(ConfigBase):
multimodal_processor: bool multimodal_processor: bool
multimodal_encode_worker: bool multimodal_encode_worker: bool
multimodal_worker: bool multimodal_worker: bool
embedding_transfer_mode: Union[str, EmbeddingTransferMode]
embedding_worker: bool embedding_worker: bool
image_diffusion_worker: bool image_diffusion_worker: bool
...@@ -116,6 +127,10 @@ class DynamoSGLangConfig(ConfigBase): ...@@ -116,6 +127,10 @@ class DynamoSGLangConfig(ConfigBase):
video_generation_worker: bool video_generation_worker: bool
def validate(self) -> None: def validate(self) -> None:
if isinstance(self.embedding_transfer_mode, str):
self.embedding_transfer_mode = EmbeddingTransferMode(
self.embedding_transfer_mode
)
if (self.disagg_config is not None) ^ (self.disagg_config_key is not None): if (self.disagg_config is not None) ^ (self.disagg_config_key is not None):
raise ValueError( raise ValueError(
"Both 'disagg_config' and 'disagg_config_key' must be provided together." "Both 'disagg_config' and 'disagg_config_key' must be provided together."
......
...@@ -102,7 +102,6 @@ async def init_multimodal_encode_worker( ...@@ -102,7 +102,6 @@ async def init_multimodal_encode_worker(
).client() ).client()
handler = MultimodalEncodeWorkerHandler(config, pd_worker_client, shutdown_event) handler = MultimodalEncodeWorkerHandler(config, pd_worker_client, shutdown_event)
await handler.async_init(runtime)
await pd_worker_client.wait_for_instances() await pd_worker_client.wait_for_instances()
...@@ -159,8 +158,6 @@ async def init_multimodal_worker( ...@@ -159,8 +158,6 @@ async def init_multimodal_worker(
else: else:
handler = MultimodalWorkerHandler(engine, config, None, shutdown_event) handler = MultimodalWorkerHandler(engine, config, None, shutdown_event)
await handler.async_init()
if config.serving_mode == DisaggregationMode.DECODE: if config.serving_mode == DisaggregationMode.DECODE:
health_check_payload = SglangDisaggHealthCheckPayload(engine).to_dict() health_check_payload = SglangDisaggHealthCheckPayload(engine).to_dict()
else: else:
...@@ -203,8 +200,6 @@ async def init_multimodal_prefill_worker( ...@@ -203,8 +200,6 @@ async def init_multimodal_prefill_worker(
shutdown_endpoints[:] = [generate_endpoint] shutdown_endpoints[:] = [generate_endpoint]
await handler.async_init()
health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict() health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict()
try: try:
......
...@@ -6,7 +6,7 @@ from typing import Any, List, Literal, Optional, Tuple, Union ...@@ -6,7 +6,7 @@ from typing import Any, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
import dynamo.nixl_connect as connect from dynamo.common.multimodal import TransferRequest
TokenIdType = int TokenIdType = int
...@@ -129,7 +129,7 @@ class SglangMultimodalRequest(BaseModel): ...@@ -129,7 +129,7 @@ class SglangMultimodalRequest(BaseModel):
embeddings_shape: Optional[ embeddings_shape: Optional[
Union[Tuple[int, int], Tuple[int, int, int], Tuple[int, int, int, int]] Union[Tuple[int, int], Tuple[int, int, int], Tuple[int, int, int, int]]
] = None ] = None
serialized_request: Optional[connect.RdmaMetadata] = None transfer_payload: Optional[TransferRequest] = None
class DisaggSglangMultimodalRequest(BaseModel): class DisaggSglangMultimodalRequest(BaseModel):
......
...@@ -15,10 +15,9 @@ except (ImportError, OSError): ...@@ -15,10 +15,9 @@ except (ImportError, OSError):
from sglang.srt.parser.conversation import chat_templates from sglang.srt.parser.conversation import chat_templates
from transformers import AutoTokenizer from transformers import AutoTokenizer
import dynamo.nixl_connect as connect
from dynamo._core import Client, Context from dynamo._core import Client, Context
from dynamo.common.multimodal import EMBEDDING_SENDER_FACTORIES
from dynamo.common.utils import nvtx_utils as _nvtx from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.runtime import DistributedRuntime
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.protocol import SglangMultimodalRequest from dynamo.sglang.protocol import SglangMultimodalRequest
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
...@@ -94,6 +93,16 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler): ...@@ -94,6 +93,16 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
self.min_workers = 1 self.min_workers = 1
sender = EMBEDDING_SENDER_FACTORIES.get(
config.dynamo_args.embedding_transfer_mode
)
if sender is None:
raise ValueError(
"Invalid embedding transfer mode: "
f"{config.dynamo_args.embedding_transfer_mode}"
)
self.embedding_sender = sender()
def cleanup(self) -> None: def cleanup(self) -> None:
pass pass
...@@ -223,9 +232,9 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler): ...@@ -223,9 +232,9 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
mm_group.image_grid_thw = image_grid_thw mm_group.image_grid_thw = image_grid_thw
mm_group.multimodal_input.image_url = None mm_group.multimodal_input.image_url = None
# Store shared serialized tensor metadata at request level. # Store shared tensor transfer metadata at request level.
request.embeddings_shape = tuple(precomputed_embeddings.shape) request.embeddings_shape = tuple(precomputed_embeddings.shape)
request.serialized_request = None request.transfer_payload = None
search_start = 0 search_start = 0
for num_image_tokens in token_counts: for num_image_tokens in token_counts:
...@@ -245,31 +254,24 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler): ...@@ -245,31 +254,24 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
) )
search_start = image_token_id_index + num_image_tokens search_start = image_token_id_index + num_image_tokens
descriptor = connect.Descriptor(precomputed_embeddings) with _nvtx.annotate("mm:enc:embedding_transfer", color="purple"):
with await self._connector.create_readable(descriptor) as readable: (
request.serialized_request = readable.metadata() transfer_request,
transfer_future,
) = await self.embedding_sender.send_embeddings(precomputed_embeddings)
request.transfer_payload = transfer_request
logger.debug(f"Request: {request.model_dump_json()}") logger.debug(f"Request: {request.model_dump_json()}")
# Get the response generator from downstream worker # Get the response generator from downstream worker
response_generator = await self.pd_worker_client.round_robin( response_generator = await self.pd_worker_client.round_robin(
request.model_dump_json() request.model_dump_json()
) )
with _nvtx.annotate("mm:enc:embedding_transfer", color="purple"):
await readable.wait_for_completion()
async for response in response_generator: async for response in response_generator:
yield response.data() if hasattr(response, "data") else str( yield response.data() if hasattr(response, "data") else str(response)
response
) await transfer_future
except Exception as e: except Exception as e:
logger.error(f"Error processing request: {e}") logger.error(f"Error processing request: {e}")
raise raise
async def async_init(self, runtime: DistributedRuntime) -> None:
logger.info("Startup started.")
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector()
logger.info("Startup completed.")
...@@ -9,9 +9,9 @@ from typing import Any, AsyncIterator, Callable, Optional ...@@ -9,9 +9,9 @@ from typing import Any, AsyncIterator, Callable, Optional
import sglang as sgl import sglang as sgl
import torch import torch
import dynamo.nixl_connect as connect
from dynamo._core import Client, Context from dynamo._core import Client, Context
from dynamo.common.constants import DisaggregationMode from dynamo.common.constants import DisaggregationMode, EmbeddingTransferMode
from dynamo.common.multimodal import EMBEDDING_RECEIVER_FACTORIES, TransferRequest
from dynamo.common.utils import nvtx_utils as _nvtx from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.engine_response import normalize_finish_reason from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
...@@ -74,14 +74,17 @@ class SglangUtils: ...@@ -74,14 +74,17 @@ class SglangUtils:
class EmbeddingsProcessor: class EmbeddingsProcessor:
"""Handles multimodal embeddings processing and multimodal item creation""" """Handles multimodal embeddings processing and multimodal item creation"""
def __init__(self): def __init__(self, embedding_transfer_mode: EmbeddingTransferMode):
self._connector = None receiver = EMBEDDING_RECEIVER_FACTORIES.get(embedding_transfer_mode)
if receiver is None:
async def initialize(self): raise ValueError(
"""Initialize the connector for embeddings processing""" f"Invalid embedding transfer mode: {embedding_transfer_mode}"
self._connector = connect.Connector() )
self.embedding_receiver = receiver()
async def process_embeddings(self, request: SglangMultimodalRequest): async def process_embeddings(
self, request: SglangMultimodalRequest
) -> tuple[torch.Tensor, int]:
"""Process one concatenated embedding tensor from serialized request.""" """Process one concatenated embedding tensor from serialized request."""
logger.debug("Processing embeddings with shape: " f"{request.embeddings_shape}") logger.debug("Processing embeddings with shape: " f"{request.embeddings_shape}")
...@@ -89,37 +92,26 @@ class EmbeddingsProcessor: ...@@ -89,37 +92,26 @@ class EmbeddingsProcessor:
if not multimodal_groups: if not multimodal_groups:
raise ValueError("multimodal_inputs is required") raise ValueError("multimodal_inputs is required")
serialized_request = request.serialized_request transfer_request = request.transfer_payload
embeddings_shape = request.embeddings_shape if transfer_request is None:
if serialized_request is None: raise ValueError("transfer_payload is required on request")
raise ValueError("serialized_request is required on request")
if embeddings_shape is None: if not isinstance(transfer_request, TransferRequest):
raise ValueError("embeddings_shape is required on request") transfer_request = TransferRequest.model_validate(transfer_request)
embeddings_shape = request.embeddings_shape or tuple(
transfer_request.embeddings_shape
)
if len(embeddings_shape) < 2: if len(embeddings_shape) < 2:
raise ValueError(f"Invalid embeddings shape: {embeddings_shape}") raise ValueError(f"Invalid embeddings shape: {embeddings_shape}")
embeddings = torch.empty( tensor_id, embeddings = await self.embedding_receiver.receive_embeddings(
embeddings_shape, transfer_request
dtype=MultimodalConfig.EMBEDDINGS_DTYPE,
device=MultimodalConfig.EMBEDDINGS_DEVICE,
) )
return embeddings, tensor_id
descriptor = connect.Descriptor(embeddings) def release_embeddings(self, tensor_id: int) -> None:
if descriptor is None: self.embedding_receiver.release_tensor(tensor_id)
raise RuntimeError("Descriptor is None - cannot process embeddings")
if self._connector is None:
logger.warning(
"Connector is None - this should not happen after initialization"
)
self._connector = connect.Connector()
with _nvtx.annotate("mm:nixl:begin_read", color="blue"):
read_op = await self._connector.begin_read(serialized_request, descriptor)
with _nvtx.annotate("mm:nixl:wait_completion", color="cyan"):
await read_op.wait_for_completion()
return embeddings, descriptor
@staticmethod @staticmethod
def create_multimodal_item(embeddings: torch.Tensor, image_grid_thw) -> dict: def create_multimodal_item(embeddings: torch.Tensor, image_grid_thw) -> dict:
...@@ -249,9 +241,9 @@ class ErrorResponseBuilder: ...@@ -249,9 +241,9 @@ class ErrorResponseBuilder:
async def _build_mm_items( async def _build_mm_items(
request: SglangMultimodalRequest, embeddings_processor: EmbeddingsProcessor request: SglangMultimodalRequest, embeddings_processor: EmbeddingsProcessor
) -> tuple[list[dict], torch.Tensor]: ) -> tuple[list[dict], torch.Tensor, int]:
"""Process embeddings and build a single multimodal item for SGLang.""" """Process embeddings and build a single multimodal item for SGLang."""
embeddings, _ = await embeddings_processor.process_embeddings(request) embeddings, tensor_id = await embeddings_processor.process_embeddings(request)
image_grid_thw_list = [group.image_grid_thw for group in request.multimodal_inputs] image_grid_thw_list = [group.image_grid_thw for group in request.multimodal_inputs]
if any(item is None for item in image_grid_thw_list): if any(item is None for item in image_grid_thw_list):
...@@ -261,7 +253,7 @@ async def _build_mm_items( ...@@ -261,7 +253,7 @@ async def _build_mm_items(
embeddings_processor.create_multimodal_item(embeddings, image_grid_thw_list) embeddings_processor.create_multimodal_item(embeddings, image_grid_thw_list)
] ]
return mm_items, embeddings return mm_items, embeddings, tensor_id
class MultimodalWorkerHandler(BaseWorkerHandler): class MultimodalWorkerHandler(BaseWorkerHandler):
...@@ -280,7 +272,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler): ...@@ -280,7 +272,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
super().__init__(engine, config, None, None, shutdown_event) super().__init__(engine, config, None, None, shutdown_event)
# Initialize processors # Initialize processors
self.embeddings_processor = EmbeddingsProcessor() self.embeddings_processor = EmbeddingsProcessor(
config.dynamo_args.embedding_transfer_mode
)
# Store serving mode and prefill client (like regular SGLang) # Store serving mode and prefill client (like regular SGLang)
self.serving_mode = config.serving_mode self.serving_mode = config.serving_mode
...@@ -296,10 +290,6 @@ class MultimodalWorkerHandler(BaseWorkerHandler): ...@@ -296,10 +290,6 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
else: else:
logger.info("Multimodal aggregated worker handler initialized") logger.info("Multimodal aggregated worker handler initialized")
async def async_init(self):
"""Initialize async components"""
await self.embeddings_processor.initialize()
def _validate_and_parse_request(self, request) -> SglangMultimodalRequest: def _validate_and_parse_request(self, request) -> SglangMultimodalRequest:
"""Validate and parse incoming request""" """Validate and parse incoming request"""
if type(request) is not SglangMultimodalRequest: if type(request) is not SglangMultimodalRequest:
...@@ -408,11 +398,11 @@ class MultimodalWorkerHandler(BaseWorkerHandler): ...@@ -408,11 +398,11 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
input_ids = request.request.token_ids input_ids = request.request.token_ids
if not input_ids: if not input_ids:
raise ValueError("input_ids is required") raise ValueError("input_ids is required")
tensor_id: int | None = None
try: try:
sampling_params = SglangUtils.build_sampling_params(request) sampling_params = SglangUtils.build_sampling_params(request)
with _nvtx.annotate("mm:pd:load_multimodal", color="cyan"): with _nvtx.annotate("mm:pd:load_multimodal", color="cyan"):
mm_items, combined_embeddings = await _build_mm_items( mm_items, combined_embeddings, tensor_id = await _build_mm_items(
request, self.embeddings_processor request, self.embeddings_processor
) )
...@@ -434,6 +424,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler): ...@@ -434,6 +424,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
try: try:
async for output in StreamProcessor.process_sglang_stream(agg_stream): async for output in StreamProcessor.process_sglang_stream(agg_stream):
if first_token: if first_token:
if tensor_id is not None:
self.embeddings_processor.release_embeddings(tensor_id)
tensor_id = None
end_ttft() end_ttft()
_nvtx.end_range(rng_first) _nvtx.end_range(rng_first)
first_token = False first_token = False
...@@ -461,6 +454,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler): ...@@ -461,6 +454,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
yield ErrorResponseBuilder.build_error_response(RuntimeError(error_msg)) yield ErrorResponseBuilder.build_error_response(RuntimeError(error_msg))
else: else:
yield ErrorResponseBuilder.build_error_response(e) yield ErrorResponseBuilder.build_error_response(e)
finally:
if tensor_id is not None:
self.embeddings_processor.release_embeddings(tensor_id)
async def _get_bootstrap_from_prefill( async def _get_bootstrap_from_prefill(
self, request: SglangMultimodalRequest, sampling_params: dict self, request: SglangMultimodalRequest, sampling_params: dict
...@@ -509,7 +505,9 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler): ...@@ -509,7 +505,9 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
super().__init__(engine, config, None, None, shutdown_event) super().__init__(engine, config, None, None, shutdown_event)
# Initialize processors # Initialize processors
self.embeddings_processor = EmbeddingsProcessor() self.embeddings_processor = EmbeddingsProcessor(
config.dynamo_args.embedding_transfer_mode
)
# Get bootstrap info using BootstrapManager # Get bootstrap info using BootstrapManager
self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info(engine) self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info(engine)
...@@ -518,10 +516,6 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler): ...@@ -518,10 +516,6 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
f"Multimodal prefill worker handler initialized - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}" f"Multimodal prefill worker handler initialized - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}"
) )
async def async_init(self):
"""Initialize async components like connector"""
await self.embeddings_processor.initialize()
async def generate( async def generate(
self, disagg_request: DisaggSglangMultimodalRequest, context: Context self, disagg_request: DisaggSglangMultimodalRequest, context: Context
) -> AsyncIterator[str]: ) -> AsyncIterator[str]:
...@@ -592,10 +586,13 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler): ...@@ -592,10 +586,13 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
request = disagg_request.request request = disagg_request.request
input_ids = request.request.token_ids input_ids = request.request.token_ids
sampling_params = disagg_request.sampling_params sampling_params = disagg_request.sampling_params
tensor_id: int | None = None
# Process embeddings from encode worker using our embeddings processor # Process embeddings from encode worker using our embeddings processor
with _nvtx.annotate("mm:prefill:load_multimodal", color="cyan"): with _nvtx.annotate("mm:prefill:load_multimodal", color="cyan"):
mm_items, _ = await _build_mm_items(request, self.embeddings_processor) mm_items, _, tensor_id = await _build_mm_items(
request, self.embeddings_processor
)
# Start SGLang prefill generation (like regular SGLang) # Start SGLang prefill generation (like regular SGLang)
with _nvtx.annotate("mm:prefill:engine_async_generate", color="blue"): with _nvtx.annotate("mm:prefill:engine_async_generate", color="blue"):
...@@ -610,12 +607,19 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler): ...@@ -610,12 +607,19 @@ class MultimodalPrefillWorkerHandler(BaseWorkerHandler):
) )
# Consume results without yielding (prefill doesn't return text, just coordinates) # Consume results without yielding (prefill doesn't return text, just coordinates)
asyncio.create_task(self._consume_results(results)) asyncio.create_task(self._consume_results(results, tensor_id))
async def _consume_results(self, results): async def _consume_results(self, results, tensor_id: int):
"""Consume prefill results without returning them (like regular SGLang)""" """Consume prefill results without returning them (like regular SGLang)"""
async for _ in results: released = False
pass try:
async for _ in results:
if not released:
self.embeddings_processor.release_embeddings(tensor_id)
released = True
finally:
if not released:
self.embeddings_processor.release_embeddings(tensor_id)
def cleanup(self): def cleanup(self):
super().cleanup() super().cleanup()
......
...@@ -182,6 +182,8 @@ sglang_configs = { ...@@ -182,6 +182,8 @@ sglang_configs = {
"DYN_WORKER_GPU": "0", "DYN_WORKER_GPU": "0",
"DYN_ENCODE_GPU_MEM": "0.1", "DYN_ENCODE_GPU_MEM": "0.1",
"DYN_WORKER_GPU_MEM": "0.4", "DYN_WORKER_GPU_MEM": "0.4",
# FIXME: NIXL Agent Initialization (shared memory interface) causes segfault
"UCX_TLS": "^mm",
}, },
frontend_port=DefaultPort.FRONTEND.value, frontend_port=DefaultPort.FRONTEND.value,
request_payloads=[ request_payloads=[
...@@ -218,7 +220,10 @@ sglang_configs = { ...@@ -218,7 +220,10 @@ sglang_configs = {
model="Qwen/Qwen3-VL-2B-Instruct", model="Qwen/Qwen3-VL-2B-Instruct",
script_args=["--model", "Qwen/Qwen3-VL-2B-Instruct", "--single-gpu"], script_args=["--model", "Qwen/Qwen3-VL-2B-Instruct", "--single-gpu"],
timeout=360, timeout=360,
env={}, env={
# FIXME: NIXL Agent Initialization (shared memory interface) causes segfault
"UCX_TLS": "^mm",
},
frontend_port=DefaultPort.FRONTEND.value, frontend_port=DefaultPort.FRONTEND.value,
request_payloads=[ request_payloads=[
chat_payload( chat_payload(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment