test_async_llm.py 14.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from contextlib import ExitStack
6
from typing import Optional
7
from unittest.mock import MagicMock
8

zhuwenwen's avatar
zhuwenwen committed
9
import os
10
11
12
import pytest

from vllm import SamplingParams
13
from vllm.assets.image import ImageAsset
14
from vllm.config import VllmConfig
15
from vllm.engine.arg_utils import AsyncEngineArgs
16
from vllm.inputs import PromptType
17
from vllm.platforms import current_platform
18
from vllm.sampling_params import RequestOutputKind
19
from vllm.utils import set_default_torch_num_threads
20
from vllm.v1.engine.async_llm import AsyncLLM
21
from vllm.v1.metrics.loggers import LoggingStatLogger
zhuwenwen's avatar
zhuwenwen committed
22
from ...utils import models_path_prefix
23
24
25
26
27

# These tests drive the V1 AsyncLLM engine, which is only exercised on CUDA;
# skip the entire module on any other platform.
if not current_platform.is_cuda():
    pytest.skip(allow_module_level=True,
                reason="V1 currently only supported on CUDA.")

28
# Engine arguments for the small text-only model used by most tests. The model
# path is joined onto ``models_path_prefix``; eager mode is forced and request
# logging is disabled.
TEXT_ENGINE_ARGS = AsyncEngineArgs(
    model=os.path.join(models_path_prefix,
                       "meta-llama/Llama-3.2-1B-Instruct"),
    disable_log_requests=True,
    enforce_eager=True,
)
33

34
35
36
# Engine arguments for the vision-language model used by the multimodal test
# cases. The model path is joined onto ``models_path_prefix``, the same way
# TEXT_ENGINE_ARGS resolves its model, so both engines load models from the
# same location (previously only the text model honoured the prefix).
VISION_ENGINE_ARGS = AsyncEngineArgs(
    model=os.path.join(models_path_prefix, "Qwen/Qwen2-VL-2B-Instruct"),
    enforce_eager=True,
    disable_log_requests=True,
)
37

38
39
40
41
42
43
44
45
46
47
48
# Plain-text prompt paired with TEXT_ENGINE_ARGS in the parametrized tests.
TEXT_PROMPT = "Hello my name is Robert and"

# Chat-formatted prompt (im_start/im_end markers) containing a single image
# placeholder between the vision_start/vision_end markers.
VISION_PROMPT_TEMPLATE = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
    "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "What is in the image?<|im_end|>\n"
    "<|im_start|>assistant\n")

# Multimodal prompt: the template above plus the "stop_sign" image asset.
# Note the image is loaded once, at module import time.
VISION_PROMPT = dict(
    prompt=VISION_PROMPT_TEMPLATE,
    multi_modal_data=dict(image=ImageAsset("stop_sign").pil_image),
)
51
52


53
54
55
56
57
58
59
60
61
62
async def generate(
    engine: AsyncLLM,
    request_id: str,
    prompt: PromptType,
    output_kind: RequestOutputKind,
    max_tokens: int,
    n: int = 1,
    prompt_logprobs: Optional[int] = None,
    cancel_after: Optional[int] = None,
) -> tuple[int, str]:
    """Stream a single request through ``engine`` and count its tokens.

    Args:
        engine: The engine to submit the request to.
        request_id: Id passed through to ``engine.generate``.
        prompt: Prompt passed through to ``engine.generate``.
        output_kind: DELTA outputs are summed across chunks; for any other
            kind each chunk's token count replaces the running total.
        max_tokens: Generation budget (EOS is ignored, so a request that is
            not cancelled produces exactly this many tokens per sample).
        n: Number of parallel samples; token counts are summed across them.
        prompt_logprobs: Forwarded to ``SamplingParams``.
        cancel_after: If set, stop consuming the stream as soon as the
            running total reaches this value.

    Returns:
        ``(token_count, request_id)``.
    """
    # Hold off briefly so the abort test can cancel the task before it
    # completes.
    await asyncio.sleep(0.2)

    params = SamplingParams(
        max_tokens=max_tokens,
        ignore_eos=True,
        output_kind=output_kind,
        temperature=0.5,
        seed=33,
        n=n,
        prompt_logprobs=prompt_logprobs,
    )
    accumulate = output_kind == RequestOutputKind.DELTA

    total = 0
    async for chunk in engine.generate(request_id=request_id,
                                       prompt=prompt,
                                       sampling_params=params):
        produced = sum(len(sample.token_ids) for sample in chunk.outputs)
        total = total + produced if accumulate else produced

        if cancel_after is not None and total >= cancel_after:
            # Abandon the stream mid-way to simulate a client disconnect.
            break

        # Yield to the event loop between chunks.
        await asyncio.sleep(0.0)

    return total, request_id


