test_async_llm.py 32.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import time
6
from contextlib import ExitStack
7
from unittest.mock import MagicMock
8
9
10
11

import pytest

from vllm import SamplingParams
12
from vllm.assets.image import ImageAsset
13
from vllm.config import VllmConfig
14
from vllm.engine.arg_utils import AsyncEngineArgs
15
from vllm.entrypoints.openai.chat_completion.protocol import (
16
17
    ChatCompletionRequest,
    ChatCompletionResponse,
18
19
20
)
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.engine.protocol import (
21
22
    ErrorResponse,
)
23
24
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
25
from vllm.inputs import PromptType
26
from vllm.outputs import RequestOutput
27
from vllm.platforms import current_platform
28
from vllm.sampling_params import RequestOutputKind
29
from vllm.utils.torch_utils import set_default_torch_num_threads
30
from vllm.v1.engine.async_llm import AsyncLLM
31
32
33
34
35
36
from vllm.v1.metrics.loggers import (
    AggregatedLoggingStatLogger,
    LoggingStatLogger,
    PerEngineStatLoggerAdapter,
    PrometheusStatLogger,
)
37
38

# Every test below starts a V1 engine, which needs a CUDA device; skip the
# whole module on any other platform.
if not current_platform.is_cuda():
    pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)
40

41
42
43
44
# Engine configurations shared by the tests. Both run with enforce_eager=True
# (presumably to skip graph capture and keep start-up fast — TODO confirm).
TEXT_ENGINE_ARGS = AsyncEngineArgs(
    model="meta-llama/Llama-3.2-1B-Instruct",
    enforce_eager=True,
)

VISION_ENGINE_ARGS = AsyncEngineArgs(
    model="Qwen/Qwen2-VL-2B-Instruct", enforce_eager=True
)

TEXT_PROMPT = "Hello my name is Robert and"

# Chat-formatted prompt containing a single image placeholder
# (<|vision_start|><|image_pad|><|vision_end|>) for the Qwen2-VL model.
VISION_PROMPT_TEMPLATE = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
    "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "What is in the image?<|im_end|>\n"
    "<|im_start|>assistant\n"
)

# Multimodal prompt: the template above plus the image for its placeholder.
VISION_PROMPT = {
    "prompt": VISION_PROMPT_TEMPLATE,
    "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
}
62
63


64
65
66
67
68
69
70
async def generate(
    engine: AsyncLLM,
    request_id: str,
    prompt: PromptType,
    output_kind: RequestOutputKind,
    max_tokens: int,
    n: int = 1,
    prompt_logprobs: int | None = None,
    cancel_after: int | None = None,
) -> tuple[int, str]:
    """Stream one request through ``engine`` and count the generated tokens.

    Args:
        engine: Engine the request is submitted to.
        request_id: Id of the request.
        prompt: Text or multimodal prompt.
        output_kind: With ``DELTA`` the per-output token counts are summed;
            for any other kind each output is treated as cumulative and the
            latest count replaces the running total.
        max_tokens: Tokens to generate per sequence (``ignore_eos=True``).
        n: Number of parallel samples; counts are summed across all of them.
        prompt_logprobs: Forwarded unchanged to ``SamplingParams``.
        cancel_after: If set, stop consuming the stream and return as soon
            as at least this many tokens have been counted.

    Returns:
        ``(token_count, request_id)``.
    """
    # Ensure generate doesn't complete too fast for cancellation test.
    await asyncio.sleep(0.2)

    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        ignore_eos=True,
        output_kind=output_kind,
        temperature=0.5,
        seed=33,
        n=n,
        prompt_logprobs=prompt_logprobs,
    )

    count = 0
    async for out in engine.generate(
        request_id=request_id, prompt=prompt, sampling_params=sampling_params
    ):
        num_tokens = sum(len(output.token_ids) for output in out.outputs)
        if output_kind == RequestOutputKind.DELTA:
            count += num_tokens
        else:
            count = num_tokens

        if cancel_after is not None and count >= cancel_after:
            # Stop reading mid-stream; the caller checks the engine cleans up.
            return count, request_id

        # Yield to the event loop so concurrent requests interleave.
        await asyncio.sleep(0.0)

    return count, request_id


