"tests/vscode:/vscode.git/clone" did not exist on "7c139ab23f6d2e9b4603b40814956100a1ccf569"
test_async_llm.py 18.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from contextlib import ExitStack
6
from unittest.mock import MagicMock
7
8
9
10

import pytest

from vllm import SamplingParams
11
from vllm.assets.image import ImageAsset
12
from vllm.config import VllmConfig
13
from vllm.engine.arg_utils import AsyncEngineArgs
14
from vllm.inputs import PromptType
15
from vllm.outputs import RequestOutput
16
from vllm.platforms import current_platform
17
from vllm.sampling_params import RequestOutputKind
18
from vllm.utils import set_default_torch_num_threads
19
from vllm.v1.engine.async_llm import AsyncLLM
20
from vllm.v1.metrics.loggers import LoggingStatLogger
21
22

# Skip the whole module when not running on CUDA.
if not current_platform.is_cuda():
    pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)

# Engine arguments used by the text-only test cases.
TEXT_ENGINE_ARGS = AsyncEngineArgs(
    model="meta-llama/Llama-3.2-1B-Instruct",
    enforce_eager=True,
)

# Engine arguments used by the image + text test cases.
VISION_ENGINE_ARGS = AsyncEngineArgs(
    model="Qwen/Qwen2-VL-2B-Instruct", enforce_eager=True
)

TEXT_PROMPT = "Hello my name is Robert and"

# Chat-formatted prompt containing a single image placeholder.
VISION_PROMPT_TEMPLATE = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
    "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "What is in the image?<|im_end|>\n"
    "<|im_start|>assistant\n"
)

# Prompt dict pairing the template with the "stop_sign" image asset.
VISION_PROMPT = {
    "prompt": VISION_PROMPT_TEMPLATE,
    "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
}
46
47


48
49
50
51
52
53
54
async def generate(
    engine: AsyncLLM,
    request_id: str,
    prompt: PromptType,
    output_kind: RequestOutputKind,
    max_tokens: int,
    n: int = 1,
    prompt_logprobs: int | None = None,
    cancel_after: int | None = None,
) -> tuple[int, str]:
    """Run one request through ``engine`` and count the tokens it yields.

    Args:
        engine: Engine whose ``generate`` async iterator is consumed.
        request_id: Identifier passed to the engine and echoed in the result.
        prompt: Prompt forwarded to the engine unchanged.
        output_kind: Output mode; for ``DELTA`` the per-output token counts
            are accumulated, otherwise the latest count replaces the total.
        max_tokens: ``max_tokens`` for the sampling params (``ignore_eos`` is
            set, temperature is 0.5 and the seed is fixed at 33).
        n: Number of parallel samples requested.
        prompt_logprobs: Forwarded to ``SamplingParams``.
        cancel_after: When not ``None``, stop consuming and return as soon as
            the running token count reaches this value.

    Returns:
        ``(token_count, request_id)``.
    """
    # Ensure generate doesn't complete too fast for cancellation test.
    await asyncio.sleep(0.2)

    count = 0
    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        ignore_eos=True,
        output_kind=output_kind,
        temperature=0.5,
        seed=33,
        n=n,
        prompt_logprobs=prompt_logprobs,
    )
    async for out in engine.generate(
        request_id=request_id, prompt=prompt, sampling_params=sampling_params
    ):
        # Tokens are summed over every completion in this output.
        num_tokens = sum(len(output.token_ids) for output in out.outputs)
        if output_kind == RequestOutputKind.DELTA:
            count += num_tokens
        else:
            count = num_tokens

        if cancel_after is not None and count >= cancel_after:
            return count, request_id

        # Yield control so other tasks on the loop can make progress.
        await asyncio.sleep(0.0)

    return count, request_id


88
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_load(
    output_kind: RequestOutputKind,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Issue 100 concurrent requests and check each yields 10 tokens."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        # Shut the engine down when the ExitStack unwinds.
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 10

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks = [
            asyncio.create_task(
                generate(engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS)
            )
            for request_id in request_ids
        ]

        # Confirm that we got all the EXPECTED tokens from the requests.
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
        # Anything still pending means another task raised; stop the rest.
        for task in pending:
            task.cancel()
        for task in done:
            # Awaiting a finished task re-raises its exception, if any.
            num_generated_tokens, request_id = await task
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} but "
                f"expected {NUM_EXPECTED_TOKENS}"
            )

        assert not engine.output_processor.has_unfinished_requests()


