"vscode:/vscode.git/clone" did not exist on "5da4f5d857933329aaca779e3a81f1385c84e34a"
test_async_llm.py 20.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from contextlib import ExitStack
6
from unittest.mock import MagicMock
7
8
9
10

import pytest

from vllm import SamplingParams
11
from vllm.assets.image import ImageAsset
12
from vllm.config import VllmConfig
13
from vllm.engine.arg_utils import AsyncEngineArgs
14
from vllm.inputs import PromptType
15
from vllm.outputs import RequestOutput
16
from vllm.platforms import current_platform
17
from vllm.sampling_params import RequestOutputKind
18
from vllm.utils.torch_utils import set_default_torch_num_threads
19
from vllm.v1.engine.async_llm import AsyncLLM
20
21
22
23
24
25
from vllm.v1.metrics.loggers import (
    AggregatedLoggingStatLogger,
    LoggingStatLogger,
    PerEngineStatLoggerAdapter,
    PrometheusStatLogger,
)
26
27

# These tests exercise the V1 engine, which the skip message below states is
# CUDA-only; skip the whole module on any other platform.
if not current_platform.is_cuda():
    pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)

# Engine arguments paired with TEXT_PROMPT in the parametrized tests.
TEXT_ENGINE_ARGS = AsyncEngineArgs(
    model="meta-llama/Llama-3.2-1B-Instruct",
    enforce_eager=True,
)

# Engine arguments paired with VISION_PROMPT in the parametrized tests.
VISION_ENGINE_ARGS = AsyncEngineArgs(
    model="Qwen/Qwen2-VL-2B-Instruct", enforce_eager=True
)

TEXT_PROMPT = "Hello my name is Robert and"

# Chat-formatted prompt containing a single image placeholder.
VISION_PROMPT_TEMPLATE = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
    "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "What is in the image?<|im_end|>\n"
    "<|im_start|>assistant\n"
)

# Multimodal prompt: the template text plus the image it refers to.
VISION_PROMPT = {
    "prompt": VISION_PROMPT_TEMPLATE,
    "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
}
51
52


53
54
55
56
57
58
59
async def generate(
    engine: AsyncLLM,
    request_id: str,
    prompt: PromptType,
    output_kind: RequestOutputKind,
    max_tokens: int,
    n: int = 1,
    prompt_logprobs: int | None = None,
    cancel_after: int | None = None,
) -> tuple[int, str]:
    """Run a single request on ``engine`` and count the tokens it yields.

    Args:
        engine: Engine the request is submitted to.
        request_id: Identifier passed through to ``engine.generate``.
        prompt: Prompt passed through to ``engine.generate``.
        output_kind: Output mode; also decides how tokens are counted
            (accumulated for ``DELTA``, overwritten otherwise).
        max_tokens: Forwarded to ``SamplingParams``; ``ignore_eos`` is set.
        n: Forwarded to ``SamplingParams`` as the number of samples.
        prompt_logprobs: Forwarded to ``SamplingParams``.
        cancel_after: When not ``None``, stop consuming the stream and return
            as soon as the running count reaches this value.

    Returns:
        ``(count, request_id)`` where ``count`` is the token count observed
        across all completions of the request.
    """
    # Ensure generate doesn't complete too fast for cancellation test.
    await asyncio.sleep(0.2)

    count = 0
    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        ignore_eos=True,
        output_kind=output_kind,
        temperature=0.5,
        seed=33,
        n=n,
        prompt_logprobs=prompt_logprobs,
    )
    async for out in engine.generate(
        request_id=request_id, prompt=prompt, sampling_params=sampling_params
    ):
        # Sum over every completion carried by this output (there may be
        # several when n > 1).
        num_tokens = sum(len(output.token_ids) for output in out.outputs)
        if output_kind == RequestOutputKind.DELTA:
            # Delta outputs are accumulated into a running total.
            count += num_tokens
        else:
            # Otherwise the latest output's total replaces the previous one.
            count = num_tokens

        if cancel_after is not None and count >= cancel_after:
            return count, request_id

        # Yield to the event loop so other tasks can make progress.
        await asyncio.sleep(0.0)

    return count, request_id


