# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import time
from contextlib import ExitStack
from unittest.mock import MagicMock

import pytest

from vllm import SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
)
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.inputs import PromptType
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.utils.torch_utils import set_default_torch_num_threads
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import (
    AggregatedLoggingStatLogger,
    LoggingStatLogger,
    PerEngineStatLoggerAdapter,
    PrometheusStatLogger,
)

# Skip the whole module on non-CUDA platforms instead of failing each test.
if not current_platform.is_cuda():
    pytest.skip(reason="V1 currently only supported on CUDA.", allow_module_level=True)

# Engine arguments for a small text-only model and a vision-language model;
# most tests are parametrized over both.
TEXT_ENGINE_ARGS = AsyncEngineArgs(
    model="meta-llama/Llama-3.2-1B-Instruct",
    enforce_eager=True,
)

VISION_ENGINE_ARGS = AsyncEngineArgs(
    model="Qwen/Qwen2-VL-2B-Instruct",
    enforce_eager=True,
)

TEXT_PROMPT = "Hello my name is Robert and"

# Chat-formatted prompt for the vision model with a single image placeholder
# (<|image_pad|>) that is filled by the image in VISION_PROMPT below.
VISION_PROMPT_TEMPLATE = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
    "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "What is in the image?<|im_end|>\n"
    "<|im_start|>assistant\n"
)

VISION_PROMPT = {
    "prompt": VISION_PROMPT_TEMPLATE,
    "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
}


async def generate(
    engine: AsyncLLM,
    request_id: str,
    prompt: PromptType,
    output_kind: RequestOutputKind,
    max_tokens: int,
    n: int = 1,
    prompt_logprobs: int | None = None,
    cancel_after: int | None = None,
) -> tuple[int, str]:
    """Stream one request through ``engine`` and count the generated tokens.

    Args:
        engine: Engine the request is submitted to via ``engine.generate``.
        request_id: Request id passed to ``engine.generate``.
        prompt: Prompt passed to ``engine.generate``.
        output_kind: With ``DELTA`` each output's token count is added to the
            running total; with any other kind the latest output's count
            replaces it.
        max_tokens: Maximum number of tokens to generate (EOS is ignored).
        n: Number of parallel samples; token counts are summed across them.
        prompt_logprobs: Forwarded to ``SamplingParams``.
        cancel_after: If set, stop consuming the stream and return as soon as
            at least this many tokens have been counted.

    Returns:
        ``(token_count, request_id)``.
    """
    # Ensure generate doesn't complete too fast for cancellation test.
    await asyncio.sleep(0.2)

    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        ignore_eos=True,
        output_kind=output_kind,
        temperature=0.5,
        seed=33,
        n=n,
        prompt_logprobs=prompt_logprobs,
    )

    count = 0
    async for out in engine.generate(
        request_id=request_id, prompt=prompt, sampling_params=sampling_params
    ):
        num_tokens = sum(len(output.token_ids) for output in out.outputs)
        if output_kind == RequestOutputKind.DELTA:
            count += num_tokens
        else:
            count = num_tokens

        if cancel_after is not None and count >= cancel_after:
            return count, request_id

        # Yield to the event loop so other concurrent requests can progress.
        await asyncio.sleep(0.0)

    return count, request_id


@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_load(
    output_kind: RequestOutputKind,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Run many concurrent requests; each must produce the expected tokens."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 10

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks = [
            asyncio.create_task(
                generate(engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS)
            )
            for request_id in request_ids
        ]

        # Confirm that we got all the EXPECTED tokens from the requests.
        # Stop at the first failure and cancel whatever is still running.
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
        for task in pending:
            task.cancel()
        for task in done:
            num_generated_tokens, request_id = await task
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} but "
                f"expected {NUM_EXPECTED_TOKENS}"
            )

        assert not engine.output_processor.has_unfinished_requests()


