test_async_llm.py 22.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
from contextlib import ExitStack
6
from unittest.mock import MagicMock
7
8
9
10

import pytest

from vllm import SamplingParams
11
from vllm.assets.image import ImageAsset
12
from vllm.config import VllmConfig
13
from vllm.engine.arg_utils import AsyncEngineArgs
14
from vllm.entrypoints.openai.chat_completion.protocol import (
15
16
    ChatCompletionRequest,
    ChatCompletionResponse,
17
18
19
)
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import (
20
21
    ErrorResponse,
)
22
23
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
24
from vllm.inputs import PromptType
25
from vllm.outputs import RequestOutput
26
from vllm.platforms import current_platform
27
from vllm.sampling_params import RequestOutputKind
28
from vllm.utils.torch_utils import set_default_torch_num_threads
29
from vllm.v1.engine.async_llm import AsyncLLM
30
31
32
33
34
35
from vllm.v1.metrics.loggers import (
    AggregatedLoggingStatLogger,
    LoggingStatLogger,
    PerEngineStatLoggerAdapter,
    PrometheusStatLogger,
)
36
37

# Skip the whole module on non-CUDA platforms.
if not current_platform.is_cuda():
    pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)

# Engine arguments shared by the tests below (both set enforce_eager=True).
TEXT_ENGINE_ARGS = AsyncEngineArgs(
    model="meta-llama/Llama-3.2-1B-Instruct",
    enforce_eager=True,
)

VISION_ENGINE_ARGS = AsyncEngineArgs(
    model="Qwen/Qwen2-VL-2B-Instruct", enforce_eager=True
)

# Plain-text prompt used with TEXT_ENGINE_ARGS.
TEXT_PROMPT = "Hello my name is Robert and"

# Chat-formatted prompt containing a single image placeholder.
VISION_PROMPT_TEMPLATE = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
    "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "What is in the image?<|im_end|>\n"
    "<|im_start|>assistant\n"
)

# Multi-modal prompt used with VISION_ENGINE_ARGS: the template text plus
# the PIL image of the "stop_sign" asset.
VISION_PROMPT = {
    "prompt": VISION_PROMPT_TEMPLATE,
    "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
}
61
62


63
64
65
66
67
68
69
async def generate(
    engine: AsyncLLM,
    request_id: str,
    prompt: PromptType,
    output_kind: RequestOutputKind,
    max_tokens: int,
    n: int = 1,
    prompt_logprobs: int | None = None,
    cancel_after: int | None = None,
) -> tuple[int, str]:
    """Run one request through ``engine.generate`` and count generated tokens.

    Args:
        engine: Engine whose ``generate`` async iterator is consumed.
        request_id: Identifier passed to the engine and echoed in the result.
        prompt: Prompt forwarded unchanged to the engine.
        output_kind: Output mode placed in the sampling params; it also
            decides how per-iteration token counts are combined.
        max_tokens: ``max_tokens`` for the sampling params (EOS is ignored).
        n: Number of samples requested in the sampling params.
        prompt_logprobs: Forwarded to the sampling params.
        cancel_after: If set, stop consuming and return as soon as the
            running count reaches this value.

    Returns:
        ``(count, request_id)`` where ``count`` is the number of tokens seen
        across all outputs of the request.
    """
    # Ensure generate doesn't complete too fast for cancellation test.
    await asyncio.sleep(0.2)

    count = 0
    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        ignore_eos=True,
        output_kind=output_kind,
        temperature=0.5,
        seed=33,
        n=n,
        prompt_logprobs=prompt_logprobs,
    )
    async for out in engine.generate(
        request_id=request_id, prompt=prompt, sampling_params=sampling_params
    ):
        num_tokens = sum(len(output.token_ids) for output in out.outputs)
        # DELTA counts are accumulated across iterations; for any other
        # output kind the latest count replaces the running total.
        if output_kind == RequestOutputKind.DELTA:
            count += num_tokens
        else:
            count = num_tokens

        if cancel_after is not None and count >= cancel_after:
            return count, request_id

        # Yield control to the event loop between outputs.
        await asyncio.sleep(0.0)

    return count, request_id


103
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_load(
    output_kind: RequestOutputKind,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Issue many concurrent requests and check each yields the expected tokens."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        # Shut the engine down when the ExitStack unwinds, pass or fail.
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 10

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks = [
            asyncio.create_task(
                generate(engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS)
            )
            for request_id in request_ids
        ]

        # Confirm that we got all the EXPECTED tokens from the requests.
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
        # Anything still pending means a task raised first; stop the rest.
        for task in pending:
            task.cancel()
        for task in done:
            # Awaiting a finished task re-raises its exception, if any.
            num_generated_tokens, request_id = await task
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} but "
                f"expected {NUM_EXPECTED_TOKENS}"
            )

        assert not engine.output_processor.has_unfinished_requests()


