test_output_processor.py 48 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import math
5
import time
6
7
8

import pytest

9
10
11
12
13
14
15
from tests.v1.engine.utils import (
    NUM_PROMPT_LOGPROBS_UNDER_TEST,
    NUM_SAMPLE_LOGPROBS_UNDER_TEST,
    STOP_STRINGS,
    DummyOutputProcessorTestVectors,
    MockEngineCore,
)
16
from vllm import PoolingParams
17
from vllm.logprobs import PromptLogprobs, SampleLogprobs
18
from vllm.lora.request import LoRARequest
19
from vllm.outputs import CompletionOutput, RequestOutput
20
from vllm.sampling_params import RequestOutputKind, SamplingParams
21
from vllm.transformers_utils.tokenizer import AnyTokenizer
22
23
24
25
26
27
28
from vllm.v1.engine import (
    EngineCoreEvent,
    EngineCoreEventType,
    EngineCoreOutputs,
    EngineCoreRequest,
    FinishReason,
)
29
from vllm.v1.engine.output_processor import OutputProcessor, RequestOutputCollector
30
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
31

32
33
34
35
36
37
38
39
40
41
42
43
44
45

def _ref_convert_id_to_token(
    tokenizer: AnyTokenizer,
    token_id: int,
) -> str:
    """Reference implementation of single-token logprob detokenization.

    Decodes ``token_id`` in isolation so tests can compare the result with
    the ``decoded_token`` reported on each ``Logprob``.

    Args:
      tokenizer: tokenizer used by the model under test
      token_id: the token id to decode

    Returns:
      The decoded text, or ``""`` when ``tokenizer.decode`` yields a falsy
      value (e.g. ``None`` or an empty string).
    """
    decoded = tokenizer.decode([token_id])
    if not decoded:
        return ""
    return decoded
47
48
49


@pytest.mark.parametrize(
    "request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
def test_incremental_detokenization(
    request_output_kind: RequestOutputKind, dummy_test_vectors
):
    """Detokenized text and token ids must reassemble into the references.

    One request per dummy prompt is registered with the ``OutputProcessor``.
    A ``MockEngineCore`` then replays the reference generation tokens step by
    step. The per-request text and token ids gathered from every output
    (concatenated for DELTA, a single final output for FINAL_ONLY) must equal
    the reference generation strings and tokens.
    """
    processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
    mock_core = MockEngineCore(tokens_list=dummy_test_vectors.generation_tokens)

    # Build one request per prompt; no stop strings, so generations run fully.
    requests = []
    for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens):
        params = SamplingParams(
            skip_special_tokens=False,
            spaces_between_special_tokens=False,
            output_kind=request_output_kind,
            stop=[],
            include_stop_str_in_output=False,
        )
        requests.append(
            EngineCoreRequest(
                request_id=f"request-{idx}",
                prompt_token_ids=prompt_tokens,
                mm_features=None,
                eos_token_id=None,
                arrival_time=0,
                lora_request=None,
                cache_salt=None,
                data_parallel_rank=None,
                sampling_params=params,
                pooling_params=None,
            )
        )

    # Register every request together with its prompt text.
    for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
        processor.add_request(request, prompt)

    text_by_req: dict[str, str] = {}
    tokens_by_req: dict[str, list[int]] = {}
    # Step until the mock engine core has nothing left to emit.
    while len(outputs := mock_core.get_outputs()) > 0:
        processed = processor.process_outputs(outputs)
        assert len(processed.reqs_to_abort) == 0

        for out in processed.request_outputs:
            rid = out.request_id
            completion = out.outputs[0]
            if rid in text_by_req:
                text_by_req[rid] += completion.text
                tokens_by_req[rid].extend(completion.token_ids)
            else:
                text_by_req[rid] = completion.text
                tokens_by_req[rid] = completion.token_ids

    # Compare the accumulated results against the reference generations.
    references = zip(
        dummy_test_vectors.generation_strings, dummy_test_vectors.generation_tokens
    )
    for idx, (ref_gen_str, ref_gen_toks) in enumerate(references):
        rid = f"request-{idx}"
        gen_str = text_by_req[rid]
        gen_toks = tokens_by_req[rid]
        assert gen_str == ref_gen_str, f"{gen_str=}, {ref_gen_str=}"
        assert gen_toks == ref_gen_toks, f"{gen_toks=}, {ref_gen_toks=}"

    # Every request should have finished and been released by the processor.
    assert processor.get_num_unfinished_requests() == 0
    assert not processor.has_unfinished_requests()
123
124


