# test_async_llm.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from contextlib import ExitStack
6
from typing import Optional
7
from unittest.mock import MagicMock
8
9
10
11

import pytest

from vllm import SamplingParams
12
from vllm.assets.image import ImageAsset
13
from vllm.config import VllmConfig
14
from vllm.engine.arg_utils import AsyncEngineArgs
15
from vllm.inputs import PromptType
16
from vllm.outputs import RequestOutput
17
from vllm.platforms import current_platform
18
from vllm.sampling_params import RequestOutputKind
19
from vllm.utils import set_default_torch_num_threads
20
from vllm.v1.engine.async_llm import AsyncLLM
21
from vllm.v1.metrics.loggers import LoggingStatLogger
22
23
24
25
26

# Everything below needs the V1 engine, which this module only runs on CUDA;
# skip the entire module on any other platform.
if not current_platform.is_cuda():
    pytest.skip(reason="V1 currently only supported on CUDA.",
                allow_module_level=True)

# Engine configurations shared by the parametrized tests below. The tests
# pair TEXT_ENGINE_ARGS with TEXT_PROMPT and VISION_ENGINE_ARGS with
# VISION_PROMPT.
TEXT_ENGINE_ARGS = AsyncEngineArgs(
    model="meta-llama/Llama-3.2-1B-Instruct",
    enforce_eager=True,
)

VISION_ENGINE_ARGS = AsyncEngineArgs(
    model="Qwen/Qwen2-VL-2B-Instruct",
    enforce_eager=True,
)

# Plain-text prompt.
TEXT_PROMPT = "Hello my name is Robert and"

# Chat-formatted prompt containing one image placeholder; the image itself
# is supplied through ``multi_modal_data`` in VISION_PROMPT.
VISION_PROMPT_TEMPLATE = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
    "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "What is in the image?<|im_end|>\n"
    "<|im_start|>assistant\n")
VISION_PROMPT = dict(
    prompt=VISION_PROMPT_TEMPLATE,
    multi_modal_data=dict(image=ImageAsset("stop_sign").pil_image),
)


async def generate(
    engine: AsyncLLM,
    request_id: str,
    prompt: PromptType,
    output_kind: RequestOutputKind,
    max_tokens: int,
    n: int = 1,
    prompt_logprobs: Optional[int] = None,
    cancel_after: Optional[int] = None,
) -> tuple[int, str]:
    """Stream a single request through ``engine`` and count output tokens.

    Tokens are counted across all of a chunk's ``outputs`` (one per parallel
    sample). With ``RequestOutputKind.DELTA`` the per-chunk counts are summed;
    with any other output kind the latest chunk's count replaces the total.

    If ``cancel_after`` is set, the coroutine returns as soon as the running
    count reaches it, abandoning the stream before it ends.

    Returns:
        A ``(token_count, request_id)`` pair.
    """
    # Ensure generate doesn't complete too fast for cancellation test.
    await asyncio.sleep(0.2)

    params = SamplingParams(
        max_tokens=max_tokens,
        ignore_eos=True,
        output_kind=output_kind,
        temperature=0.5,
        seed=33,
        n=n,
        prompt_logprobs=prompt_logprobs,
    )
    accumulate = output_kind == RequestOutputKind.DELTA
    total = 0

    async for chunk in engine.generate(request_id=request_id,
                                       prompt=prompt,
                                       sampling_params=params):
        chunk_tokens = sum(len(sample.token_ids) for sample in chunk.outputs)
        total = total + chunk_tokens if accumulate else chunk_tokens

        if cancel_after is not None and total >= cancel_after:
            return total, request_id

        # Hand control back to the event loop between chunks.
        await asyncio.sleep(0.0)

    return total, request_id