@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_abort(
    output_kind: RequestOutputKind,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Cancel some in-flight requests and check the others are unaffected.

    Cancelling the client-side task mirrors the API server cancelling a
    request when its client disconnects. Afterwards no request may be left
    unfinished, and an aborted request id must be reusable.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_EXPECTED_TOKENS = 100
        NUM_EXPECTED_TOKENS_LONG = 50000
        REQUEST_IDS_TO_ABORT = range(1, 100, 10)
        PARALLEL_SAMPLE_REQ_IDS = range(1, 100, 15)

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests. Requests to be aborted get a very large
        # max_tokens so they are expected to still be running when cancelled.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            max_tokens = (
                NUM_EXPECTED_TOKENS_LONG
                if idx in REQUEST_IDS_TO_ABORT
                else NUM_EXPECTED_TOKENS
            )
            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
            tasks.append(
                asyncio.create_task(
                    generate(engine, request_id, prompt, output_kind, max_tokens, n)
                )
            )

        # API server cancels requests when they disconnect.
        for idx in REQUEST_IDS_TO_ABORT:
            tasks[idx].cancel()
            await asyncio.sleep(0.1)

        # Confirm the other requests are okay.
        for idx, task in enumerate(tasks):
            if idx in REQUEST_IDS_TO_ABORT:
                # Confirm that it was actually canceled.
                with pytest.raises(asyncio.CancelledError):
                    await task
            else:
                # Otherwise, make sure the request was not impacted.
                num_generated_tokens, request_id = await task
                n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
                expected_tokens = NUM_EXPECTED_TOKENS * n
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {expected_tokens}"
                )

        # Make sure all aborted requests were really aborted.
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can do another generation, reusing an aborted request id.
        request_id = f"request-{REQUEST_IDS_TO_ABORT[0]}"
        task = asyncio.create_task(
            generate(engine, request_id, prompt, output_kind, NUM_EXPECTED_TOKENS)
        )
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()


@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.asyncio
async def test_multi_abort(output_kind: RequestOutputKind):
    """Abort several requests with a single ``engine.abort()`` call.

    Aborted requests must end with some (partial) tokens. All other requests
    must complete with the full expected token count.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 50
        NUM_EXPECTED_TOKENS = 100
        NUM_EXPECTED_TOKENS_LONG = 50000
        REQUEST_IDS_TO_ABORT = [5, 10, 15, 20, 25]
        PARALLEL_SAMPLE_REQ_IDS = [5, 15, 30, 35]

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            max_tokens = (
                NUM_EXPECTED_TOKENS_LONG
                if idx in REQUEST_IDS_TO_ABORT
                else NUM_EXPECTED_TOKENS
            )
            n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
            tasks.append(
                asyncio.create_task(
                    generate(
                        engine, request_id, TEXT_PROMPT, output_kind, max_tokens, n
                    )
                )
            )

        # Let requests start
        await asyncio.sleep(0.5)

        # Use multi-abort to abort multiple requests at once
        abort_request_ids = [request_ids[i] for i in REQUEST_IDS_TO_ABORT]
        await engine.abort(abort_request_ids, internal=False)

        # Wait for all tasks to complete. Exceptions are returned rather than
        # raised, so they are included in the assertion messages below.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Verify results
        for idx, result in enumerate(results):
            if idx in REQUEST_IDS_TO_ABORT:
                # Aborted requests should return partial results
                assert isinstance(result, tuple), (
                    f"Request {idx} should have completed with partial results, "
                    f"got {result!r}"
                )
                num_generated_tokens, request_id = result
                # Should have generated some tokens before abort
                assert num_generated_tokens > 0, (
                    f"Aborted request {request_id} should have generated some tokens"
                )
            else:
                # Non-aborted requests should complete normally
                assert isinstance(result, tuple), (
                    f"Request {idx} should have completed successfully, "
                    f"got {result!r}"
                )
                num_generated_tokens, request_id = result
                n = 3 if idx in PARALLEL_SAMPLE_REQ_IDS else 1
                expected_tokens = NUM_EXPECTED_TOKENS * n
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but "
                    f"expected {expected_tokens}"
                )

        # Make sure all aborted requests were cleaned up
        assert not engine.output_processor.has_unfinished_requests()


@pytest.mark.parametrize("n", [1, 3])
@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_finished_flag(
    n: int,
    engine_args: AsyncEngineArgs,
    prompt: PromptType,
):
    """Only the last streamed DELTA output of a request is marked finished."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        sampling_params = SamplingParams(
            max_tokens=100,
            output_kind=RequestOutputKind.DELTA,
            temperature=1.0,
            seed=33,
            n=n,
        )
        outputs = [
            out
            async for out in engine.generate(
                request_id="request-33", prompt=prompt, sampling_params=sampling_params
            )
        ]

        # An empty stream should fail as an assertion, not as an IndexError
        # on outputs[-1] below.
        assert outputs, "engine.generate yielded no outputs"

        # Assert only the last output has the finished flag set
        assert all(not out.finished for out in outputs[:-1])
        assert outputs[-1].finished