94
95
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_load(
    monkeypatch: pytest.MonkeyPatch,
    output_kind: RequestOutputKind,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Submit many concurrent requests and check each gets its full budget.

    Every request must report exactly NUM_EXPECTED_TOKENS generated tokens,
    and the engine's output processor must be empty afterwards.
    """
    # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
    # so that in the future when we switch, we don't have to change all the
    # tests.
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 10

        # Launch every request concurrently on the event loop.
        tasks = [
            asyncio.create_task(
                generate(engine, f"request-{i}", prompt, output_kind,
                         NUM_EXPECTED_TOKENS)) for i in range(NUM_REQUESTS)
        ]

        # Stop waiting as soon as any request fails. Whatever is still pending
        # gets cancelled, and awaiting the failed task below re-raises.
        done, pending = await asyncio.wait(tasks,
                                           return_when=asyncio.FIRST_EXCEPTION)
        for straggler in pending:
            straggler.cancel()

        # Confirm that we got all the EXPECTED tokens from the requests.
        for finished in done:
            num_generated_tokens, request_id = await finished
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} but "
                f"expected {NUM_EXPECTED_TOKENS}")

        assert not engine.output_processor.has_unfinished_requests()
142

143

144
145
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_abort(
    monkeypatch: pytest.MonkeyPatch,
    output_kind: RequestOutputKind,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Cancel a subset of in-flight requests; the others must be unaffected.

    Requests selected by REQUEST_IDS_TO_ABORT get a very large token budget,
    so they are still running when their task is cancelled. Every other
    request must produce exactly NUM_EXPECTED_TOKENS tokens per sample.
    Afterwards the engine must hold no unfinished requests, and the id of an
    aborted request must be usable for a fresh generation.
    """
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 100
        NUM_EXPECTED_TOKENS_LONG = 50000
        # Derive the index sets from NUM_REQUESTS instead of a hard-coded
        # 100 so they always stay within range of ``tasks``.
        REQUEST_IDS_TO_ABORT = range(1, NUM_REQUESTS, 10)
        PARALLEL_SAMPLE_REQ_IDS = range(1, NUM_REQUESTS, 15)

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            # Requests that will be aborted get a huge budget so they cannot
            # finish before the cancellation below reaches them.
            max_tokens = (NUM_EXPECTED_TOKENS_LONG if
                          (idx
                           in REQUEST_IDS_TO_ABORT) else NUM_EXPECTED_TOKENS)
            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
            tasks.append(
                asyncio.create_task(
                    generate(engine, request_id, prompt, output_kind,
                             max_tokens, n)))

        # API server cancels requests when they disconnect.
        for idx in REQUEST_IDS_TO_ABORT:
            tasks[idx].cancel()
            await asyncio.sleep(0.1)

        # Confirm the other requests are okay.
        for idx, task in enumerate(tasks):
            # Confirm that it was actually canceled.
            if idx in REQUEST_IDS_TO_ABORT:
                with pytest.raises(asyncio.CancelledError):
                    await task
            else:
                # Otherwise, make sure the request was not impacted.
                num_generated_tokens, request_id = await task
                n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
                expected_tokens = NUM_EXPECTED_TOKENS * n
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {expected_tokens}")

        # Make sure all aborted requests were really aborted.
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can do another generation, reusing an aborted id.
        request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
        task = asyncio.create_task(
            generate(engine, request_id, prompt, output_kind,
                     NUM_EXPECTED_TOKENS))
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()
216
217
218


@pytest.mark.parametrize("n", [1, 3])
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_finished_flag(
    monkeypatch: pytest.MonkeyPatch,
    n: int,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Only the last streamed output of a request may be marked finished.

    This is checked for both single-sample and parallel-sample (``n``)
    requests.
    """
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        params = SamplingParams(
            max_tokens=100,
            output_kind=RequestOutputKind.DELTA,
            temperature=1.0,
            seed=33,
            n=n,
        )

        streamed = []
        async for chunk in engine.generate(request_id="request-33",
                                           prompt=prompt,
                                           sampling_params=params):
            streamed.append(chunk)

        # Assert only the last output has the finished flag set
        assert not any(chunk.finished for chunk in streamed[:-1])
        assert streamed[-1].finished
255
256


257
258
259
260
261
262
263
264
265
266
267
268
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
                                       engine_args: AsyncEngineArgs,
                                       prompt: PromptType):
    """Test that requests can be cancelled mid-stream.

    Each request asks for NUM_TOKENS tokens, but the consumer abandons its
    stream once NUM_EXPECTED_TOKENS have arrived. No request may be left
    unfinished in the engine, and a cancelled request id must be reusable.
    """
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_TOKENS = 1000
        NUM_EXPECTED_TOKENS = 20

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Every consumer stops reading after NUM_EXPECTED_TOKENS tokens.
        tasks = [
            asyncio.create_task(
                generate(
                    engine,
                    rid,
                    prompt,
                    RequestOutputKind.DELTA,
                    NUM_TOKENS,
                    cancel_after=NUM_EXPECTED_TOKENS,
                )) for rid in request_ids
        ]

        # Each result is (token_count, request_id); every consumer must have
        # stopped exactly at the cancellation point.
        for num_generated_tokens, request_id in await asyncio.gather(*tasks):
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} tokens but "
                f"expected to cancel after {NUM_EXPECTED_TOKENS}")

        # Make sure no requests are left hanging
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can reuse the request id after the cancellations.
        reused = asyncio.create_task(
            generate(engine, request_ids[0], prompt, RequestOutputKind.DELTA,
                     NUM_EXPECTED_TOKENS))
        num_generated_tokens, _ = await reused
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()


315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
class MockLoggingStatLogger(LoggingStatLogger):
    """A ``LoggingStatLogger`` whose ``log`` attribute is a ``MagicMock``.

    Tests pass this class as a custom stat logger and then assert on how the
    engine called ``log``.
    """

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        super().__init__(vllm_config, engine_index)
        # Replace ``log`` on the instance with a recording mock; presumably
        # this shadows the inherited ``log`` method (confirm in
        # LoggingStatLogger).
        mock_log = MagicMock()
        self.log = mock_log


@pytest.mark.asyncio
async def test_customize_loggers(monkeypatch: pytest.MonkeyPatch):
    """Test that we can customize the loggers.

    If a custom logger class is passed via ``stat_loggers`` at construction
    time, it should be used directly. Here the engine must end up with
    exactly one logger for engine 0, and a single ``do_log_stats()`` call
    must invoke its ``log`` exactly once.
    """

    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(
                TEXT_ENGINE_ARGS,
                stat_loggers=[MockLoggingStatLogger],
            )
        after.callback(engine.shutdown)

        await engine.do_log_stats()

        stat_loggers = engine.logger_manager.per_engine_logger_dict
        # One engine, with exactly the one custom logger registered for it.
        assert len(stat_loggers) == 1
        assert len(stat_loggers[0]) == 1
        stat_loggers[0][0].log.assert_called_once()
345
346
347
348
349
350
351


@pytest.mark.asyncio(scope="module")
async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
    """Check validation of ``data_parallel_rank`` in ``AsyncLLM.generate``.

    TEXT_ENGINE_ARGS configures no data parallelism, so rank 0 must stream
    normally and rank 1 (out of range) must raise ``ValueError``.
    """
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        params = SamplingParams(
            max_tokens=100,
            output_kind=RequestOutputKind.DELTA,
            temperature=1.0,
            seed=33,
        )

        async def drain(rid: str, dp_rank: int) -> None:
            # Consume the whole stream, discarding every chunk.
            async for _ in engine.generate(request_id=rid,
                                           prompt=TEXT_PROMPT,
                                           sampling_params=params,
                                           data_parallel_rank=dp_rank):
                pass

        # Test with valid DP rank.
        await drain("request-34", 0)

        # Test with out-of-range DP rank.
        with pytest.raises(ValueError):
            await drain("request-35", 1)
375
376
377
378
379
380
381
382
383
384
385
386
387
388


@pytest.mark.asyncio
async def test_check_health(monkeypatch: pytest.MonkeyPatch):
    """``check_health`` passes on a healthy engine and raises when dead.

    A dead engine is simulated by temporarily patching the engine class's
    ``errored`` property to return True, which must make ``check_health``
    raise ``EngineDeadError``. Once the patch is removed, the engine must
    report healthy again.
    """
    from unittest.mock import patch

    from vllm.v1.engine.exceptions import EngineDeadError

    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        # 1. A freshly started engine reports healthy.
        await engine.check_health()

        # 2. Patch ``errored`` on the class (not the instance) so the engine
        #    looks dead; the original attribute is restored when the outer
        #    ``with`` exits.
        always_errored = property(lambda self: True)
        with patch.object(type(engine), 'errored', new=always_errored):
            with pytest.raises(EngineDeadError):
                await engine.check_health()

        # 3. With the patch undone, the engine is healthy again.
        await engine.check_health()