# Source file: test_async_mm_data_processor.py
"""
Unit tests for AsyncMMDataProcessor.

Covers:
  - Async and sync processing paths
  - Concurrency limiting via semaphore
  - Per-call timeout behavior (async and sync)
  - Argument passthrough (images, audios, text/ids, request_obj, kwargs)
  - Error propagation and shutdown behavior
"""

import asyncio
import logging
import threading
import time
from unittest.mock import Mock

import pytest

from sglang.srt.managers.async_mm_data_processor import AsyncMMDataProcessor


class TestAsyncMMDataProcessor:
    """Test suite for AsyncMMDataProcessor."""

    @pytest.fixture
    def async_processor(self):
        """Build a stub processor whose only entry point is a coroutine.

        The stub's ``process_mm_data_async`` echoes its arguments back in a
        dict tagged ``"path": "async"`` so tests can verify passthrough.
        """

        class AsyncProc:
            async def process_mm_data_async(
                self,
                *,
                image_data=None,
                audio_data=None,
                input_text=None,
                request_obj=None,
                **kwargs,
            ):
                # An optional "delay_s" kwarg lets a test inject artificial latency.
                pause = kwargs.get("delay_s", 0.0)
                if pause:
                    await asyncio.sleep(pause)

                # Echo every received argument so the caller can inspect it.
                echoed = {
                    "path": "async",
                    "images": image_data,
                    "audios": audio_data,
                    "text": input_text,
                    "request": request_obj,
                    "kwargs": kwargs,
                }
                return echoed

        stub = AsyncProc()
        return stub

    @pytest.fixture
    def sync_processor(self):
        """Build a stub processor whose only entry point is a blocking method.

        The stub's ``process_mm_data`` echoes its arguments back in a dict
        tagged ``"path": "sync"`` so tests can verify passthrough.
        """

        class SyncProc:
            def process_mm_data(
                self,
                *,
                image_data=None,
                audio_data=None,
                input_text=None,
                request_obj=None,
                **kwargs,
            ):
                # An optional "delay_s" kwarg stands in for blocking work.
                pause = kwargs.get("delay_s", 0.0)
                if pause:
                    time.sleep(pause)

                # Echo every received argument so the caller can inspect it.
                echoed = {
                    "path": "sync",
                    "images": image_data,
                    "audios": audio_data,
                    "text": input_text,
                    "request": request_obj,
                    "kwargs": kwargs,
                }
                return echoed

        stub = SyncProc()
        return stub

    @pytest.mark.asyncio
    async def test_async_path_basic(self, async_processor):
        """Wrapping the coroutine stub should return its echoed payload."""
        wrapper = AsyncMMDataProcessor(async_processor)
        result = await wrapper.process(
            image_data=["img1.png"],
            audio_data=["a.wav"],
            input_text_or_ids="hello",
            request_obj={"rid": 1},
            mode="fast",
        )

        # Top-level fields the stub is expected to echo back unchanged.
        expected_fields = {
            "path": "async",
            "images": ["img1.png"],
            "audios": ["a.wav"],
            "text": "hello",
            "request": {"rid": 1},
        }
        for field, value in expected_fields.items():
            assert result[field] == value

        # Extra keyword arguments must reach the stub as kwargs.
        assert result["kwargs"]["mode"] == "fast"

    @pytest.mark.asyncio
    async def test_sync_fallback_basic(self, sync_processor):
        """Wrapping the blocking stub should return its echoed payload."""
        wrapper = AsyncMMDataProcessor(sync_processor)
        result = await wrapper.process(
            image_data=[b"\x00\x01"],
            audio_data=None,
            input_text_or_ids=[1, 2, 3],
            request_obj="req-obj",
            role="user",
        )

        # Top-level fields the stub is expected to echo back unchanged.
        expected_fields = {
            "path": "sync",
            "images": [b"\x00\x01"],
            "text": [1, 2, 3],
            "request": "req-obj",
        }
        for field, value in expected_fields.items():
            assert result[field] == value

        # audio_data was passed as None and must stay None (identity check).
        assert result["audios"] is None
        assert result["kwargs"]["role"] == "user"

    @pytest.mark.asyncio
    async def test_timeout_async(self, async_processor):
        """Expect asyncio.TimeoutError when the coroutine stub outlasts timeout_s."""
        wrapper = AsyncMMDataProcessor(async_processor, timeout_s=0.01)
        slow_request = {
            "input_text_or_ids": "slow",
            "request_obj": None,
            "delay_s": 0.05,  # stub sleeps longer than the 0.01s timeout
        }
        with pytest.raises(asyncio.TimeoutError):
            await wrapper.process(**slow_request)

    @pytest.mark.asyncio
    async def test_timeout_sync(self, sync_processor):
        """Expect asyncio.TimeoutError when the blocking stub outlasts timeout_s."""
        wrapper = AsyncMMDataProcessor(sync_processor, timeout_s=0.01)
        slow_request = {
            "input_text_or_ids": "slow",
            "request_obj": None,
            "delay_s": 0.05,  # stub blocks longer than the 0.01s timeout
        }
        with pytest.raises(asyncio.TimeoutError):
            await wrapper.process(**slow_request)

    @pytest.mark.asyncio
    async def test_semaphore_release_after_timeout(self, sync_processor):
        """
        A call that times out must not keep its concurrency slot: a follow-up
        call should still run to completion.

        max_concurrent_calls is set to 2 so the thread still sleeping from the
        timed-out call does not block the next one (assumes the wrapper sizes
        its fallback worker pool from this value - confirm against the wrapper).
        """
        wrapper = AsyncMMDataProcessor(
            sync_processor,
            timeout_s=0.01,
            max_concurrent_calls=2,
        )

        # Step 1: this call blocks for 0.05s and therefore exceeds the timeout.
        with pytest.raises(asyncio.TimeoutError):
            await wrapper.process(
                input_text_or_ids="slow1",
                request_obj=None,
                delay_s=0.05,
            )

        # Step 2: a fast call must still get through and return its echo.
        follow_up = await wrapper.process(
            input_text_or_ids="ok",
            request_obj=None,
            delay_s=0.0,
        )
        assert follow_up["text"] == "ok"

    @pytest.mark.asyncio
    async def test_concurrency_limit_async(self):
        """Ensure max_concurrent_calls caps concurrency for async path.

        Six calls are issued against a wrapper limited to two concurrent calls.
        The stub tracks how many invocations are in flight at once; the peak
        must never exceed the limit, and must be at least one so the test
        cannot pass vacuously if the stub is never invoked.
        """
        current = 0
        max_seen = 0

        class AsyncProc:
            async def process_mm_data_async(self, **kwargs):
                nonlocal current, max_seen
                # No lock needed: there is no await between these two updates.
                current += 1
                max_seen = max(max_seen, current)
                try:
                    # Yield long enough for the other calls to try to overlap.
                    await asyncio.sleep(0.02)
                    return {"ok": True}
                finally:
                    current -= 1

        proc = AsyncMMDataProcessor(AsyncProc(), max_concurrent_calls=2)

        tasks = [
            proc.process(input_text_or_ids=f"t{i}", request_obj=None) for i in range(6)
        ]
        results = await asyncio.gather(*tasks)

        # Every call must have completed and returned the stub's payload.
        assert len(results) == 6
        assert all(r["ok"] is True for r in results)
        # Lower bound guards against a vacuous pass (stub never called).
        assert 1 <= max_seen <= 2
        # Each invocation must have unwound through its finally block.
        assert current == 0

    @pytest.mark.asyncio
    async def test_concurrency_limit_sync(self):
        """Ensure max_concurrent_calls caps concurrency for sync fallback path.

        Nine calls are issued against a wrapper limited to three concurrent
        calls. The stub tracks how many invocations are in flight at once; the
        peak must never exceed the limit, and must be at least one so the test
        cannot pass vacuously if the stub is never invoked.
        """
        current = 0
        max_seen = 0
        # The counters are guarded by a lock because the stub is blocking and
        # may be invoked from several threads at once.
        lock = threading.Lock()

        class SyncProc:
            def process_mm_data(self, **kwargs):
                nonlocal current, max_seen
                with lock:
                    current += 1
                    max_seen = max(max_seen, current)
                try:
                    # Block long enough for the other calls to try to overlap.
                    time.sleep(0.02)
                    return {"ok": True}
                finally:
                    with lock:
                        current -= 1

        proc = AsyncMMDataProcessor(SyncProc(), max_concurrent_calls=3)

        tasks = [
            proc.process(input_text_or_ids=f"s{i}", request_obj=None) for i in range(9)
        ]
        results = await asyncio.gather(*tasks)

        # Every call must have completed and returned the stub's payload.
        assert len(results) == 9
        assert all(r["ok"] is True for r in results)
        with lock:
            # Lower bound guards against a vacuous pass (stub never called).
            assert 1 <= max_seen <= 3
            # Each invocation must have unwound through its finally block.
            assert current == 0

    @pytest.mark.asyncio
    async def test_error_from_async_processor(self):
        """A ValueError raised inside the coroutine stub must reach the caller."""

        class BadAsync:
            async def process_mm_data_async(self, **_):
                # Yield once before failing so the error comes from a resumed coroutine.
                await asyncio.sleep(0)
                raise ValueError("async boom")

        wrapper = AsyncMMDataProcessor(BadAsync())
        call_kwargs = {"input_text_or_ids": "x", "request_obj": None}
        with pytest.raises(ValueError, match="async boom"):
            await wrapper.process(**call_kwargs)

    @pytest.mark.asyncio
    async def test_error_from_sync_processor(self):
        """A RuntimeError raised inside the blocking stub must reach the caller."""

        class BadSync:
            def process_mm_data(self, **_):
                raise RuntimeError("sync boom")

        wrapper = AsyncMMDataProcessor(BadSync())
        call_kwargs = {"input_text_or_ids": "x", "request_obj": None}
        with pytest.raises(RuntimeError, match="sync boom"):
            await wrapper.process(**call_kwargs)

    @pytest.mark.asyncio
    async def test_missing_both_methods_raises(self):
        """A processor with neither entry point fails with RuntimeError on process().

        Construction itself happens outside the ``pytest.raises`` block, so the
        error is expected only when ``process`` is awaited.
        """

        class Empty:
            pass

        expected_message = "neither 'process_mm_data_async' nor 'process_mm_data'"
        wrapper = AsyncMMDataProcessor(Empty())
        with pytest.raises(RuntimeError, match=expected_message):
            await wrapper.process(input_text_or_ids="x", request_obj=None)

    @pytest.mark.asyncio
    async def test_async_attribute_not_coroutine_uses_sync_fallback(self):
        """
        When `process_mm_data_async` is a plain function rather than a
        coroutine function, the result must come from `process_mm_data`.
        """

        class WeirdProc:
            def process_mm_data(self, **_):
                return {"path": "sync"}

            # Deliberately a regular def, not async def.
            def process_mm_data_async(self, **_):
                return {"path": "would-be-async"}

        wrapper = AsyncMMDataProcessor(WeirdProc())
        result = await wrapper.process(input_text_or_ids="x", request_obj=None)
        assert result["path"] == "sync"

    @pytest.mark.asyncio
    async def test_kwargs_and_request_passthrough_async(self, async_processor):
        """The coroutine stub must receive request_obj and extra kwargs unchanged."""
        wrapper = AsyncMMDataProcessor(async_processor)
        result = await wrapper.process(
            image_data=["i1", "i2"],
            audio_data=["a1"],
            input_text_or_ids="hello world",
            request_obj={"uid": 42},
            return_meta=True,
            delay_s=0.0,
        )

        expected_fields = {
            "images": ["i1", "i2"],
            "audios": ["a1"],
            "text": "hello world",
            "request": {"uid": 42},
        }
        for field, value in expected_fields.items():
            assert result[field] == value

        # Identity check: the flag must arrive as the bool True itself.
        assert result["kwargs"]["return_meta"] is True

    @pytest.mark.asyncio
    async def test_kwargs_and_request_passthrough_sync(self, sync_processor):
        """The blocking stub must receive request_obj and extra kwargs unchanged."""
        wrapper = AsyncMMDataProcessor(sync_processor)
        result = await wrapper.process(
            image_data=None,
            audio_data=[],
            input_text_or_ids=[101, 102],
            request_obj=("r", 7),
            lang="en",
        )

        # None must stay None; the empty list must stay an empty list.
        assert result["images"] is None
        expected_fields = {
            "audios": [],
            "text": [101, 102],
            "request": ("r", 7),
        }
        for field, value in expected_fields.items():
            assert result[field] == value
        assert result["kwargs"]["lang"] == "en"

    def test_shutdown_on_sync_executor(self, sync_processor):
        """shutdown() must call fallback_exec.shutdown(wait=False) exactly once."""
        wrapper = AsyncMMDataProcessor(sync_processor)
        # Install a Mock in place of the executor so the call can be inspected.
        executor_stub = Mock()
        wrapper.fallback_exec = executor_stub

        wrapper.shutdown()

        # Read the attribute back from the wrapper, as it stands after shutdown().
        wrapper.fallback_exec.shutdown.assert_called_once_with(wait=False)

    def test_del_calls_shutdown(self, sync_processor, caplog):
        """Calling __del__ must invoke fallback_exec.shutdown(wait=False) once."""
        caplog.set_level(logging.DEBUG)
        wrapper = AsyncMMDataProcessor(sync_processor)
        # Install a Mock in place of the executor so the call can be inspected.
        executor_stub = Mock()
        wrapper.fallback_exec = executor_stub

        # Invoke the finalizer by hand instead of waiting for garbage collection.
        wrapper.__del__()

        # Read the attribute back from the wrapper, as it stands after __del__().
        wrapper.fallback_exec.shutdown.assert_called_once_with(wait=False)

    @pytest.mark.asyncio
    async def test_concurrent_mixed_requests(self, async_processor):
        """Mix different payloads and ensure all complete with valid outputs.

        Four differently shaped requests run concurrently. Each result must
        come from the async stub and carry the request_obj of the call at the
        same position, so a mixed-up or swapped result is detected.
        """
        proc = AsyncMMDataProcessor(async_processor, max_concurrent_calls=4)

        tasks = [
            proc.process(input_text_or_ids="t1", request_obj=1),
            proc.process(image_data=["i.png"], input_text_or_ids=[9, 8], request_obj=2),
            proc.process(
                audio_data=["v.wav"], input_text_or_ids="speech", request_obj=3
            ),
            proc.process(
                image_data=[], audio_data=[], input_text_or_ids=None, request_obj=4
            ),
        ]
        outs = await asyncio.gather(*tasks)
        assert len(outs) == 4
        # asyncio.gather returns results in submission order, so the echoed
        # request ids must line up with 1..4.
        for expected_request, out in zip((1, 2, 3, 4), outs):
            assert out["path"] == "async"
            assert out["request"] == expected_request

    @pytest.mark.asyncio
    async def test_many_requests_values_match_inputs(self, sync_processor):
        """Ten concurrent sync-path calls must return their texts in input order."""
        wrapper = AsyncMMDataProcessor(sync_processor, max_concurrent_calls=8)
        messages = [f"msg-{n}" for n in range(10)]

        # One call per message; the index doubles as request_obj.
        pending = []
        for index, message in enumerate(messages):
            pending.append(
                wrapper.process(input_text_or_ids=message, request_obj=index)
            )

        results = await asyncio.gather(*pending)
        echoed_texts = [result["text"] for result in results]
        assert echoed_texts == messages


if __name__ == "__main__":
    # Propagate pytest's exit code so running this file as a script reports
    # test failures to the shell instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__]))