125
def _check_position_logprobs(
    pos_logprob_dict,
    token_id: int,
    ref_pos_logprob_toks,
    ref_pos_logprob_vals,
    ref_token_rank,
    num_logprobs: int,
    position: int,
    tokenizer: AnyTokenizer,
    is_prompt: bool,
) -> None:
    """Validate the logprobs reported for a single sequence position.

    Shared by the sample- and prompt-logprob checks, so both sequences are
    held to the same rules and report errors naming the right sequence.

    Args:
      pos_logprob_dict: token id -> ``Logprob`` mapping reported by the
        output processor at this position
      token_id: token actually at this position (sampled or prompt token)
      ref_pos_logprob_toks: reference logprob token ids for this position
      ref_pos_logprob_vals: reference log probabilities, aligned with
        ``ref_pos_logprob_toks``
      ref_token_rank: reference rank of ``token_id``
      num_logprobs: number of logprobs requested
      position: position of this entry in its sequence (used in messages)
      tokenizer: tokenizer used to compute reference decoded tokens
      is_prompt: True when checking prompt logprobs, False for sample ones

    Raises:
      AssertionError: on the first mismatch found.
    """
    token_kind = "Prompt" if is_prompt else "Sampled"
    lp_kind = "prompt" if is_prompt else "sample"
    seq_kind = "prompt" if is_prompt else "completion"

    # The actual token at this position must be among the logprobs.
    assert token_id in pos_logprob_dict, (
        f"{token_kind} token {token_id} not present in logprob at index {position}"
    )

    # Either exactly the requested number of logprobs, or one extra entry
    # (presumably the actual token when it is not among the top ones).
    num_lp_toks = len(pos_logprob_dict)
    assert num_lp_toks in (num_logprobs, num_logprobs + 1), (
        f"Valid numbers of {lp_kind} logprobs are"
        f" {num_logprobs} or"
        f" {num_logprobs + 1} but"
        f" {num_lp_toks} logprobs found at"
        f" position {position}. Logprobs dict:"
        f" {pos_logprob_dict}"
    )

    # The actual token's rank must match the reference rank.
    token_lp = pos_logprob_dict[token_id]
    assert ref_token_rank == token_lp.rank, (
        f"{token_kind} token logprob rank"
        f" {token_lp.rank} does not match"
        " correct value"
        f" {ref_token_rank}"
        f" in Logprob {token_lp}"
    )

    # Each reference (token id, logprob) pair must be reported with the
    # expected value and a valid rank, and some entry must have rank 1.
    # NOTE(review): index 0 of the reference vectors is skipped, as in the
    # original checks; presumably it corresponds to the actual token, whose
    # rank is verified above -- confirm against the fixture.
    rank_one_appears = False
    for jdx in range(1, len(ref_pos_logprob_toks)):
        # float()/int() normalize tensor/array elements to Python scalars so
        # the dict lookup and math.isclose behave as intended.
        ref_lp_val = float(ref_pos_logprob_vals[jdx])
        ref_tok_id = int(ref_pos_logprob_toks[jdx])
        assert ref_tok_id in pos_logprob_dict, (
            f"Expected token {ref_tok_id} to be in logprob dict but it is not."
        )

        lp = pos_logprob_dict[ref_tok_id]
        rank_one_appears = rank_one_appears or lp.rank == 1

        assert lp.rank >= 1, (
            f"Logprob {lp} has invalid"
            f" rank {lp.rank} < 1."
            f" Logprob dict: {pos_logprob_dict}"
        )
        assert math.isclose(lp.logprob, ref_lp_val), (
            f"Token id {ref_tok_id} appears in logprobs dict"
            f" at position {position} in {seq_kind} with log"
            f" probability {lp.logprob} but {ref_lp_val} was"
            f" expected. Logprob: {lp}"
        )

    assert rank_one_appears, (
        "No Logprob has rank 1"
        " in the following Logprob"
        f" dict: {pos_logprob_dict}"
    )

    # Every reported entry must carry the reference decoding of its token id.
    for lp_tok in pos_logprob_dict:
        decoded_token = pos_logprob_dict[lp_tok].decoded_token
        ref_decoded_token = _ref_convert_id_to_token(tokenizer, lp_tok)
        assert decoded_token == ref_decoded_token, (
            f"{token_kind} logprob token id {lp_tok} decodes to"
            f" {ref_decoded_token} but Logprob decoded"
            f" token is {decoded_token} instead"
            f" (at position {position})"
        )


def _validate_sample_logprobs(
    req_id: str,
    new_tokens: list[int],
    logprobs: SampleLogprobs | None,
    cumulative_logprob: float | None,
    ref_logprobs,
    tokenizer: AnyTokenizer,
    num_sample_logprobs: int,
) -> None:
    """Validate one request's sample logprobs and cumulative logprob.

    Args:
      req_id: request id (used in messages)
      new_tokens: all completion token ids collected for the request
      logprobs: per-position sample logprobs collected for the request
      cumulative_logprob: cumulative logprob reported for the request
      ref_logprobs: per-position reference entries, each unpacked as
        ``(token ids, logprob values, sampled token rank)``
      tokenizer: tokenizer used to compute reference decoded tokens
      num_sample_logprobs: number of sample logprobs requested

    Raises:
      AssertionError: on the first mismatch found.
    """
    assert logprobs is not None, (
        f"Request {req_id} requires sample"
        " logprobs but sample logprobs are"
        " None."
    )

    # Require num sampled tokens to match num sampled logprobs - especially
    # important to check since the detokenizer can cause a request to finish
    # early due to a stop string being hit.
    num_new_tokens = len(new_tokens)
    len_sample_logprobs = len(logprobs)
    assert num_new_tokens == len_sample_logprobs, (
        f"Request {req_id} has {num_new_tokens}"
        " completion tokens but has"
        f" {len_sample_logprobs} sample logprobs."
    )

    ref_cumulative_logprob = 0.0
    for idx, (sampled_token, pos_logprob_dict) in enumerate(
        zip(new_tokens, logprobs)
    ):
        ref_toks, ref_vals, ref_rank = ref_logprobs[idx]
        _check_position_logprobs(
            pos_logprob_dict,
            sampled_token,
            ref_toks,
            ref_vals,
            ref_rank,
            num_sample_logprobs,
            idx,
            tokenizer,
            is_prompt=False,
        )
        # The cumulative logprob must equal the sum of the reported
        # per-token logprobs of the sampled tokens.
        ref_cumulative_logprob += pos_logprob_dict[sampled_token].logprob

    assert math.isclose(cumulative_logprob, ref_cumulative_logprob), (
        f"Request {req_id} has cumulative logprob {cumulative_logprob}"
        f" but {ref_cumulative_logprob} was expected."
    )


