# test_async_llm.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
from contextlib import ExitStack, aclosing
from unittest.mock import MagicMock

import pytest

from vllm import SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
)
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import (
    ErrorResponse,
)
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.inputs import PromptType
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import (
    AggregatedLoggingStatLogger,
    LoggingStatLogger,
    PerEngineStatLoggerAdapter,
    PrometheusStatLogger,
)

# Every test in this module needs the V1 engine, which currently runs only
# on CUDA, so the whole module is skipped elsewhere.
if not current_platform.is_cuda():
    pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)

# Engine configurations shared by the tests below.
TEXT_ENGINE_ARGS = AsyncEngineArgs(
    model="meta-llama/Llama-3.2-1B-Instruct",
    enforce_eager=True,
)

VISION_ENGINE_ARGS = AsyncEngineArgs(
    model="Qwen/Qwen2-VL-2B-Instruct", enforce_eager=True
)

TEXT_PROMPT = "Hello my name is Robert and"

# Chat-formatted prompt with a single image placeholder, written in the
# Qwen2-VL token format used by VISION_ENGINE_ARGS. The image itself is
# supplied through "multi_modal_data" in VISION_PROMPT.
VISION_PROMPT_TEMPLATE = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
    "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "What is in the image?<|im_end|>\n"
    "<|im_start|>assistant\n"
)

VISION_PROMPT = {
    "prompt": VISION_PROMPT_TEMPLATE,
    "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
}


async def generate(
    engine: AsyncLLM,
    request_id: str,
    prompt: PromptType,
    output_kind: RequestOutputKind,
    max_tokens: int,
    n: int = 1,
    prompt_logprobs: int | None = None,
    cancel_after: int | None = None,
) -> tuple[int, str]:
    """Stream one request through ``engine`` and count the generated tokens.

    Args:
        engine: Engine the request is submitted to.
        request_id: ID the request is submitted under; echoed in the result.
        prompt: Text or multimodal prompt.
        output_kind: For ``DELTA`` the per-output token counts are summed.
            For any other kind, the latest output's count replaces the
            running total, because it is treated as cumulative.
        max_tokens: Requested generation length. EOS is ignored, so the
            request is expected to run to this length unless it is
            cancelled or aborted.
        n: Number of parallel samples. Token counts are summed across them.
        prompt_logprobs: Forwarded to ``SamplingParams``.
        cancel_after: If set, stop consuming the stream once at least this
            many tokens have been counted.

    Returns:
        ``(token_count, request_id)``.
    """
    # Ensure generate doesn't complete too fast for cancellation test.
    await asyncio.sleep(0.2)

    count = 0
    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        ignore_eos=True,
        output_kind=output_kind,
        temperature=0.5,
        seed=33,
        n=n,
        prompt_logprobs=prompt_logprobs,
    )

    # aclosing() closes the output stream as soon as we return early.
    # Without it, the abandoned async generator is only finalized later,
    # whenever the event loop gets around to it. Callers assert on engine
    # state (e.g. has_unfinished_requests()) right after this coroutine
    # returns, so cleanup must already have happened.
    async with aclosing(
        engine.generate(
            request_id=request_id, prompt=prompt, sampling_params=sampling_params
        )
    ) as stream:
        async for out in stream:
            num_tokens = sum(len(output.token_ids) for output in out.outputs)
            if output_kind == RequestOutputKind.DELTA:
                count += num_tokens
            else:
                count = num_tokens

            if cancel_after is not None and count >= cancel_after:
                return count, request_id

            # Yield to the event loop so other requests and cancellations
            # can make progress between outputs.
            await asyncio.sleep(0.0)

    return count, request_id


