sampler.py 23.2 KB
Newer Older
1
"""A layer that samples the next tokens from the model's outputs."""
2
from typing import Dict, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
3
4
5
6

import torch
import torch.nn as nn

7
from vllm.model_executor.parallel_utils.communication_op import (
8
    tensor_model_parallel_gather)
9
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
10
from vllm.sampling_params import SamplingParams, SamplingType
11
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
Zhuohan Li's avatar
Zhuohan Li committed
12
                           SequenceData, SequenceGroupOutput, SequenceOutput)
Woosuk Kwon's avatar
Woosuk Kwon committed
13

14

Woosuk Kwon's avatar
Woosuk Kwon committed
15
class Sampler(nn.Module):
    """Turn model hidden states into sampled next tokens.

    Pipeline executed by :meth:`forward`:
      1. Keep only the hidden-state rows selected for sampling
         (``sampling_metadata.selected_token_indices``).
      2. Project them onto the vocabulary to obtain logits.
      3. Run per-request logits processors, then presence, frequency and
         repetition penalties.
      4. Divide by the per-sequence temperature.
      5. Truncate with top-p / top-k and with min-p.
      6. Sample the next tokens and gather the requested logprobs.

    Every sequence group in the batch may carry its own sampling parameters
    (sampling type, temperature, top-p, top-k, ...).
    """

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        # Logit columns at or beyond this index are vocab padding and are
        # sliced off after the logits are gathered.
        self.vocab_size = vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens for every sequence group in the batch.

        Args:
            embedding: Output-embedding weight; logits are computed as
                ``hidden_states @ embedding.T``.
            hidden_states: Model outputs; flattened to 2-D before the rows
                to sample from are selected.
            sampling_metadata: Per-batch sampling information.
            embedding_bias: Optional bias added to the logits.

        Returns:
            The sampler output, or ``None`` when this worker does not
            perform sampling (``sampling_metadata.perform_sampling`` is
            false).
        """
        selected = _prune_hidden_states(hidden_states, sampling_metadata)

        # The projection runs on every tensor-parallel worker because the
        # `embedding` weight is sharded across them; only the driver worker
        # continues on to sampling.
        # TODO(zhuohan): Change the get_logits part to a separate stage.
        logits = _get_logits(selected, embedding, embedding_bias,
                             self.vocab_size)
        if not sampling_metadata.perform_sampling:
            return None
        assert logits is not None
        _, vocab_size = logits.shape

        logits = _apply_logits_processors(logits, sampling_metadata)

        # Sampling tensors are prepared with pinned memory to avoid blocking.
        (tensors, do_penalties, do_top_p_top_k,
         do_min_p) = SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype)

        if do_penalties:
            logits = _apply_penalties(
                logits,
                tensors.prompt_tokens,
                tensors.output_tokens,
                tensors.presence_penalties,
                tensors.frequency_penalties,
                tensors.repetition_penalties,
            )

        # Temperature scaling, done in place to avoid allocating a new tensor.
        logits.div_(tensors.temperatures.unsqueeze_(dim=1))

        if do_top_p_top_k:
            logits = _apply_top_k_top_p(logits, tensors.top_ps,
                                        tensors.top_ks)
        if do_min_p:
            logits = _apply_min_p(logits, tensors.min_ps)

        # Probabilities and log-probabilities are computed in float32;
        # log_softmax is used directly for numerical stability.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        sample_results = _sample(probs, logprobs, sampling_metadata)
        prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, sampling_metadata, sample_results)
        return _build_sampler_output(sample_results, sampling_metadata,
                                     prompt_logprobs, sample_logprobs)
99
100


101
102
def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
                embedding_bias: Optional[torch.Tensor],
                vocab_size: int) -> Optional[torch.Tensor]:
    """Project hidden states onto the (possibly padded) vocabulary.

    Computes ``hidden_states @ embedding.T`` (plus ``embedding_bias`` when
    given), gathers the result across tensor-parallel ranks and drops any
    padding columns at or beyond ``vocab_size``.

    Returns:
        The gathered logits, or ``None`` whenever
        ``tensor_model_parallel_gather`` returns ``None`` (presumably on
        ranks that are not the gather destination -- TODO confirm).
    """
    local_logits = hidden_states @ embedding.t()
    if embedding_bias is not None:
        local_logits += embedding_bias
    gathered = tensor_model_parallel_gather(local_logits)
    if gathered is None:
        return None
    # Remove paddings in vocab (if any).
    return gathered[:, :vocab_size]


115
116
def _prune_hidden_states(
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Keep only the hidden-state rows that will be sampled from.

    All leading dimensions are flattened so the tensor becomes
    ``(rows, hidden_size)``; then the rows listed in
    ``sampling_metadata.selected_token_indices`` are selected.
    """
    flat = hidden_states.view(-1, hidden_states.shape[-1])
    row_indices = sampling_metadata.selected_token_indices
    return torch.index_select(flat, 0, row_indices)