93
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_load(
    output_kind: RequestOutputKind,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Submit many concurrent requests and check each yields the expected
    number of tokens, leaving no unfinished requests behind.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        # Shut the engine down when the stack unwinds, even on failure.
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 10

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks = [
            asyncio.create_task(
                generate(engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS)
            )
            for request_id in request_ids
        ]

        # Confirm that we got all the EXPECTED tokens from the requests.
        # FIRST_EXCEPTION returns as soon as any task raises (or once all have
        # finished), so a failure surfaces without waiting for the rest.
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
        for task in pending:
            task.cancel()
        for task in done:
            # Awaiting a finished task re-raises its exception, if any.
            num_generated_tokens, request_id = await task
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} but "
                f"expected {NUM_EXPECTED_TOKENS}"
            )

        assert not engine.output_processor.has_unfinished_requests()


141
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_abort(
    output_kind: RequestOutputKind,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Cancel a subset of in-flight request tasks and check that the remaining
    requests complete normally and the engine is left with nothing unfinished.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        # Shut the engine down when the stack unwinds, even on failure.
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 100
        # Requests that will be cancelled get a far larger token budget than
        # the rest.
        NUM_EXPECTED_TOKENS_LONG = 50000
        REQUEST_IDS_TO_ABORT = range(1, 100, 10)
        # Indices of requests issued with n=3 instead of n=1.
        PARALLEL_SAMPLE_REQ_IDS = range(1, 100, 15)

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            max_tokens = (
                NUM_EXPECTED_TOKENS_LONG
                if (idx in REQUEST_IDS_TO_ABORT)
                else NUM_EXPECTED_TOKENS
            )
            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
            tasks.append(
                asyncio.create_task(
                    generate(engine, request_id, prompt, output_kind, max_tokens, n)
                )
            )

        # API server cancels requests when they disconnect.
        for idx in REQUEST_IDS_TO_ABORT:
            tasks[idx].cancel()
            await asyncio.sleep(0.1)

        # Confirm the other requests are okay.
        for idx, task in enumerate(tasks):
            # Confirm that it was actually canceled.
            if idx in REQUEST_IDS_TO_ABORT:
                with pytest.raises(asyncio.CancelledError):
                    await task
            else:
                # Otherwise, make sure the request was not impacted.
                num_generated_tokens, request_id = await task
                n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
                # The count returned by generate() sums over all n samples.
                expected_tokens = NUM_EXPECTED_TOKENS * n
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {expected_tokens}"
                )

        # Make sure all aborted requests were really aborted.
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can do another generation, reusing the id of a request
        # that was cancelled above.
        request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
        task = asyncio.create_task(
            generate(engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS)
        )
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()
214
215