104
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_load(
    output_kind: RequestOutputKind,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Run many concurrent requests; each must yield exactly the token budget."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 10

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks = [
            asyncio.create_task(
                generate(engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS)
            )
            for request_id in request_ids
        ]

        # Return early on the first failure; awaiting the failed task below
        # re-raises its exception so the test reports it.
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
        for task in pending:
            task.cancel()

        # Confirm that we got all the EXPECTED tokens from the requests.
        for task in done:
            num_generated_tokens, request_id = await task
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} but "
                f"expected {NUM_EXPECTED_TOKENS}"
            )

        assert not engine.output_processor.has_unfinished_requests()


152
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_abort(
    output_kind: RequestOutputKind,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Cancel a subset of in-flight requests; the rest must be unaffected.

    Cancelled tasks must raise ``CancelledError``. Every other request must
    produce ``NUM_EXPECTED_TOKENS * n`` tokens. No requests may be left
    unfinished, and an aborted request id must be reusable afterwards.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 100
        # Requests that will be cancelled get a much larger budget, so they
        # should still be generating when the cancel arrives.
        NUM_EXPECTED_TOKENS_LONG = 50000
        REQUEST_IDS_TO_ABORT = range(1, 100, 10)
        PARALLEL_SAMPLE_REQ_IDS = range(1, 100, 15)

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            max_tokens = (
                NUM_EXPECTED_TOKENS_LONG
                if (idx in REQUEST_IDS_TO_ABORT)
                else NUM_EXPECTED_TOKENS
            )
            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
            tasks.append(
                asyncio.create_task(
                    generate(engine, request_id, prompt, output_kind, max_tokens, n)
                )
            )

        # API server cancels requests when they disconnect.
        for idx in REQUEST_IDS_TO_ABORT:
            tasks[idx].cancel()
            await asyncio.sleep(0.1)

        # Confirm the other requests are okay.
        for idx, task in enumerate(tasks):
            # Confirm that it was actually canceled.
            if idx in REQUEST_IDS_TO_ABORT:
                with pytest.raises(asyncio.CancelledError):
                    await task
            else:
                # Otherwise, make sure the request was not impacted.
                num_generated_tokens, request_id = await task
                n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
                expected_tokens = NUM_EXPECTED_TOKENS * n
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {expected_tokens}"
                )

        # Make sure all aborted requests were really aborted.
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can do another generation, reusing an aborted request id.
        request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
        task = asyncio.create_task(
            generate(engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS)
        )
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()
225
226