136
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_abort(
    output_kind: RequestOutputKind,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Cancel a subset of in-flight request tasks and check the rest.

    Cancelled tasks must raise ``asyncio.CancelledError``; every other task
    must produce its full token count, and the request id of a cancelled
    request must be reusable afterwards.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        # Shut the engine down when the ExitStack unwinds.
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 100
        # Requests that will be cancelled get a much larger token budget.
        NUM_EXPECTED_TOKENS_LONG = 50000
        REQUEST_IDS_TO_ABORT = range(1, 100, 10)
        # These requests ask for n=3 samples instead of 1.
        PARALLEL_SAMPLE_REQ_IDS = range(1, 100, 15)

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            max_tokens = (
                NUM_EXPECTED_TOKENS_LONG
                if (idx in REQUEST_IDS_TO_ABORT)
                else NUM_EXPECTED_TOKENS
            )
            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
            tasks.append(
                asyncio.create_task(
                    generate(engine, request_id, prompt, output_kind, max_tokens, n)
                )
            )

        # API server cancels requests when they disconnect.
        for idx in REQUEST_IDS_TO_ABORT:
            tasks[idx].cancel()
            await asyncio.sleep(0.1)

        # Confirm the other requests are okay.
        for idx, task in enumerate(tasks):
            # Confirm that it was actually canceled.
            if idx in REQUEST_IDS_TO_ABORT:
                with pytest.raises(asyncio.CancelledError):
                    await task
            else:
                # Otherwise, make sure the request was not impacted.
                num_generated_tokens, request_id = await task
                n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
                expected_tokens = NUM_EXPECTED_TOKENS * n
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {expected_tokens}"
                )

        # Make sure all aborted requests were really aborted.
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can do another generation.
        request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
        task = asyncio.create_task(
            generate(engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS)
        )
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()
209
210


