"""A layer that samples the next tokens from the model's outputs."""
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.parallel_utils.tensor_parallel import (
    gather_from_tensor_model_parallel_region)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceOutputs

_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
        tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence and frequency penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.
    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
    """

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.vocab_size = vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> SamplerOutput:
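        # Shape notes (inferred from the call sites below, not stated in the
        # original): hidden_states has one row per input token in the batch,
        # and `embedding` is the LM head weight used as
        # logits = hidden_states @ embedding.T inside _get_logits.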
        # Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, input_metadata)

        # Get the logits for the next tokens.
        logits = _get_logits(hidden_states, embedding, embedding_bias,
                             self.vocab_size)

        # Apply presence and frequency penalties.
        output_tokens = _get_output_tokens(input_metadata)
        assert len(output_tokens) == logits.shape[0]
        presence_penalties, frequency_penalties = _get_penalties(
            input_metadata)
        assert len(presence_penalties) == logits.shape[0]
        assert len(frequency_penalties) == logits.shape[0]
        logits = _apply_penalties(logits, output_tokens, presence_penalties,
                                  frequency_penalties)

        # Apply temperature scaling.
        temperatures = _get_temperatures(input_metadata)
        assert len(temperatures) == logits.shape[0]
        if any(t != 1.0 for t in temperatures):
            t = torch.tensor(temperatures,
                             dtype=logits.dtype,
                             device=logits.device)
            # Use in-place division to avoid creating a new tensor.
            logits.div_(t.unsqueeze(dim=1))

        # Apply top-p and top-k truncation.
        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
        assert len(top_ps) == len(top_ks) == logits.shape[0]
        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
        do_top_k = any(k != self.vocab_size for k in top_ks)
        if do_top_p or do_top_k:
            logits = _apply_top_p_top_k(logits, top_ps, top_ks)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        # Use log_softmax to ensure numerical stability.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        return _sample(probs, logprobs, input_metadata)


def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
                embedding_bias: Optional[torch.Tensor],
                vocab_size: int) -> torch.Tensor:
    # Get the logits for the next tokens.
    logits = torch.matmul(hidden_states, embedding.t())
    if embedding_bias is not None:
        logits += embedding_bias
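    # Under tensor parallelism each rank holds only a vocab shard of the
    # embedding, so the matmul above yields partial logits; gather them
    # across ranks to recover the full vocabulary dimension.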
    logits = gather_from_tensor_model_parallel_region(logits)
    # Remove paddings in vocab (if any).
    logits = logits[:, :vocab_size]
    return logits


def _prune_hidden_states(
    hidden_states: torch.Tensor,
    input_metadata: InputMetadata,
) -> torch.Tensor:
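    # For each prompt, only the hidden state of its last token is needed to
    # sample the next token; generation inputs already have one row per
    # sequence. Indices are bucketed by SamplingType so that, after the
    # index_select below, rows of the same sampling type are contiguous and
    # can be sliced as a block in _sample.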
    last_token_indices = {t: [] for t in SamplingType}
    start_idx = 0
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        sampling_type = sampling_params.sampling_type
        if i < input_metadata.num_prompts:
            assert len(seq_ids) == 1, "Prompt input should have only one seq."
            prompt_len = input_metadata.prompt_lens[i]
            last_token_indices[sampling_type].append(start_idx + prompt_len -
                                                     1)
            start_idx += prompt_len
        else:
            num_seqs = len(seq_ids)
            last_token_indices[sampling_type].extend(
                range(start_idx, start_idx + num_seqs))
            start_idx += num_seqs

    all_last_token_indices = []
    for sampling_type in SamplingType:
        all_last_token_indices.extend(last_token_indices[sampling_type])
    all_last_token_indices = torch.tensor(all_last_token_indices,
                                          dtype=torch.long,
                                          device=hidden_states.device)
    return hidden_states.index_select(0, all_last_token_indices)


def _get_penalties(
        input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
    # Collect the presence and frequency penalties.
    presence_penalties: List[float] = []
    frequency_penalties: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        p = sampling_params.presence_penalty
        f = sampling_params.frequency_penalty
        if i < input_metadata.num_prompts:
            # A prompt input.
            presence_penalties.append(p)
            frequency_penalties.append(f)
        else:
            # A generation token.
            presence_penalties += [p] * len(seq_ids)
            frequency_penalties += [f] * len(seq_ids)
    return presence_penalties, frequency_penalties


def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
    output_tokens: List[List[int]] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, _ = seq_group
        if i < input_metadata.num_prompts:
            # A prompt input.
            # NOTE: While the prompt input usually has no output tokens,
            # it may have output tokens in the case of recomputation.
            seq_id = seq_ids[0]
            seq_data = input_metadata.seq_data[seq_id]
            output_tokens.append(seq_data.output_token_ids)
        else:
            # A generation token.
            for seq_id in seq_ids:
                seq_data = input_metadata.seq_data[seq_id]
                output_tokens.append(seq_data.output_token_ids)
    return output_tokens


def _apply_penalties(
    logits: torch.Tensor,
    output_tokens: List[List[int]],
    presence_penalties: List[float],
    frequency_penalties: List[float],
) -> torch.Tensor:
    num_seqs, vocab_size = logits.shape
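    # Look for at least one sequence that has output tokens and a non-zero
    # penalty; the for-else falls through to the early return when every
    # sequence can be skipped.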
    for i in range(num_seqs):
        if not output_tokens[i]:
            continue
        p = presence_penalties[i]
        f = frequency_penalties[i]
        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
            continue
        break
    else:
        # Return early if all sequences have zero penalties.
        return logits

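    # Pad every sequence's output tokens to the same length with the
    # out-of-range id `vocab_size`; the padding lands in an extra bin that is
    # dropped below, so it never affects the real token counts.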
    max_output_len = max(len(tokens) for tokens in output_tokens)
    padded_output_tokens = [
        tokens + [vocab_size] * (max_output_len - len(tokens))
        for tokens in output_tokens
    ]
    output_tokens_tensor = torch.tensor(padded_output_tokens,
                                        dtype=torch.long,
                                        device=logits.device)

    # Compute the bin counts for the output tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=logits.device)
    bin_counts.scatter_add_(1, output_tokens_tensor,
                            torch.ones_like(output_tokens_tensor))
    bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.

    frequency_penalties = torch.tensor(frequency_penalties,
                                       dtype=logits.dtype,
                                       device=logits.device)
    presence_penalties = torch.tensor(presence_penalties,
                                      dtype=logits.dtype,
                                      device=logits.device)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
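    # For every previously generated token t:
    #   logits[i, t] -= frequency_penalty[i] * count(t)
    #                   + presence_penalty[i] * 1[count(t) > 0]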
    logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
    return logits


def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
    # Collect the temperatures for the logits.
    temperatures: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        temperature = sampling_params.temperature
        if temperature < _SAMPLING_EPS:
            # NOTE: Zero temperature means deterministic sampling
            # (i.e., greedy sampling or beam search).
            # Set the temperature to 1 to avoid division by zero.
            temperature = 1.0

        if i < input_metadata.num_prompts:
            # A prompt input.
            temperatures.append(temperature)
        else:
            # A generation token.
            temperatures += [temperature] * len(seq_ids)
    return temperatures


def _get_top_p_top_k(
    input_metadata: InputMetadata,
    vocab_size: int,
) -> Tuple[List[float], List[int]]:
    top_ps: List[float] = []
    top_ks: List[int] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        top_p = sampling_params.top_p
        # k should not be greater than the vocab size.
        top_k = min(sampling_params.top_k, vocab_size)
        # k=-1 means no truncation.
        top_k = vocab_size if top_k == -1 else top_k
        if i < input_metadata.num_prompts:
            # A prompt input.
            top_ps.append(top_p)
            top_ks.append(top_k)
        else:
            # A generation token.
            top_ps += [top_p] * len(seq_ids)
            top_ks += [top_k] * len(seq_ids)
    return top_ps, top_ks


def _apply_top_p_top_k(
    logits: torch.Tensor,
    top_ps: List[float],
    top_ks: List[int],
) -> torch.Tensor:
    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)

    # Apply top-p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
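    # Mask tokens whose cumulative probability, excluding the token itself,
    # already exceeds top_p. Subtracting probs_sort makes the cutoff
    # exclusive, so the token that crosses the threshold is kept and at
    # least one token always survives per row.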
    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
    logits_sort[top_p_mask] = -float("inf")

    # Apply top-k.
    # Create a mask for the top-k elements.
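    # logits_sort is in descending order, so sorted positions >= k are
    # exactly the tokens outside the top-k for that row.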
    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
    logits_sort[top_k_mask] = -float("inf")

    # Re-sort the probabilities.
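    # argsort(logits_idx) is the inverse permutation of the sort, so the
    # gather below restores the original vocabulary order (with the masked
    # entries kept at -inf).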
    logits = torch.gather(logits_sort,
                          dim=-1,
                          index=torch.argsort(logits_idx, dim=-1))
    return logits


def _get_topk_logprobs(
    logprobs: torch.Tensor,
    num_logprobs: Optional[int],
) -> List[Dict[int, float]]:
    num_seqs = logprobs.size(0)
    if num_logprobs is None or num_logprobs == 0:
        return [{} for _ in range(num_seqs)]

    all_topk_logprobs, all_topk_ids = torch.topk(logprobs,
                                                 num_logprobs,
                                                 dim=-1)
    all_topk_logprobs = all_topk_logprobs.cpu()
    all_topk_ids = all_topk_ids.cpu()
    all_token_to_logprob = []
    for topk_logprobs, topk_ids in zip(all_topk_logprobs, all_topk_ids):
        token_to_logprob: Dict[int, float] = {}
        for token_id, logprob in zip(topk_ids, topk_logprobs):
            token_to_logprob[token_id.item()] = logprob.item()
        all_token_to_logprob.append(token_to_logprob)
    return all_token_to_logprob


def _build_sequence_outputs(
    parent_ids: List[int],
    next_token_ids: List[int],
    selected_token_logprobs: torch.Tensor,
    parent_seq_ids: List[int],
    parent_logprobs: torch.Tensor,
    num_output_logprobs: Optional[int],
) -> List[SequenceOutputs]:
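    # parent_ids are local indices into this group's parent sequences (and
    # into next_logprobs); parent_seq_ids maps those local indices back to
    # global sequence ids.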
    # Get top-k log probabilities for the next tokens.
    next_logprobs = _get_topk_logprobs(parent_logprobs, num_output_logprobs)
    seq_outputs: List[SequenceOutputs] = []
    for parent_id, next_token_id, token_logprob in zip(
            parent_ids, next_token_ids, selected_token_logprobs):
        output_logprobs = next_logprobs[parent_id].copy()
        output_logprobs[next_token_id] = token_logprob
        seq_outputs.append(
            SequenceOutputs(parent_seq_ids[parent_id], next_token_id,
                            output_logprobs))
    return seq_outputs


def _greedy_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    samples = torch.argmax(logprobs, dim=-1).cpu()
    sample_idx = 0
    results = []
    for seq_group in selected_seq_groups:
        seq_ids, _ = seq_group
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        parent_ids = list(range(num_parent_seqs))
        next_token_ids = [samples[sample_idx].item()]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results


def _random_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    probs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # Find the maximum best_of value of the prompt phase requests.
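    # Prompt-phase groups need `best_of` samples from a single distribution,
    # while generation-phase groups need one sample per parent sequence.
    # Sampling `max_best_of` columns for every row lets a single batched
    # multinomial call serve both cases; each group slices what it needs.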
    max_best_of = 1
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        if is_prompt:
            seq_ids, sampling_params = seq_group
            max_best_of = max(max_best_of, sampling_params.best_of)
    random_samples = torch.multinomial(probs,
                                       num_samples=max_best_of,
                                       replacement=True).cpu()
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * sampling_params.best_of
            next_token_ids = random_samples[
                sample_idx, :sampling_params.best_of].tolist()
        else:
            # Generation phase.
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = random_samples[sample_idx:sample_idx +
                                            num_parent_seqs, 0].tolist()
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == probs.size(0)
    return results


def _beam_search_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    seq_data: Dict[int, SequenceData],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # Note: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        beam_width = sampling_params.best_of
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * (2 * beam_width)
            _, next_token_ids = torch.topk(seq_group_logprobs[0],
                                           2 * beam_width)
            next_token_ids = next_token_ids.tolist()
        else:
            # Generation phase.
            cumulative_logprobs = [
                seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
            ]
            cumulative_logprobs = torch.tensor(
                cumulative_logprobs,
                dtype=torch.float,
                device=seq_group_logprobs.device)
            seq_group_logprobs = (seq_group_logprobs +
                                  cumulative_logprobs.unsqueeze(dim=1))
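            # Flattening to (num_parent_seqs * vocab_size,) lets a single
            # topk jointly rank (parent beam, token) pairs by cumulative
            # log probability across all beams of the group.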
            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                     2 * beam_width)
            topk_ids = topk_ids.tolist()
            vocab_size = seq_group_logprobs.size(-1)
            parent_ids = [i // vocab_size for i in topk_ids]
            next_token_ids = [i % vocab_size for i in topk_ids]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    input_metadata: InputMetadata,
) -> SamplerOutput:
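    # Sequence groups are bucketed by sampling type. _prune_hidden_states
    # ordered the rows of probs/logprobs by the same SamplingType order, so
    # each bucket occupies a contiguous slice starting at category_start_idx.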
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    category_num_tokens = {t: 0 for t in SamplingType}
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)
        num_seqs = len(seq_ids)
        category_num_tokens[sampling_type] += num_seqs

    seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
    category_start_idx = 0
    for sampling_type in SamplingType:
        seq_group_ids = categorized_seq_group_ids[sampling_type]
        seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
        is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
        num_tokens = category_num_tokens[sampling_type]
        if num_tokens == 0:
            continue
        category_logprobs = logprobs[category_start_idx:category_start_idx +
                                     num_tokens]
        category_probs = probs[category_start_idx:category_start_idx +
                               num_tokens]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, category_logprobs)
        elif sampling_type == SamplingType.RANDOM:
            sample_results = _random_sample(seq_groups, is_prompts,
                                            category_probs)
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                 input_metadata.seq_data,
                                                 category_logprobs)
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

        # Batched query for logprobs of selected token
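        # Collect (row, token) index pairs for every sampled token so their
        # log probabilities can be read from category_logprobs with one
        # advanced-indexing lookup instead of per-sequence indexing.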
        batched_logprobs_query_seq_indices: List[int] = []
        batched_logprobs_query_token_indices: List[int] = []
        sample_idx = 0
        for seq_group_id, seq_group, sample_result in zip(
                seq_group_ids, seq_groups, sample_results):
            seq_ids, sampling_params = seq_group
            next_token_ids, parent_ids = sample_result
            num_parent_seqs = len(seq_ids)
            batched_logprobs_query_seq_indices.extend(
                [sample_idx + parent_id for parent_id in parent_ids])
            batched_logprobs_query_token_indices.extend(next_token_ids)
            sample_idx += num_parent_seqs
        assert sample_idx == num_tokens
        batched_logprobs_query_result = category_logprobs[[
            batched_logprobs_query_seq_indices,
            batched_logprobs_query_token_indices
        ]].tolist()

        # Build the sequence outputs.
        sample_idx = 0
        result_idx = 0
        for seq_group_id, seq_group, sample_result in zip(
                seq_group_ids, seq_groups, sample_results):
            seq_ids, sampling_params = seq_group
            next_token_ids, parent_ids = sample_result
            num_results = len(next_token_ids)
            num_parent_seqs = len(seq_ids)
            parent_logprobs = category_logprobs[sample_idx:sample_idx +
                                                num_parent_seqs]
            selected_token_logprobs = batched_logprobs_query_result[
                result_idx:result_idx + num_results]
            seq_output = _build_sequence_outputs(parent_ids, next_token_ids,
                                                 selected_token_logprobs,
                                                 seq_ids, parent_logprobs,
                                                 sampling_params.logprobs)
            seq_outputs_dict[seq_group_id] = seq_output
            sample_idx += num_parent_seqs
            result_idx += num_results
        assert sample_idx == num_tokens
        category_start_idx += num_tokens

    return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]