216
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.asyncio
async def test_multi_abort(output_kind: RequestOutputKind):
    """Abort several requests with one ``engine.abort`` call and check that
    aborted requests return partial results while the others finish normally.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        # Shut the engine down when the stack unwinds, even on failure.
        after.callback(engine.shutdown)

        NUM_REQUESTS = 50
        NUM_EXPECTED_TOKENS = 100
        # Requests that will be aborted get a far larger token budget than
        # the rest.
        NUM_EXPECTED_TOKENS_LONG = 50000
        REQUEST_IDS_TO_ABORT = [5, 10, 15, 20, 25]
        # Indices of requests issued with n=3 instead of n=1.
        PARALLEL_SAMPLE_REQ_IDS = [5, 15, 30, 35]

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            max_tokens = (
                NUM_EXPECTED_TOKENS_LONG
                if (idx in REQUEST_IDS_TO_ABORT)
                else NUM_EXPECTED_TOKENS
            )
            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
            tasks.append(
                asyncio.create_task(
                    generate(
                        engine, request_id, TEXT_PROMPT, output_kind, max_tokens, n
                    )
                )
            )

        # Let requests start
        await asyncio.sleep(0.5)

        # Use multi-abort to abort multiple requests at once
        abort_request_ids = [request_ids[i] for i in REQUEST_IDS_TO_ABORT]
        await engine.abort(abort_request_ids)

        # Wait for all tasks to complete. return_exceptions=True turns any
        # raised exception into a result so the isinstance checks below can
        # report which request failed.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Verify results
        for idx, result in enumerate(results):
            if idx in REQUEST_IDS_TO_ABORT:
                # Aborted requests should return partial results
                assert isinstance(result, tuple), (
                    f"Request {idx} should have completed with partial results"
                )
                num_generated_tokens, request_id = result
                # Should have generated some tokens before abort
                assert num_generated_tokens > 0, (
                    f"Aborted request {request_id} should have generated some tokens"
                )
            else:
                # Non-aborted requests should complete normally
                assert isinstance(result, tuple), (
                    f"Request {idx} should have completed successfully"
                )
                num_generated_tokens, request_id = result
                n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
                # The count returned by generate() sums over all n samples.
                expected_tokens = NUM_EXPECTED_TOKENS * n
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {expected_tokens}"
                )

        # Make sure all aborted requests were cleaned up
        assert not engine.output_processor.has_unfinished_requests()


290
@pytest.mark.parametrize("n", [1, 3])
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_finished_flag(
    n: int,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Check that, in a DELTA stream, only the last output is marked finished."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        # Shut the engine down when the stack unwinds, even on failure.
        after.callback(engine.shutdown)

        sampling_params = SamplingParams(
            max_tokens=100,
            output_kind=RequestOutputKind.DELTA,
            temperature=1.0,
            seed=33,
            n=n,
        )
        # Drain the whole stream so every intermediate output can be inspected.
        outputs = [
            out
            async for out in engine.generate(
                request_id="request-33", prompt=prompt, sampling_params=sampling_params
            )
        ]

        # Assert only the last output has the finished flag set
        assert all(not out.finished for out in outputs[:-1])
        assert outputs[-1].finished
323
324


325
326
327
328
329
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_mid_stream_cancellation(
    engine_args: AsyncEngineArgs, prompt: PromptType
):
    """Test that requests can be cancelled mid-stream."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        # Shut the engine down when the stack unwinds, even on failure.
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        # Token budget per request; far more than is consumed before stopping.
        NUM_TOKENS = 1000
        # Token count at which each consumer stops reading its stream.
        NUM_EXPECTED_TOKENS = 20

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests that will be cancelled mid-stream
        tasks = [
            asyncio.create_task(
                generate(
                    engine,
                    request_id,
                    prompt,
                    RequestOutputKind.DELTA,
                    NUM_TOKENS,
                    cancel_after=NUM_EXPECTED_TOKENS,
                )
            )
            for request_id in request_ids
        ]

        # Wait for all tasks to complete
        results = await asyncio.gather(*tasks)

        # Verify all tasks were cancelled at the expected point
        for num_generated_tokens, request_id in results:
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} tokens but "
                f"expected to cancel after {NUM_EXPECTED_TOKENS}"
            )

        # Make sure no requests are left hanging
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can reuse the request id after the cancellations.
        request_id = request_ids[0]
        task = asyncio.create_task(
            generate(
                engine, request_id, prompt, RequestOutputKind.DELTA, NUM_EXPECTED_TOKENS
            )
        )
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()


386
387
388
389
390
391
class MockLoggingStatLogger(LoggingStatLogger):
    """``LoggingStatLogger`` subclass whose ``log`` attribute is a mock."""

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0) -> None:
        super().__init__(vllm_config, engine_index)
        # Install a MagicMock as ``log`` so tests can assert it was invoked.
        self.log = MagicMock()