151
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_abort(
    output_kind: RequestOutputKind,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Cancel a subset of in-flight tasks and check the rest are unaffected."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 100
        NUM_EXPECTED_TOKENS_LONG = 50000
        REQUEST_IDS_TO_ABORT = range(1, 100, 10)
        PARALLEL_SAMPLE_REQ_IDS = range(1, 100, 15)

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            # Requests that will be cancelled get a much larger token budget
            # than the others.
            max_tokens = (
                NUM_EXPECTED_TOKENS_LONG
                if (idx in REQUEST_IDS_TO_ABORT)
                else NUM_EXPECTED_TOKENS
            )
            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
            tasks.append(
                asyncio.create_task(
                    generate(engine, request_id, prompt, output_kind, max_tokens, n)
                )
            )

        # API server cancels requests when they disconnect.
        for idx in REQUEST_IDS_TO_ABORT:
            tasks[idx].cancel()
            await asyncio.sleep(0.1)

        # Confirm the other requests are okay.
        for idx, task in enumerate(tasks):
            # Confirm that it was actually canceled.
            if idx in REQUEST_IDS_TO_ABORT:
                with pytest.raises(asyncio.CancelledError):
                    await task
            else:
                # Otherwise, make sure the request was not impacted.
                num_generated_tokens, request_id = await task
                n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
                expected_tokens = NUM_EXPECTED_TOKENS * n
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {expected_tokens}"
                )

        # Make sure all aborted requests were really aborted.
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can do another generation.
        request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
        task = asyncio.create_task(
            generate(engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS)
        )
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()
224
225


226
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.asyncio
async def test_multi_abort(output_kind: RequestOutputKind):
    """Abort several requests with one ``engine.abort`` call and check results."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 50
        NUM_EXPECTED_TOKENS = 100
        NUM_EXPECTED_TOKENS_LONG = 50000
        REQUEST_IDS_TO_ABORT = [5, 10, 15, 20, 25]
        PARALLEL_SAMPLE_REQ_IDS = [5, 15, 30, 35]

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            max_tokens = (
                NUM_EXPECTED_TOKENS_LONG
                if (idx in REQUEST_IDS_TO_ABORT)
                else NUM_EXPECTED_TOKENS
            )
            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
            tasks.append(
                asyncio.create_task(
                    generate(
                        engine, request_id, TEXT_PROMPT, output_kind, max_tokens, n
                    )
                )
            )

        # Let requests start
        await asyncio.sleep(0.5)

        # Use multi-abort to abort multiple requests at once
        abort_request_ids = [request_ids[i] for i in REQUEST_IDS_TO_ABORT]
        await engine.abort(abort_request_ids, internal=False)

        # Wait for all tasks to complete
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Verify results
        for idx, result in enumerate(results):
            if idx in REQUEST_IDS_TO_ABORT:
                # Aborted requests should return partial results
                assert isinstance(result, tuple), (
                    f"Request {idx} should have completed with partial results"
                )
                num_generated_tokens, request_id = result
                # Should have generated some tokens before abort
                assert num_generated_tokens > 0, (
                    f"Aborted request {request_id} should have generated some tokens"
                )
            else:
                # Non-aborted requests should complete normally
                assert isinstance(result, tuple), (
                    f"Request {idx} should have completed successfully"
                )
                num_generated_tokens, request_id = result
                n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
                expected_tokens = NUM_EXPECTED_TOKENS * n
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {expected_tokens}"
                )

        # Make sure all aborted requests were cleaned up
        assert not engine.output_processor.has_unfinished_requests()


300
@pytest.mark.parametrize("n", [1, 3])
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_finished_flag(
    n: int,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Check that only the last streamed output has ``finished`` set."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        sampling_params = SamplingParams(
            max_tokens=100,
            output_kind=RequestOutputKind.DELTA,
            temperature=1.0,
            seed=33,
            n=n,
        )
        # Drain the whole stream so every output can be inspected.
        outputs = [
            out
            async for out in engine.generate(
                request_id="request-33", prompt=prompt, sampling_params=sampling_params
            )
        ]

        # Assert only the last output has the finished flag set
        assert all(not out.finished for out in outputs[:-1])
        assert outputs[-1].finished
333
334


335
336
337
338
339
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_mid_stream_cancellation(
    engine_args: AsyncEngineArgs, prompt: PromptType
):
    """Test that requests can be cancelled mid-stream."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_TOKENS = 1000
        NUM_EXPECTED_TOKENS = 20

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests that will be cancelled mid-stream:
        # each asks for NUM_TOKENS but stops consuming at NUM_EXPECTED_TOKENS.
        tasks = [
            asyncio.create_task(
                generate(
                    engine,
                    request_id,
                    prompt,
                    RequestOutputKind.DELTA,
                    NUM_TOKENS,
                    cancel_after=NUM_EXPECTED_TOKENS,
                )
            )
            for request_id in request_ids
        ]

        # Wait for all tasks to complete
        results = await asyncio.gather(*tasks)

        # Verify all tasks were cancelled at the expected point
        for num_generated_tokens, request_id in results:
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} tokens but "
                f"expected to cancel after {NUM_EXPECTED_TOKENS}"
            )

        # Make sure no requests are left hanging
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can reuse the request id after the cancellations.
        request_id = request_ids[0]
        task = asyncio.create_task(
            generate(
                engine, request_id, prompt, RequestOutputKind.DELTA, NUM_EXPECTED_TOKENS
            )
        )
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()