def _validate_prompt_logprobs(
    req_id: str,
    prompt_token_ids: list[int],
    prompt_logprobs: PromptLogprobs | None,
    ref_prompt_logprobs,
    tokenizer: AnyTokenizer,
    num_prompt_logprobs: int,
) -> None:
    """Validate one request's prompt logprobs.

    Args:
      req_id: request id (used in messages)
      prompt_token_ids: prompt token ids of the request
      prompt_logprobs: per-position prompt logprobs collected for the request
      ref_prompt_logprobs: reference ``(token id matrix, logprob matrix,
        prompt token rank vector)``; row ``i`` describes prompt position
        ``i + 1``, since the first prompt position has no logprobs
      tokenizer: tokenizer used to compute reference decoded tokens
      num_prompt_logprobs: number of prompt logprobs requested

    Raises:
      AssertionError: on the first mismatch found.
    """
    assert prompt_logprobs is not None, (
        f"Request {req_id} requires prompt"
        " logprobs but prompt logprobs are"
        " None."
    )

    # Require num prompt tokens to match num prompt logprobs.
    num_prompt_tokens = len(prompt_token_ids)
    len_prompt_logprobs = len(prompt_logprobs)
    assert num_prompt_tokens == len_prompt_logprobs, (
        f"Request {req_id} has {num_prompt_tokens}"
        " prompt tokens but has"
        f" {len_prompt_logprobs} prompt logprobs."
    )

    # The first prompt position has nothing to condition on, so it is None.
    first_plp_dict = prompt_logprobs[0]
    assert first_plp_dict is None, (
        f"Request {req_id} first prompt logprob"
        " should be None but has following value"
        f" instead: {first_plp_dict}"
    )

    ref_toks, ref_vals, ref_ranks = ref_prompt_logprobs
    for idx, (prompt_token, pos_logprob_dict) in enumerate(
        zip(prompt_token_ids[1:], prompt_logprobs[1:])
    ):
        # Row idx of the reference matrices describes prompt position idx + 1.
        _check_position_logprobs(
            pos_logprob_dict,
            prompt_token,
            ref_toks[idx, :],
            ref_vals[idx, :],
            ref_ranks[idx],
            num_prompt_logprobs,
            idx + 1,
            tokenizer,
            is_prompt=True,
        )


def _validate_logprobs(
    gen_tokens: dict[str, list[int]],
    gen_logprobs: dict[str, SampleLogprobs | None],
    gen_prompt_logprobs: dict[str, PromptLogprobs | None],
    gen_cumulative_logprob: dict[str, float | None],
    dtv: DummyOutputProcessorTestVectors,
    request_id_list: list[str],
    num_sample_logprobs: int | None,
    num_prompt_logprobs: int | None,
) -> None:
    """Validate logprobs collected from the output processor against ``dtv``.

    Args:
      gen_tokens: request id -> all generated token ids
      gen_logprobs: request id -> collected sample logprobs (None if disabled)
      gen_prompt_logprobs: request id -> collected prompt logprobs
        (None if disabled)
      gen_cumulative_logprob: request id -> reported cumulative logprob
        (None if sample logprobs are disabled)
      dtv: dummy test vectors holding prompts, tokenizer and references
      request_id_list: request ids, in the same order as ``dtv`` entries
      num_sample_logprobs: requested sample logprobs, or None if disabled
      num_prompt_logprobs: requested prompt logprobs, or None if disabled

    Raises:
      AssertionError: on the first mismatch found.
    """
    for req_idx, req_id in enumerate(request_id_list):
        new_tokens = gen_tokens[req_id]
        logprobs = gen_logprobs[req_id]
        prompt_logprobs = gen_prompt_logprobs[req_id]
        cumulative_logprob = gen_cumulative_logprob[req_id]
        prompt_token_ids = dtv.prompt_tokens[req_idx]
        ref_logprobs = dtv.generation_logprobs[req_idx]
        ref_prompt_logprobs = dtv.prompt_logprobs[req_idx]

        if num_sample_logprobs is not None:
            _validate_sample_logprobs(
                req_id,
                new_tokens,
                logprobs,
                cumulative_logprob,
                ref_logprobs,
                dtv.tokenizer,
                num_sample_logprobs,
            )
        else:
            # Sample logprobs disabled for this request
            assert logprobs is None
            assert cumulative_logprob is None

        if num_prompt_logprobs is not None:
            _validate_prompt_logprobs(
                req_id,
                prompt_token_ids,
                prompt_logprobs,
                ref_prompt_logprobs,
                dtv.tokenizer,
                num_prompt_logprobs,
            )
        else:
            # Prompt logprobs disabled for this request
            assert prompt_logprobs is None