211
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.asyncio
async def test_multi_abort(output_kind: RequestOutputKind):
    """Abort several requests with a single ``engine.abort`` call.

    The aborted requests are expected to return a ``(count, request_id)``
    tuple with a positive token count; the remaining requests must produce
    their full token count.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        # Shut the engine down when the ExitStack unwinds.
        after.callback(engine.shutdown)

        NUM_REQUESTS = 50
        NUM_EXPECTED_TOKENS = 100
        # Requests that will be aborted get a much larger token budget.
        NUM_EXPECTED_TOKENS_LONG = 50000
        REQUEST_IDS_TO_ABORT = [5, 10, 15, 20, 25]
        # These requests ask for n=3 samples instead of 1.
        PARALLEL_SAMPLE_REQ_IDS = [5, 15, 30, 35]

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            max_tokens = (
                NUM_EXPECTED_TOKENS_LONG
                if (idx in REQUEST_IDS_TO_ABORT)
                else NUM_EXPECTED_TOKENS
            )
            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
            tasks.append(
                asyncio.create_task(
                    generate(
                        engine, request_id, TEXT_PROMPT, output_kind, max_tokens, n
                    )
                )
            )

        # Let requests start
        await asyncio.sleep(0.5)

        # Use multi-abort to abort multiple requests at once
        abort_request_ids = [request_ids[i] for i in REQUEST_IDS_TO_ABORT]
        await engine.abort(abort_request_ids)

        # Wait for all tasks to complete
        # (exceptions are returned as results rather than raised here).
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Verify results
        for idx, result in enumerate(results):
            if idx in REQUEST_IDS_TO_ABORT:
                # Aborted requests should return partial results
                assert isinstance(result, tuple), (
                    f"Request {idx} should have completed with partial results"
                )
                num_generated_tokens, request_id = result
                # Should have generated some tokens before abort
                assert num_generated_tokens > 0, (
                    f"Aborted request {request_id} should have generated some tokens"
                )
            else:
                # Non-aborted requests should complete normally
                assert isinstance(result, tuple), (
                    f"Request {idx} should have completed successfully"
                )
                num_generated_tokens, request_id = result
                n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
                expected_tokens = NUM_EXPECTED_TOKENS * n
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {expected_tokens}"
                )

        # Make sure all aborted requests were cleaned up
        assert not engine.output_processor.has_unfinished_requests()


285
@pytest.mark.parametrize("n", [1, 3])
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_finished_flag(
    n: int,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Check that only the last streamed output has ``finished`` set."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        # Shut the engine down when the ExitStack unwinds.
        after.callback(engine.shutdown)

        sampling_params = SamplingParams(
            max_tokens=100,
            output_kind=RequestOutputKind.DELTA,
            temperature=1.0,
            seed=33,
            n=n,
        )
        outputs = [
            out
            async for out in engine.generate(
                request_id="request-33", prompt=prompt, sampling_params=sampling_params
            )
        ]

        # Fail with a clear message instead of an IndexError on outputs[-1].
        assert outputs, "engine.generate yielded no outputs"

        # Assert only the last output has the finished flag set
        assert all(not out.finished for out in outputs[:-1])
        assert outputs[-1].finished
318
319


320
321
322
323
324
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_mid_stream_cancellation(
    engine_args: AsyncEngineArgs, prompt: PromptType
):
    """Test that requests can be cancelled mid-stream."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        # Shut the engine down when the ExitStack unwinds.
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        # Token budget per request; far above the point where we stop reading.
        NUM_TOKENS = 1000
        # Each request stops consuming once this many tokens were counted.
        NUM_EXPECTED_TOKENS = 20

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests that will be cancelled mid-stream
        tasks = [
            asyncio.create_task(
                generate(
                    engine,
                    request_id,
                    prompt,
                    RequestOutputKind.DELTA,
                    NUM_TOKENS,
                    cancel_after=NUM_EXPECTED_TOKENS,
                )
            )
            for request_id in request_ids
        ]

        # Wait for all tasks to complete
        results = await asyncio.gather(*tasks)

        # Verify all tasks were cancelled at the expected point
        for num_generated_tokens, request_id in results:
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} tokens but "
                f"expected to cancel after {NUM_EXPECTED_TOKENS}"
            )

        # Make sure no requests are left hanging
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can reuse the request id after the cancellations.
        request_id = request_ids[0]
        task = asyncio.create_task(
            generate(
                engine, request_id, prompt, RequestOutputKind.DELTA, NUM_EXPECTED_TOKENS
            )
        )
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()


381
382
383
384
385
386
387
388
389
390
class MockLoggingStatLogger(LoggingStatLogger):
    """``LoggingStatLogger`` subclass whose ``log`` is a ``MagicMock``."""

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        super().__init__(vllm_config, engine_index)
        # Shadow the log attribute on the instance so tests can assert on calls.
        self.log = MagicMock()


@pytest.mark.asyncio
async def test_customize_loggers(monkeypatch):
    """Test that we can customize the loggers.
    If a customized logger is provided at the init, it should
    be added to the default loggers.
    """

    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(
                TEXT_ENGINE_ARGS,
                stat_loggers=[MockLoggingStatLogger],
            )
        # Shut the engine down when the ExitStack unwinds.
        after.callback(engine.shutdown)

        await engine.do_log_stats()

        stat_loggers = engine.logger_manager.per_engine_logger_dict
        assert len(stat_loggers) == 1
        assert len(stat_loggers[0]) == 2  # LoggingStatLogger + MockLoggingStatLogger
        # NOTE(review): relies on the mock logger being first in the list —
        # confirm against the logger manager's ordering.
        stat_loggers[0][0].log.assert_called_once()
408
409
410