396
397
398
399
400
401
class MockLoggingStatLogger(LoggingStatLogger):
    """LoggingStatLogger subclass whose ``log`` attribute is a MagicMock."""

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        super().__init__(vllm_config, engine_index)
        # Swap in a mock so tests can assert on calls to ``log``.
        self.log = MagicMock()


402
403
404
405
406
407
class MockAggregatedStatLogger(AggregatedLoggingStatLogger):
    """AggregatedLoggingStatLogger subclass whose ``log`` attribute is a MagicMock."""

    def __init__(self, vllm_config: VllmConfig, engine_indexes: list[int]):
        super().__init__(vllm_config, engine_indexes)
        # Swap in a mock so tests can assert on calls to ``log``.
        self.log = MagicMock()


408
409
410
411
@pytest.mark.asyncio
async def test_customize_loggers(monkeypatch):
    """Test that we can customize the loggers.
    If a customized logger is provided at the init, it should
    be added to the default loggers.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(
                TEXT_ENGINE_ARGS,
                stat_loggers=[MockLoggingStatLogger],
            )
        after.callback(engine.shutdown)

        await engine.do_log_stats()

        stat_loggers = engine.logger_manager.stat_loggers
        assert (
            len(stat_loggers) == 3
        )  # MockLoggingStatLogger + LoggingStatLogger + PrometheusStatLogger
        print(f"{stat_loggers=}")
        # The custom logger comes first and must have been asked to log once.
        stat_loggers[0].per_engine_stat_loggers[0].log.assert_called_once()
        assert isinstance(stat_loggers[1], PerEngineStatLoggerAdapter)
        assert isinstance(stat_loggers[1].per_engine_stat_loggers[0], LoggingStatLogger)
        assert isinstance(stat_loggers[2], PrometheusStatLogger)


@pytest.mark.asyncio
async def test_customize_aggregated_loggers():
    """Test that we can customize the aggregated loggers.
    If a customized logger is provided at the init, it should
    be added to the default loggers.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(
                TEXT_ENGINE_ARGS,
                stat_loggers=[MockLoggingStatLogger, MockAggregatedStatLogger],
            )
        after.callback(engine.shutdown)

        await engine.do_log_stats()

        stat_loggers = engine.logger_manager.stat_loggers
        # MockLoggingStatLogger + MockAggregatedStatLogger
        # + LoggingStatLogger + PrometheusStatLogger
        assert len(stat_loggers) == 4
        # Both custom loggers come first and must each have logged once.
        stat_loggers[0].per_engine_stat_loggers[0].log.assert_called_once()
        stat_loggers[1].log.assert_called_once()
        assert isinstance(stat_loggers[2], PerEngineStatLoggerAdapter)
        assert isinstance(stat_loggers[2].per_engine_stat_loggers[0], LoggingStatLogger)
        assert isinstance(stat_loggers[3], PrometheusStatLogger)
461
462
463