@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_load(
    output_kind: RequestOutputKind,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Run many concurrent requests and check that each one produces exactly
    the requested number of tokens and none is left unfinished."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 10

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks = [
            asyncio.create_task(
                generate(engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS)
            )
            for request_id in request_ids
        ]

        # Return as soon as any request fails, so the failure surfaces quickly.
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
        if pending:
            # Only reachable when some task raised. Cancel the rest and let
            # the cancellations settle before the engine is shut down.
            for task in pending:
                task.cancel()
            await asyncio.gather(*pending, return_exceptions=True)

        # Confirm that we got all the EXPECTED tokens from the requests.
        # Awaiting a failed task re-raises its exception here.
        for task in done:
            num_generated_tokens, request_id = await task
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} but "
                f"expected {NUM_EXPECTED_TOKENS}"
            )

        assert not engine.output_processor.has_unfinished_requests()


@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_abort(
    output_kind: RequestOutputKind,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Cancel a subset of in-flight requests and check the rest are unaffected.

    Cancelling the consuming task mirrors what the API server does when a
    client disconnects. Afterwards, no request may be left unfinished and an
    aborted request ID must be reusable.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 100
        # Requests to be aborted ask for many more tokens, so they are
        # expected to still be running when they get cancelled.
        NUM_EXPECTED_TOKENS_LONG = 50000
        REQUEST_IDS_TO_ABORT = range(1, NUM_REQUESTS, 10)
        PARALLEL_SAMPLE_REQ_IDS = range(1, NUM_REQUESTS, 15)

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            max_tokens = (
                NUM_EXPECTED_TOKENS_LONG
                if (idx in REQUEST_IDS_TO_ABORT)
                else NUM_EXPECTED_TOKENS
            )
            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
            tasks.append(
                asyncio.create_task(
                    generate(engine, request_id, prompt, output_kind, max_tokens, n)
                )
            )

        # API server cancels requests when they disconnect.
        for idx in REQUEST_IDS_TO_ABORT:
            tasks[idx].cancel()
            await asyncio.sleep(0.1)

        # Confirm the other requests are okay.
        for idx, task in enumerate(tasks):
            # Confirm that it was actually canceled.
            if idx in REQUEST_IDS_TO_ABORT:
                with pytest.raises(asyncio.CancelledError):
                    await task
            else:
                # Otherwise, make sure the request was not impacted.
                num_generated_tokens, request_id = await task
                n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
                expected_tokens = NUM_EXPECTED_TOKENS * n
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {expected_tokens}"
                )

        # Make sure all aborted requests were really aborted.
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can do another generation, reusing an aborted request ID.
        request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
        task = asyncio.create_task(
            generate(engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS)
        )
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()