122
123


124
def _get_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Count token occurrences per row and flag which tokens appear.

    Args:
        tokens: 2-D long tensor of token ids, one row per sequence. The id
            ``vocab_size`` is treated as padding and is not counted.
        vocab_size: Number of real vocabulary entries.
        num_seqs: Number of rows in the result.

    Returns:
        ``(bin_counts, mask)`` where ``bin_counts[i, t]`` is how many times
        token ``t`` occurs in row ``i`` (dtype ``torch.long``) and ``mask``
        is ``bin_counts > 0``.
    """
    # One extra column absorbs the padding id; it is dropped afterwards.
    counts = tokens.new_zeros((num_seqs, vocab_size + 1), dtype=torch.long)
    counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    counts = counts[:, :vocab_size]
    return counts, counts > 0
139
140


141
142
143
144
def _apply_logits_processors(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Run each sequence group's logits processors on its sampling rows.

    Rows of ``logits`` follow ``sampling_metadata.seq_groups`` order. A
    prompt group that requested ``prompt_logprobs`` contributes
    ``prompt_len - 1`` leading rows (one per scored prompt position) before
    its sampling row; those rows are skipped so that processors only see the
    rows that are actually sampled from. Every other group owns one row per
    sequence.

    Each processor is called as ``processor(output_token_ids, logits_row)``
    and must return the (possibly new) row.

    Returns:
        ``logits``, updated in place.
    """
    logits_row_idx = 0
    found_logits_processors = False
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        seq_ids, sampling_params = seq_group

        # Skip the rows produced only for prompt-logprob computation; without
        # this, the processor would be applied to a prompt-position row and
        # the row count check below would fail.
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            logits_row_idx += sampling_metadata.prompt_lens[i] - 1

        logits_processors = sampling_params.logits_processors
        if not logits_processors:
            logits_row_idx += len(seq_ids)
            continue

        found_logits_processors = True
        for seq_id in seq_ids:
            logits_row = logits[logits_row_idx]
            token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
            for logits_processor in logits_processors:
                logits_row = logits_processor(token_ids, logits_row)
            logits[logits_row_idx] = logits_row
            logits_row_idx += 1

    if found_logits_processors:
        # Every row must be accounted for; a mismatch means the assumed row
        # layout no longer matches the caller's.
        assert logits_row_idx == logits.shape[0]
    return logits