@pytest.mark.parametrize(
    "request_output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]
)
@pytest.mark.parametrize("num_sample_logprobs", [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
@pytest.mark.parametrize("num_prompt_logprobs", [None, NUM_PROMPT_LOGPROBS_UNDER_TEST])
def test_logprobs_processor(
    request_output_kind: RequestOutputKind,
    num_sample_logprobs: int | None,
    num_prompt_logprobs: int | None,
    dummy_test_vectors,
):
    """Logprobs emitted by the OutputProcessor must match the reference vectors.

    Covers DELTA and FINAL_ONLY output kinds, with sample and prompt logprobs
    each either enabled or disabled. The mock engine core only emits raw
    logprobs for the kinds that are enabled.
    """
    dtv = dummy_test_vectors
    processor = OutputProcessor(dtv.tokenizer, log_stats=False)
    sample_lp_raw = None if num_sample_logprobs is None else dtv.generation_logprobs
    prompt_lp_raw = None if num_prompt_logprobs is None else dtv.prompt_logprobs
    mock_core = MockEngineCore(
        tokens_list=dtv.generation_tokens,
        generated_logprobs_raw=sample_lp_raw,
        prompt_logprobs_raw=prompt_lp_raw,
    )

    # One request per prompt, each requesting the parametrized logprobs.
    request_id_list = [f"request-{i}" for i in range(len(dtv.prompt_strings))]
    requests = []
    for idx, prompt_tokens in enumerate(dtv.prompt_tokens):
        params = SamplingParams(
            skip_special_tokens=False,
            spaces_between_special_tokens=False,
            output_kind=request_output_kind,
            stop=[],
            include_stop_str_in_output=False,
            logprobs=num_sample_logprobs,
            prompt_logprobs=num_prompt_logprobs,
        )
        requests.append(
            EngineCoreRequest(
                request_id=request_id_list[idx],
                prompt_token_ids=prompt_tokens,
                mm_features=None,
                eos_token_id=None,
                arrival_time=0,
                lora_request=None,
                cache_salt=None,
                data_parallel_rank=None,
                sampling_params=params,
                pooling_params=None,
            )
        )

    for request, prompt in zip(requests, dtv.prompt_strings):
        processor.add_request(request, prompt)

    gen_tokens = {}
    gen_logprobs = {}
    gen_prompt_logprobs = {}
    gen_cumulative_logprobs = {}
    # Drive the processor until the mock engine core runs dry.
    while len(outputs := mock_core.get_outputs()) > 0:
        processed = processor.process_outputs(outputs)
        assert len(processed.reqs_to_abort) == 0

        for out in processed.request_outputs:
            rid = out.request_id
            completion = out.outputs[0]
            # Always keep the most recently reported cumulative logprob.
            gen_cumulative_logprobs[rid] = completion.cumulative_logprob

            if rid in gen_logprobs:
                # Append this step's data to what was tracked so far.
                gen_tokens[rid].extend(completion.token_ids)
                tracked_lp = gen_logprobs[rid]
                if tracked_lp:
                    tracked_lp.extend(completion.logprobs)
                tracked_plp = gen_prompt_logprobs[rid]
                if tracked_plp:
                    tracked_plp.extend(out.prompt_logprobs)
            else:
                # First output for this request: start tracking.
                gen_tokens[rid] = completion.token_ids
                gen_logprobs[rid] = completion.logprobs
                gen_prompt_logprobs[rid] = out.prompt_logprobs

    _validate_logprobs(
        gen_tokens,
        gen_logprobs,
        gen_prompt_logprobs,
        gen_cumulative_logprobs,
        dtv,
        request_id_list,
        num_sample_logprobs,
        num_prompt_logprobs,
    )

    assert processor.get_num_unfinished_requests() == 0
    assert not processor.has_unfinished_requests()


525
526
@pytest.mark.parametrize(
    "include_stop_str_in_output,stop_token_type,ignore_eos,num_sample_logprobs",
    [
        (False, "stop_token_ids", False, None),
        (True, "stop_token_ids", False, None),
        (False, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
        (True, "stop_token_ids", False, NUM_SAMPLE_LOGPROBS_UNDER_TEST),
        (False, "eos_token_id", False, None),
        (True, "eos_token_id", False, None),
        (False, "eos_token_id", True, None),
    ],
)
def test_stop_token(
    include_stop_str_in_output: bool,
    num_sample_logprobs: int | None,
    stop_token_type: str,
    ignore_eos: bool,
    dummy_test_vectors,
):
    """Test the output processor's handling of EOS and stop tokens.

    A single request is sent through a mock engine core whose generation is
    the first reference generation followed by two identical control tokens:
    <token><token>...<token><control><control>

    The control token is EOS when ``stop_token_type == "eos_token_id"``, and
    stop token id 128009 ('<|eot_id|>') otherwise.

    Expected results:

    * EOS under test with ``ignore_eos=True``: no stop happens. The text is
      <token>...<token><control><control> and the finish reason is "length".
    * Otherwise, with ``include_stop_str_in_output=True``: the first control
      token stops the request and appears in the text, giving
      <token>...<token><control>, with finish reason "stop".
    * Otherwise: the first control token stops the request but does not
      appear in the text, giving <token>...<token>, with finish reason "stop".

    When sample logprobs are enabled, the number of logprobs must equal the
    number of generated tokens.

    Note: details of this test are tuned for meta-llama/Llama-3.2-1B; other
    models require adapting the test.

    Args:
        include_stop_str_in_output: stop token str appears in output text
        num_sample_logprobs: number of sample logprobs (`None` for no logprobs)
        stop_token_type: "eos_token_id" for EOS, "stop_token_ids" for stop token
        ignore_eos: if True, EOS stops are disabled
        dummy_test_vectors: dummy engine core outputs and other data structures
    """
    dtv = dummy_test_vectors
    tokenizer = dtv.tokenizer
    model_id = tokenizer.name_or_path
    if model_id != "meta-llama/Llama-3.2-1B":
        raise AssertionError(
            f"Test requires meta-llama/Llama-3.2-1B but {model_id} is in use."
        )

    do_logprobs = num_sample_logprobs is not None
    is_eos_test = stop_token_type == "eos_token_id"
    is_eos_ignore_test = is_eos_test and ignore_eos

    # Pick the control token appended twice to the generation.
    if is_eos_test:
        eos_token_id = tokenizer.eos_token_id  # '<|end_of_text|>'
        stop_token_ids = None
        suffix_token = [eos_token_id]
        control_token = "<|end_of_text|>"
    else:
        eos_token_id = None
        stop_token_ids = [128009]  # '<|eot_id|>'
        suffix_token = stop_token_ids
        control_token = "<|eot_id|>"

    output_processor = OutputProcessor(tokenizer, log_stats=False)

    assert suffix_token is not None and isinstance(suffix_token[0], int)
    generation_string = dtv.generation_strings[0]
    generation_tokens = dtv.generation_tokens[0] + 2 * suffix_token
    sample_logprobs_raw = None
    if do_logprobs:
        ref_sample_logprobs = dtv.generation_logprobs[0]
        # Reuse the last position's logprobs for both control tokens.
        sample_logprobs_raw = [
            ref_sample_logprobs + 2 * [ref_sample_logprobs[-1]]
        ]

    engine_core = MockEngineCore(
        tokens_list=[generation_tokens],
        generated_logprobs_raw=sample_logprobs_raw,
        prompt_logprobs_raw=None,
        eos_token_id=eos_token_id,
        stop_token_ids=stop_token_ids,
        ignore_eos=ignore_eos,
    )

    request = EngineCoreRequest(
        request_id="request-0",
        prompt_token_ids=dtv.prompt_tokens[0],
        mm_features=None,
        eos_token_id=eos_token_id,
        arrival_time=0,
        lora_request=None,
        cache_salt=None,
        data_parallel_rank=None,
        sampling_params=SamplingParams(
            skip_special_tokens=False,
            spaces_between_special_tokens=False,
            output_kind=RequestOutputKind.DELTA,
            stop=[],
            stop_token_ids=stop_token_ids,
            include_stop_str_in_output=include_stop_str_in_output,
            logprobs=num_sample_logprobs,
            prompt_logprobs=None,
            ignore_eos=ignore_eos,
        ),
        pooling_params=None,
    )
    output_processor.add_request(request, dtv.prompt_strings[0])

    expected_finish_reason = "length" if is_eos_ignore_test else "stop"
    gen_string = ""
    gen_tokens = []
    gen_logprobs = []
    # Step the processor until the mock engine core has no more outputs.
    while len(outputs := engine_core.get_outputs()) > 0:
        processed = output_processor.process_outputs(outputs)
        request_outputs = processed.request_outputs
        assert len(request_outputs) == 1
        # A stop token must finish the request without relying on abort.
        assert not processed.reqs_to_abort

        request_output = request_outputs[0]
        completion = request_output.outputs[0]
        if request_output.finished:
            assert completion.finish_reason == expected_finish_reason

        gen_string += completion.text
        gen_tokens.extend(completion.token_ids)
        if do_logprobs:
            gen_logprobs.extend(completion.logprobs)

    # Number of control tokens expected in the final text.
    if is_eos_ignore_test:
        num_control_tokens = 2
    elif include_stop_str_in_output:
        num_control_tokens = 1
    else:
        num_control_tokens = 0
    ref_str = generation_string + num_control_tokens * control_token
    assert gen_string == ref_str, f"{gen_string=}, {ref_str=}"

    if do_logprobs:
        num_tokens, num_logprobs = len(gen_tokens), len(gen_logprobs)
        assert num_tokens == num_logprobs, (
            f"Token count ({num_tokens}) != logprobs count ({num_logprobs})"
        )

    assert output_processor.get_num_unfinished_requests() == 0
    assert not output_processor.has_unfinished_requests()


@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
@pytest.mark.parametrize("num_sample_logprobs", [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
def test_stop_string(
    include_stop_str_in_output: bool,
    num_sample_logprobs: int | None,
    dummy_test_vectors,
):
    """Stop strings truncate the streamed text and cause the request to be aborted.

    Every request is created with ``STOP_STRINGS`` as its stop strings and
    streams DELTA outputs. The accumulated text must equal the reference
    generation cut at the stop string, with or without the stop string itself
    depending on ``include_stop_str_in_output``.

    The stop string is matched on detokenized text by the OutputProcessor.
    Every request must therefore be reported in ``reqs_to_abort``, and no
    further output may be emitted for it afterwards. Sample logprobs, when
    requested, are validated with ``_validate_logprobs``.
    """
    output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=False)
    engine_core = MockEngineCore(
        tokens_list=dummy_test_vectors.generation_tokens,
        generated_logprobs_raw=dummy_test_vectors.generation_logprobs
        if num_sample_logprobs
        else None,
        prompt_logprobs_raw=None,
    )

    # Make N requests.
    request_id_list = [
        f"request-{idx}" for idx in range(len(dummy_test_vectors.prompt_strings))
    ]
    requests = [
        EngineCoreRequest(
            request_id=request_id_list[idx],
            prompt_token_ids=prompt_tokens,
            mm_features=None,
            eos_token_id=None,
            arrival_time=0,
            lora_request=None,
            cache_salt=None,
            data_parallel_rank=None,
            sampling_params=SamplingParams(
                skip_special_tokens=False,
                spaces_between_special_tokens=False,
                output_kind=RequestOutputKind.DELTA,
                stop=STOP_STRINGS,
                include_stop_str_in_output=include_stop_str_in_output,
                logprobs=num_sample_logprobs,
                prompt_logprobs=None,
            ),
            pooling_params=None,
        )
        for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
    ]

    # Add requests to the detokenizer.
    for request, prompt in zip(requests, dummy_test_vectors.prompt_strings):
        output_processor.add_request(request, prompt)

    gen_strings = {}
    gen_tokens = {}
    gen_logprobs = {}
    gen_prompt_logprobs = {}
    gen_cumulative_logprobs = {}
    # Request ids the OutputProcessor has asked to abort so far.
    aborted: set[str] = set()
    while True:
        # Mock output from the EngineCore.
        outputs = engine_core.get_outputs()
        if len(outputs) == 0:
            break

        # Step the Detokenizer.
        processed_outputs = output_processor.process_outputs(outputs)
        request_outputs = processed_outputs.request_outputs
        requests_to_abort = processed_outputs.reqs_to_abort
        for request_output in request_outputs:
            # If aborted, we should not get a request output.
            assert request_output.request_id not in aborted
        aborted.update(requests_to_abort)

        # Update tracking.
        for request_output in request_outputs:
            completion = request_output.outputs[0]
            if request_output.finished:
                assert completion.finish_reason == "stop"

            request_id = request_output.request_id
            new_text = completion.text
            new_tokens = completion.token_ids
            prompt_logprobs = request_output.prompt_logprobs
            logprobs = completion.logprobs
            gen_cumulative_logprobs[request_id] = completion.cumulative_logprob
            if request_id not in gen_strings:
                gen_strings[request_id] = new_text
                gen_tokens[request_id] = new_tokens
                gen_logprobs[request_id] = logprobs
                gen_prompt_logprobs[request_id] = prompt_logprobs
            else:
                gen_strings[request_id] += new_text
                gen_tokens[request_id].extend(new_tokens)
                lp = gen_logprobs[request_id]
                plp = gen_prompt_logprobs[request_id]
                if lp:
                    lp.extend(logprobs)
                if plp:
                    plp.extend(prompt_logprobs)

    # Confirm tracked values match what we expected.
    for idx, (ref_gen_str, stop_str) in enumerate(
        zip(dummy_test_vectors.generation_strings, STOP_STRINGS)
    ):
        # Request should be aborted.
        request_id = f"request-{idx}"
        assert request_id in aborted

        # Collected values that were generated.
        gen_str = gen_strings[request_id]

        # Construct reference strings. str.find returns -1 on a miss, which
        # would make the slice below silently drop the last character, so
        # require the stop string to actually occur in the reference.
        stop_str_idx = ref_gen_str.find(stop_str)
        assert stop_str_idx != -1, f"{stop_str=} not found in {ref_gen_str=}"
        ref_str_exc_stop = ref_gen_str[:stop_str_idx]
        ref_str_inc_stop = ref_str_exc_stop + stop_str

        if include_stop_str_in_output:
            assert gen_str == ref_str_inc_stop, f"{gen_str=}, {ref_str_inc_stop=}"
        else:
            assert gen_str == ref_str_exc_stop, f"{gen_str=}, {ref_str_exc_stop=}"

    # Confirm tracked logprobs match what we expect.
    _validate_logprobs(
        gen_tokens,
        gen_logprobs,
        gen_prompt_logprobs,
        gen_cumulative_logprobs,
        dummy_test_vectors,
        request_id_list,
        num_sample_logprobs,
        None,
    )

    assert output_processor.get_num_unfinished_requests() == 0
    assert not output_processor.has_unfinished_requests()


def test_iteration_stats(dummy_test_vectors):
    """IterationStats counts prompt tokens for prefills and one token per decode.

    All requests except the last are added up front. The last one is added
    after a pure-decode step, so that a later iteration mixes one prefill
    with decodes of the requests that are already running.
    """
    output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
    engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
    engine_core_timestamp = time.monotonic()

    # Make N requests.
    requests = [
        EngineCoreRequest(
            request_id=f"request-{idx}",
            prompt_token_ids=prompt_tokens,
            mm_features=None,
            eos_token_id=None,
            arrival_time=0,
            lora_request=None,
            cache_salt=None,
            data_parallel_rank=None,
            sampling_params=SamplingParams(),
            pooling_params=None,
        )
        for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
    ]

    # Add all requests except one to the OutputProcessor.
    num_active = len(dummy_test_vectors.generation_tokens) - 1
    for request in requests[:num_active]:
        output_processor.add_request(request, None)
    inactive_request = requests[num_active]

    # First iteration: every active request is a prefill.
    outputs = engine_core.get_outputs()[:num_active]
    iteration_stats = IterationStats()
    output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
    total_prompt_tokens = sum(
        len(prompt_tokens)
        for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active]
    )

    assert iteration_stats.num_prompt_tokens == total_prompt_tokens
    assert iteration_stats.num_generation_tokens == num_active

    # Just decodes in this step.
    outputs = engine_core.get_outputs()[:num_active]
    iteration_stats = IterationStats()
    output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)

    assert iteration_stats.num_prompt_tokens == 0
    assert iteration_stats.num_generation_tokens == num_active

    # Add a new request: one prefill plus decodes for the running requests.
    output_processor.add_request(inactive_request, None)
    num_active += 1
    outputs = engine_core.get_outputs()[:num_active]
    iteration_stats = IterationStats()
    output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
    # Only the newly added request contributes prompt tokens in this step.
    total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1])

    assert iteration_stats.num_prompt_tokens == total_prompt_tokens
    assert iteration_stats.num_generation_tokens == num_active

    # Just decodes in this step.
    outputs = engine_core.get_outputs()[:num_active]
    iteration_stats = IterationStats()
    output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)

    assert iteration_stats.num_prompt_tokens == 0
    assert iteration_stats.num_generation_tokens == num_active


