"""A layer that samples the next tokens from the model's outputs."""
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_gather)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
                           SequenceData, SequenceGroupOutputs, SequenceOutputs)

# Tolerance used when comparing sampling parameters against their neutral
# values: penalties against 0, the repetition penalty against 1, the
# temperature against 0 and top-p against 1.
_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
        tokens except the final one in each prompt, unless prompt logprobs
        were requested for that prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence, frequency and repetition penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.
    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
    """

    def __init__(self, vocab_size: int) -> None:
        """Remember the vocabulary size.

        Args:
            vocab_size: Number of logit columns to keep; also the value that
                means "no top-k truncation".
        """
        super().__init__()
        self.vocab_size = vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> SamplerOutput:
        """Run the full sampling pipeline for one batch.

        Args:
            embedding: Weight matrix whose transpose is multiplied with the
                pruned hidden states to produce logits.
            hidden_states: Model outputs; flattened to 2-D and pruned to the
                rows needed for sampling.
            input_metadata: Batch description (sequence groups, prompt
                lengths, per-sequence data).
            embedding_bias: Optional bias added to the logits.

        Returns:
            One output entry per sequence group, in batch order.
        """
        # Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, input_metadata)

        # Get the logits for the next tokens.
        logits = _get_logits(hidden_states, embedding, embedding_bias,
                             self.vocab_size)

        # Apply presence, frequency and repetition penalties.
        output_tokens = _get_output_tokens(input_metadata)
        assert len(output_tokens) == logits.shape[0]
        presence_penalties, frequency_penalties, repetition_penalties = (
            _get_penalties(input_metadata))
        assert len(presence_penalties) == logits.shape[0]
        assert len(frequency_penalties) == logits.shape[0]
        assert len(repetition_penalties) == logits.shape[0]
        logits = _apply_penalties(logits, output_tokens, presence_penalties,
                                  frequency_penalties, repetition_penalties)

        # Apply temperature scaling. Skipped entirely when every row uses
        # a temperature of exactly 1.
        temperatures = _get_temperatures(input_metadata)
        assert len(temperatures) == logits.shape[0]
        if any(t != 1.0 for t in temperatures):
            t = torch.tensor(temperatures,
                             dtype=logits.dtype,
                             device=logits.device)
            # Use in-place division to avoid creating a new tensor.
            logits.div_(t.unsqueeze(dim=1))

        # Apply top-p and top-k truncation.
        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
        assert len(top_ps) == len(top_ks) == logits.shape[0]
        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
        do_top_k = any(k != self.vocab_size for k in top_ks)
        if do_top_p or do_top_k:
            logits = _apply_top_p_top_k(logits, top_ps, top_ks)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        # Use log_softmax to ensure numerical stability.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        sample_results = _sample(probs, logprobs, input_metadata)
        # Get the logprobs query results.
        prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, input_metadata, sample_results)
        return _build_sampler_output(sample_results, input_metadata,
                                     prompt_logprobs, sample_logprobs)
93
94


95
96
97
98
99
100
101
def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
                embedding_bias: Optional[torch.Tensor],
                vocab_size: int) -> torch.Tensor:
    """Project hidden states to logits and trim them to `vocab_size`.

    Args:
        hidden_states: Rows to project.
        embedding: Weight matrix; its transpose is the right-hand operand of
            the matmul.
        embedding_bias: Optional bias added in place to the logits.
        vocab_size: Number of leading logit columns to keep.

    Returns:
        The gathered logits restricted to the first `vocab_size` columns.
    """
    # Get the logits for the next tokens.
    logits = torch.matmul(hidden_states, embedding.t())
    if embedding_bias is not None:
        logits += embedding_bias
    logits = tensor_model_parallel_all_gather(logits)
    # Remove paddings in vocab (if any).
    logits = logits[:, :vocab_size]
    return logits


108
109
110
111
def _prune_hidden_states(
    hidden_states: torch.Tensor,
    input_metadata: InputMetadata,
) -> torch.Tensor:
112
    selected_token_indices: List[int] = []
113
    start_idx = 0
114
    for i, seq_group in enumerate(input_metadata.seq_groups):
115
        seq_ids, sampling_params = seq_group
116
117
118
        if i < input_metadata.num_prompts:
            assert len(seq_ids) == 1, "Prompt input should have only one seq."
            prompt_len = input_metadata.prompt_lens[i]
119
120
121
122
            if sampling_params.prompt_logprobs is not None:
                selected_token_indices.extend(
                    range(start_idx, start_idx + prompt_len - 1))
            selected_token_indices.append(start_idx + prompt_len - 1)
123
            start_idx += input_metadata.max_prompt_len
124
125
        else:
            num_seqs = len(seq_ids)
126
127
            selected_token_indices.extend(
                range(start_idx, start_idx + num_seqs))
128
129
            start_idx += num_seqs

130
131
132
    selected_token_indices = torch.tensor(selected_token_indices,
                                          dtype=torch.long,
                                          device=hidden_states.device)
133
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
134
    return hidden_states.index_select(0, selected_token_indices)
135
136


137
def _get_penalties(
ljss's avatar
ljss committed
138
139
    input_metadata: InputMetadata
) -> Tuple[List[float], List[float], List[float]]:
140
141
142
    # Collect the presence and frequency penalties.
    presence_penalties: List[float] = []
    frequency_penalties: List[float] = []
ljss's avatar
ljss committed
143
    repetition_penalties: List[float] = []
144
    for i, seq_group in enumerate(input_metadata.seq_groups):
145
146
147
        seq_ids, sampling_params = seq_group
        p = sampling_params.presence_penalty
        f = sampling_params.frequency_penalty
ljss's avatar
ljss committed
148
        r = sampling_params.repetition_penalty
149
150
151
152
153
154
155
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            # NOTE: We do not apply presence and frequency penalties for the
            # prompt token positions where we don't sample new tokens.
            prompt_len = input_metadata.prompt_lens[i]
            presence_penalties += [0] * (prompt_len - 1)
            frequency_penalties += [0] * (prompt_len - 1)
ljss's avatar
ljss committed
156
            repetition_penalties += [1] * (prompt_len - 1)
157
158
        presence_penalties += [p] * len(seq_ids)
        frequency_penalties += [f] * len(seq_ids)
ljss's avatar
ljss committed
159
160
        repetition_penalties += [r] * len(seq_ids)
    return presence_penalties, frequency_penalties, repetition_penalties
161
162


163
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
164
    output_tokens: List[List[int]] = []
165
166
167
168
169
170
171
172
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            # NOTE: prompt token positions do not need output tokens to
            # compute penalties.
            prompt_len = input_metadata.prompt_lens[i]
            output_tokens.extend([] for _ in range(prompt_len - 1))
173
        for seq_id in seq_ids:
174
175
176
177
178
179
180
181
182
183
            seq_data = input_metadata.seq_data[seq_id]
            output_tokens.append(seq_data.output_token_ids)
    return output_tokens


def _apply_penalties(
    logits: torch.Tensor,
    output_tokens: List[List[int]],
    presence_penalties: List[float],
    frequency_penalties: List[float],
    repetition_penalties: List[float],
) -> torch.Tensor:
    """Apply repetition, frequency and presence penalties to the logits.

    Args:
        logits: 2-D tensor; one row per sequence, one column per token id.
        output_tokens: Token ids generated so far for each row.
        presence_penalties: Per-row value subtracted once from every token
            that already appears in the row's output.
        frequency_penalties: Per-row value subtracted once per occurrence.
        repetition_penalties: Per-row factor; positive logits of seen tokens
            are divided by it and the remaining ones multiplied by it.

    Returns:
        `logits` itself when no row with output tokens has a non-neutral
        penalty; otherwise a new tensor with the penalties applied.
    """
    num_seqs, vocab_size = logits.shape
    for i in range(num_seqs):
        if not output_tokens[i]:
            continue
        p = presence_penalties[i]
        f = frequency_penalties[i]
        r = repetition_penalties[i]
        if (abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS
                and abs(r - 1.0) < _SAMPLING_EPS):
            continue
        break
    else:
        # Return early if all sequences have zero penalties.
        return logits

    # Pad every row to the same length with the out-of-vocabulary id
    # `vocab_size` so the token lists form a rectangular tensor.
    max_output_len = max(len(tokens) for tokens in output_tokens)
    padded_output_tokens = [
        tokens + [vocab_size] * (max_output_len - len(tokens))
        for tokens in output_tokens
    ]
    output_tokens_tensor = torch.tensor(padded_output_tokens,
                                        dtype=torch.long,
                                        device=logits.device)

    # Compute the bin counts for the output tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=logits.device)
    bin_counts.scatter_add_(1, output_tokens_tensor,
                            torch.ones_like(output_tokens_tensor))
    bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.
    mask = bin_counts > 0

    repetition_t = torch.tensor(repetition_penalties,
                                dtype=logits.dtype,
                                device=logits.device)
    frequency_t = torch.tensor(frequency_penalties,
                               dtype=logits.dtype,
                               device=logits.device)
    presence_t = torch.tensor(presence_penalties,
                              dtype=logits.dtype,
                              device=logits.device)

    # Expand the repetition penalty per token and neutralize it for tokens
    # that were never generated in that row.
    repetition_t = repetition_t[:, None].repeat(1, vocab_size)
    repetition_t[~mask] = 1.0
    logits = torch.where(logits > 0, logits / repetition_t,
                         logits * repetition_t)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_t.unsqueeze(dim=1) * bin_counts
    logits -= presence_t.unsqueeze(dim=1) * mask
    return logits


242
def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
    """Collect the per-row temperature used to scale the logits.

    Temperatures below `_SAMPLING_EPS` are replaced by 1.0. For a prompt
    group that requested `prompt_logprobs`, the group's temperature is also
    repeated for its `prompt_len - 1` extra rows.

    Returns:
        One temperature per logits row, in batch order.
    """
    temperatures: List[float] = []
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        temperature = sampling_params.temperature
        if temperature < _SAMPLING_EPS:
            # NOTE: Zero temperature means deterministic sampling
            # (i.e., greedy sampling or beam search).
            # Set the temperature to 1 to avoid division by zero.
            temperature = 1.0
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            prompt_len = input_metadata.prompt_lens[i]
            temperatures += [temperature] * (prompt_len - 1)
        temperatures += [temperature] * len(seq_ids)
    return temperatures


def _get_top_p_top_k(
262
    input_metadata: InputMetadata,
Woosuk Kwon's avatar
Woosuk Kwon committed
263
264
    vocab_size: int,
) -> Tuple[List[float], List[int]]:
265
    top_ps: List[float] = []
Woosuk Kwon's avatar
Woosuk Kwon committed
266
    top_ks: List[int] = []
267
    for i, seq_group in enumerate(input_metadata.seq_groups):
268
        seq_ids, sampling_params = seq_group
Woosuk Kwon's avatar
Woosuk Kwon committed
269
270
271
272
273
        top_p = sampling_params.top_p
        # k should not be greater than the vocab size.
        top_k = min(sampling_params.top_k, vocab_size)
        # k=-1 means no truncation.
        top_k = vocab_size if top_k == -1 else top_k
274
275
276
277
278
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            prompt_len = input_metadata.prompt_lens[i]
            top_ps += [top_p] * (prompt_len - 1)
            top_ks += [top_k] * (prompt_len - 1)
279
280
        top_ps += [top_p] * len(seq_ids)
        top_ks += [top_k] * len(seq_ids)
Woosuk Kwon's avatar
Woosuk Kwon committed
281
    return top_ps, top_ks
282
283


Woosuk Kwon's avatar
Woosuk Kwon committed
284
def _apply_top_p_top_k(
285
    logits: torch.Tensor,
286
287
    top_ps: List[float],
    top_ks: List[int],
288
) -> torch.Tensor:
289
290
291
    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
Woosuk Kwon's avatar
Woosuk Kwon committed
292
293

    # Apply top-p.
294
295
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
Woosuk Kwon's avatar
Woosuk Kwon committed
296
    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
297
    logits_sort[top_p_mask] = -float("inf")
Woosuk Kwon's avatar
Woosuk Kwon committed
298
299
300

    # Apply top-k.
    # Create a mask for the top-k elements.
301
302
    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
Woosuk Kwon's avatar
Woosuk Kwon committed
303
    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
304
    logits_sort[top_k_mask] = -float("inf")
Woosuk Kwon's avatar
Woosuk Kwon committed
305
306

    # Re-sort the probabilities.
307
308
309
310
    logits = torch.gather(logits_sort,
                          dim=-1,
                          index=torch.argsort(logits_idx, dim=-1))
    return logits
311
312


313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
def _greedy_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    samples = torch.argmax(logprobs, dim=-1).cpu()
    sample_idx = 0
    results = []
    for seq_group in selected_seq_groups:
        seq_ids, _ = seq_group
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        parent_ids = list(range(num_parent_seqs))
        next_token_ids = [samples[sample_idx].item()]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results


def _random_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
336
    probs: torch.Tensor,
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
) -> List[Tuple[List[int], List[int]]]:
    # Find the maximum best_of value of the prompt phase requests.
    max_best_of = 1
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        if is_prompt:
            seq_ids, sampling_params = seq_group
            max_best_of = max(max_best_of, sampling_params.best_of)
    random_samples = torch.multinomial(probs,
                                       num_samples=max_best_of,
                                       replacement=True).cpu()
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * sampling_params.best_of
            next_token_ids = random_samples[
                sample_idx, :sampling_params.best_of].tolist()
        else:
            # Generation phase.
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = random_samples[sample_idx:sample_idx +
                                            num_parent_seqs, 0].tolist()
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == probs.size(0)
    return results


def _beam_search_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    seq_data: Dict[int, SequenceData],
374
    logprobs: torch.Tensor,
375
376
377
378
379
380
381
382
) -> List[Tuple[List[int], List[int]]]:
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
383
    # NOTE: Beam search is not vectorized, so its speed can be slower than
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
    # other sampling methods.
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        beam_width = sampling_params.best_of
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * (2 * beam_width)
            _, next_token_ids = torch.topk(seq_group_logprobs[0],
                                           2 * beam_width)
            next_token_ids = next_token_ids.tolist()
        else:
            # Generation phase.
            cumulative_logprobs = [
                seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
            ]
            cumulative_logprobs = torch.tensor(
                cumulative_logprobs,
                dtype=torch.float,
                device=seq_group_logprobs.device)
            seq_group_logprobs = (seq_group_logprobs +
                                  cumulative_logprobs.unsqueeze(dim=1))
            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                     2 * beam_width)
            topk_ids = topk_ids.tolist()
            vocab_size = seq_group_logprobs.size(-1)
            parent_ids = [i // vocab_size for i in topk_ids]
            next_token_ids = [i % vocab_size for i in topk_ids]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results
421
422
423
424
425
426


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    input_metadata: InputMetadata,
) -> List[Tuple[List[int], List[int]]]:
    """Sample next tokens for every sequence group in the batch.

    Groups are bucketed by `sampling_params.sampling_type`, each bucket is
    handled by the matching sampler on its own rows, and the results are
    put back in batch order.

    Args:
        probs: Probability rows for the whole batch.
        logprobs: Log-probability rows for the whole batch.
        input_metadata: Batch description.

    Returns:
        One `(next_token_ids, parent_ids)` pair per sequence group.

    Raises:
        ValueError: If a group uses an unsupported sampling type.
    """
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = {t: [] for t in SamplingType}
    start_idx = 0
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        sampling_type = sampling_params.sampling_type
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            # NOTE: prompt token positions do not need sample, skip
            prompt_len = input_metadata.prompt_lens[i]
            start_idx += prompt_len - 1
        categorized_seq_group_ids[sampling_type].append(i)
        num_seqs = len(seq_ids)
        categorized_sample_indices[sampling_type].extend(
            range(start_idx, start_idx + num_seqs))
        start_idx += num_seqs

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    for sampling_type in SamplingType:
        seq_group_ids = categorized_seq_group_ids[sampling_type]
        seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
        is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
        sample_indices = categorized_sample_indices[sampling_type]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue
        if sampling_type == SamplingType.GREEDY:
            category_logprobs = logprobs[sample_indices]
            sample_results = _greedy_sample(seq_groups, category_logprobs)
        elif sampling_type == SamplingType.RANDOM:
            category_probs = probs[sample_indices]
            sample_results = _random_sample(seq_groups, is_prompts,
                                            category_probs)
        elif sampling_type == SamplingType.BEAM:
            category_logprobs = logprobs[sample_indices]
            sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                 input_metadata.seq_data,
                                                 category_logprobs)
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")
        sample_results_dict.update(zip(seq_group_ids, sample_results))

    # Restore the original batch order.
    sample_results = [
        sample_results_dict[i] for i in range(len(input_metadata.seq_groups))
    ]
    return sample_results


def _get_logprobs(
    logprobs: torch.Tensor,
    input_metadata: InputMetadata,
    sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
        int, float]]]]:
    # Prepare query indices
    batched_logprobs_query_seq_indices: List[int] = []
    batched_logprobs_query_token_indices: List[int] = []
    largest_num_logprobs = 0
    sample_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(input_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result
        num_parent_seqs = len(seq_ids)
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            prompt_len = input_metadata.prompt_lens[i]
            prompt_tokens = input_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
499
            batched_logprobs_query_seq_indices.extend(
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
                sample_idx + j for j in range(prompt_len - 1))
            batched_logprobs_query_token_indices.extend(
                token_id for token_id in prompt_tokens[1:])
            sample_idx += prompt_len - 1
        batched_logprobs_query_seq_indices.extend(
            [sample_idx + parent_id for parent_id in parent_ids])
        batched_logprobs_query_token_indices.extend(next_token_ids)
        if sampling_params.logprobs is not None:
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.logprobs)
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)

    # Batched query for logprobs of selected token
    batched_logprobs_query_result = logprobs[[
        batched_logprobs_query_seq_indices,
        batched_logprobs_query_token_indices
    ]].cpu()

    # Batched query for logprobs of topk tokens
    if largest_num_logprobs > 0:
        top_logprobs, top_token_ids = torch.topk(logprobs,
                                                 largest_num_logprobs,
                                                 dim=-1)
        top_logprobs = top_logprobs.cpu()
        top_token_ids = top_token_ids.cpu()
    else:
        top_logprobs, top_token_ids = None, None

    # Gather results
    result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
    result_sample_logprobs: List[SampleLogprobs] = []
    sample_idx = 0
    query_result_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(input_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result

        # Prompt logprobs
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            num_logprobs = sampling_params.prompt_logprobs
            prompt_len = input_metadata.prompt_lens[i]
            prompt_tokens = input_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            group_prompt_logprobs: PromptLogprobs = [None]
            for token_id in prompt_tokens[1:]:
                prompt_logprobs_dict = {
                    token_id:
                    batched_logprobs_query_result[query_result_idx].item()
                }
                if num_logprobs > 0:
                    prompt_logprobs_dict.update(
                        zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
                            top_logprobs[sample_idx, :num_logprobs].tolist()))
                group_prompt_logprobs.append(prompt_logprobs_dict)
                sample_idx += 1
                query_result_idx += 1
            result_prompt_logprobs.append(group_prompt_logprobs)
        else:
            result_prompt_logprobs.append(None)

        # Sample logprobs
        num_logprobs = sampling_params.logprobs
        if num_logprobs is None:
            num_logprobs = 0
        group_sample_logprobs: SampleLogprobs = []
        for next_token_id, parent_id in zip(next_token_ids, parent_ids):
            sample_logprobs_dict = {
                next_token_id:
                batched_logprobs_query_result[query_result_idx].item()
            }
            query_result_idx += 1
            if num_logprobs > 0:
                sample_logprobs_dict.update(
                    zip(
                        top_token_ids[sample_idx +
                                      parent_id, :num_logprobs].tolist(),
                        top_logprobs[sample_idx +
                                     parent_id, :num_logprobs].tolist()))
            group_sample_logprobs.append(sample_logprobs_dict)
        result_sample_logprobs.append(group_sample_logprobs)
        sample_idx += len(seq_ids)

    return result_prompt_logprobs, result_sample_logprobs


def _build_sampler_output(
    sample_results: List[Tuple[List[int], List[int]]],
    input_metadata: InputMetadata,
    prompt_logprobs: List[Optional[PromptLogprobs]],
    sample_logprobs: List[SampleLogprobs],
) -> SamplerOutput:
    """Package per-group sampling results into the sampler's return value.

    For every sequence group, each sampled token becomes a
    `SequenceOutputs` carrying the parent sequence id, the token id and its
    logprob dict; these are wrapped, together with the group's prompt
    logprobs, in a `SequenceGroupOutputs`.
    """
    packaged = []
    per_group = zip(input_metadata.seq_groups, sample_results,
                    prompt_logprobs, sample_logprobs)
    for group, result, prompt_lps, sample_lps in per_group:
        seq_ids, _ = group
        next_token_ids, parent_ids = result
        # `parent_ids` are positions within this group's `seq_ids`.
        children = [
            SequenceOutputs(seq_ids[parent], token, token_lps)
            for parent, token, token_lps in zip(parent_ids, next_token_ids,
                                                sample_lps)
        ]
        packaged.append(SequenceGroupOutputs(children, prompt_lps))
    return packaged