"""A layer that samples the next tokens from the model's outputs."""
from typing import Dict, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
3
4
5
6

import torch
import torch.nn as nn

Woosuk Kwon's avatar
Woosuk Kwon committed
7
from vllm.model_executor.input_metadata import InputMetadata
8
9
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_gather)
10
from vllm.sampling_params import SamplingParams, SamplingType
11
12
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
                           SequenceData, SequenceGroupOutputs, SequenceOutputs)
Woosuk Kwon's avatar
Woosuk Kwon committed
13

14
# Tolerance used when deciding whether a sampling parameter is effectively
# "off": penalties and temperature are compared against it in absolute terms,
# and top-p is compared against `1.0 - _SAMPLING_EPS`.
_SAMPLING_EPS = 1e-5
Woosuk Kwon's avatar
Minor  
Woosuk Kwon committed
15

16

Woosuk Kwon's avatar
Woosuk Kwon committed
17
class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
        tokens except the final one in each prompt, unless prompt logprobs
        were requested for that prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence and frequency penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.
    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
    """

    def __init__(self, vocab_size: int) -> None:
        """Store the vocabulary size used to trim padded logits.

        Args:
            vocab_size: Number of logit columns to keep after the logits are
                computed; also the "no truncation" value for top-k.
        """
        super().__init__()
        self.vocab_size = vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> SamplerOutput:
        """Run the full sampling pipeline for one batch.

        Args:
            embedding: Weight matrix whose transpose is multiplied with the
                hidden states to produce logits.
            hidden_states: Model outputs; rows not needed for sampling are
                dropped first.
            input_metadata: Batch description (sequence groups, sampling
                parameters, prompt lengths, per-sequence data).
            embedding_bias: Optional bias added to the logits.

        Returns:
            The per-sequence-group sampler output built by
            `_build_sampler_output`.
        """
        # Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, input_metadata)

        # Get the logits for the next tokens.
        logits = _get_logits(hidden_states, embedding, embedding_bias,
                             self.vocab_size)

        # Apply presence and frequency penalties.
        output_tokens = _get_output_tokens(input_metadata)
        assert len(output_tokens) == logits.shape[0]
        presence_penalties, frequency_penalties = _get_penalties(
            input_metadata)
        assert len(presence_penalties) == logits.shape[0]
        assert len(frequency_penalties) == logits.shape[0]
        logits = _apply_penalties(logits, output_tokens, presence_penalties,
                                  frequency_penalties)

        # Apply temperature scaling.
        temperatures = _get_temperatures(input_metadata)
        assert len(temperatures) == logits.shape[0]
        if any(t != 1.0 for t in temperatures):
            temperature_tensor = torch.tensor(temperatures,
                                              dtype=logits.dtype,
                                              device=logits.device)
            # Use in-place division to avoid creating a new tensor.
            logits.div_(temperature_tensor.unsqueeze(dim=1))

        # Apply top-p and top-k truncation.
        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
        assert len(top_ps) == len(top_ks) == logits.shape[0]
        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
        do_top_k = any(k != self.vocab_size for k in top_ks)
        if do_top_p or do_top_k:
            logits = _apply_top_p_top_k(logits, top_ps, top_ks)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        # Use log_softmax to ensure numerical stability.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        sample_results = _sample(probs, logprobs, input_metadata)
        # Get the logprobs query results.
        prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, input_metadata, sample_results)
        return _build_sampler_output(sample_results, input_metadata,
                                     prompt_logprobs, sample_logprobs)
92
93


94
95
96
97
98
99
100
def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
                embedding_bias: Optional[torch.Tensor],
                vocab_size: int) -> torch.Tensor:
    """Project hidden states to logits.

    Args:
        hidden_states: Rows to project.
        embedding: Weight matrix; its transpose is the projection.
        embedding_bias: Optional bias added to the projected logits.
        vocab_size: Number of leading logit columns to keep.

    Returns:
        The gathered logits, truncated to `vocab_size` columns.
    """
    # Get the logits for the next tokens.
    logits = torch.matmul(hidden_states, embedding.t())
    if embedding_bias is not None:
        logits += embedding_bias
    logits = tensor_model_parallel_all_gather(logits)
    # Remove paddings in vocab (if any).
    logits = logits[:, :vocab_size]
    return logits


107
108
109
110
def _prune_hidden_states(
    hidden_states: torch.Tensor,
    input_metadata: InputMetadata,
) -> torch.Tensor:
    """Select the hidden-state rows that are needed for sampling.

    The first `num_prompts` sequence groups are prompts. For each prompt the
    last prompt position is kept; when `prompt_logprobs` is requested, all
    preceding prompt positions are kept as well. The row offset advances by
    `max_prompt_len` per prompt. Every remaining group contributes one row per
    sequence.

    Args:
        hidden_states: Model outputs; flattened to 2-D over the last dim.
        input_metadata: Batch description providing `seq_groups`,
            `num_prompts`, `prompt_lens` and `max_prompt_len`.

    Returns:
        The selected rows, in batch order.
    """
    selected: List[int] = []
    start_idx = 0
    for i, (seq_ids, sampling_params) in enumerate(input_metadata.seq_groups):
        if i < input_metadata.num_prompts:
            assert len(seq_ids) == 1, "Prompt input should have only one seq."
            prompt_len = input_metadata.prompt_lens[i]
            last_idx = start_idx + prompt_len - 1
            if sampling_params.prompt_logprobs is not None:
                # Keep every prompt position so their logprobs can be queried.
                selected.extend(range(start_idx, last_idx))
            selected.append(last_idx)
            start_idx += input_metadata.max_prompt_len
        else:
            num_seqs = len(seq_ids)
            selected.extend(range(start_idx, start_idx + num_seqs))
            start_idx += num_seqs

    selected_token_indices = torch.tensor(selected,
                                          dtype=torch.long,
                                          device=hidden_states.device)
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    return hidden_states.index_select(0, selected_token_indices)
134
135


136
def _get_penalties(
        input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
    """Collect per-row presence and frequency penalties.

    Args:
        input_metadata: Batch description providing `seq_groups`,
            `num_prompts` and `prompt_lens`.

    Returns:
        `(presence_penalties, frequency_penalties)`, one entry per logits
        row. Prompt positions kept only for prompt logprobs get a penalty
        of zero.
    """
    presence_penalties: List[float] = []
    frequency_penalties: List[float] = []
    for i, (seq_ids, sampling_params) in enumerate(input_metadata.seq_groups):
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            # NOTE: We do not apply presence and frequency penalties for the
            # prompt token positions where we don't sample new tokens.
            num_prompt_rows = input_metadata.prompt_lens[i] - 1
            presence_penalties += [0.0] * num_prompt_rows
            frequency_penalties += [0.0] * num_prompt_rows
        presence_penalties += [sampling_params.presence_penalty] * len(seq_ids)
        frequency_penalties += ([sampling_params.frequency_penalty] *
                                len(seq_ids))
    return presence_penalties, frequency_penalties


157
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
    """Collect the already-generated token ids for every logits row.

    Args:
        input_metadata: Batch description providing `seq_groups`,
            `num_prompts`, `prompt_lens` and `seq_data`.

    Returns:
        One list of output token ids per logits row. Prompt positions kept
        only for prompt logprobs get an empty list.
    """
    output_tokens: List[List[int]] = []
    for i, (seq_ids, sampling_params) in enumerate(input_metadata.seq_groups):
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            # NOTE: prompt token positions do not need output tokens to
            # compute penalties.
            prompt_len = input_metadata.prompt_lens[i]
            output_tokens.extend([] for _ in range(prompt_len - 1))
        output_tokens.extend(input_metadata.seq_data[seq_id].output_token_ids
                             for seq_id in seq_ids)
    return output_tokens


def _apply_penalties(
    logits: torch.Tensor,
    output_tokens: List[List[int]],
    presence_penalties: List[float],
    frequency_penalties: List[float],
) -> torch.Tensor:
    """Apply presence and frequency penalties to the logits.

    Args:
        logits: 2-D tensor (its shape is unpacked as `(num_seqs,
            vocab_size)`); modified in place when penalties apply.
        output_tokens: Previously generated token ids per row.
        presence_penalties: Presence penalty per row.
        frequency_penalties: Frequency penalty per row.

    Returns:
        The same `logits` tensor. It is returned untouched when no row has
        both output tokens and a non-negligible penalty.
    """
    num_seqs, vocab_size = logits.shape

    def _row_needs_penalty(i: int) -> bool:
        # A row matters only if it has generated tokens and at least one of
        # its penalties is not (approximately) zero.
        if not output_tokens[i]:
            return False
        negligible = (abs(presence_penalties[i]) < _SAMPLING_EPS
                      and abs(frequency_penalties[i]) < _SAMPLING_EPS)
        return not negligible

    if not any(_row_needs_penalty(i) for i in range(num_seqs)):
        # Return early if all sequences have zero penalties.
        return logits

    # Pad every row to the same length with `vocab_size`, an id one past the
    # real vocabulary, so the padding lands in its own bin below.
    max_output_len = max(len(tokens) for tokens in output_tokens)
    padded_output_tokens = [
        tokens + [vocab_size] * (max_output_len - len(tokens))
        for tokens in output_tokens
    ]
    output_tokens_tensor = torch.tensor(padded_output_tokens,
                                        dtype=torch.long,
                                        device=logits.device)

    # Compute the bin counts for the output tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=logits.device)
    bin_counts.scatter_add_(1, output_tokens_tensor,
                            torch.ones_like(output_tokens_tensor))
    bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.

    frequency_penalties = torch.tensor(frequency_penalties,
                                       dtype=logits.dtype,
                                       device=logits.device)
    presence_penalties = torch.tensor(presence_penalties,
                                      dtype=logits.dtype,
                                      device=logits.device)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
    return logits


224
def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
    """Collect the temperature to divide each logits row by.

    Args:
        input_metadata: Batch description providing `seq_groups`,
            `num_prompts` and `prompt_lens`.

    Returns:
        One temperature per logits row. Temperatures below `_SAMPLING_EPS`
        are replaced by 1.0.
    """
    temperatures: List[float] = []
    for i, (seq_ids, sampling_params) in enumerate(input_metadata.seq_groups):
        temperature = sampling_params.temperature
        if temperature < _SAMPLING_EPS:
            # NOTE: Zero temperature means deterministic sampling
            # (i.e., greedy sampling or beam search).
            # Set the temperature to 1 to avoid division by zero.
            temperature = 1.0
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            # Extra rows kept for prompt logprobs use the same temperature.
            prompt_len = input_metadata.prompt_lens[i]
            temperatures += [temperature] * (prompt_len - 1)
        temperatures += [temperature] * len(seq_ids)
    return temperatures


Woosuk Kwon's avatar
Woosuk Kwon committed
243
def _get_top_p_top_k(
    input_metadata: InputMetadata,
    vocab_size: int,
) -> Tuple[List[float], List[int]]:
    """Collect per-row top-p and top-k values.

    Args:
        input_metadata: Batch description providing `seq_groups`,
            `num_prompts` and `prompt_lens`.
        vocab_size: Upper bound for top-k and the value used when top-k is
            disabled.

    Returns:
        `(top_ps, top_ks)`, one entry per logits row.
    """
    top_ps: List[float] = []
    top_ks: List[int] = []
    for i, (seq_ids, sampling_params) in enumerate(input_metadata.seq_groups):
        top_p = sampling_params.top_p
        # k should not be greater than the vocab size.
        top_k = min(sampling_params.top_k, vocab_size)
        # k=-1 means no truncation.
        if top_k == -1:
            top_k = vocab_size
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            # Extra rows kept for prompt logprobs use the same settings.
            prompt_len = input_metadata.prompt_lens[i]
            top_ps += [top_p] * (prompt_len - 1)
            top_ks += [top_k] * (prompt_len - 1)
        top_ps += [top_p] * len(seq_ids)
        top_ks += [top_k] * len(seq_ids)
    return top_ps, top_ks
264
265


Woosuk Kwon's avatar
Woosuk Kwon committed
266
def _apply_top_p_top_k(
    logits: torch.Tensor,
    top_ps: List[float],
    top_ks: List[int],
) -> torch.Tensor:
    """Mask logits that fall outside each row's top-p / top-k set.

    Args:
        logits: 2-D logits, one row per sequence.
        top_ps: Cumulative-probability threshold per row.
        top_ks: Number of highest logits to keep per row.

    Returns:
        A new tensor in the original column order where excluded entries are
        set to negative infinity.
    """
    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)

    # Apply top-p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    # Mask a token when the probability mass strictly before it already
    # exceeds p; the token that crosses the threshold is therefore kept.
    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
    logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Apply top-k.
    # Mask every sorted position whose rank is >= k for its row.
    ranks = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
    ranks = ranks.expand(logits_idx.shape[0], -1)
    top_k_mask = ranks >= k.unsqueeze(dim=1)
    logits_sort.masked_fill_(top_k_mask, -float("inf"))

    # Undo the sort so the columns are back in their original order.
    logits = torch.gather(logits_sort,
                          dim=-1,
                          index=torch.argsort(logits_idx, dim=-1))
    return logits
293
294


295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
def _greedy_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    """Pick the highest-logprob token for each greedy sequence group.

    Args:
        selected_seq_groups: The `(seq_ids, sampling_params)` groups using
            greedy sampling; each must contain exactly one sequence.
        logprobs: One row of log probabilities per selected sequence.

    Returns:
        One `(next_token_ids, parent_ids)` pair per group.
    """
    # Convert once instead of calling `.item()` per row.
    samples = torch.argmax(logprobs, dim=-1).tolist()
    sample_idx = 0
    results = []
    for seq_ids, _ in selected_seq_groups:
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        parent_ids = list(range(num_parent_seqs))
        next_token_ids = [samples[sample_idx]]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results


def _random_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    probs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    """Draw random samples for each randomly-sampled sequence group.

    Args:
        selected_seq_groups: The `(seq_ids, sampling_params)` groups using
            random sampling.
        is_prompts: Whether each group is in the prompt phase.
        probs: One row of probabilities per selected sequence.

    Returns:
        One `(next_token_ids, parent_ids)` pair per group. A prompt group
        yields `best_of` tokens that all share parent 0; a generation group
        yields one token per existing sequence.
    """
    # Find the maximum best_of value of the prompt phase requests, so that a
    # single multinomial call draws enough samples for every row.
    max_best_of = 1
    for (_, sampling_params), is_prompt in zip(selected_seq_groups,
                                               is_prompts):
        if is_prompt:
            max_best_of = max(max_best_of, sampling_params.best_of)
    random_samples = torch.multinomial(probs,
                                       num_samples=max_best_of,
                                       replacement=True).cpu()

    sample_idx = 0
    results = []
    for (seq_ids, sampling_params), is_prompt in zip(selected_seq_groups,
                                                     is_prompts):
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            best_of = sampling_params.best_of
            parent_ids = [0] * best_of
            next_token_ids = random_samples[sample_idx, :best_of].tolist()
        else:
            # Generation phase: only the first drawn sample of each row is
            # used.
            parent_ids = list(range(num_parent_seqs))
            end_idx = sample_idx + num_parent_seqs
            next_token_ids = random_samples[sample_idx:end_idx, 0].tolist()
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == probs.size(0)
    return results


def _beam_search_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    seq_data: Dict[int, SequenceData],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    """Select beam-search candidates for each beam-search sequence group.

    Args:
        selected_seq_groups: The `(seq_ids, sampling_params)` groups using
            beam search; `best_of` is used as the beam width.
        is_prompts: Whether each group is in the prompt phase.
        seq_data: Per-sequence data, read for `cumulative_logprob`.
        logprobs: One row of log probabilities per selected sequence.

    Returns:
        One `(next_token_ids, parent_ids)` pair per group, each holding
        `2 * beam_width` candidates.
    """
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # NOTE: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    sample_idx = 0
    results = []
    for (seq_ids, sampling_params), is_prompt in zip(selected_seq_groups,
                                                     is_prompts):
        num_parent_seqs = len(seq_ids)
        beam_width = sampling_params.best_of
        num_candidates = 2 * beam_width
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * num_candidates
            _, next_token_ids = torch.topk(seq_group_logprobs[0],
                                           num_candidates)
            next_token_ids = next_token_ids.tolist()
        else:
            # Generation phase: rank candidates by the total logprob of the
            # extended sequence.
            cumulative_logprobs = torch.tensor(
                [seq_data[seq_id].cumulative_logprob for seq_id in seq_ids],
                dtype=torch.float,
                device=seq_group_logprobs.device)
            seq_group_logprobs = (seq_group_logprobs +
                                  cumulative_logprobs.unsqueeze(dim=1))
            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                     num_candidates)
            topk_ids = topk_ids.tolist()
            # A flat index encodes (parent row, token id).
            vocab_size = seq_group_logprobs.size(-1)
            parent_ids = [i // vocab_size for i in topk_ids]
            next_token_ids = [i % vocab_size for i in topk_ids]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results
403
404
405
406
407
408


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    input_metadata: InputMetadata,
) -> List[Tuple[List[int], List[int]]]:
    """Sample next tokens for every sequence group in the batch.

    Groups are bucketed by their sampling type, each bucket is sampled with
    the matching routine, and the results are put back in batch order.

    Args:
        probs: Probabilities, one row per logits row.
        logprobs: Log probabilities, one row per logits row.
        input_metadata: Batch description providing `seq_groups`,
            `num_prompts`, `prompt_lens` and `seq_data`.

    Returns:
        One `(next_token_ids, parent_ids)` pair per sequence group.

    Raises:
        ValueError: If a sampling type has no matching routine.
    """
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = {t: [] for t in SamplingType}
    start_idx = 0
    for i, (seq_ids, sampling_params) in enumerate(input_metadata.seq_groups):
        sampling_type = sampling_params.sampling_type
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            # NOTE: prompt token positions do not need sample, skip
            prompt_len = input_metadata.prompt_lens[i]
            start_idx += prompt_len - 1

        categorized_seq_group_ids[sampling_type].append(i)
        num_seqs = len(seq_ids)
        categorized_sample_indices[sampling_type].extend(
            range(start_idx, start_idx + num_seqs))
        start_idx += num_seqs

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    for sampling_type in SamplingType:
        seq_group_ids = categorized_seq_group_ids[sampling_type]
        seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
        is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
        sample_indices = categorized_sample_indices[sampling_type]
        if len(sample_indices) == 0:
            continue
        if sampling_type == SamplingType.GREEDY:
            category_logprobs = logprobs[sample_indices]
            sample_results = _greedy_sample(seq_groups, category_logprobs)
        elif sampling_type == SamplingType.RANDOM:
            category_probs = probs[sample_indices]
            sample_results = _random_sample(seq_groups, is_prompts,
                                            category_probs)
        elif sampling_type == SamplingType.BEAM:
            category_logprobs = logprobs[sample_indices]
            sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                 input_metadata.seq_data,
                                                 category_logprobs)
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")
        sample_results_dict.update(zip(seq_group_ids, sample_results))

    # Restore the original batch order of the sequence groups.
    sample_results = [
        sample_results_dict[i] for i in range(len(input_metadata.seq_groups))
    ]
    return sample_results


def _get_logprobs(
    logprobs: torch.Tensor,
    input_metadata: InputMetadata,
    sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
        int, float]]]]:
    """Look up the logprobs to report for prompt tokens and sampled tokens.

    Args:
        logprobs: Log probabilities, one row per logits row.
        input_metadata: Batch description providing `seq_groups`,
            `num_prompts`, `prompt_lens` and `seq_data`.
        sample_results: One `(next_token_ids, parent_ids)` pair per group.

    Returns:
        `(prompt_logprobs, sample_logprobs)`, each with one entry per
        sequence group.

        * The prompt entry is `None` unless the group is a prompt with
          `prompt_logprobs` set; otherwise it is a list that starts with
          `None` (for the first prompt token) followed by one
          `{token_id: logprob}` dict per remaining prompt token.
        * The sample entry is a list with one `{token_id: logprob}` dict per
          sampled token.

        Each dict holds the chosen token plus the requested number of
        highest-logprob tokens for that row.
    """
    # Prepare query indices
    batched_logprobs_query_seq_indices: List[int] = []
    batched_logprobs_query_token_indices: List[int] = []
    largest_num_logprobs = 0
    sample_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(input_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result
        num_parent_seqs = len(seq_ids)
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            prompt_len = input_metadata.prompt_lens[i]
            prompt_tokens = input_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            # Row j is paired with prompt token j + 1.
            batched_logprobs_query_seq_indices.extend(
                sample_idx + j for j in range(prompt_len - 1))
            batched_logprobs_query_token_indices.extend(prompt_tokens[1:])
            sample_idx += prompt_len - 1
        batched_logprobs_query_seq_indices.extend(
            sample_idx + parent_id for parent_id in parent_ids)
        batched_logprobs_query_token_indices.extend(next_token_ids)
        if sampling_params.logprobs is not None:
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.logprobs)
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)

    # Batched query for logprobs of selected token.
    # Index with an explicit (rows, cols) pair; a nested list only works
    # through the legacy "sequence treated as a tuple" indexing rule.
    batched_logprobs_query_result = logprobs[
        batched_logprobs_query_seq_indices,
        batched_logprobs_query_token_indices].cpu()

    # Batched query for logprobs of topk tokens
    if largest_num_logprobs > 0:
        top_logprobs, top_token_ids = torch.topk(logprobs,
                                                 largest_num_logprobs,
                                                 dim=-1)
        top_logprobs = top_logprobs.cpu()
        top_token_ids = top_token_ids.cpu()
    else:
        top_logprobs, top_token_ids = None, None

    # Gather results
    result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
    result_sample_logprobs: List[SampleLogprobs] = []
    sample_idx = 0
    query_result_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(input_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result

        # Prompt logprobs
        if (i < input_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            num_logprobs = sampling_params.prompt_logprobs
            prompt_tokens = input_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            # The first prompt token gets a None placeholder.
            group_prompt_logprobs: PromptLogprobs = [None]
            for token_id in prompt_tokens[1:]:
                prompt_logprobs_dict = {
                    token_id:
                    batched_logprobs_query_result[query_result_idx].item()
                }
                if num_logprobs > 0:
                    prompt_logprobs_dict.update(
                        zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
                            top_logprobs[sample_idx, :num_logprobs].tolist()))
                group_prompt_logprobs.append(prompt_logprobs_dict)
                sample_idx += 1
                query_result_idx += 1
            result_prompt_logprobs.append(group_prompt_logprobs)
        else:
            result_prompt_logprobs.append(None)

        # Sample logprobs
        num_logprobs = sampling_params.logprobs
        if num_logprobs is None:
            num_logprobs = 0
        group_sample_logprobs: SampleLogprobs = []
        for next_token_id, parent_id in zip(next_token_ids, parent_ids):
            sample_logprobs_dict = {
                next_token_id:
                batched_logprobs_query_result[query_result_idx].item()
            }
            query_result_idx += 1
            if num_logprobs > 0:
                # Top-k entries come from the parent sequence's row.
                row = sample_idx + parent_id
                sample_logprobs_dict.update(
                    zip(top_token_ids[row, :num_logprobs].tolist(),
                        top_logprobs[row, :num_logprobs].tolist()))
            group_sample_logprobs.append(sample_logprobs_dict)
        result_sample_logprobs.append(group_sample_logprobs)
        sample_idx += len(seq_ids)

    return result_prompt_logprobs, result_sample_logprobs


def _build_sampler_output(
    sample_results: List[Tuple[List[int], List[int]]],
    input_metadata: InputMetadata,
    prompt_logprobs: List[Optional[PromptLogprobs]],
    sample_logprobs: List[SampleLogprobs],
) -> SamplerOutput:
    """Package sampling results and logprobs into the sampler output.

    Args:
        sample_results: One `(next_token_ids, parent_ids)` pair per group.
        input_metadata: Batch description providing `seq_groups`.
        prompt_logprobs: Prompt logprobs entry per group.
        sample_logprobs: Per-sample logprob dicts per group.

    Returns:
        A list with one `SequenceGroupOutputs` per sequence group.
    """
    sampler_output = []
    for (seq_group, sample_result, group_prompt_logprobs,
         group_sample_logprobs) in zip(input_metadata.seq_groups,
                                       sample_results, prompt_logprobs,
                                       sample_logprobs):
        seq_ids, _ = seq_group
        next_token_ids, parent_ids = sample_result
        # `parent_id` indexes into this group's `seq_ids`.
        seq_outputs = [
            SequenceOutputs(seq_ids[parent_id], next_token_id, logprobs)
            for parent_id, next_token_id, logprobs in zip(
                parent_ids, next_token_ids, group_sample_logprobs)
        ]
        sampler_output.append(
            SequenceGroupOutputs(seq_outputs, group_prompt_logprobs))
    return sampler_output