165
166
167
168
169
def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                     output_tokens_tensor: torch.Tensor,
                     presence_penalties: torch.Tensor,
                     frequency_penalties: torch.Tensor,
                     repetition_penalties: torch.Tensor) -> torch.Tensor:
    """Apply repetition, frequency and presence penalties to ``logits``.

    Args:
        logits: ``(num_seqs, vocab_size)`` logits.
        prompt_tokens_tensor: Prompt token ids per row, padded with
            ``vocab_size``.
        output_tokens_tensor: Generated token ids per row, padded with
            ``vocab_size``.
        presence_penalties: Per-row penalty subtracted once from every token
            that appears in the output.
        frequency_penalties: Per-row penalty multiplied by each token's
            output count and subtracted.
        repetition_penalties: Per-row multiplicative penalty applied to
            tokens seen in the prompt or the output.

    Returns:
        A new logits tensor; the input ``logits`` is not modified.

    The penalty tensors are left untouched. They are broadcast with the
    non-in-place ``unsqueeze``, so the caller's tensors keep their 1-D shape
    and the function can safely be called again with the same inputs.
    """
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
                                              num_seqs)
    output_bin_counts, output_mask = _get_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    # Repetition penalty only affects tokens seen in the prompt or output.
    # `.repeat` copies, so the caller's tensor is not modified here.
    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
    # Divide positive logits and multiply negative ones; for a penalty > 1
    # both directions lower the token's score.
    logits = torch.where(logits > 0, logits / repetition_penalties,
                         logits * repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits


188
def _apply_top_k_top_p(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """Mask logits outside each row's top-k and top-p (nucleus) sets.

    Args:
        logits: 2-D logits, one row per sequence (not modified).
        p: Per-row top-p threshold.
        k: Per-row top-k value; ``k`` equal to the row length keeps every
            token.

    Returns:
        A new tensor with the same layout as ``logits``, with filtered
        entries set to ``-inf``. Tokens tied with the k-th largest logit are
        kept, and the single most likely token always survives top-p.
    """
    # Ascending order: the largest logits sit at the end of each row.
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=False)
    row_len = sorted_logits.size(1)

    # Top-k: drop everything strictly below the k-th largest value.
    kth_pos = (row_len - k.to(torch.long)).unsqueeze(dim=1)
    kth_value = sorted_logits.gather(1, kth_pos)
    sorted_logits.masked_fill_(sorted_logits < kth_value, -float("inf"))

    # Top-p: drop the low-probability prefix whose cumulative mass stays
    # within 1 - p, but never the last (most likely) entry.
    cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    drop = cum_probs <= 1 - p.unsqueeze(dim=1)
    drop[:, -1] = False
    sorted_logits.masked_fill_(drop, -float("inf"))

    # Undo the sort by scattering each value back to its original column.
    restored = torch.empty_like(sorted_logits)
    return restored.scatter_(dim=-1, index=sorted_idx, src=sorted_logits)
218
219


Roy's avatar
Roy committed
220
221
def _apply_min_p(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    """Mask tokens whose probability is below ``min_p`` times the row max.

    Adapted from
    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17

    Args:
        logits: 2-D logits, one row per sequence. Masked entries are set to
            ``-inf`` in place.
        min_p: Per-row minimum relative probability (1-D). It is not
            modified; the previous in-place ``unsqueeze_`` reshaped the
            caller's tensor and broke any repeated use of it.

    Returns:
        ``logits`` (the same tensor, modified in place).
    """
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
    tokens_to_remove = probs < scaled_min_p
    return logits.masked_fill_(tokens_to_remove, -float("inf"))


237
238
def _greedy_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    """Pair each greedy sequence group with its argmax token.

    ``samples`` holds one chosen token per sequence, in group order. Every
    greedy group must contain exactly one sequence; its result is
    ``([token_id], [0])``.
    """
    chosen = samples.tolist()
    results: List[Tuple[List[int], List[int]]] = []
    position = 0
    for seq_ids, _ in selected_seq_groups:
        group_size = len(seq_ids)
        assert group_size == 1, (
            "Greedy sampling should have only one seq.")
        results.append(([chosen[position]], list(range(group_size))))
        position += group_size
    return results


def _random_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    random_samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    """Turn multinomial draws into per-group ``(next_token_ids, parent_ids)``.

    ``random_samples`` has one row per sequence (in group order) and one
    column per draw.
      * Prompt group: takes the first ``best_of`` draws of its row; every
        resulting child has parent index 0.
      * Generation group: takes the first draw of each of its sequences;
        sequence ``j`` is its own parent.
    """
    # One device->host copy up front; all slicing below is on the CPU.
    samples_cpu = random_samples.cpu()
    results: List[Tuple[List[int], List[int]]] = []
    row = 0
    for (seq_ids, sampling_params), is_prompt in zip(selected_seq_groups,
                                                     is_prompts):
        group_size = len(seq_ids)
        if is_prompt:
            best_of = sampling_params.best_of
            token_ids = samples_cpu[row, :best_of].tolist()
            parents = [0] * best_of
        else:
            token_ids = samples_cpu[row:row + group_size, 0].tolist()
            parents = list(range(group_size))
        results.append((token_ids, parents))
        row += group_size
    return results


def _beam_search_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    seq_data: Dict[int, SequenceData],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    """Propose ``2 * beam_width`` candidates per beam-search sequence group.

    Twice the beam width is proposed so that, with high probability,
    ``beam_width`` candidates remain after some sequences finish. See
    https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    and the HF reference
    https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065

    * Prompt group (exactly one sequence): the top candidates of its single
      row, all with parent index 0.
    * Generation group: each row's logprobs are offset by that sequence's
      cumulative logprob; the top candidates over the flattened group give
      ``parent = flat_id // vocab`` and ``token = flat_id % vocab``.

    NOTE: This is not vectorized across groups, so it can be slower than the
    other sampling methods.
    """
    results: List[Tuple[List[int], List[int]]] = []
    start = 0
    for (seq_ids, sampling_params), is_prompt in zip(selected_seq_groups,
                                                     is_prompts):
        group_size = len(seq_ids)
        num_candidates = 2 * sampling_params.best_of
        group_logprobs = logprobs[start:start + group_size]
        if is_prompt:
            assert group_size == 1, (
                "Prompt input should have only one seq.")
            best = torch.topk(group_logprobs[0], num_candidates)
            token_ids = best.indices.tolist()
            parents = [0] * num_candidates
        else:
            cumulative = torch.tensor(
                [seq_data[seq_id].cumulative_logprob for seq_id in seq_ids],
                dtype=torch.float,
                device=group_logprobs.device)
            scores = group_logprobs + cumulative.unsqueeze(dim=1)
            vocab = scores.size(-1)
            flat_ids = torch.topk(scores.flatten(),
                                  num_candidates).indices.tolist()
            parents = [flat_id // vocab for flat_id in flat_ids]
            token_ids = [flat_id % vocab for flat_id in flat_ids]
        results.append((token_ids, parents))
        start += group_size
    assert start == logprobs.size(0)
    return results
334
335


336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
):
    """Draw ``num_samples`` column indices per row of ``probs``, with replacement.

    This replaces ``torch.multinomial``, which forces a GPU<->CPU sync. Each
    probability is divided by an independent Exp(1) draw and the argmax of
    each row is taken. This exponential race samples index ``i`` with
    probability proportional to ``probs[i]``.

    ``probs`` is modified in place when ``num_samples == 1``. Otherwise an
    expanded private copy is modified. Callers already pass a copy.

    Returns:
        A ``(num_rows, num_samples)`` tensor of sampled indices.
    """
    if num_samples > 1:
        # Materialize num_samples copies of every row, laid out like
        # torch.repeat_interleave(probs, num_samples, dim=0) but without its
        # GPU<->CPU sync, so each copy gets an independent draw.
        # .contiguous() is required: the in-place division below cannot
        # write into an expanded (stride-0) view.
        num_rows, row_len = probs.shape[0], probs.shape[1]
        expanded = probs.unsqueeze(1).expand(num_rows, num_samples, row_len)
        probs = expanded.contiguous().view(-1, row_len)
    noise = torch.empty_like(probs).exponential_(1)
    winners = probs.div_(noise).argmax(dim=1)
    return winners.view(-1, num_samples)


358
359
360
def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> List[Tuple[List[int], List[int]]]:
    """Sample next tokens for every sequence group, dispatched by sampling type.

    The work is split into two passes. This counterintuitively turns out to
    be faster:
      1. Launch the on-device work for every sampling type (greedy argmax,
         multinomial draws, beam-search logprob slices) without reading
         anything back, so no GPU<->CPU sync happens yet.
      2. Convert each type's device result into per-group
         ``(next_token_ids, parent_ids)`` pairs; the sync happens here.

    Returns:
        One ``(next_token_ids, parent_ids)`` pair per entry of
        ``sampling_metadata.seq_groups``, in that order.

    Raises:
        ValueError: If a sampling type with samples is not greedy, random or
            beam search.
    """
    all_groups = sampling_metadata.seq_groups
    group_ids_by_type: Dict[SamplingType, List[int]] = {
        t: [] for t in SamplingType
    }
    for group_id, (_, params) in enumerate(all_groups):
        group_ids_by_type[params.sampling_type].append(group_id)

    indices_by_type = sampling_metadata.categorized_sample_indices

    # Pass 1: launch device work only.
    pending = {}
    for sampling_type in SamplingType:
        sample_indices = indices_by_type[sampling_type]
        if len(sample_indices) == 0:
            continue
        group_ids = group_ids_by_type[sampling_type]
        groups = [all_groups[i] for i in group_ids]
        is_prompts = [i < sampling_metadata.num_prompts for i in group_ids]

        if sampling_type == SamplingType.GREEDY:
            device_result = torch.argmax(logprobs[sample_indices], dim=-1)
        elif sampling_type == SamplingType.RANDOM:
            # Prompt groups fork `best_of` children, so draw enough samples
            # per row for the largest such request (at least one).
            prompt_best_ofs = [
                params.best_of
                for (_, params), is_prompt in zip(groups, is_prompts)
                if is_prompt
            ]
            max_best_of = max([1] + prompt_best_ofs)
            device_result = _multinomial(probs[sample_indices], max_best_of)
        elif sampling_type == SamplingType.BEAM:
            device_result = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")
        pending[sampling_type] = (group_ids, groups, is_prompts,
                                  device_result)

    # Pass 2: GPU<->CPU sync happens in the per-type helpers below.
    results_by_group: Dict[int, Tuple[List[int], List[int]]] = {}
    for sampling_type, (group_ids, groups, is_prompts,
                        device_result) in pending.items():
        if sampling_type == SamplingType.GREEDY:
            type_results = _greedy_sample(groups, device_result)
        elif sampling_type == SamplingType.RANDOM:
            type_results = _random_sample(groups, is_prompts, device_result)
        elif sampling_type == SamplingType.BEAM:
            type_results = _beam_search_sample(groups, is_prompts,
                                               sampling_metadata.seq_data,
                                               device_result)
        results_by_group.update(zip(group_ids, type_results))

    return [results_by_group[i] for i in range(len(all_groups))]


def _get_logprobs(
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
        int, float]]]]:
    """Collect prompt and sample logprobs for every sequence group.

    Rows of ``logprobs`` follow ``sampling_metadata.seq_groups`` order. A
    prompt group that requested ``prompt_logprobs`` owns ``prompt_len - 1``
    rows (scoring prompt tokens ``1..``), followed by its sampling rows.
    Every group owns one sampling row per sequence.

    All lookups are batched:
      * one advanced-indexing query for the chosen / prompt tokens, and
      * one ``topk`` over the largest requested count.
    Each result is copied to the CPU once and then unpacked per group.

    Returns:
        ``(prompt_logprobs, sample_logprobs)``, one entry per group.
        ``prompt_logprobs[g]`` is ``None`` when the group did not ask for
        it. Otherwise it is a list starting with ``None`` for the first
        prompt token, then one dict per remaining prompt token.
        ``sample_logprobs[g]`` holds one dict per sampled token. Each dict
        maps token id -> logprob and contains the chosen token plus the top
        ``prompt_logprobs`` / ``logprobs`` tokens of its row.
    """
    seq_groups = sampling_metadata.seq_groups
    num_prompts = sampling_metadata.num_prompts

    # Build the batched (row, token) query.
    query_rows: List[int] = []
    query_tokens: List[int] = []
    max_num_logprobs = 0
    row = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(seq_groups, sample_results)):
        seq_ids, params = seq_group
        next_token_ids, parent_ids = sample_result
        wants_prompt = (i < num_prompts
                        and params.prompt_logprobs is not None)
        if wants_prompt:
            max_num_logprobs = max(max_num_logprobs, params.prompt_logprobs)
            prompt_len = sampling_metadata.prompt_lens[i]
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            query_rows.extend(range(row, row + prompt_len - 1))
            query_tokens.extend(prompt_tokens[1:])
            row += prompt_len - 1
        query_rows.extend(row + parent_id for parent_id in parent_ids)
        query_tokens.extend(next_token_ids)
        if params.logprobs is not None:
            max_num_logprobs = max(max_num_logprobs, params.logprobs)
        row += len(seq_ids)
    assert row == logprobs.size(0)

    # Batched query for logprobs of selected token
    query_result = logprobs[[query_rows, query_tokens]]

    # Batched query for logprobs of topk tokens
    if max_num_logprobs > 0:
        top_logprobs, top_token_ids = torch.topk(logprobs,
                                                 max_num_logprobs,
                                                 dim=-1)
        top_logprobs = top_logprobs.cpu()
        top_token_ids = top_token_ids.cpu()
    else:
        top_logprobs = top_token_ids = None

    query_result = query_result.cpu()

    # Unpack the query results per group, walking rows in the same order.
    prompt_out: List[Optional[PromptLogprobs]] = []
    sample_out: List[SampleLogprobs] = []
    row = 0
    cursor = 0  # position in `query_result`
    for i, (seq_group, sample_result) in enumerate(
            zip(seq_groups, sample_results)):
        seq_ids, params = seq_group
        next_token_ids, parent_ids = sample_result

        if i < num_prompts and params.prompt_logprobs is not None:
            k = params.prompt_logprobs
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            group_prompt: PromptLogprobs = [None]
            for token_id in prompt_tokens[1:]:
                entry = {token_id: query_result[cursor].item()}
                if k > 0:
                    entry.update(
                        zip(top_token_ids[row, :k].tolist(),
                            top_logprobs[row, :k].tolist()))
                group_prompt.append(entry)
                row += 1
                cursor += 1
            prompt_out.append(group_prompt)
        else:
            prompt_out.append(None)

        k = 0 if params.logprobs is None else params.logprobs
        group_sample: SampleLogprobs = []
        for token_id, parent_id in zip(next_token_ids, parent_ids):
            entry = {token_id: query_result[cursor].item()}
            cursor += 1
            if k > 0:
                parent_row = row + parent_id
                entry.update(
                    zip(top_token_ids[parent_row, :k].tolist(),
                        top_logprobs[parent_row, :k].tolist()))
            group_sample.append(entry)
        sample_out.append(group_sample)
        row += len(seq_ids)

    return prompt_out, sample_out


def _build_sampler_output(
    sample_results: List[Tuple[List[int], List[int]]],
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: List[Optional[PromptLogprobs]],
    sample_logprobs: List[SampleLogprobs],
) -> SamplerOutput:
    """Package per-group sampling results into the sampler output list.

    Each sampled token becomes a ``SequenceOutput`` holding:
      * the id of the sequence it extends (``seq_ids[parent_id]``),
      * the token id, and
      * its logprob dict.
    A group's outputs are wrapped in a ``SequenceGroupOutput`` together with
    that group's prompt logprobs.
    """
    per_group = zip(sampling_metadata.seq_groups, sample_results,
                    prompt_logprobs, sample_logprobs)
    sampler_output = []
    for ((seq_ids, _), (next_token_ids, parent_ids), group_prompt,
         group_sample) in per_group:
        seq_outputs = [
            SequenceOutput(seq_ids[parent_id], token_id, token_logprobs)
            for parent_id, token_id, token_logprobs in zip(
                parent_ids, next_token_ids, group_sample)
        ]
        sampler_output.append(SequenceGroupOutput(seq_outputs, group_prompt))
    return sampler_output