392
393
394
395
396
397
class MockAggregatedStatLogger(AggregatedLoggingStatLogger):
    """``AggregatedLoggingStatLogger`` subclass whose ``log`` attribute is a mock."""

    def __init__(self, vllm_config: VllmConfig, engine_indexes: list[int]) -> None:
        super().__init__(vllm_config, engine_indexes)
        # Install a MagicMock as ``log`` so tests can assert it was invoked.
        self.log = MagicMock()


398
399
400
401
@pytest.mark.asyncio
async def test_customize_loggers(monkeypatch):
    """Test that we can customize the loggers.
    If a customized logger is provided at the init, it should
    be added to the default loggers.
    """

    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(
                TEXT_ENGINE_ARGS,
                stat_loggers=[MockLoggingStatLogger],
            )
        # Shut the engine down when the stack unwinds, even on failure.
        after.callback(engine.shutdown)

        await engine.do_log_stats()

        stat_loggers = engine.logger_manager.stat_loggers
        assert (
            len(stat_loggers) == 3
        )  # MockLoggingStatLogger + LoggingStatLogger + Prometheus logger
        print(f"{stat_loggers=}")
        # The custom logger comes first and is reached through its per-engine
        # list; its mocked ``log`` must have been called by do_log_stats().
        stat_loggers[0].per_engine_stat_loggers[0].log.assert_called_once()
        assert isinstance(stat_loggers[1], PerEngineStatLoggerAdapter)
        assert isinstance(stat_loggers[1].per_engine_stat_loggers[0], LoggingStatLogger)
        assert isinstance(stat_loggers[2], PrometheusStatLogger)


@pytest.mark.asyncio
async def test_customize_aggregated_loggers():
    """Test that we can customize the aggregated loggers.
    If a customized logger is provided at the init, it should
    be added to the default loggers.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(
                TEXT_ENGINE_ARGS,
                stat_loggers=[MockLoggingStatLogger, MockAggregatedStatLogger],
            )
        # Shut the engine down when the stack unwinds, even on failure.
        after.callback(engine.shutdown)

        await engine.do_log_stats()

        stat_loggers = engine.logger_manager.stat_loggers
        assert len(stat_loggers) == 4
        #  MockLoggingStatLogger + MockAggregatedStatLogger
        # + LoggingStatLogger + PrometheusStatLogger
        # The per-engine mock is reached through its per-engine list, while
        # the aggregated mock is accessed directly.
        stat_loggers[0].per_engine_stat_loggers[0].log.assert_called_once()
        stat_loggers[1].log.assert_called_once()
        assert isinstance(stat_loggers[2], PerEngineStatLoggerAdapter)
        assert isinstance(stat_loggers[2].per_engine_stat_loggers[0], LoggingStatLogger)
        assert isinstance(stat_loggers[3], PrometheusStatLogger)
451
452
453


@pytest.mark.asyncio(scope="module")
async def test_dp_rank_argument():
    """Check that ``data_parallel_rank=0`` is accepted by ``engine.generate``
    and that ``data_parallel_rank=1`` raises ``ValueError``.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        # Shut the engine down when the stack unwinds, even on failure.
        after.callback(engine.shutdown)

        sampling_params = SamplingParams(
            max_tokens=100,
            output_kind=RequestOutputKind.DELTA,
            temperature=1.0,
            seed=33,
        )

        # Test with valid DP rank.
        async for _ in engine.generate(
            request_id="request-34",
            prompt=TEXT_PROMPT,
            sampling_params=sampling_params,
            data_parallel_rank=0,
        ):
            pass

        # Test with out-of-range DP rank.
        with pytest.raises(ValueError):
            async for _ in engine.generate(
                request_id="request-35",
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
                data_parallel_rank=1,
            ):
                pass
485
486
487