@pytest.mark.parametrize("log_stats", [True, False])
def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
    """Test LoRA request lifecycle tracking through waiting -> running -> finished.

    Three requests are used:
      - ``request-0`` on lora-1
      - ``request-1`` on lora-2
      - ``request-2`` without a LoRA adapter

    Each iteration feeds one round of mock engine-core outputs through the
    OutputProcessor. It then lets the processor fill a fresh
    ``SchedulerStats`` with the per-adapter waiting/running counts. With
    ``log_stats=False`` no LoRA state may be tracked at all.
    """
    output_processor = OutputProcessor(
        dummy_test_vectors.tokenizer, log_stats=log_stats
    )
    engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
    engine_core_timestamp = time.monotonic()

    # Create LoRA requests
    lora1 = LoRARequest(lora_name="lora-1", lora_int_id=1, lora_path="/path/to/lora1")
    lora2 = LoRARequest(lora_name="lora-2", lora_int_id=2, lora_path="/path/to/lora2")

    # Create requests with different LoRA adapters:
    # - request-0: lora-1
    # - request-1: lora-2
    # - request-2: None (no LoRA)
    lora_assignments = [lora1, lora2, None]
    requests = [
        EngineCoreRequest(
            request_id=f"request-{idx}",
            prompt_token_ids=prompt_tokens,
            mm_features=None,
            eos_token_id=None,
            arrival_time=0,
            lora_request=lora_assignments[idx],
            cache_salt=None,
            data_parallel_rank=None,
            sampling_params=SamplingParams(),
            pooling_params=None,
        )
        for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
    ]

    # Add all requests to the OutputProcessor
    for request in requests:
        output_processor.add_request(request, None)

    def run_iteration(
        event_type: EngineCoreEventType | None = None,
        finished_request_id: str | None = None,
    ) -> SchedulerStats:
        """Process one round of engine-core outputs and return its scheduler stats.

        If ``event_type`` is given, every output is tagged with that event.
        If ``finished_request_id`` is given, that request's output is marked
        as finished with ``FinishReason.LENGTH``.
        """
        outputs = EngineCoreOutputs(
            outputs=engine_core.get_outputs(), scheduler_stats=SchedulerStats()
        )
        for output in outputs.outputs:
            if event_type is not None:
                output.events = [
                    EngineCoreEvent.new_event(event_type, engine_core_timestamp)
                ]
            if output.request_id == finished_request_id:
                output.finish_reason = FinishReason.LENGTH

        iteration_stats = IterationStats() if log_stats else None
        output_processor.process_outputs(
            outputs.outputs, engine_core_timestamp, iteration_stats
        )
        output_processor.update_scheduler_stats(outputs.scheduler_stats)
        return outputs.scheduler_stats

    # First iteration: all requests report QUEUED (waiting).
    scheduler_stats = run_iteration(event_type=EngineCoreEventType.QUEUED)
    if log_stats:
        # Verify waiting counts
        assert scheduler_stats.waiting_lora_adapters.get("lora-1") == 1
        assert scheduler_stats.waiting_lora_adapters.get("lora-2") == 1
        assert scheduler_stats.running_lora_adapters.get("lora-1") == 0
        assert scheduler_stats.running_lora_adapters.get("lora-2") == 0
        # Verify internal state: only the two adapters are tracked, since
        # request-2 has no LoRA.
        assert len(output_processor.lora_states.requests) == 2
        assert "lora-1" in output_processor.lora_states.requests
        assert "lora-2" in output_processor.lora_states.requests
    else:
        # When log_stats=False, no tracking should occur
        assert len(output_processor.lora_states.requests) == 0

    # Second iteration: all requests report SCHEDULED (running).
    scheduler_stats = run_iteration(event_type=EngineCoreEventType.SCHEDULED)
    if log_stats:
        # Verify running counts
        assert scheduler_stats.waiting_lora_adapters.get("lora-1") == 0
        assert scheduler_stats.waiting_lora_adapters.get("lora-2") == 0
        assert scheduler_stats.running_lora_adapters.get("lora-1") == 1
        assert scheduler_stats.running_lora_adapters.get("lora-2") == 1
    else:
        assert len(output_processor.lora_states.requests) == 0

    # Third iteration: finish request-0 (lora-1)
    scheduler_stats = run_iteration(finished_request_id="request-0")
    if log_stats:
        # lora-1 should be removed since no requests remain
        assert "lora-1" not in output_processor.lora_states.requests
        # lora-2 should still be running
        assert scheduler_stats.running_lora_adapters.get("lora-2") == 1
        assert len(output_processor.lora_states.requests) == 1
    else:
        assert len(output_processor.lora_states.requests) == 0

    # Fourth iteration: finish request-1 (lora-2)
    scheduler_stats = run_iteration(finished_request_id="request-1")
    if log_stats:
        # lora-2 should be removed since no requests remain
        assert "lora-2" not in output_processor.lora_states.requests
        assert len(scheduler_stats.running_lora_adapters) == 0
        assert len(output_processor.lora_states.requests) == 0
    else:
        assert len(output_processor.lora_states.requests) == 0

    # Finish the last request (no LoRA)
    run_iteration(finished_request_id="request-2")

    # Verify all requests are finished
    assert output_processor.get_num_unfinished_requests() == 0
    assert not output_processor.has_unfinished_requests()