@pytest.mark.parametrize(
    "engine_args,prompt",
    [(TEXT_ENGINE_ARGS, TEXT_PROMPT), (VISION_ENGINE_ARGS, VISION_PROMPT)],
)
@pytest.mark.asyncio
async def test_mid_stream_cancellation(
    engine_args: AsyncEngineArgs, prompt: PromptType
):
    """Test that requests can be cancelled mid-stream.

    Each consumer stops reading its stream after NUM_EXPECTED_TOKENS tokens.
    Afterwards no request may be left unfinished, and a request id must be
    reusable.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(engine_args)
        after.callback(engine.shutdown)

        NUM_REQUESTS = 100
        NUM_TOKENS = 1000
        NUM_EXPECTED_TOKENS = 20

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests that will be cancelled mid-stream
        tasks = [
            asyncio.create_task(
                generate(
                    engine,
                    request_id,
                    prompt,
                    RequestOutputKind.DELTA,
                    NUM_TOKENS,
                    cancel_after=NUM_EXPECTED_TOKENS,
                )
            )
            for request_id in request_ids
        ]

        # Wait for all tasks to complete
        results = await asyncio.gather(*tasks)

        # Verify all tasks were cancelled at the expected point
        for num_generated_tokens, request_id in results:
            assert num_generated_tokens == NUM_EXPECTED_TOKENS, (
                f"{request_id} generated {num_generated_tokens} tokens but "
                f"expected to cancel after {NUM_EXPECTED_TOKENS}"
            )

        # Make sure no requests are left hanging
        assert not engine.output_processor.has_unfinished_requests()

        # Confirm we can reuse the request id after the cancellations.
        request_id = request_ids[0]
        task = asyncio.create_task(
            generate(
                engine, request_id, prompt, RequestOutputKind.DELTA, NUM_EXPECTED_TOKENS
            )
        )
        num_generated_tokens, request_id = await task
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
        assert not engine.output_processor.has_unfinished_requests()


class MockLoggingStatLogger(LoggingStatLogger):
    """``LoggingStatLogger`` whose ``log`` method is replaced by a MagicMock.

    Lets tests assert how often ``log`` was invoked without running the real
    logging implementation.
    """

    def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
        super().__init__(vllm_config, engine_index)
        self.log = MagicMock()


class MockAggregatedStatLogger(AggregatedLoggingStatLogger):
    """``AggregatedLoggingStatLogger`` whose ``log`` method is a MagicMock.

    Lets tests assert how often ``log`` was invoked without running the real
    logging implementation.
    """

    def __init__(self, vllm_config: VllmConfig, engine_indexes: list[int]):
        super().__init__(vllm_config, engine_indexes)
        self.log = MagicMock()


@pytest.mark.asyncio
async def test_customize_loggers(monkeypatch):
    """Test that we can customize the loggers.

    If a customized logger is provided at init, it should be added in front
    of the default loggers.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(
                TEXT_ENGINE_ARGS,
                stat_loggers=[MockLoggingStatLogger],
            )
        after.callback(engine.shutdown)

        await engine.do_log_stats()

        stat_loggers = engine.logger_manager.stat_loggers
        # MockLoggingStatLogger + LoggingStatLogger + PrometheusStatLogger
        assert len(stat_loggers) == 3
        print(f"{stat_loggers=}")
        # Per-engine loggers are wrapped in a PerEngineStatLoggerAdapter.
        stat_loggers[0].per_engine_stat_loggers[0].log.assert_called_once()
        assert isinstance(stat_loggers[1], PerEngineStatLoggerAdapter)
        assert isinstance(stat_loggers[1].per_engine_stat_loggers[0], LoggingStatLogger)
        assert isinstance(stat_loggers[2], PrometheusStatLogger)