@pytest.mark.asyncio
async def test_check_health():
    """Test that check_health returns normally for healthy engine
    and raises EngineDeadError when the engine is dead.
    """
    from unittest.mock import patch

    from vllm.v1.engine.exceptions import EngineDeadError

    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        # Shut the engine down when the stack unwinds, even on failure.
        after.callback(engine.shutdown)

        # Test 1: Healthy engine should not raise any exception
        await engine.check_health()

        # Test 2: Mock the errored property to simulate a dead engine.
        # The patch targets type(engine) because a property has to live on the
        # class to take effect; it is undone when the ``with`` block exits.
        with (
            patch.object(
                type(engine),
                "errored",
                new_callable=lambda: property(lambda self: True),
            ),
            pytest.raises(EngineDeadError),
        ):
            await engine.check_health()

        # Test 3: Verify healthy engine still works after mock
        await engine.check_health()
517
518
519


@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.asyncio
async def test_abort_final_output(output_kind: RequestOutputKind):
    """Test that abort() returns a final output with correct information."""

    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        # Shut the engine down when the stack unwinds, even on failure.
        after.callback(engine.shutdown)

        request_id = "test-abort-final-output"

        # Start a long-running request
        sampling_params = SamplingParams(
            max_tokens=3000,  # Long enough to allow abort
            ignore_eos=True,
            output_kind=output_kind,
            temperature=0.5,
            seed=42,
        )

        # collect_outputs() appends every non-finished output to this list and
        # returns the last output it saw.
        outputs: list[RequestOutput] = []
        generated = asyncio.create_task(
            collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params, outputs)
        )

        # Let it generate some tokens
        await asyncio.sleep(0.5)

        # Abort the request
        await engine.abort(request_id)

        # Wait for generation to complete and return final output
        final_output = await generated

        # Verify we got a final output
        assert final_output is not None
        assert final_output.finished
        assert len(final_output.outputs) == 1

        assert final_output.outputs[0].finish_reason == "abort"
        assert final_output.outputs[0].stop_reason is None

        # Verify num_cached_tokens is present and non-negative
        assert hasattr(final_output, "num_cached_tokens")
        assert final_output.num_cached_tokens >= 0

        # If we got intermediate outputs, verify they are consistent
        if output_kind == RequestOutputKind.DELTA:
            # For DELTA, the intermediate outputs collected before the abort
            # must carry at least one token in total.
            token_count = sum(len(output.outputs[0].token_ids) for output in outputs)
            assert token_count > 0
            # This would ordinarily be 0, but could end up > 0 if the
            # final abort is coalesced with another chunk in the output queue.
            # NOTE: the check is always true; it only documents that either
            # outcome is acceptable.
            assert len(final_output.outputs[0].token_ids) >= 0
        else:
            # For FINAL_ONLY, we should only get the final output
            assert len(outputs) == 0
            assert len(final_output.outputs[0].token_ids) > 0

        assert not engine.output_processor.has_unfinished_requests()


async def collect_outputs(
    engine: AsyncLLM,
    request_id: str,
    prompt: PromptType,
    sampling_params: SamplingParams,
    outputs_list: list[RequestOutput],
) -> RequestOutput | None:
    """Helper to collect outputs and return the final one.

    Args:
        engine: Engine the request is submitted to.
        request_id: Identifier passed through to ``engine.generate``.
        prompt: Prompt passed through to ``engine.generate``.
        sampling_params: Sampling parameters passed through unchanged.
        outputs_list: Mutated in place; every output whose ``finished`` flag
            is false is appended to it.

    Returns:
        The last output yielded by the stream, or ``None`` if the stream
        yielded nothing.
    """
    final_output: RequestOutput | None = None
    async for output in engine.generate(
        request_id=request_id, prompt=prompt, sampling_params=sampling_params
    ):
        if not output.finished:
            outputs_list.append(output)
        # Track the most recent output so the last one can be returned.
        final_output = output
    return final_output