227
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.asyncio
async def test_multi_abort(output_kind: RequestOutputKind):
    """Abort several requests with one ``engine.abort`` call.

    Aborted requests must still return (partial) results with at least one
    token. All other requests must produce their full token budget.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 50
        NUM_EXPECTED_TOKENS = 100
        NUM_EXPECTED_TOKENS_LONG = 50000
        REQUEST_IDS_TO_ABORT = [5, 10, 15, 20, 25]
        PARALLEL_SAMPLE_REQ_IDS = [5, 15, 30, 35]

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            max_tokens = (
                NUM_EXPECTED_TOKENS_LONG
                if (idx in REQUEST_IDS_TO_ABORT)
                else NUM_EXPECTED_TOKENS
            )
            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
            tasks.append(
                asyncio.create_task(
                    generate(
                        engine, request_id, TEXT_PROMPT, output_kind, max_tokens, n
                    )
                )
            )

        # Let requests start
        await asyncio.sleep(0.5)

        # Use multi-abort to abort multiple requests at once
        abort_request_ids = [request_ids[i] for i in REQUEST_IDS_TO_ABORT]
        await engine.abort(abort_request_ids, internal=False)

        # Wait for all tasks to complete
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Verify results
        for idx, result in enumerate(results):
            if idx in REQUEST_IDS_TO_ABORT:
                # Aborted requests should return partial results
                assert isinstance(result, tuple), (
                    f"Request {idx} should have completed with partial results"
                )
                num_generated_tokens, request_id = result
                # Should have generated some tokens before abort
                assert num_generated_tokens > 0, (
                    f"Aborted request {request_id} should have generated some tokens"
                )
            else:
                # Non-aborted requests should complete normally
                assert isinstance(result, tuple), (
                    f"Request {idx} should have completed successfully"
                )
                num_generated_tokens, request_id = result
                n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
                expected_tokens = NUM_EXPECTED_TOKENS * n
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {expected_tokens}"
                )

        # Make sure all aborted requests were cleaned up
        assert not engine.output_processor.has_unfinished_requests()


301
@pytest.mark.parametrize("n", [1, 3])
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_finished_flag(
    n: int,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Only the final streamed output may carry ``finished=True``."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        sampling_params = SamplingParams(
            max_tokens=100,
            output_kind=RequestOutputKind.DELTA,
            temperature=1.0,
            seed=33,
            n=n,
        )
        outputs = [
            out
            async for out in engine.generate(
                request_id="request-33", prompt=prompt, sampling_params=sampling_params
            )
        ]

        # Assert only the last output has the finished flag set
        assert all(not out.finished for out in outputs[:-1])
        assert outputs[-1].finished
334
335


336
337
338
339
340
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_mid_stream_cancellation(
    engine_args: AsyncEngineArgs, prompt: PromptType
):
    """Test that requests can be cancelled mid-stream.

    Each request stops reading its stream after ``NUM_EXPECTED_TOKENS``
    tokens. The engine must then report no unfinished requests, and the
    request ids must be reusable.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_TOKENS = 1000
        NUM_EXPECTED_TOKENS = 20

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests that will be cancelled mid-stream
        tasks = []
        for request_id in request_ids:
            tasks.append(
                asyncio.create_task(
                    generate(
                        engine,
                        request_id,
                        prompt,
                        RequestOutputKind.DELTA,
                        NUM_TOKENS,
                        cancel_after=NUM_EXPECTED_TOKENS,
                    )
                )
            )

        # Wait for all tasks to complete
        results = await asyncio.gather(*tasks)

        # Verify all tasks were cancelled at the expected point
        for num_generated_tokens, request_id in results:
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} tokens but "
                f"expected to cancel after {NUM_EXPECTED_TOKENS}"
            )

        # Make sure no requests are left hanging
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can reuse the request id after the cancellations.
        request_id = request_ids[0]
        task = asyncio.create_task(
            generate(
                engine, request_id, prompt, RequestOutputKind.DELTA, NUM_EXPECTED_TOKENS
            )
        )
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()


397
398
399
400
401
402
class MockLoggingStatLogger(LoggingStatLogger):
    """``LoggingStatLogger`` whose ``log`` method is replaced by a MagicMock.

    Tests use it to assert that ``log`` was called without invoking the real
    logging implementation.
    """

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        super(MockLoggingStatLogger, self).__init__(vllm_config, engine_index)
        recorder = MagicMock()
        self.log = recorder


403
404
405
406
407
408
class MockAggregatedStatLogger(AggregatedLoggingStatLogger):
    """``AggregatedLoggingStatLogger`` whose ``log`` method is a MagicMock.

    Tests use it to assert that ``log`` was called without invoking the real
    logging implementation.
    """

    def __init__(self, vllm_config: VllmConfig, engine_indexes: list[int]):
        super(MockAggregatedStatLogger, self).__init__(vllm_config, engine_indexes)
        recorder = MagicMock()
        self.log = recorder