@pytest.mark.asyncio(scope="module")
async def test_dp_rank_argument():
    """Check ``data_parallel_rank``: rank 0 generates, rank 1 raises ValueError."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        sampling_params = SamplingParams(
            max_tokens=100,
            output_kind=RequestOutputKind.DELTA,
            temperature=1.0,
            seed=33,
        )

        # Test with valid DP rank.
        async for _ in engine.generate(
            request_id="request-34",
            prompt=TEXT_PROMPT,
            sampling_params=sampling_params,
            data_parallel_rank=0,
        ):
            pass

        # Test with out-of-range DP rank.
        with pytest.raises(ValueError):
            async for _ in engine.generate(
                request_id="request-35",
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
                data_parallel_rank=1,
            ):
                pass
495
496


497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
@pytest.mark.asyncio(scope="module")
async def test_header_dp_rank_argument():
    """Check the ``X-data-parallel-rank`` header on chat completions.

    Rank "0" must produce a ChatCompletionResponse; rank "1" must produce
    an ErrorResponse.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        MODEL_NAME = "test-model"
        BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]

        # The serving-models object is built first, then passed to the
        # chat handler.
        serving_models = OpenAIServingModels(
            engine_client=engine,
            base_model_paths=BASE_MODEL_PATHS,
        )
        chat_handler = OpenAIServingChat(
            engine_client=engine,
            models=serving_models,
            response_role="assistant",
            chat_template=None,
            chat_template_content_format="auto",
            request_logger=None,
        )

        # One request object is reused for both header variants.
        chat_request = ChatCompletionRequest(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": TEXT_PROMPT}],
            max_tokens=100,
            temperature=1.0,
            seed=33,
        )

        # Case 1: rank 0 is valid and should yield a normal response.
        raw_request = MagicMock()
        raw_request.headers = {"X-data-parallel-rank": "0"}
        raw_request.state = MagicMock()

        ok_response = await chat_handler.create_chat_completion(
            chat_request, raw_request
        )
        assert isinstance(ok_response, ChatCompletionResponse), (
            "Expected a ChatCompletionResponse for valid DP rank"
        )

        # Case 2: rank 1 is out of range and should yield an error response.
        raw_request.headers = {"X-data-parallel-rank": "1"}

        error_response = await chat_handler.create_chat_completion(
            chat_request, raw_request
        )
        assert isinstance(error_response, ErrorResponse), (
            "Expected an ErrorResponse for out-of-range DP rank"
        )


551
@pytest.mark.asyncio
async def test_check_health():
    """Test that check_health returns normally for healthy engine
    and raises EngineDeadError when the engine is dead.
    """
    from unittest.mock import patch

    from vllm.v1.engine.exceptions import EngineDeadError

    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        # Test 1: Healthy engine should not raise any exception
        await engine.check_health()

        # Test 2: Mock the errored property to simulate a dead engine.
        # The property is patched on the class (type(engine)) and restored
        # when the ``with`` block exits.
        with (
            patch.object(
                type(engine),
                "errored",
                new_callable=lambda: property(lambda self: True),
            ),
            pytest.raises(EngineDeadError),
        ):
            await engine.check_health()

        # Test 3: Verify healthy engine still works after mock
        await engine.check_health()
581
582
583


@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.asyncio
async def test_abort_final_output(output_kind: RequestOutputKind):
    """Test that abort() returns a final output with correct information."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        request_id = "test-abort-final-output"

        # Start a long-running request
        sampling_params = SamplingParams(
            max_tokens=3000,  # Long enough to allow abort
            ignore_eos=True,
            output_kind=output_kind,
            temperature=0.5,
            seed=42,
        )

        # Filled by collect_outputs with every output that is not finished.
        outputs: list[RequestOutput] = []
        generated = asyncio.create_task(
            collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params, outputs)
        )

        # Let it generate some tokens
        await asyncio.sleep(0.5)

        # Abort the request
        await engine.abort(request_id, internal=False)

        # Wait for generation to complete and return final output
        final_output = await generated

        # Verify we got a final output
        assert final_output is not None
        assert final_output.finished
        assert len(final_output.outputs) == 1

        assert final_output.outputs[0].finish_reason == "abort"
        assert final_output.outputs[0].stop_reason is None

        # Verify num_cached_tokens is set correctly
        assert hasattr(final_output, "num_cached_tokens")
        assert final_output.num_cached_tokens >= 0

        # If we got intermediate outputs, verify they are consistent
        if output_kind == RequestOutputKind.DELTA:
            # For DELTA, the intermediate outputs together must carry at
            # least one token.
            token_count = sum(len(output.outputs[0].token_ids) for output in outputs)
            assert token_count > 0
            # This would ordinarily be 0, but could end up > 0 if the
            # final abort is coalesced with another chunk in the output queue.
            assert len(final_output.outputs[0].token_ids) >= 0
        else:
            # For FINAL_ONLY, we should only get the final output
            assert len(outputs) == 0
            assert len(final_output.outputs[0].token_ids) > 0

        assert not engine.output_processor.has_unfinished_requests()


async def collect_outputs(
    engine: AsyncLLM,
    request_id: str,
    prompt: PromptType,
    sampling_params: SamplingParams,
    outputs_list: list[RequestOutput],
) -> RequestOutput | None:
    """Helper to collect outputs and return the final one.

    Args:
        engine: Engine whose ``generate`` async iterator is consumed.
        request_id: Identifier passed to the engine.
        prompt: Prompt forwarded unchanged to the engine.
        sampling_params: Sampling parameters forwarded unchanged.
        outputs_list: Mutated in place; every output whose ``finished``
            flag is false is appended to it.

    Returns:
        The last output yielded by the engine, or ``None`` if it yielded
        nothing.
    """
    final_output: RequestOutput | None = None
    async for output in engine.generate(
        request_id=request_id, prompt=prompt, sampling_params=sampling_params
    ):
        if not output.finished:
            outputs_list.append(output)
        # Track the most recent output, so the last one is returned.
        final_output = output
    return final_output