@pytest.mark.asyncio
async def test_customize_aggregated_loggers():
    """Test that we can customize the aggregated loggers.

    If customized loggers are provided at init, they should be added in front
    of the default loggers.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(
                TEXT_ENGINE_ARGS,
                stat_loggers=[MockLoggingStatLogger, MockAggregatedStatLogger],
            )
        after.callback(engine.shutdown)

        await engine.do_log_stats()

        stat_loggers = engine.logger_manager.stat_loggers
        # MockLoggingStatLogger + MockAggregatedStatLogger
        # + LoggingStatLogger + PrometheusStatLogger
        assert len(stat_loggers) == 4
        stat_loggers[0].per_engine_stat_loggers[0].log.assert_called_once()
        stat_loggers[1].log.assert_called_once()
        assert isinstance(stat_loggers[2], PerEngineStatLoggerAdapter)
        assert isinstance(stat_loggers[2].per_engine_stat_loggers[0], LoggingStatLogger)
        assert isinstance(stat_loggers[3], PrometheusStatLogger)


@pytest.mark.asyncio(scope="module")
async def test_dp_rank_argument():
    """``generate`` accepts a valid ``data_parallel_rank`` and raises
    ``ValueError`` for an out-of-range one."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        sampling_params = SamplingParams(
            max_tokens=100,
            output_kind=RequestOutputKind.DELTA,
            temperature=1.0,
            seed=33,
        )

        # Test with valid DP rank.
        async for _ in engine.generate(
            request_id="request-34",
            prompt=TEXT_PROMPT,
            sampling_params=sampling_params,
            data_parallel_rank=0,
        ):
            pass

        # Test with out-of-range DP rank.
        with pytest.raises(ValueError):
            async for _ in engine.generate(
                request_id="request-35",
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
                data_parallel_rank=1,
            ):
                pass