@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.asyncio
async def test_multi_abort(output_kind: RequestOutputKind):
    """Abort several requests with a single ``engine.abort`` call.

    Unlike task cancellation, an external abort is expected to end each
    aborted stream without raising. Its consumer therefore returns a partial
    token count, while every other request completes normally.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 50
        NUM_EXPECTED_TOKENS = 100
        NUM_EXPECTED_TOKENS_LONG = 50000
        REQUEST_IDS_TO_ABORT = [5, 10, 15, 20, 25]
        PARALLEL_SAMPLE_REQ_IDS = [5, 15, 30, 35]

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            max_tokens = (
                NUM_EXPECTED_TOKENS_LONG
                if (idx in REQUEST_IDS_TO_ABORT)
                else NUM_EXPECTED_TOKENS
            )
            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
            tasks.append(
                asyncio.create_task(
                    generate(
                        engine, request_id, TEXT_PROMPT, output_kind, max_tokens, n
                    )
                )
            )

        # Let requests start
        await asyncio.sleep(0.5)

        # Use multi-abort to abort multiple requests at once
        abort_request_ids = [request_ids[i] for i in REQUEST_IDS_TO_ABORT]
        await engine.abort(abort_request_ids, internal=False)

        # Wait for all tasks to complete
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Verify results
        for idx, result in enumerate(results):
            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
            if idx in REQUEST_IDS_TO_ABORT:
                # Aborted requests should return partial results
                assert isinstance(result, tuple), (
                    f"Request {idx} should have completed with partial results, "
                    f"got {result!r}"
                )
                num_generated_tokens, request_id = result
                # Should have generated some tokens before abort...
                assert num_generated_tokens > 0, (
                    f"Aborted request {request_id} should have generated some tokens"
                )
                # ...but must have been cut off before its requested length.
                assert num_generated_tokens < NUM_EXPECTED_TOKENS_LONG * n, (
                    f"Aborted request {request_id} ran to completion"
                )
            else:
                # Non-aborted requests should complete normally
                assert isinstance(result, tuple), (
                    f"Request {idx} should have completed successfully, "
                    f"got {result!r}"
                )
                num_generated_tokens, request_id = result
                expected_tokens = NUM_EXPECTED_TOKENS * n
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {expected_tokens}"
                )

        # Make sure all aborted requests were cleaned up
        assert not engine.output_processor.has_unfinished_requests()


@pytest.mark.parametrize("n", [1, 3])
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_finished_flag(
    n: int,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Check that, in a DELTA stream, only the last output is marked finished."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        sampling_params = SamplingParams(
            max_tokens=100,
            output_kind=RequestOutputKind.DELTA,
            temperature=1.0,
            seed=33,
            n=n,
        )

        outputs = [
            out
            async for out in engine.generate(
                request_id="request-33", prompt=prompt, sampling_params=sampling_params
            )
        ]

        # An empty stream would make outputs[-1] fail with a bare IndexError;
        # report it as a test failure instead.
        assert outputs, "engine.generate produced no outputs"

        # Assert only the last output has the finished flag set
        assert all(not out.finished for out in outputs[:-1])
        assert outputs[-1].finished


@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_mid_stream_cancellation(
    engine_args: AsyncEngineArgs, prompt: PromptType
):
    """Test that requests can be cancelled mid-stream.

    Each consumer stops reading after NUM_EXPECTED_TOKENS tokens. After that,
    the engine must hold no unfinished requests, and the IDs must be
    reusable.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_TOKENS = 1000
        NUM_EXPECTED_TOKENS = 20

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests that will be cancelled mid-stream
        tasks = [
            asyncio.create_task(
                generate(
                    engine,
                    request_id,
                    prompt,
                    RequestOutputKind.DELTA,
                    NUM_TOKENS,
                    cancel_after=NUM_EXPECTED_TOKENS,
                )
            )
            for request_id in request_ids
        ]

        # Wait for all tasks to complete
        results = await asyncio.gather(*tasks)

        # Verify all tasks were cancelled at the expected point
        for num_generated_tokens, request_id in results:
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} tokens but "
                f"expected to cancel after {NUM_EXPECTED_TOKENS}"
            )

        # Make sure no requests are left hanging
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can reuse the request id after the cancellations.
        request_id = request_ids[0]
        task = asyncio.create_task(
            generate(
                engine, request_id, prompt, RequestOutputKind.DELTA, NUM_EXPECTED_TOKENS
            )
        )
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()


class MockLoggingStatLogger(LoggingStatLogger):
    """A ``LoggingStatLogger`` whose ``log`` is a ``MagicMock``.

    Tests can then assert how often the engine invoked it.
    """

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        super().__init__(vllm_config, engine_index)
        # Shadow the inherited method with a per-instance call recorder.
        recorder = MagicMock()
        self.log = recorder


class MockAggregatedStatLogger(AggregatedLoggingStatLogger):
    """An ``AggregatedLoggingStatLogger`` whose ``log`` is a ``MagicMock``.

    Tests can then assert how often the engine invoked it.
    """

    def __init__(self, vllm_config: VllmConfig, engine_indexes: list[int]):
        super().__init__(vllm_config, engine_indexes)
        # Shadow the inherited method with a per-instance call recorder.
        recorder = MagicMock()
        self.log = recorder