@pytest.mark.asyncio
async def test_request_output_collector():
    """A DELTA-mode RequestOutputCollector merges outputs put between two gets.

    ``make_outputs`` builds ``NUM_REQS`` single-token outputs for the same
    request. Output ``idx`` carries:
      - token ``idx``
      - text ``TEXT``
      - one logprob entry
      - a cumulative logprob of ``idx + 1``

    Only the last output is finished.
    """
    NUM_REQS = 3
    TEXT = "a"

    def make_outputs() -> list[RequestOutput]:
        return [
            RequestOutput(
                request_id="my-request-id",
                prompt=None,
                prompt_token_ids=[1, 2, 3],
                prompt_logprobs=None,
                outputs=[
                    CompletionOutput(
                        index=0,
                        text=TEXT,
                        token_ids=[idx],
                        cumulative_logprob=float(idx + 1),
                        logprobs=[{"a": idx, "b": idx}],
                        finish_reason="length" if (idx == NUM_REQS - 1) else None,
                    )
                ],
                finished=(idx == NUM_REQS - 1),
            )
            for idx in range(NUM_REQS)
        ]

    collector = RequestOutputCollector(RequestOutputKind.DELTA)

    # CASE 1: Put then get.
    outputs = make_outputs()
    collector.put(outputs[0])
    output = await collector.get()
    assert not collector.ready.is_set()
    assert collector.output is None
    assert output.outputs[0].text == "a"
    assert output.outputs[0].token_ids == [0]

    # CASE 2: 2 puts then get.
    num_to_put = 2
    outputs = make_outputs()
    for i in range(num_to_put):
        collector.put(outputs[i])
    output = await collector.get()
    assert not collector.ready.is_set()
    assert collector.output is None

    assert not output.finished
    # Text, token_ids, and logprobs should get merged. token_ids is compared
    # as a whole because zip() would silently stop at the shorter sequence.
    assert output.outputs[0].text == TEXT * num_to_put
    assert list(output.outputs[0].token_ids) == list(range(num_to_put))
    assert len(output.outputs[0].logprobs) == num_to_put

    # Cumulative logprobs should be the last one.
    cumulative_logprob_expected = 1.0 * num_to_put
    assert output.outputs[0].cumulative_logprob == cumulative_logprob_expected

    # CASE 3: Put all 3 (including a finished).
    num_to_put = 3
    outputs = make_outputs()
    for i in range(num_to_put):
        collector.put(outputs[i])
    output = await collector.get()
    assert not collector.ready.is_set()
    assert collector.output is None

    assert output.finished
    assert output.outputs[0].finish_reason == "length"
    # Text, token_ids, and logprobs should get merged.
    assert output.outputs[0].text == TEXT * num_to_put
    assert list(output.outputs[0].token_ids) == list(range(num_to_put))
    assert len(output.outputs[0].logprobs) == num_to_put

    # Cumulative logprobs should be the last one.
    cumulative_logprob_expected = 1.0 * num_to_put
    assert output.outputs[0].cumulative_logprob == cumulative_logprob_expected