@pytest.mark.asyncio(scope="module")
async def test_header_dp_rank_argument():
    """The DP rank can be supplied via the ``X-data-parallel-rank`` header.

    A valid rank yields a ``ChatCompletionResponse``; an out-of-range rank
    raises ``ValueError`` from ``create_chat_completion``.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        MODEL_NAME = "test-model"
        BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)]

        # Create models first
        models = OpenAIServingModels(
            engine_client=engine,
            base_model_paths=BASE_MODEL_PATHS,
        )

        # Create serving chat instance
        serving_chat = OpenAIServingChat(
            engine_client=engine,
            models=models,
            response_role="assistant",
            chat_template=None,
            chat_template_content_format="auto",
            request_logger=None,
        )

        # Create a chat completion request
        req = ChatCompletionRequest(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": TEXT_PROMPT}],
            max_tokens=100,
            temperature=1.0,
            seed=33,
        )

        # Test 1: Valid DP rank (0)
        mock_raw_request = MagicMock()
        mock_raw_request.headers = {"X-data-parallel-rank": "0"}
        mock_raw_request.state = MagicMock()

        # Should succeed with valid rank
        response = await serving_chat.create_chat_completion(req, mock_raw_request)
        assert isinstance(response, ChatCompletionResponse), (
            "Expected a ChatCompletionResponse for valid DP rank"
        )

        # Test 2: Out-of-range DP rank (1)
        mock_raw_request.headers = {"X-data-parallel-rank": "1"}

        # should raise ValueError for out-of-range rank
        with pytest.raises(ValueError):
            await serving_chat.create_chat_completion(req, mock_raw_request)


@pytest.mark.asyncio
async def test_check_health():
    """Test that check_health returns normally for healthy engine
    and raises EngineDeadError when the engine is dead.
    """
    from unittest.mock import patch

    from vllm.v1.engine.exceptions import EngineDeadError

    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        # Test 1: Healthy engine should not raise any exception
        await engine.check_health()

        # Test 2: Mock the errored property to simulate a dead engine. The
        # patch targets the class because `errored` is patched as a property.
        with (
            patch.object(
                type(engine),
                "errored",
                new_callable=lambda: property(lambda self: True),
            ),
            pytest.raises(EngineDeadError),
        ):
            await engine.check_health()

        # Test 3: Verify healthy engine still works after mock
        await engine.check_health()


@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.asyncio
async def test_abort_final_output(output_kind: RequestOutputKind):
    """Test that abort() returns a final output with correct information."""
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        request_id = "test-abort-final-output"

        # Start a long-running request
        sampling_params = SamplingParams(
            max_tokens=3000,  # Long enough to allow abort
            ignore_eos=True,
            output_kind=output_kind,
            temperature=0.5,
            seed=42,
        )

        outputs: list[RequestOutput] = []
        generated = asyncio.create_task(
            collect_outputs(engine, request_id, TEXT_PROMPT, sampling_params, outputs)
        )

        # Let it generate some tokens
        await asyncio.sleep(0.5)

        # Abort the request
        await engine.abort(request_id, internal=False)

        # Wait for generation to complete and return final output
        final_output = await generated

        # Verify we got a final output
        assert final_output is not None
        assert final_output.finished
        assert len(final_output.outputs) == 1

        assert final_output.outputs[0].finish_reason == "abort"
        assert final_output.outputs[0].stop_reason is None

        # Verify num_cached_tokens is present and non-negative
        assert hasattr(final_output, "num_cached_tokens")
        assert final_output.num_cached_tokens >= 0

        if output_kind == RequestOutputKind.DELTA:
            # For DELTA, the intermediate (non-finished) outputs must have
            # streamed some tokens before the abort.
            token_count = sum(len(output.outputs[0].token_ids) for output in outputs)
            assert token_count > 0
            # The final abort output would ordinarily carry 0 new tokens, but
            # could carry more if the abort is coalesced with another chunk in
            # the output queue, so only a permissive bound is checked here.
            assert len(final_output.outputs[0].token_ids) >= 0
        else:
            # For FINAL_ONLY, we should only get the final output
            assert len(outputs) == 0
            assert len(final_output.outputs[0].token_ids) > 0

        assert not engine.output_processor.has_unfinished_requests()


async def collect_outputs(
    engine: AsyncLLM,
    request_id: str,
    prompt: PromptType,
    sampling_params: SamplingParams,
    outputs_list: list[RequestOutput],
) -> RequestOutput | None:
    """Stream a request, recording intermediate outputs, and return the last.

    Args:
        engine: Engine the request is submitted to via ``engine.generate``.
        request_id: Request id passed to ``engine.generate``.
        prompt: Prompt passed to ``engine.generate``.
        sampling_params: Sampling parameters passed to ``engine.generate``.
        outputs_list: Receives every output whose ``finished`` flag is False.
            It is appended to as the stream progresses, so callers holding a
            reference can observe outputs while this coroutine is running.

    Returns:
        The last output yielded (finished or not), or ``None`` if the stream
        yielded nothing.
    """
    final_output: RequestOutput | None = None
    async for output in engine.generate(
        request_id=request_id, prompt=prompt, sampling_params=sampling_params
    ):
        if not output.finished:
            outputs_list.append(output)
        final_output = output
    return final_output


# =============================================================================
# Pause/Resume Tests
# =============================================================================


@pytest.mark.asyncio
async def test_pause_resume_basic():
    """Test basic pause/resume flag behavior and idempotency.

    Tests:
    - pause_generation sets the paused flag
    - resume_generation clears the paused flag
    - calling pause when already paused is a no-op
    - calling resume when not paused is safe
    - all pause modes work with no requests in flight
    - rapid pause/resume cycles don't break the engine
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        # Initially not paused
        assert not await engine.is_paused()

        # Resume when not paused should be safe
        await engine.resume_generation()
        assert not await engine.is_paused()

        # Pause sets flag
        await engine.pause_generation(mode="abort")
        assert await engine.is_paused()

        # Pause when already paused is a no-op
        await engine.pause_generation(mode="abort")
        assert await engine.is_paused()

        # Resume clears flag
        await engine.resume_generation()
        assert not await engine.is_paused()

        # Test all modes with no requests in flight
        for mode in ("abort", "wait", "keep"):
            await engine.pause_generation(mode=mode)
            assert await engine.is_paused()
            await engine.resume_generation()
            assert not await engine.is_paused()

        # Concurrent pause/resume race conditions - should not deadlock or raise
        await asyncio.gather(
            engine.pause_generation(mode="abort"),
            engine.resume_generation(),
            engine.pause_generation(mode="abort"),
            engine.resume_generation(),
        )

        # Ensure we end in a known state
        await engine.resume_generation()
        assert not await engine.is_paused()

        # Engine should still work after all cycles. `out` is pre-bound so an
        # empty stream fails the assertion instead of raising NameError.
        sampling_params = SamplingParams(max_tokens=5)
        out = None
        async for out in engine.generate(
            request_id="post-cycles",
            prompt=TEXT_PROMPT,
            sampling_params=sampling_params,
        ):
            pass
        assert out is not None, "engine.generate yielded no outputs"
        assert out.finished