@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_load(
    monkeypatch: pytest.MonkeyPatch,
    output_kind: RequestOutputKind,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Run many requests concurrently; each must yield its full budget."""
    # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
    # so that in the future when we switch, we don't have to change all the
    # tests.
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 10

        # Launch every request as its own concurrent task.
        tasks = [
            asyncio.create_task(
                generate(engine, f"request-{i}", prompt, output_kind,
                         NUM_EXPECTED_TOKENS)) for i in range(NUM_REQUESTS)
        ]

        # Wait for all of them, but stop early on the first failure and
        # cancel whatever is still running.
        finished, unfinished = await asyncio.wait(
            tasks, return_when=asyncio.FIRST_EXCEPTION)
        for straggler in unfinished:
            straggler.cancel()

        # Every completed request must have produced exactly the budget.
        for task in finished:
            got, rid = await task
            assert got == NUM_EXPECTED_TOKENS, (
                f"{rid} generated {got} but "
                f"expected {NUM_EXPECTED_TOKENS}")

        assert not engine.output_processor.has_unfinished_requests()


@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_abort(
    monkeypatch: pytest.MonkeyPatch,
    output_kind: RequestOutputKind,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Cancel a subset of in-flight request tasks.

    The cancelled tasks must raise ``CancelledError``. Every other request
    must still produce its full token budget, and the engine must accept a
    new request afterwards, reusing one of the cancelled request ids.
    """
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 100
        NUM_EXPECTED_TOKENS_LONG = 50000
        REQUEST_IDS_TO_ABORT = range(1, 100, 10)
        PARALLEL_SAMPLE_REQ_IDS = range(1, 100, 15)

        def samples_for(idx: int) -> int:
            # A subset of requests exercises parallel sampling (n=3).
            return 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests; the ones to be cancelled get a far
        # larger budget than the rest.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            budget = (NUM_EXPECTED_TOKENS_LONG
                      if idx in REQUEST_IDS_TO_ABORT else NUM_EXPECTED_TOKENS)
            tasks.append(
                asyncio.create_task(
                    generate(engine, request_id, prompt, output_kind, budget,
                             samples_for(idx))))

        # API server cancels requests when they disconnect.
        for idx in REQUEST_IDS_TO_ABORT:
            tasks[idx].cancel()
            await asyncio.sleep(0.1)

        for idx, task in enumerate(tasks):
            if idx in REQUEST_IDS_TO_ABORT:
                # Confirm that it was actually canceled.
                with pytest.raises(asyncio.CancelledError):
                    await task
                continue

            # Otherwise, make sure the request was not impacted.
            num_generated_tokens, request_id = await task
            expected_tokens = NUM_EXPECTED_TOKENS * samples_for(idx)
            assert num_generated_tokens == expected_tokens, (
                f"{request_id} generated {num_generated_tokens} but "
                f"expected {expected_tokens}")

        # Make sure all aborted requests were really aborted.
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can do another generation, reusing an aborted id.
        request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
        num_generated_tokens, request_id = await asyncio.create_task(
            generate(engine, request_id, prompt, output_kind,
                     NUM_EXPECTED_TOKENS))
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()


@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio
async def test_multi_abort(
    monkeypatch: pytest.MonkeyPatch,
    output_kind: RequestOutputKind,
):
    """Abort several running requests with a single ``engine.abort`` call.

    Aborted requests are expected to finish with partial results rather
    than raise. Every other request must produce its full token budget.
    """
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 50
        NUM_EXPECTED_TOKENS = 100
        NUM_EXPECTED_TOKENS_LONG = 50000
        REQUEST_IDS_TO_ABORT = [5, 10, 15, 20, 25]
        PARALLEL_SAMPLE_REQ_IDS = [5, 15, 30, 35]

        def num_samples(idx: int) -> int:
            # A subset of requests exercises parallel sampling (n=3).
            return 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests. Requests that will be aborted get a
        # much larger budget so they are still running at abort time.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            aborting = idx in REQUEST_IDS_TO_ABORT
            max_tokens = (NUM_EXPECTED_TOKENS_LONG
                          if aborting else NUM_EXPECTED_TOKENS)
            tasks.append(
                asyncio.create_task(
                    generate(engine, request_id, TEXT_PROMPT, output_kind,
                             max_tokens, num_samples(idx))))

        # Let requests start
        await asyncio.sleep(0.5)

        # Use multi-abort to abort multiple requests at once
        abort_request_ids = [request_ids[i] for i in REQUEST_IDS_TO_ABORT]
        await engine.abort(abort_request_ids)

        # Wait for all tasks to complete. Exceptions come back as values so
        # that every request can be inspected below.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        for idx, result in enumerate(results):
            # gather(return_exceptions=True) turns failures into values.
            # Re-raise with the original exception chained as the cause so
            # a failing run shows the real traceback, not just a failed
            # isinstance() check.
            if isinstance(result, BaseException):
                raise AssertionError(
                    f"Request {idx} raised {result!r} instead of "
                    "completing") from result
            assert isinstance(result, tuple), (
                f"Request {idx} returned unexpected value {result!r}")
            num_generated_tokens, request_id = result

            if idx in REQUEST_IDS_TO_ABORT:
                # Aborted requests return partial results, so they must
                # have generated some tokens before the abort.
                assert num_generated_tokens > 0, (
                    f"Aborted request "
                    f"{request_id} should have generated some tokens")
            else:
                # Non-aborted requests should complete normally
                expected_tokens = NUM_EXPECTED_TOKENS * num_samples(idx)
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {expected_tokens}")

        # Make sure all aborted requests were cleaned up
        assert not engine.output_processor.has_unfinished_requests()


