Unverified Commit c30ebb93 authored by Yuan Luo, committed by GitHub

[VLM] Optimize async mm data process mechanism (#12066)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
parent 41efcaeb
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any, Dict, List, Optional, Union
logger = logging.getLogger(__name__)
class AsyncMMDataProcessor:
"""
Async wrapper for a multimodal processor.
Behavior:
- If the underlying processor exposes `process_mm_data_async`, call/await it directly.
- Otherwise, fall back to running a synchronous `process_mm_data` in a thread pool.
- Optionally guard per-call concurrency via an asyncio.Semaphore.
- Optionally enforce per-call timeout via asyncio.wait_for.
"""
def __init__(
self,
mm_processor: Any,
*,
max_concurrent_calls: Optional[int] = None,
timeout_s: Optional[float] = None,
) -> None:
"""
Args:
mm_processor: An object exposing either
- async def process_mm_data_async(...): -> Dict[str, Any]
or
- def process_mm_data(...): -> Dict[str, Any]
max_concurrent_calls: Optional concurrency cap for per-call execution.
timeout_s: Optional timeout (seconds) for each `process()` call.
"""
self.mm_processor = mm_processor
self.timeout_s = timeout_s
# Concurrency guard (None -> unlimited)
self.semaphore = (
asyncio.Semaphore(max_concurrent_calls) if max_concurrent_calls else None
)
# Detect async path; if missing, prepare a fallback executor for sync path
self._proc_async = getattr(mm_processor, "process_mm_data_async", None)
self.is_async = asyncio.iscoroutinefunction(self._proc_async)
self.fallback_exec: Optional[ThreadPoolExecutor] = (
ThreadPoolExecutor(max_workers=max_concurrent_calls)
if not self.is_async
else None
)
async def process(
self,
*,
image_data: Optional[List[Union[str, bytes]]] = None,
audio_data: Optional[List[Union[str, bytes]]] = None,
input_text_or_ids: Union[str, List[int], None] = None,
request_obj: Any,
**kwargs: Any,
) -> Dict[str, Any]:
"""
Public entrypoint: process a single multimodal request without blocking the event loop.
"""
async def _invoke() -> Dict[str, Any]:
if self.is_async:
# Native async implementation
return await self._proc_async(
image_data=image_data,
audio_data=audio_data,
input_text=input_text_or_ids,
request_obj=request_obj,
**kwargs,
)
# Synchronous fallback
sync_fn = getattr(self.mm_processor, "process_mm_data", None)
if not callable(sync_fn):
raise RuntimeError(
"mm_processor has neither 'process_mm_data_async' nor 'process_mm_data'."
)
loop = asyncio.get_running_loop()
fn = partial(
sync_fn,
image_data=image_data,
audio_data=audio_data,
input_text=input_text_or_ids,
request_obj=request_obj,
**kwargs,
)
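# Offload the blocking call to the thread pool so the event loop stays
# responsive; run_in_executor returns an awaitable for the result.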
return await loop.run_in_executor(self.fallback_exec, fn)
# Apply optional concurrency guard
if self.semaphore is not None:
async with self.semaphore:
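# NOTE: on timeout, asyncio.wait_for cancels the awaiting task, but a sync
# fallback call keeps running in its worker thread until it returns; only
# the semaphore slot is freed immediately.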
if self.timeout_s is not None:
return await asyncio.wait_for(_invoke(), timeout=self.timeout_s)
return await _invoke()
# No concurrency guard
if self.timeout_s is not None:
return await asyncio.wait_for(_invoke(), timeout=self.timeout_s)
return await _invoke()
def shutdown(self) -> None:
"""Gracefully shutdown resources owned by this wrapper."""
try:
if self.fallback_exec:
self.fallback_exec.shutdown(wait=False)
except Exception:
logger.exception(
"Error while shutting down fallback executor in AsyncMMDataProcessor"
)
def __del__(self):
# Best-effort shutdown
try:
self.shutdown()
except Exception:
pass
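Below is a minimal usage sketch of the wrapper above, included for illustration only; the DummyProcessor class and its payloads are assumptions, not part of this change. It exercises the synchronous fallback path, the concurrency cap, and the per-call timeout.

import asyncio

from sglang.srt.managers.async_mm_data_processor import AsyncMMDataProcessor


class DummyProcessor:
    """Stand-in processor that only exposes the synchronous entrypoint."""

    def process_mm_data(self, *, image_data=None, audio_data=None,
                        input_text=None, request_obj=None, **kwargs):
        # Pretend to do some blocking preprocessing work.
        return {"num_images": len(image_data or []), "text": input_text}


async def main():
    proc = AsyncMMDataProcessor(
        DummyProcessor(),
        max_concurrent_calls=4,  # at most 4 in-flight calls
        timeout_s=5.0,           # each call must finish within 5 seconds
    )
    try:
        out = await proc.process(
            image_data=["cat.png"],
            input_text_or_ids="describe the image",
            request_obj=None,
        )
        print(out)  # {'num_images': 1, 'text': 'describe the image'}
    finally:
        proc.shutdown()


asyncio.run(main())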
@@ -43,6 +43,7 @@ from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.lora.lora_registry import LoRARegistry
from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
+from sglang.srt.managers.async_mm_data_processor import AsyncMMDataProcessor
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
    AbortReq,
@@ -215,6 +216,11 @@ class TokenizerManager(TokenizerCommunicatorMixin):
        self.mm_processor = get_mm_processor(
            self.model_config.hf_config, server_args, _processor, transport_mode
        )
+        self.mm_data_processor = AsyncMMDataProcessor(
+            self.mm_processor,
+            max_concurrent_calls=self.server_args.mm_max_concurrent_calls,
+            timeout_s=self.server_args.mm_per_request_timeout,
+        )
        if server_args.skip_tokenizer_init:
            self.tokenizer = self.processor = None
@@ -598,10 +604,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
            obj.image_data = [obj.image_data]
        if obj.audio_data is not None and not isinstance(obj.audio_data, list):
            obj.audio_data = [obj.audio_data]
-        mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
+        mm_inputs: Dict = await self.mm_data_processor.process(
            image_data=obj.image_data,
            audio_data=obj.audio_data,
-            input_text=input_text or input_ids,
+            input_text_or_ids=(input_text or input_ids),
            request_obj=obj,
            max_req_input_len=self.max_req_input_len,
        )
...
@@ -542,6 +542,10 @@ class ServerArgs:
    pdmux_config_path: Optional[str] = None
    sm_group_num: int = 8

+    # For Multi-Modal
+    mm_max_concurrent_calls: int = 32
+    mm_per_request_timeout: float = 10.0
+
    def __post_init__(self):
        """
        Orchestrates the handling of various server arguments, ensuring proper configuration and validation.
@@ -3519,6 +3523,20 @@ class ServerArgs:
            help="Read CLI options from a config file. Must be a YAML file with configuration options.",
        )

+        # For Multi-Modal
+        parser.add_argument(
+            "--mm-max-concurrent-calls",
+            type=int,
+            default=ServerArgs.mm_max_concurrent_calls,
+            help="The maximum number of concurrent calls for async multimodal data processing.",
+        )
+        parser.add_argument(
+            "--mm-per-request-timeout",
+            type=float,
+            default=ServerArgs.mm_per_request_timeout,
+            help="The timeout in seconds for each multimodal request.",
+        )
+
    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        args.tp_size = args.tensor_parallel_size
...
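As a side note, here is a self-contained sketch of how the two new flags parse, for illustration only; the real wiring is the ServerArgs code above, and the example values 64 and 30 are arbitrary.

import argparse

# Stand-alone mirror of the two options added above, using the same defaults
# as the ServerArgs fields (32 calls, 10.0 seconds).
parser = argparse.ArgumentParser()
parser.add_argument("--mm-max-concurrent-calls", type=int, default=32)
parser.add_argument("--mm-per-request-timeout", type=float, default=10.0)

args = parser.parse_args(
    ["--mm-max-concurrent-calls", "64", "--mm-per-request-timeout", "30"]
)
# argparse maps dashes to underscores on the resulting namespace.
assert args.mm_max_concurrent_calls == 64
assert args.mm_per_request_timeout == 30.0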
"""
Unit tests for AsyncMMDataProcessor.
Covers:
- Async and sync processing paths
- Concurrency limiting via semaphore
- Per-call timeout behavior (async and sync)
- Argument passthrough (images, audios, text/ids, request_obj, kwargs)
- Error propagation and shutdown behavior
"""
import asyncio
import logging
import threading
import time
from unittest.mock import Mock
import pytest
from sglang.srt.managers.async_mm_data_processor import AsyncMMDataProcessor
class TestAsyncMMDataProcessor:
"""Test suite for AsyncMMDataProcessor."""
@pytest.fixture
def async_processor(self):
"""Create a processor exposing an async process_mm_data_async."""
class AsyncProc:
async def process_mm_data_async(
self,
*,
image_data=None,
audio_data=None,
input_text=None,
request_obj=None,
**kwargs,
):
# Allow tests to simulate latency via kwargs
delay = kwargs.get("delay_s", 0.0)
if delay:
await asyncio.sleep(delay)
return {
"path": "async",
"images": image_data,
"audios": audio_data,
"text": input_text,
"request": request_obj,
"kwargs": kwargs,
}
return AsyncProc()
@pytest.fixture
def sync_processor(self):
"""Provide a processor exposing a sync process_mm_data."""
class SyncProc:
def process_mm_data(
self,
*,
image_data=None,
audio_data=None,
input_text=None,
request_obj=None,
**kwargs,
):
delay = kwargs.get("delay_s", 0.0)
if delay:
# Simulate CPU/blocking work
time.sleep(delay)
return {
"path": "sync",
"images": image_data,
"audios": audio_data,
"text": input_text,
"request": request_obj,
"kwargs": kwargs,
}
return SyncProc()
@pytest.mark.asyncio
async def test_async_path_basic(self, async_processor):
"""Async processor should be awaited directly."""
proc = AsyncMMDataProcessor(async_processor)
out = await proc.process(
image_data=["img1.png"],
audio_data=["a.wav"],
input_text_or_ids="hello",
request_obj={"rid": 1},
mode="fast",
)
assert out["path"] == "async"
assert out["images"] == ["img1.png"]
assert out["audios"] == ["a.wav"]
assert out["text"] == "hello"
assert out["request"] == {"rid": 1}
assert out["kwargs"]["mode"] == "fast"
@pytest.mark.asyncio
async def test_sync_fallback_basic(self, sync_processor):
"""Sync processor should run in fallback executor."""
proc = AsyncMMDataProcessor(sync_processor)
out = await proc.process(
image_data=[b"\x00\x01"],
audio_data=None,
input_text_or_ids=[1, 2, 3],
request_obj="req-obj",
role="user",
)
assert out["path"] == "sync"
assert out["images"] == [b"\x00\x01"]
assert out["audios"] is None
assert out["text"] == [1, 2, 3]
assert out["request"] == "req-obj"
assert out["kwargs"]["role"] == "user"
@pytest.mark.asyncio
async def test_timeout_async(self, async_processor):
"""Timeout should raise asyncio.TimeoutError for async path."""
proc = AsyncMMDataProcessor(async_processor, timeout_s=0.01)
with pytest.raises(asyncio.TimeoutError):
await proc.process(
input_text_or_ids="slow",
request_obj=None,
delay_s=0.05, # longer than timeout
)
@pytest.mark.asyncio
async def test_timeout_sync(self, sync_processor):
"""Timeout should raise asyncio.TimeoutError for sync fallback path."""
proc = AsyncMMDataProcessor(sync_processor, timeout_s=0.01)
with pytest.raises(asyncio.TimeoutError):
await proc.process(
input_text_or_ids="slow",
request_obj=None,
delay_s=0.05, # longer than timeout
)
@pytest.mark.asyncio
async def test_semaphore_release_after_timeout(self, sync_processor):
"""
If a call times out, the semaphore should be released so a subsequent call can proceed.
Use >=2 fallback workers so the timed-out thread doesn't block the next call.
"""
proc = AsyncMMDataProcessor(
sync_processor,
max_concurrent_calls=2,
timeout_s=0.01,
)
# First call will time out
with pytest.raises(asyncio.TimeoutError):
await proc.process(
input_text_or_ids="slow1", request_obj=None, delay_s=0.05
)
# Second call should be able to acquire the semaphore and complete
out = await proc.process(input_text_or_ids="ok", request_obj=None, delay_s=0.0)
assert out["text"] == "ok"
@pytest.mark.asyncio
async def test_concurrency_limit_async(self):
"""Ensure max_concurrent_calls caps concurrency for async path."""
current = 0
max_seen = 0
class AsyncProc:
async def process_mm_data_async(self, **kwargs):
nonlocal current, max_seen
current += 1
max_seen = max(max_seen, current)
try:
await asyncio.sleep(0.02)
return {"ok": True}
finally:
current -= 1
proc = AsyncMMDataProcessor(AsyncProc(), max_concurrent_calls=2)
tasks = [
proc.process(input_text_or_ids=f"t{i}", request_obj=None) for i in range(6)
]
await asyncio.gather(*tasks)
assert max_seen <= 2
@pytest.mark.asyncio
async def test_concurrency_limit_sync(self):
"""Ensure max_concurrent_calls caps concurrency for sync fallback path."""
current = 0
max_seen = 0
lock = threading.Lock()
class SyncProc:
def process_mm_data(self, **kwargs):
nonlocal current, max_seen
with lock:
current += 1
max_seen = max(max_seen, current)
try:
time.sleep(0.02)
return {"ok": True}
finally:
with lock:
current -= 1
proc = AsyncMMDataProcessor(SyncProc(), max_concurrent_calls=3)
tasks = [
proc.process(input_text_or_ids=f"s{i}", request_obj=None) for i in range(9)
]
await asyncio.gather(*tasks)
assert max_seen <= 3
@pytest.mark.asyncio
async def test_error_from_async_processor(self):
"""Exceptions raised by the async processor should propagate."""
class BadAsync:
async def process_mm_data_async(self, **_):
await asyncio.sleep(0)
raise ValueError("async boom")
proc = AsyncMMDataProcessor(BadAsync())
with pytest.raises(ValueError, match="async boom"):
await proc.process(input_text_or_ids="x", request_obj=None)
@pytest.mark.asyncio
async def test_error_from_sync_processor(self):
"""Exceptions raised by the sync processor should propagate."""
class BadSync:
def process_mm_data(self, **_):
raise RuntimeError("sync boom")
proc = AsyncMMDataProcessor(BadSync())
with pytest.raises(RuntimeError, match="sync boom"):
await proc.process(input_text_or_ids="x", request_obj=None)
@pytest.mark.asyncio
async def test_missing_both_methods_raises(self):
"""Processor missing both methods should raise at call time."""
class Empty:
pass
proc = AsyncMMDataProcessor(Empty())
with pytest.raises(
RuntimeError, match="neither 'process_mm_data_async' nor 'process_mm_data'"
):
await proc.process(input_text_or_ids="x", request_obj=None)
@pytest.mark.asyncio
async def test_async_attribute_not_coroutine_uses_sync_fallback(self):
"""
If `process_mm_data_async` exists but isn't a coroutine function,
wrapper should treat it as sync and use `process_mm_data`.
"""
class WeirdProc:
# Not a coroutine function:
def process_mm_data_async(self, **_):
return {"path": "would-be-async"}
def process_mm_data(self, **_):
return {"path": "sync"}
proc = AsyncMMDataProcessor(WeirdProc())
out = await proc.process(input_text_or_ids="x", request_obj=None)
assert out["path"] == "sync"
@pytest.mark.asyncio
async def test_kwargs_and_request_passthrough_async(self, async_processor):
"""Extra kwargs and request_obj should be forwarded on async path."""
proc = AsyncMMDataProcessor(async_processor)
out = await proc.process(
image_data=["i1", "i2"],
audio_data=["a1"],
input_text_or_ids="hello world",
request_obj={"uid": 42},
return_meta=True,
delay_s=0.0,
)
assert out["images"] == ["i1", "i2"]
assert out["audios"] == ["a1"]
assert out["text"] == "hello world"
assert out["request"] == {"uid": 42}
assert out["kwargs"]["return_meta"] is True
@pytest.mark.asyncio
async def test_kwargs_and_request_passthrough_sync(self, sync_processor):
"""Extra kwargs and request_obj should be forwarded on sync path."""
proc = AsyncMMDataProcessor(sync_processor)
out = await proc.process(
image_data=None,
audio_data=[],
input_text_or_ids=[101, 102],
request_obj=("r", 7),
lang="en",
)
assert out["images"] is None
assert out["audios"] == []
assert out["text"] == [101, 102]
assert out["request"] == ("r", 7)
assert out["kwargs"]["lang"] == "en"
def test_shutdown_on_sync_executor(self, sync_processor):
"""Explicit shutdown should close fallback executor for sync path."""
proc = AsyncMMDataProcessor(sync_processor)
# Swap real executor for a mock to assert shutdown behavior
proc.fallback_exec = Mock()
proc.shutdown()
proc.fallback_exec.shutdown.assert_called_once_with(wait=False)
def test_del_calls_shutdown(self, sync_processor, caplog):
"""__del__ should best-effort shutdown without raising."""
caplog.set_level(logging.DEBUG)
proc = AsyncMMDataProcessor(sync_processor)
proc.fallback_exec = Mock()
# Simulate object destruction
proc.__del__()
proc.fallback_exec.shutdown.assert_called_once_with(wait=False)
@pytest.mark.asyncio
async def test_concurrent_mixed_requests(self, async_processor):
"""Mix different payloads and ensure all complete with valid outputs."""
proc = AsyncMMDataProcessor(async_processor, max_concurrent_calls=4)
tasks = [
proc.process(input_text_or_ids="t1", request_obj=1),
proc.process(image_data=["i.png"], input_text_or_ids=[9, 8], request_obj=2),
proc.process(
audio_data=["v.wav"], input_text_or_ids="speech", request_obj=3
),
proc.process(
image_data=[], audio_data=[], input_text_or_ids=None, request_obj=4
),
]
outs = await asyncio.gather(*tasks)
assert len(outs) == 4
for out in outs:
assert "path" in out
assert out["path"] == "async"
@pytest.mark.asyncio
async def test_many_requests_values_match_inputs(self, sync_processor):
"""For sync path, ensure each response corresponds to its specific input."""
proc = AsyncMMDataProcessor(sync_processor, max_concurrent_calls=8)
texts = [f"msg-{i}" for i in range(10)]
tasks = [
proc.process(input_text_or_ids=t, request_obj=i)
for i, t in enumerate(texts)
]
outs = await asyncio.gather(*tasks)
got = [o["text"] for o in outs]
assert got == texts
if __name__ == "__main__":
pytest.main([__file__])