@pytest.mark.asyncio
async def test_pause_abort():
    """Test that mode='abort' aborts in-flight requests immediately.

    Also checks that a request submitted while paused does not complete until
    the engine is resumed.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        # Start a long-running request
        sampling_params = SamplingParams(max_tokens=1000, ignore_eos=True)
        outputs: list[RequestOutput] = []

        async def gen():
            async for out in engine.generate(
                request_id="test-abort-pause",
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
            ):
                outputs.append(out)
            return outputs[-1] if outputs else None

        # Start generation task
        gen_task = asyncio.create_task(gen())

        # Wait for some tokens to be generated. If the generation task ends
        # first (e.g. it raised), stop waiting so the test fails instead of
        # spinning forever.
        while len(outputs) < 3:
            if gen_task.done():
                gen_task.result()  # Re-raises any exception from gen().
                pytest.fail("Generation finished before producing 3 outputs")
            await asyncio.sleep(0.01)

        # Pause with abort mode
        await engine.pause_generation(mode="abort")

        # Wait for task to complete (should be aborted)
        final_output = await gen_task

        # Request should be finished (aborted)
        assert final_output is not None
        assert final_output.finished
        assert final_output.outputs[0].finish_reason == "abort"

        # Also test that new requests are blocked while paused, then resume
        assert await engine.is_paused()

        request_completed = False

        async def gen_blocked():
            nonlocal request_completed
            out = None
            async for out in engine.generate(
                request_id="test-blocked",
                prompt=TEXT_PROMPT,
                sampling_params=SamplingParams(max_tokens=5),
            ):
                pass
            request_completed = True
            return out

        # Start a request (should block)
        gen_task2 = asyncio.create_task(gen_blocked())

        # Wait a bit - request should not have completed
        await asyncio.sleep(0.3)
        assert not request_completed, "Request should be blocked while paused"

        # Resume
        await engine.resume_generation()

        # Now request should complete
        final_output2 = await asyncio.wait_for(gen_task2, timeout=10.0)
        assert request_completed
        assert final_output2 is not None
        assert final_output2.finished


@pytest.mark.asyncio
async def test_pause_then_abort_queued_request():
    """Abort a request that was submitted while the engine was paused.

    The request is added after ``pause_generation(mode="keep")``, so it is
    held in ``_paused_adds_queue`` instead of being scheduled. The client then
    aborts it and the engine is resumed. The client must receive a finished
    output whose finish reason is ``"abort"`` and that carries no tokens,
    i.e. the request never ran.
    """
    with ExitStack() as cleanup:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        cleanup.callback(engine.shutdown)

        req_id = "abort-queued-request"
        params = SamplingParams(max_tokens=20, ignore_eos=True)
        received: list[RequestOutput] = []

        # Pause before submitting so the add is queued rather than scheduled.
        await engine.pause_generation(mode="keep")
        assert await engine.is_paused()

        async def consume():
            stream = engine.generate(
                request_id=req_id,
                prompt=TEXT_PROMPT,
                sampling_params=params,
            )
            async for chunk in stream:
                received.append(chunk)
            if received:
                return received[-1]
            return None

        consumer = asyncio.create_task(consume())

        # Let the add reach the engine and park in _paused_adds_queue.
        await asyncio.sleep(0.2)

        # Client-side abort of the still-queued request.
        await engine.abort(req_id, internal=False)

        # Resume so the engine can process the abort and emit its output.
        await engine.resume_generation()

        last = await asyncio.wait_for(consumer, timeout=10.0)
        assert last is not None
        assert last.finished
        completion = last.outputs[0]
        assert completion.finish_reason == "abort"
        # The request was never executed, so it produced no tokens.
        assert len(completion.token_ids) == 0


851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
@pytest.mark.asyncio
async def test_pause_wait():
    """Test that mode='wait' waits for in-flight requests to complete.

    A request is started and, once its first output arrives, the engine is
    paused with ``mode="wait"``. ``pause_generation`` must not return until
    the in-flight request has finished. The request must then have finished
    normally, by reaching ``max_tokens``, rather than being aborted.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        # Start a request - use fewer tokens since wait mode waits for completion
        max_tokens = 10
        sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True)
        got_first_token = asyncio.Event()
        request_completed = False

        async def gen():
            nonlocal request_completed
            async for out in engine.generate(
                request_id="test-wait",
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
            ):
                got_first_token.set()
            request_completed = True
            return out

        # Start generation
        gen_task = asyncio.create_task(gen())

        # Wait for generation to start (event-driven)
        await asyncio.wait_for(got_first_token.wait(), timeout=30.0)

        # Pause with wait mode - should wait for request to finish
        await engine.pause_generation(mode="wait")

        # By now the request should be done (wait mode waits for completion)
        assert request_completed, "Request should have completed during wait"

        # Awaiting (rather than Task.result()) does not rely on the task
        # already being in the done state.
        final_output = await gen_task
        assert final_output.finished
        # Should complete normally, not aborted. With ignore_eos=True the only
        # normal way for the request to end is by reaching max_tokens, which
        # is reported as "length".
        assert final_output.outputs[0].finish_reason == "length"
        assert len(final_output.outputs[0].token_ids) == max_tokens