@pytest.mark.parametrize("n", [1, 3])
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_finished_flag(
    monkeypatch: pytest.MonkeyPatch,
    n: int,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Only the last streamed output of a request may be marked finished."""
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        params = SamplingParams(
            max_tokens=100,
            output_kind=RequestOutputKind.DELTA,
            temperature=1.0,
            seed=33,
            n=n,
        )

        # Drain the whole stream for a single request.
        outputs = []
        async for out in engine.generate(request_id="request-33",
                                         prompt=prompt,
                                         sampling_params=params):
            outputs.append(out)

        # No intermediate output is finished; the final one is.
        assert not any(out.finished for out in outputs[:-1])
        assert outputs[-1].finished


@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
                                       engine_args: AsyncEngineArgs,
                                       prompt: PromptType):
    """Consumers that stop reading mid-stream must not leave requests behind.

    Each request gets a large token budget, but its consumer returns once
    ``NUM_EXPECTED_TOKENS`` have arrived. Afterwards, no request may remain
    unfinished and a request id must be reusable.
    """
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_TOKENS = 1000
        NUM_EXPECTED_TOKENS = 20

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests that will be cancelled mid-stream.
        tasks = [
            asyncio.create_task(
                generate(
                    engine,
                    rid,
                    prompt,
                    RequestOutputKind.DELTA,
                    NUM_TOKENS,
                    cancel_after=NUM_EXPECTED_TOKENS,
                )) for rid in request_ids
        ]
        results = await asyncio.gather(*tasks)

        # Every consumer must have stopped exactly at the cut-off.
        for num_generated_tokens, request_id in results:
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} tokens but "
                f"expected to cancel after {NUM_EXPECTED_TOKENS}")

        # Make sure no requests are left hanging
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can reuse the request id after the cancellations.
        num_generated_tokens, request_id = await asyncio.create_task(
            generate(engine, request_ids[0], prompt, RequestOutputKind.DELTA,
                     NUM_EXPECTED_TOKENS))
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()


class MockLoggingStatLogger(LoggingStatLogger):
    """``LoggingStatLogger`` whose ``log`` method is replaced by a MagicMock.

    Lets a test assert that the engine invoked ``log`` without running the
    real logging implementation.
    """

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        super().__init__(vllm_config, engine_index)
        # Instance attribute shadows the inherited method and records calls.
        self.log = MagicMock()


@pytest.mark.asyncio
async def test_customize_loggers(monkeypatch):
    """A stat-logger class passed at construction is used as-is.

    The engine is built with ``stat_loggers=[MockLoggingStatLogger]``. After
    one ``do_log_stats`` call, the single engine must hold exactly that one
    logger, and its ``log`` must have been called once.
    """
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(
                TEXT_ENGINE_ARGS,
                stat_loggers=[MockLoggingStatLogger],
            )
        after.callback(engine.shutdown)

        await engine.do_log_stats()

        per_engine = engine.logger_manager.per_engine_logger_dict
        assert len(per_engine) == 1

        engine0_loggers = per_engine[0]
        assert len(engine0_loggers) == 1
        engine0_loggers[0].log.assert_called_once()


@pytest.mark.asyncio(scope="module")
async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
    """``generate`` accepts data-parallel rank 0 and rejects rank 1.

    The engine is built from TEXT_ENGINE_ARGS. Rank 0 must stream normally;
    rank 1 must raise ``ValueError``.
    """
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        sampling_params = SamplingParams(max_tokens=100,
                                         output_kind=RequestOutputKind.DELTA,
                                         temperature=1.0,
                                         seed=33)

        async def drain(request_id: str, dp_rank: int) -> None:
            # Consume the whole stream, discarding outputs.
            async for _ in engine.generate(request_id=request_id,
                                           prompt=TEXT_PROMPT,
                                           sampling_params=sampling_params,
                                           data_parallel_rank=dp_rank):
                pass

        # Test with valid DP rank.
        await drain("request-34", 0)

        # Test with out-of-range DP rank.
        with pytest.raises(ValueError):
            await drain("request-35", 1)


@pytest.mark.asyncio
async def test_check_health(monkeypatch: pytest.MonkeyPatch):
    """``check_health`` passes on a live engine and raises EngineDeadError
    while the engine's ``errored`` property reports True.
    """
    from unittest.mock import patch

    from vllm.v1.engine.exceptions import EngineDeadError

    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        # A freshly started engine is healthy.
        await engine.check_health()

        # Temporarily override the class-level ``errored`` property so the
        # engine appears dead; the health check must then fail.
        def always_errored():
            return property(lambda self: True)

        with patch.object(type(engine), 'errored',
                          new_callable=always_errored):
            with pytest.raises(EngineDeadError):
                await engine.check_health()

        # With the patch undone the engine reports healthy again.
        await engine.check_health()


@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio
async def test_abort_final_output(
    monkeypatch: pytest.MonkeyPatch,
    output_kind: RequestOutputKind,
):
    """Test that abort() returns a final output with correct information.

    A long request is started, aborted after a short delay, and its stream
    is drained. The last output must be finished with ``finish_reason``
    ``"abort"`` and no ``stop_reason``.
    """
    with monkeypatch.context() as m, ExitStack() as after:
        m.setenv("VLLM_USE_V1", "1")

        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        request_id = "test-abort-final-output"

        # Start a long-running request
        sampling_params = SamplingParams(
            max_tokens=3000,  # Long enough to allow abort
            ignore_eos=True,
            output_kind=output_kind,
            temperature=0.5,
            seed=42,
        )

        # ``outputs`` receives every non-finished output; the task returns
        # the last output of the stream.
        outputs: list[RequestOutput] = []
        generated = asyncio.create_task(
            collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params,
                            outputs))

        # Let it generate some tokens
        await asyncio.sleep(0.5)

        # Abort the request
        await engine.abort(request_id)

        # Wait for generation to complete and return final output
        final_output = await generated

        # Verify we got a final output
        assert final_output is not None
        assert final_output.finished
        assert len(final_output.outputs) == 1

        assert final_output.outputs[0].finish_reason == "abort"
        assert final_output.outputs[0].stop_reason is None

        # Verify num_cached_tokens is set correctly
        assert hasattr(final_output, 'num_cached_tokens')
        assert final_output.num_cached_tokens >= 0

        if output_kind == RequestOutputKind.DELTA:
            # For DELTA, the tokens generated before the abort arrive in the
            # intermediate (unfinished) outputs, so together they must
            # contain at least one token.
            token_count = sum(
                len(output.outputs[0].token_ids) for output in outputs)
            assert token_count > 0

            # This would ordinarily be 0, but could end up > 0 if the
            # final abort is coalesced with another chunk in the output queue.
            # Either case is accepted: this check documents the expectation
            # and cannot fail.
            assert len(final_output.outputs[0].token_ids) >= 0
        else:
            # For FINAL_ONLY, we should only get the final output, and it
            # carries every token generated before the abort.
            assert len(outputs) == 0
            assert len(final_output.outputs[0].token_ids) > 0

        assert not engine.output_processor.has_unfinished_requests()


async def collect_outputs(
    engine: AsyncLLM,
    request_id: str,
    prompt: PromptType,
    sampling_params: SamplingParams,
    outputs_list: list[RequestOutput],
) -> Optional[RequestOutput]:
    """Drain one request's output stream.

    Each output whose ``finished`` flag is False is appended to
    ``outputs_list``. The last output seen (finished or not) is returned, or
    ``None`` if the stream yielded nothing.
    """
    last: Optional[RequestOutput] = None
    async for last in engine.generate(request_id=request_id,
                                      prompt=prompt,
                                      sampling_params=sampling_params):
        if not last.finished:
            outputs_list.append(last)
    return last