409
410
411
412
@pytest.mark.asyncio
async def test_customize_loggers(monkeypatch):
    """Test that we can customize the loggers.

    If a customized logger is provided at the init, it should
    be added to the default loggers.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(
                TEXT_ENGINE_ARGS,
                stat_loggers=[MockLoggingStatLogger],
            )
        after.callback(engine.shutdown)

        await engine.do_log_stats()

        stat_loggers = engine.logger_manager.stat_loggers
        # MockLoggingStatLogger + LoggingStatLogger + PrometheusStatLogger
        assert len(stat_loggers) == 3
        # Per-engine loggers are wrapped in an adapter; the custom one comes first.
        stat_loggers[0].per_engine_stat_loggers[0].log.assert_called_once()
        assert isinstance(stat_loggers[1], PerEngineStatLoggerAdapter)
        assert isinstance(stat_loggers[1].per_engine_stat_loggers[0], LoggingStatLogger)
        assert isinstance(stat_loggers[2], PrometheusStatLogger)


@pytest.mark.asyncio
async def test_customize_aggregated_loggers():
    """Test that we can customize the aggregated loggers.

    If a customized logger is provided at the init, it should
    be added to the default loggers.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(
                TEXT_ENGINE_ARGS,
                stat_loggers=[MockLoggingStatLogger, MockAggregatedStatLogger],
            )
        after.callback(engine.shutdown)

        await engine.do_log_stats()

        stat_loggers = engine.logger_manager.stat_loggers
        # MockLoggingStatLogger + MockAggregatedStatLogger
        # + LoggingStatLogger + PrometheusStatLogger
        assert len(stat_loggers) == 4
        # The per-engine mock is wrapped in an adapter; the aggregated mock
        # is stored directly.
        stat_loggers[0].per_engine_stat_loggers[0].log.assert_called_once()
        stat_loggers[1].log.assert_called_once()
        assert isinstance(stat_loggers[2], PerEngineStatLoggerAdapter)
        assert isinstance(stat_loggers[2].per_engine_stat_loggers[0], LoggingStatLogger)
        assert isinstance(stat_loggers[3], PrometheusStatLogger)
462
463
464


@pytest.mark.asyncio(scope="module")
async def test_dp_rank_argument():
    """``generate`` accepts DP rank 0 and raises ValueError for rank 1.

    Rank 1 is out of range here, presumably because the engine is started
    without data parallelism — TODO confirm.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        sampling_params = SamplingParams(
            max_tokens=100,
            output_kind=RequestOutputKind.DELTA,
            temperature=1.0,
            seed=33,
        )

        # Test with valid DP rank.
        async for _ in engine.generate(
            request_id="request-34",
            prompt=TEXT_PROMPT,
            sampling_params=sampling_params,
            data_parallel_rank=0,
        ):
            pass

        # Test with out-of-range DP rank.
        with pytest.raises(ValueError):
            async for _ in engine.generate(
                request_id="request-35",
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
                data_parallel_rank=1,
            ):
                pass
496
497


498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
@pytest.mark.asyncio(scope="module")
async def test_header_dp_rank_argument():
    """Route a chat completion through the ``X-data-parallel-rank`` header.

    Rank 0 must yield a ``ChatCompletionResponse``. Rank 1 must yield an
    ``ErrorResponse`` (presumably out of range because this engine runs
    without data parallelism — TODO confirm).
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        served_name = "test-model"
        serving_models = OpenAIServingModels(
            engine_client=engine,
            base_model_paths=[BaseModelPath(name=served_name, model_path=served_name)],
        )
        chat = OpenAIServingChat(
            engine_client=engine,
            models=serving_models,
            response_role="assistant",
            chat_template=None,
            chat_template_content_format="auto",
            request_logger=None,
        )
        request = ChatCompletionRequest(
            model=served_name,
            messages=[{"role": "user", "content": TEXT_PROMPT}],
            max_tokens=100,
            temperature=1.0,
            seed=33,
        )

        # A single fake HTTP request object, reused for both calls.
        raw = MagicMock()
        raw.headers = {"X-data-parallel-rank": "0"}
        raw.state = MagicMock()

        # In-range rank: a regular chat completion comes back.
        in_range = await chat.create_chat_completion(request, raw)
        assert isinstance(in_range, ChatCompletionResponse), (
            "Expected a ChatCompletionResponse for valid DP rank"
        )

        # Out-of-range rank: the serving layer answers with an error object.
        raw.headers = {"X-data-parallel-rank": "1"}
        out_of_range = await chat.create_chat_completion(request, raw)
        assert isinstance(out_of_range, ErrorResponse), (
            "Expected an ErrorResponse for out-of-range DP rank"
        )


552
@pytest.mark.asyncio
async def test_check_health():
    """Test that check_health returns normally for healthy engine
    and raises EngineDeadError when the engine is dead.
    """
    from unittest.mock import patch

    from vllm.v1.engine.exceptions import EngineDeadError

    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        # Test 1: Healthy engine should not raise any exception
        await engine.check_health()

        # Test 2: Mock the errored property to simulate a dead engine.
        # ``errored`` is patched on the class (properties live on the type);
        # patch.object restores the original when the ``with`` exits.
        with (
            patch.object(
                type(engine),
                "errored",
                new_callable=lambda: property(lambda self: True),
            ),
            pytest.raises(EngineDeadError),
        ):
            await engine.check_health()

        # Test 3: Verify healthy engine still works after mock
        await engine.check_health()