@pytest.mark.asyncio
async def test_cumulative_output_collector_n():
    """Test collector correctly handles multiple outputs by index.

    In CUMULATIVE mode a later completion with the same ``index`` replaces the
    earlier one rather than being appended to it. A completion with a new
    ``index`` is added alongside the existing ones.
    """
    collector = RequestOutputCollector(RequestOutputKind.CUMULATIVE)
    outputs = [
        RequestOutput(
            request_id="my-request-id",
            prompt=None,
            prompt_token_ids=[1, 2, 3],
            prompt_logprobs=None,
            outputs=[
                CompletionOutput(
                    index=0,
                    text="a",
                    token_ids=[0],
                    cumulative_logprob=None,
                    logprobs=None,
                    finish_reason=None,
                ),
                CompletionOutput(
                    index=1,
                    text="b",
                    token_ids=[1],
                    cumulative_logprob=None,
                    logprobs=None,
                    finish_reason=None,
                ),
            ],
            finished=False,
        ),
        RequestOutput(
            request_id="my-request-id",
            prompt=None,
            prompt_token_ids=[1, 2, 3],
            prompt_logprobs=None,
            outputs=[
                CompletionOutput(
                    index=0,
                    text="ab",
                    token_ids=[0, 1],
                    cumulative_logprob=None,
                    logprobs=None,
                    finish_reason=None,
                ),
                CompletionOutput(
                    index=2,
                    text="c",
                    token_ids=[2],
                    cumulative_logprob=None,
                    logprobs=None,
                    finish_reason=None,
                ),
            ],
            finished=False,
        ),
    ]
    for output in outputs:
        collector.put(output)

    # Get the output and check that the text and token_ids are correct.
    result = await collector.get()
    # We are expecting
    # [{index: 0, text: "ab"}, {index: 1, text: "b"}, {index: 2, text: "c"}]
    assert len(result.outputs) == 3
    # First is the one where index is 0 (replaced by the second put).
    first = [k for k in result.outputs if k.index == 0]
    assert len(first) == 1
    assert first[0].text == "ab"
    assert list(first[0].token_ids) == [0, 1]

    # Second is the one where index is 1
    second = [k for k in result.outputs if k.index == 1]
    assert len(second) == 1
    assert second[0].text == "b"
    assert second[0].token_ids == [1]

    # Third is the one where index is 2
    third = [k for k in result.outputs if k.index == 2]
    assert len(third) == 1
    assert third[0].text == "c"
    assert list(third[0].token_ids) == [2]