@pytest.mark.asyncio
async def test_pause_keep_single_request():
    """Test that mode='keep' freezes a single request and resumes with timing gap.

    After a few outputs have streamed, the engine is paused with
    ``mode="keep"`` for ``pause_duration`` seconds and then resumed. The
    request must still produce all 30 tokens. The output arrival timestamps
    must also show a gap of at least 80% of ``pause_duration`` at or after
    the point where the controller observed the pause.
    """
    with ExitStack() as after:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        after.callback(engine.shutdown)

        sampling_params = SamplingParams(max_tokens=30, ignore_eos=True)
        # (cumulative token count, arrival time) for every streamed output.
        token_times: list[tuple[int, float]] = []
        pause_duration = 5.0
        # Number of outputs the generator had recorded when the pause returned.
        pause_token_idx = 0

        async def generator_task():
            """Generate tokens and record timestamps."""
            async for output in engine.generate(
                request_id="test-keep-single",
                prompt=TEXT_PROMPT,
                sampling_params=sampling_params,
            ):
                token_count = len(output.outputs[0].token_ids)
                token_times.append((token_count, time.monotonic()))
            return output

        async def controller_task():
            """Pause and resume the engine."""
            nonlocal pause_token_idx
            # Wait for some tokens (event-driven, handles slow token generation)
            while len(token_times) < 5:
                await asyncio.sleep(0.01)

            # Pause with keep mode
            await engine.pause_generation(mode="keep")
            pause_token_idx = len(token_times)

            # Sleep while paused
            await asyncio.sleep(pause_duration)

            # Resume
            await engine.resume_generation()

        # Run both tasks with timeout for slow generation
        gen_task = asyncio.create_task(generator_task())
        ctrl_task = asyncio.create_task(controller_task())

        final_output, _ = await asyncio.wait_for(
            asyncio.gather(gen_task, ctrl_task), timeout=60.0
        )

        # Request should complete with all tokens
        assert final_output.finished
        assert len(final_output.outputs[0].token_ids) == 30

        # An output produced just before the pause took effect may still be
        # in flight to the generator when pause_generation() returns, so it
        # can be recorded after pause_token_idx was snapshotted. The pause gap
        # is therefore at index pause_token_idx - 1 or later, not necessarily
        # exactly there. gaps[i] is the delay between outputs i and i + 1.
        assert pause_token_idx < len(token_times), (
            "No outputs were received after resuming"
        )
        timestamps = [ts for _, ts in token_times]
        gaps = [b - a for a, b in zip(timestamps, timestamps[1:])]
        pause_gap = max(gaps[pause_token_idx - 1 :])
        assert pause_gap >= pause_duration * 0.8, (
            f"Expected gap of ~{pause_duration}s after pause, got {pause_gap:.3f}s"
        )