582
583
584


@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.asyncio
async def test_abort_final_output(output_kind: RequestOutputKind):
    """Test that abort() returns a final output with correct information."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        request_id = "test-abort-final-output"

        # Start a long-running request
        sampling_params = SamplingParams(
            max_tokens=3000,  # Long enough to allow abort
            ignore_eos=True,
            output_kind=output_kind,
            temperature=0.5,
            seed=42,
        )

        # Unfinished outputs are appended here; the finished one is returned.
        outputs: list[RequestOutput] = []
        generated = asyncio.create_task(
            collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params, outputs)
        )

        # Let it generate some tokens
        await asyncio.sleep(0.5)

        # Abort the request
        await engine.abort(request_id, internal=False)

        # Wait for generation to complete and return final output
        final_output = await generated

        # Verify we got a final output
        assert final_output is not None
        assert final_output.finished
        assert len(final_output.outputs) == 1

        assert final_output.outputs[0].finish_reason == "abort"
        assert final_output.outputs[0].stop_reason is None

        # Verify num_cached_tokens is set correctly
        assert hasattr(final_output, "num_cached_tokens")
        assert final_output.num_cached_tokens >= 0

        # If we got intermediate outputs, verify they are consistent
        if output_kind == RequestOutputKind.DELTA:
            # For DELTA, the intermediate (unfinished) outputs must have
            # streamed at least one token before the abort.
            token_count = sum(len(output.outputs[0].token_ids) for output in outputs)
            assert token_count > 0
            # This would ordinarily be 0, but could end up > 0 if the
            # final abort is coalesced with another chunk in the output queue.
            assert len(final_output.outputs[0].token_ids) >= 0
        else:
            # For FINAL_ONLY, we should only get the final output
            assert len(outputs) == 0
            assert len(final_output.outputs[0].token_ids) > 0

        assert not engine.output_processor.has_unfinished_requests()


async def collect_outputs(
    engine: AsyncLLM,
    request_id: str,
    prompt: PromptType,
    sampling_params: SamplingParams,
    outputs_list: list[RequestOutput],
) -> RequestOutput | None:
    """Helper to collect outputs and return the final one.

    Every output whose ``finished`` flag is unset is appended to
    ``outputs_list`` (mutated in place). The last output seen is returned,
    or ``None`` if the stream produced nothing.
    """
    final_output: RequestOutput | None = None
    async for output in engine.generate(
        request_id=request_id, prompt=prompt, sampling_params=sampling_params
    ):
        if not output.finished:
            outputs_list.append(output)
        final_output = output
    return final_output
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962


# =============================================================================
# Pause/Resume Tests
# =============================================================================


@pytest.mark.asyncio
async def test_pause_resume_basic():
    """Exercise the pause flag on an idle engine.

    Covered behaviour:
    - resuming an engine that is not paused is harmless
    - pause_generation sets the flag, and pausing again keeps it set
    - resume_generation clears the flag
    - every pause mode round-trips with no requests in flight
    - racing pause/resume calls neither deadlock nor raise
    - the engine can still generate afterwards
    """
    with ExitStack() as stack:
        with set_default_torch_num_threads(1):
            llm = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        stack.callback(llm.shutdown)

        # Starts unpaused; resuming an unpaused engine changes nothing.
        assert not await llm.is_paused()
        await llm.resume_generation()
        assert not await llm.is_paused()

        # First pause sets the flag; the second is a no-op that keeps it set.
        for _attempt in range(2):
            await llm.pause_generation(mode="abort")
            assert await llm.is_paused()

        await llm.resume_generation()
        assert not await llm.is_paused()

        # Round-trip every mode while idle. "keep" only freezes the scheduler
        # and does not set _paused, so only the other modes check the flag.
        for pause_mode in ("abort", "wait", "keep"):
            await llm.pause_generation(mode=pause_mode)
            if pause_mode in ("abort", "wait"):
                assert await llm.is_paused()
            await llm.resume_generation()
            assert not await llm.is_paused()

        # Interleaved pause/resume calls must neither deadlock nor raise.
        racing_calls = [
            llm.pause_generation(mode="abort"),
            llm.resume_generation(),
            llm.pause_generation(mode="abort"),
            llm.resume_generation(),
        ]
        await asyncio.gather(*racing_calls)

        # Settle into a known, unpaused state.
        await llm.resume_generation()
        assert not await llm.is_paused()

        # A short generation still runs to completion after all the cycling.
        async for last_output in llm.generate(
            request_id="post-cycles",
            prompt=TEXT_PROMPT,
            sampling_params=SamplingParams(max_tokens=5),
        ):
            pass
        assert last_output.finished


@pytest.mark.asyncio
async def test_pause_abort():
    """Test that mode='abort' aborts in-flight requests immediately.

    Also checks that a request submitted while paused does not complete
    until ``resume_generation`` is called.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        # Start a long-running request
        sampling_params = SamplingParams(max_tokens=1000, ignore_eos=True)
        outputs: list[RequestOutput] = []

        async def gen():
            async for out in engine.generate(
                request_id="test-abort-pause",
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
            ):
                outputs.append(out)
            return outputs[-1] if outputs else None

        # Start generation task
        gen_task = asyncio.create_task(gen())

        async def wait_for_outputs(min_outputs: int) -> None:
            # Also stop if the task ended early (e.g. it raised). The failure
            # then surfaces at ``await gen_task`` below instead of the test
            # hanging forever.
            while len(outputs) < min_outputs and not gen_task.done():
                await asyncio.sleep(0.01)

        # Wait for some tokens to be generated (bounded, so a stuck engine
        # fails the test rather than hanging it).
        await asyncio.wait_for(wait_for_outputs(3), timeout=30.0)

        # Pause with abort mode
        await engine.pause_generation(mode="abort")

        # Wait for task to complete (should be aborted)
        final_output = await gen_task

        # Request should be finished (aborted)
        assert final_output is not None
        assert final_output.finished
        assert final_output.outputs[0].finish_reason == "abort"

        # Also test that new requests are blocked while paused, then resume
        assert await engine.is_paused()

        request_completed = False

        async def gen_blocked():
            nonlocal request_completed
            async for out in engine.generate(
                request_id="test-blocked",
                prompt=TEXT_PROMPT,
                sampling_params=SamplingParams(max_tokens=5),
            ):
                pass
            request_completed = True
            return out

        # Start a request (should block)
        gen_task2 = asyncio.create_task(gen_blocked())

        # Wait a bit - request should not have completed
        await asyncio.sleep(0.3)
        assert not request_completed, "Request should be blocked while paused"

        # Resume
        await engine.resume_generation()

        # Now request should complete
        final_output2 = await asyncio.wait_for(gen_task2, timeout=10.0)
        assert request_completed
        assert final_output2.finished