@pytest.mark.parametrize("runner", ["generate", "pooling"])
def test_abort_requests(runner: str, dummy_test_vectors):
    """Aborting queued requests works for both generate and pooling requests.

    Each request is registered with a RequestOutputCollector whose output kind
    comes from the request's SamplingParams (generate) or PoolingParams
    (pooling). The requests are then aborted one at a time, after which none
    may remain tracked by the OutputProcessor.
    """
    output_processor = OutputProcessor(dummy_test_vectors.tokenizer, log_stats=True)
    requests = [
        EngineCoreRequest(
            request_id=f"request-{idx}",
            prompt_token_ids=prompt_tokens,
            mm_features=None,
            eos_token_id=None,
            arrival_time=0,
            lora_request=None,
            cache_salt=None,
            data_parallel_rank=None,
            sampling_params=SamplingParams() if runner == "generate" else None,
            pooling_params=PoolingParams(task="embed") if runner == "pooling" else None,
        )
        for idx, prompt_tokens in enumerate(dummy_test_vectors.prompt_tokens)
    ]

    for request in requests:
        # Only one of sampling_params / pooling_params is set per runner.
        if runner == "generate":
            output_kind = request.sampling_params.output_kind
        else:
            output_kind = request.pooling_params.output_kind
        queue = RequestOutputCollector(output_kind=output_kind)
        output_processor.add_request(request, None, queue=queue)

    for request in requests:
        output_processor.abort_requests([request.request_id])

    # Every request was aborted, so nothing may remain tracked.
    assert output_processor.get_num_unfinished_requests() == 0
    assert not output_processor.has_unfinished_requests()