@pytest.mark.asyncio
async def test_customize_loggers(monkeypatch):
    """Test that we can customize the loggers.
    If a customized logger is provided at the init, it should
    be added to the default loggers.
    """
    # NOTE: the `monkeypatch` fixture is accepted but currently unused.
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(
                TEXT_ENGINE_ARGS,
                stat_loggers=[MockLoggingStatLogger],
            )
        after.callback(engine.shutdown)

        await engine.do_log_stats()

        stat_loggers = engine.logger_manager.stat_loggers
        print(f"{stat_loggers=}")
        # MockLoggingStatLogger + LoggingStatLogger + PrometheusStatLogger
        assert len(stat_loggers) == 3
        # The custom per-engine logger is wrapped first; check its type
        # before relying on its mocked `log`.
        custom_logger = stat_loggers[0].per_engine_stat_loggers[0]
        assert isinstance(custom_logger, MockLoggingStatLogger)
        custom_logger.log.assert_called_once()
        assert isinstance(stat_loggers[1], PerEngineStatLoggerAdapter)
        assert isinstance(stat_loggers[1].per_engine_stat_loggers[0], LoggingStatLogger)
        assert isinstance(stat_loggers[2], PrometheusStatLogger)


@pytest.mark.asyncio
async def test_customize_aggregated_loggers():
    """Test that we can customize the aggregated loggers.
    If a customized logger is provided at the init, it should
    be added to the default loggers.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(
                TEXT_ENGINE_ARGS,
                stat_loggers=[MockLoggingStatLogger, MockAggregatedStatLogger],
            )
        after.callback(engine.shutdown)

        await engine.do_log_stats()

        stat_loggers = engine.logger_manager.stat_loggers
        # MockLoggingStatLogger + MockAggregatedStatLogger
        # + LoggingStatLogger + PrometheusStatLogger
        assert len(stat_loggers) == 4
        # Check the custom loggers' types before relying on their mocked `log`.
        custom_logger = stat_loggers[0].per_engine_stat_loggers[0]
        assert isinstance(custom_logger, MockLoggingStatLogger)
        custom_logger.log.assert_called_once()
        assert isinstance(stat_loggers[1], MockAggregatedStatLogger)
        stat_loggers[1].log.assert_called_once()
        assert isinstance(stat_loggers[2], PerEngineStatLoggerAdapter)
        assert isinstance(stat_loggers[2].per_engine_stat_loggers[0], LoggingStatLogger)
        assert isinstance(stat_loggers[3], PrometheusStatLogger)


@pytest.mark.asyncio(scope="module")
async def test_dp_rank_argument():
    """Check validation of ``data_parallel_rank`` in ``engine.generate``.

    The engine is built from TEXT_ENGINE_ARGS, which sets no data-parallel
    options. Rank 0 must be accepted, and rank 1 must raise ValueError.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        sampling_params = SamplingParams(
            max_tokens=100,
            output_kind=RequestOutputKind.DELTA,
            temperature=1.0,
            seed=33,
        )

        # Test with valid DP rank.
        async for _ in engine.generate(
            request_id="request-34",
            prompt=TEXT_PROMPT,
            sampling_params=sampling_params,
            data_parallel_rank=0,
        ):
            pass

        # Test with out-of-range DP rank.
        with pytest.raises(ValueError):
            async for _ in engine.generate(
                request_id="request-35",
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
                data_parallel_rank=1,
            ):
                pass