@pytest.mark.asyncio(scope="module")
async def test_dp_rank_argument():
    """Check ``data_parallel_rank`` handling in ``engine.generate``.

    Rank 0 must be accepted; rank 1 is expected to raise ``ValueError``.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        # Shut the engine down when the ExitStack unwinds.
        after.callback(engine.shutdown)

        sampling_params = SamplingParams(
            max_tokens=100,
            output_kind=RequestOutputKind.DELTA,
            temperature=1.0,
            seed=33,
        )

        # Test with valid DP rank.
        async for _ in engine.generate(
            request_id="request-34",
            prompt=TEXT_PROMPT,
            sampling_params=sampling_params,
            data_parallel_rank=0,
        ):
            pass

        # Test with out-of-range DP rank.
        with pytest.raises(ValueError):
            async for _ in engine.generate(
                request_id="request-35",
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
                data_parallel_rank=1,
            ):
                pass
442
443
444


@pytest.mark.asyncio
async def test_check_health():
    """Test that check_health returns normally for healthy engine
    and raises EngineDeadError when the engine is dead.
    """
    from unittest.mock import patch

    from vllm.v1.engine.exceptions import EngineDeadError

    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        # Shut the engine down when the ExitStack unwinds.
        after.callback(engine.shutdown)

        # Test 1: Healthy engine should not raise any exception
        await engine.check_health()

        # Test 2: Mock the errored property to simulate a dead engine
        # The patch is applied to the class so the replacement property is
        # picked up by attribute lookup on the instance.
        with (
            patch.object(
                type(engine),
                "errored",
                new_callable=lambda: property(lambda self: True),
            ),
            pytest.raises(EngineDeadError),
        ):
            await engine.check_health()

        # Test 3: Verify healthy engine still works after mock
        await engine.check_health()
474
475
476


@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.asyncio
async def test_abort_final_output(output_kind: RequestOutputKind):
    """Test that abort() returns a final output with correct information."""

    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        # Shut the engine down when the ExitStack unwinds.
        after.callback(engine.shutdown)

        request_id = "test-abort-final-output"

        # Start a long-running request
        sampling_params = SamplingParams(
            max_tokens=3000,  # Long enough to allow abort
            ignore_eos=True,
            output_kind=output_kind,
            temperature=0.5,
            seed=42,
        )

        # collect_outputs appends every non-finished output to this list.
        outputs: list[RequestOutput] = []
        generated = asyncio.create_task(
            collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params, outputs)
        )

        # Let it generate some tokens
        await asyncio.sleep(0.5)

        # Abort the request
        await engine.abort(request_id)

        # Wait for generation to complete and return final output
        final_output = await generated

        # Verify we got a final output
        assert final_output is not None
        assert final_output.finished
        assert len(final_output.outputs) == 1

        assert final_output.outputs[0].finish_reason == "abort"
        assert final_output.outputs[0].stop_reason is None

        # Verify num_cached_tokens is set correctly
        assert hasattr(final_output, "num_cached_tokens")
        assert final_output.num_cached_tokens >= 0

        # If we got intermediate outputs, verify they are consistent
        if output_kind == RequestOutputKind.DELTA:
            # For DELTA, the intermediate (non-finished) outputs together
            # must carry at least one token.
            token_count = sum(len(output.outputs[0].token_ids) for output in outputs)
            assert token_count > 0
            # This would ordinarily be 0, but could end up > 0 if the
            # final abort is coalesced with another chunk in the output queue.
            # (The check below is therefore always true; it documents intent.)
            assert len(final_output.outputs[0].token_ids) >= 0
        else:
            # For FINAL_ONLY, we should only get the final output
            assert len(outputs) == 0
            assert len(final_output.outputs[0].token_ids) > 0

        assert not engine.output_processor.has_unfinished_requests()


async def collect_outputs(
    engine: AsyncLLM,
    request_id: str,
    prompt: PromptType,
    sampling_params: SamplingParams,
    outputs_list: list[RequestOutput],
) -> RequestOutput | None:
    """Helper to collect outputs and return the final one.

    Args:
        engine: Engine whose ``generate`` async iterator is consumed.
        request_id: Identifier passed to the engine.
        prompt: Prompt forwarded to the engine unchanged.
        sampling_params: Sampling parameters forwarded to the engine.
        outputs_list: Mutated in place; every output whose ``finished`` flag
            is false is appended to it.

    Returns:
        The last output yielded by the engine, or ``None`` if it yielded
        nothing.
    """
    final_output: RequestOutput | None = None
    async for output in engine.generate(
        request_id=request_id, prompt=prompt, sampling_params=sampling_params
    ):
        if not output.finished:
            outputs_list.append(output)
        # Track the most recent output so the last one is returned.
        final_output = output
    return final_output