@pytest.mark.asyncio
async def test_pause_wait():
    """Test that mode='wait' waits for in-flight requests to complete."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        # Start a request - use fewer tokens since wait mode waits for completion
        sampling_params = SamplingParams(max_tokens=10, ignore_eos=True)
        got_first_token = asyncio.Event()
        request_completed = False

        async def gen():
            nonlocal request_completed
            async for out in engine.generate(
                request_id="test-wait",
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
            ):
                got_first_token.set()
            request_completed = True
            return out

        # Start generation
        gen_task = asyncio.create_task(gen())

        # Wait for generation to start (event-driven)
        await asyncio.wait_for(got_first_token.wait(), timeout=30.0)

        # Pause with wait mode - should wait for request to finish
        await engine.pause_generation(mode="wait")

        # By now the request should be done (wait mode waits for completion)
        assert request_completed, "Request should have completed during wait"

        final_output = gen_task.result()
        assert final_output.finished
        # Should complete normally, not aborted. (The previous check against
        # "eos" could never fail: EOS is ignored here, so it proved nothing.)
        assert final_output.outputs[0].finish_reason != "abort"


@pytest.mark.asyncio
async def test_pause_keep_single_request():
    """Test that mode='keep' freezes a single request and resumes with timing gap.

    A generator task records a monotonic timestamp for every streamed output.
    A controller task pauses the engine in ``keep`` mode after a few outputs,
    sleeps for ``pause_duration`` seconds, and resumes. The request must still
    produce its full token budget. The output timeline must also contain a
    gap of roughly ``pause_duration`` at or after the point where the pause
    was observed to take effect.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        max_tokens = 30
        sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True)
        # One (cumulative token count, monotonic timestamp) entry per output.
        token_times: list[tuple[int, float]] = []
        pause_duration = 5.0
        pause_token_idx = 0

        async def generator_task():
            """Generate tokens and record timestamps."""
            async for output in engine.generate(
                request_id="test-keep-single",
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
            ):
                token_count = len(output.outputs[0].token_ids)
                token_times.append((token_count, time.monotonic()))
            return output

        async def controller_task():
            """Pause and resume the engine."""
            nonlocal pause_token_idx
            # Poll until a few outputs have streamed; the outer wait_for
            # bounds this loop if generation is slow or stalls.
            while len(token_times) < 5:
                await asyncio.sleep(0.01)

            # Pause with keep mode
            await engine.pause_generation(mode="keep")
            # Outputs seen when the pause returned. Outputs already in flight
            # may still be delivered after this point, so the pause gap is
            # searched for from here onward instead of assumed to sit exactly
            # at this index.
            pause_token_idx = len(token_times)

            # Sleep while paused
            await asyncio.sleep(pause_duration)

            # Resume
            await engine.resume_generation()

        # Run both tasks with timeout for slow generation
        gen_task = asyncio.create_task(generator_task())
        ctrl_task = asyncio.create_task(controller_task())

        final_output, _ = await asyncio.wait_for(
            asyncio.gather(gen_task, ctrl_task), timeout=60.0
        )

        # Request should complete with all tokens
        assert final_output.finished
        assert len(final_output.outputs[0].token_ids) == max_tokens

        # Inter-output gaps starting with the one that ends at the pause index
        # (the gap the original fixed-index check looked at).
        post_pause = token_times[max(pause_token_idx - 1, 0) :]
        gaps = [later[1] - earlier[1] for earlier, later in zip(post_pause, post_pause[1:])]
        assert gaps, (
            f"No outputs arrived after the pause took effect "
            f"(pause index {pause_token_idx}, {len(token_times)} outputs)"
        )
        pause_gap = max(gaps)
        assert pause_gap >= pause_duration * 0.8, (
            f"Expected gap of ~{pause_duration}s after pause, got {pause_gap:.3f}s"
        )