@pytest.mark.asyncio(scope="module")
async def test_header_dp_rank_argument():
    """Select a DP rank through the ``X-data-parallel-rank`` request header.

    A valid rank (0) must yield a ChatCompletionResponse. An out-of-range
    rank (1) must yield an ErrorResponse rather than raising.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        MODEL_NAME = "test-model"

        # The serving layer is built on top of the engine: models, then chat.
        serving_models = OpenAIServingModels(
            engine_client=engine,
            base_model_paths=[BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)],
        )
        serving_chat = OpenAIServingChat(
            engine_client=engine,
            models=serving_models,
            response_role="assistant",
            chat_template=None,
            chat_template_content_format="auto",
            request_logger=None,
        )

        chat_request = ChatCompletionRequest(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": TEXT_PROMPT}],
            max_tokens=100,
            temperature=1.0,
            seed=33,
        )

        # A single mocked raw request is reused; only its headers change.
        raw_request = MagicMock()
        raw_request.state = MagicMock()

        # Valid DP rank: served normally.
        raw_request.headers = {"X-data-parallel-rank": "0"}
        ok_response = await serving_chat.create_chat_completion(
            chat_request, raw_request
        )
        assert isinstance(ok_response, ChatCompletionResponse), (
            "Expected a ChatCompletionResponse for valid DP rank"
        )

        # Out-of-range DP rank: reported back as an ErrorResponse.
        raw_request.headers = {"X-data-parallel-rank": "1"}
        err_response = await serving_chat.create_chat_completion(
            chat_request, raw_request
        )
        assert isinstance(err_response, ErrorResponse), (
            "Expected an ErrorResponse for out-of-range DP rank"
        )


@pytest.mark.asyncio
async def test_check_health():
    """Test that check_health returns normally for healthy engine
    and raises EngineDeadError when the engine is dead.
    """
    from unittest.mock import patch

    from vllm.v1.engine.exceptions import EngineDeadError

    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        # Test 1: Healthy engine should not raise any exception
        await engine.check_health()

        # Test 2: Mock the errored property to simulate a dead engine.
        # Properties are resolved on the type, so the class attribute is
        # patched. patch.object restores it when the `with` block exits.
        with (
            patch.object(
                type(engine),
                "errored",
                new_callable=lambda: property(lambda self: True),
            ),
            pytest.raises(EngineDeadError),
        ):
            await engine.check_health()

        # Test 3: Verify healthy engine still works after mock
        await engine.check_health()


@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.asyncio
async def test_abort_final_output(output_kind: RequestOutputKind):
    """Test that abort() returns a final output with correct information."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        request_id = "test-abort-final-output"

        # Start a long-running request
        sampling_params = SamplingParams(
            max_tokens=3000,  # Long enough to allow abort
            ignore_eos=True,
            output_kind=output_kind,
            temperature=0.5,
            seed=42,
        )

        # collect_outputs appends every non-finished output to this list.
        outputs: list[RequestOutput] = []
        generated = asyncio.create_task(
            collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params, outputs)
        )

        # Let it generate some tokens
        await asyncio.sleep(0.5)

        # Abort the request
        await engine.abort(request_id, internal=False)

        # Wait for generation to complete and return final output
        final_output = await generated

        # Verify we got a final output
        assert final_output is not None
        assert final_output.finished
        assert len(final_output.outputs) == 1

        assert final_output.outputs[0].finish_reason == "abort"
        assert final_output.outputs[0].stop_reason is None

        # Verify num_cached_tokens is set correctly
        assert hasattr(final_output, "num_cached_tokens")
        assert final_output.num_cached_tokens >= 0

        if output_kind == RequestOutputKind.DELTA:
            # The intermediate (non-finished) deltas must have carried
            # tokens before the abort.
            token_count = sum(len(output.outputs[0].token_ids) for output in outputs)
            assert token_count > 0
            # The final delta is ordinarily empty. It can carry tokens if the
            # abort is coalesced with another chunk in the output queue, so
            # its length is deliberately not asserted.
        else:
            # For FINAL_ONLY, we should only get the final output
            assert len(outputs) == 0
            assert len(final_output.outputs[0].token_ids) > 0

        assert not engine.output_processor.has_unfinished_requests()


async def collect_outputs(
    engine: AsyncLLM,
    request_id: str,
    prompt: PromptType,
    sampling_params: SamplingParams,
    outputs_list: list[RequestOutput],
) -> RequestOutput | None:
    """Helper to collect outputs and return the final one.

    Streams ``request_id`` through ``engine`` until the stream ends. Every
    output not yet marked finished is appended to ``outputs_list``. The list
    is mutated in place, so the caller can inspect intermediate outputs even
    while this coroutine runs as a background task.

    Returns:
        The last output received, or ``None`` if the stream yielded nothing.
    """
    final_output: RequestOutput | None = None
    async for output in engine.generate(
        request_id=request_id, prompt=prompt, sampling_params=sampling_params
    ):
        if not output.finished:
            outputs_list.append(output)
        final_output = output
    return final_output