@pytest.mark.asyncio
async def test_pause_keep_multi_request():
    """Pause several concurrent requests in keep mode and check they all finish.

    Several requests are launched together. Once any of them has streamed an
    output, the engine is paused with ``mode="keep"``, held briefly, and
    resumed. Every request must then run to completion with its full
    ``max_tokens`` budget.
    """
    with ExitStack() as cleanup:
        with set_default_torch_num_threads(1):
            engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
        cleanup.callback(engine.shutdown)

        num_requests = 3
        params = SamplingParams(max_tokens=10, ignore_eos=True)
        finished_ids: list[str] = []
        first_token_seen = asyncio.Event()

        async def run_one(req_id: str):
            async for final in engine.generate(
                request_id=req_id,
                prompt=TEXT_PROMPT,
                sampling_params=params,
            ):
                first_token_seen.set()
            finished_ids.append(req_id)
            return final

        request_ids = [f"req-multi-{i}" for i in range(num_requests)]
        tasks = [asyncio.create_task(run_one(rid)) for rid in request_ids]

        # Block until any request has streamed at least one output.
        await asyncio.wait_for(first_token_seen.wait(), timeout=30.0)

        # Freeze every in-flight request, hold briefly, then unfreeze.
        await engine.pause_generation(mode="keep")
        await asyncio.sleep(0.5)
        await engine.resume_generation()

        # Every request must run to completion after the resume.
        results = await asyncio.wait_for(asyncio.gather(*tasks), timeout=60.0)

        assert len(finished_ids) == num_requests
        for res in results:
            assert res.finished
            assert len(res.outputs[0].token_ids) == 10