@pytest.mark.asyncio
async def test_pause_keep_multi_request():
    """Test that mode='keep' freezes multiple concurrent requests and all resume.

    Several requests are launched concurrently. Once any of them has streamed
    an output, the engine is paused in ``keep`` mode for half a second and then
    resumed. Every request must then run to completion with its full token
    budget.
    """
    with ExitStack() as cleanup:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        cleanup.callback(engine.shutdown)

        num_requests = 3
        tokens_per_request = 10
        params = SamplingParams(max_tokens=tokens_per_request, ignore_eos=True)
        finished_ids: list[str] = []
        first_output_seen = asyncio.Event()

        async def run_request(req_id: str):
            # Drain the stream, flagging as soon as anything arrives, and
            # hand back the last output once the request finishes.
            async for latest in engine.generate(
                request_id=req_id,
                prompt=TEXT_PROMPT,
                sampling_params=params,
            ):
                first_output_seen.set()
            finished_ids.append(req_id)
            return latest

        request_ids = [f"req-multi-{n}" for n in range(num_requests)]
        pending = [asyncio.create_task(run_request(rid)) for rid in request_ids]

        # Block until at least one request has produced output.
        await asyncio.wait_for(first_output_seen.wait(), timeout=30.0)

        # Freeze everything in place, hold briefly, then let it continue.
        await engine.pause_generation(mode="keep")
        await asyncio.sleep(0.5)
        await engine.resume_generation()

        outputs = await asyncio.wait_for(asyncio.gather(*pending), timeout=60.0)

        assert len(finished_ids) == num_requests
        for output in outputs:
            assert output.finished
            assert len(output.outputs[0].token_ids) == tokens_per_request