sampler.py 23.7 KB
Newer Older
1
"""A layer that samples the next tokens from the model's outputs."""
2
from typing import Dict, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
3
4
5
6

import torch
import torch.nn as nn

7
8
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_gather)
9
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
10
from vllm.sampling_params import SamplingParams, SamplingType
11
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
Zhuohan Li's avatar
Zhuohan Li committed
12
                           SequenceData, SequenceGroupOutput, SequenceOutput)
Woosuk Kwon's avatar
Woosuk Kwon committed
13

14

Woosuk Kwon's avatar
Woosuk Kwon committed
15
class Sampler(nn.Module):
    """Turn final hidden states into next-token samples.

    ``forward`` runs this pipeline:
      1. Keep only the hidden states that will be sampled from (the rows
         listed in ``sampling_metadata.selected_token_indices``).
      2. Project them onto the vocabulary to obtain logits.
      3. Run any per-request logits processors.
      4. Apply presence, frequency and repetition penalties (if requested).
      5. Scale the logits by temperature.
      6. Apply top-p / top-k truncation and min-p filtering (if requested).
      7. Sample the next tokens and collect the requested logprobs.

    Every sequence group in a batch may carry its own sampling parameters
    (sampling method, temperature, top-p, top-k, ...).
    """

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        # Number of real vocabulary entries; any logits columns past this
        # index (vocab padding) are sliced off in _get_logits.
        self.vocab_size = vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> SamplerOutput:
        selected_hidden = _prune_hidden_states(hidden_states,
                                               sampling_metadata)
        logits = _get_logits(selected_hidden, embedding, embedding_bias,
                             self.vocab_size)
        _, vocab_columns = logits.shape

        logits = _apply_logits_processors(logits, sampling_metadata)

        # Batch the per-request sampling parameters into tensors, together
        # with flags telling which optional stages are needed at all.
        # (Per the original note: prepared with pinned memory so the copy
        # does not block.)
        tensors, do_penalties, do_top_p_top_k, do_min_p = (
            SamplingTensors.from_sampling_metadata(sampling_metadata,
                                                   vocab_columns,
                                                   logits.device,
                                                   logits.dtype))

        if do_penalties:
            logits = _apply_penalties(logits, tensors.prompt_tokens,
                                      tensors.output_tokens,
                                      tensors.presence_penalties,
                                      tensors.frequency_penalties,
                                      tensors.repetition_penalties)

        # Temperature scaling, done in place so no new tensor is allocated.
        logits.div_(tensors.temperatures.unsqueeze_(dim=1))

        if do_top_p_top_k:
            logits = _apply_top_p_top_k(logits, tensors.top_ps,
                                        tensors.top_ks)
        if do_min_p:
            logits = _apply_min_p(logits, tensors.min_ps)

        # Both distributions are materialized in float32; log_softmax is
        # used directly (instead of log(probs)) for numerical stability.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        sample_results = _sample(probs, logprobs, sampling_metadata)
        prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, sampling_metadata, sample_results)
        return _build_sampler_output(sample_results, sampling_metadata,
                                     prompt_logprobs, sample_logprobs)
91
92


93
94
95
96
97
98
99
def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
                embedding_bias: Optional[torch.Tensor],
                vocab_size: int) -> torch.Tensor:
    """Project hidden states onto the vocabulary embedding.

    Computes ``hidden_states @ embedding.T`` (plus ``embedding_bias`` when
    given, added in place) and passes the result through
    ``tensor_model_parallel_all_gather`` (presumably collecting the vocab
    shards held by each tensor-parallel rank). Columns at index
    ``vocab_size`` and beyond are vocabulary padding and are dropped.
    """
    local_logits = hidden_states @ embedding.t()
    if embedding_bias is not None:
        local_logits += embedding_bias
    gathered = tensor_model_parallel_all_gather(local_logits)
    return gathered[:, :vocab_size]


106
107
def _prune_hidden_states(
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Keep only the hidden-state rows that will be sampled from.

    All leading dimensions are first flattened into a single token axis,
    then the rows listed in ``sampling_metadata.selected_token_indices`` are
    gathered in that order.
    """
    hidden_size = hidden_states.shape[-1]
    flat_hidden = hidden_states.view(-1, hidden_size)
    selected_rows = sampling_metadata.selected_token_indices
    return flat_hidden.index_select(0, selected_rows)
113
114


115
def _get_prompt_and_output_tokens(
    sampling_metadata: SamplingMetadata,
) -> Tuple[List[List[int]], List[List[int]]]:
    """Build per-row prompt-token and output-token lists.

    For every sequence group, in order: if the group is a prompt
    (index < ``num_prompts``) that requested ``prompt_logprobs``,
    ``prompt_len - 1`` empty placeholder lists are emitted first (penalties
    are not needed at those prompt positions). Then one entry per sequence
    follows, holding its ``prompt_token_ids`` / ``output_token_ids``.

    Returns:
        ``(prompt_tokens, output_tokens)``, two lists of equal length.
    """
    prompt_tokens: List[List[int]] = []
    output_tokens: List[List[int]] = []
    num_prompts = sampling_metadata.num_prompts
    for group_idx, (seq_ids, sampling_params) in enumerate(
            sampling_metadata.seq_groups):
        wants_prompt_rows = (group_idx < num_prompts
                             and sampling_params.prompt_logprobs is not None)
        if wants_prompt_rows:
            num_placeholder_rows = sampling_metadata.prompt_lens[group_idx] - 1
            prompt_tokens.extend([[] for _ in range(num_placeholder_rows)])
            output_tokens.extend([[] for _ in range(num_placeholder_rows)])
        for seq_id in seq_ids:
            data = sampling_metadata.seq_data[seq_id]
            prompt_tokens.append(data.prompt_token_ids)
            output_tokens.append(data.output_token_ids)
    return prompt_tokens, output_tokens


def _get_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Count occurrences of each vocabulary id in every row of ``tokens``.

    ``tokens`` may use the id ``vocab_size`` as padding: one extra column
    is allocated to absorb those counts and is then sliced away.

    Returns:
        ``(counts, present)``: ``counts`` is a ``(num_seqs, vocab_size)``
        long tensor, and ``present`` is the boolean mask ``counts > 0``.
    """
    padded_counts = torch.zeros((num_seqs, vocab_size + 1),
                                dtype=torch.long,
                                device=tokens.device)
    padded_counts.scatter_add_(dim=1,
                               index=tokens,
                               src=torch.ones_like(tokens))
    counts = padded_counts[:, :vocab_size]
    return counts, counts > 0
151
152


153
154
155
156
def _apply_logits_processors(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Run each sequence group's logits processors on its sampling rows.

    Rows of ``logits`` are laid out group by group. A prompt group
    (index < ``num_prompts``) that requested ``prompt_logprobs`` owns
    ``prompt_len - 1`` leading rows, one per scored prompt position, before
    its sampling row. This is the same layout that _get_logprobs and
    _get_prompt_and_output_tokens walk. Those prompt rows are skipped here,
    so every processor sees the row of the token actually being sampled.

    Each processor is called as ``processor(output_token_ids, logits_row)``
    and must return the new row, which is written back into ``logits``.

    Returns:
        The same ``logits`` tensor, modified in place.
    """
    logits_row_idx = 0
    found_logits_processors = False
    for i, (seq_ids, sampling_params) in enumerate(
            sampling_metadata.seq_groups):
        # Step over the prompt-logprob rows; processors do not apply there.
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            logits_row_idx += sampling_metadata.prompt_lens[i] - 1

        logits_processors = sampling_params.logits_processors
        if not logits_processors:
            logits_row_idx += len(seq_ids)
            continue

        found_logits_processors = True
        for seq_id in seq_ids:
            token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
            logits_row = logits[logits_row_idx]
            for logits_processor in logits_processors:
                logits_row = logits_processor(token_ids, logits_row)
            logits[logits_row_idx] = logits_row
            logits_row_idx += 1

    if found_logits_processors:
        # Every row must have been accounted for exactly once.
        assert logits_row_idx == logits.shape[0]
    return logits


177
178
179
180
181
def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                     output_tokens_tensor: torch.Tensor,
                     presence_penalties: torch.Tensor,
                     frequency_penalties: torch.Tensor,
                     repetition_penalties: torch.Tensor) -> torch.Tensor:
    """Apply repetition, frequency and presence penalties to ``logits``.

    Args:
        logits: ``(num_seqs, vocab_size)`` logits.
        prompt_tokens_tensor: Per-row prompt token ids, padded with the id
            ``vocab_size`` (the bin-count helper reserves that column).
        output_tokens_tensor: Per-row generated token ids, padded likewise.
        presence_penalties, frequency_penalties, repetition_penalties:
            One value per row.

    Repetition penalty applies to every token seen in the prompt or the
    output. Positive logits are divided by it and non-positive logits are
    multiplied by it. Frequency penalty subtracts ``penalty * count`` for
    each token's count in the output. Presence penalty subtracts
    ``penalty`` once for each token that appears in the output.

    The penalty tensors are broadcast through non-mutating views, so the
    caller's tensors keep their shape and can be reused.

    Returns:
        A new penalized logits tensor.
    """
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
                                              num_seqs)
    output_bin_counts, output_mask = _get_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    # `repeat` copies, so writing 1.0 (no-op penalty) for unseen tokens does
    # not touch the caller's tensor.
    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
    logits = torch.where(logits > 0, logits / repetition_penalties,
                         logits * repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits


Woosuk Kwon's avatar
Woosuk Kwon committed
200
def _apply_top_p_top_k(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """Mask every token outside a row's top-p nucleus or top-k set.

    Args:
        logits: 2-D ``(num_seqs, vocab_size)`` logits.
        p: One top-p threshold per row (1-D, broadcast across the vocab).
        k: One top-k count per row (1-D, broadcast across the vocab).

    Returns:
        A new tensor in the original column order, with masked entries set
        to ``-inf``. ``logits``, ``p`` and ``k`` are not modified. ``p`` and
        ``k`` are broadcast through non-mutating views, so the same tensors
        can safely be passed again.
    """
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)

    # Top-p: use the *exclusive* cumulative probability, so the most likely
    # token (exclusive sum 0) is always kept for any p >= 0.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
    top_p_mask = probs_sum > p.unsqueeze(dim=1)

    # Top-k: drop every position whose rank is >= k.
    ranks = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
    top_k_mask = ranks.expand(logits_idx.shape[0], -1) >= k.unsqueeze(dim=1)

    logits_sort.masked_fill_(top_p_mask | top_k_mask, -float("inf"))

    # Undo the sort: scatter each sorted value back to its original column
    # (equivalent to gathering with the inverse permutation, in one op).
    return torch.empty_like(logits_sort).scatter_(dim=-1,
                                                  index=logits_idx,
                                                  src=logits_sort)
230
231


Roy's avatar
Roy committed
232
233
def _apply_min_p(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    """
    Adapted from
    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17

    Min-p filtering: in each row, drop every token whose probability is
    below ``min_p * (probability of that row's most likely token)``.

    Args:
        logits: 2-D ``(num_seqs, vocab_size)`` logits. They are masked in
            place (``masked_fill_``), and the same tensor is returned.
        min_p: One value per row (1-D). It is broadcast through a
            non-mutating view, so the caller's tensor keeps its shape.
    """
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
    tokens_to_remove = probs < scaled_min_p
    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))

    return logits


249
250
def _greedy_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    """Wrap greedy (argmax) token ids into per-group sampling results.

    ``samples`` holds one chosen token id per selected sequence, in group
    order. Each greedy group must contain exactly one sequence.

    Returns:
        One ``([next_token_id], [0])`` pair per group.
    """
    flat_samples = samples.tolist()
    results: List[Tuple[List[int], List[int]]] = []
    cursor = 0
    for seq_ids, _ in selected_seq_groups:
        n_parents = len(seq_ids)
        assert n_parents == 1, (
            "Greedy sampling should have only one seq.")
        results.append(([flat_samples[cursor]], list(range(n_parents))))
        cursor += n_parents
    return results


def _random_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    random_samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    """Convert multinomial draws into per-group sampling results.

    ``random_samples`` has one row per selected sequence and (at least)
    ``best_of`` columns of independent draws for prompt groups. It is moved
    to the CPU first.

    * Prompt group: the first ``best_of`` draws of its single row become
      the candidates, and every candidate's parent is sequence 0.
    * Generation group: each running sequence keeps the first draw of its
      own row, and its parent is itself.

    Returns:
        One ``(next_token_ids, parent_ids)`` pair per group.
    """
    host_samples = random_samples.cpu()
    results: List[Tuple[List[int], List[int]]] = []
    cursor = 0
    for (seq_ids, sampling_params), is_prompt in zip(selected_seq_groups,
                                                     is_prompts):
        n_parents = len(seq_ids)
        if not is_prompt:
            own_draws = host_samples[cursor:cursor + n_parents, 0].tolist()
            results.append((own_draws, list(range(n_parents))))
        else:
            width = sampling_params.best_of
            candidates = host_samples[cursor, :width].tolist()
            results.append((candidates, [0] * width))
        cursor += n_parents
    return results


def _beam_search_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    seq_data: Dict[int, SequenceData],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    """Propose beam-search candidates for every beam-search group.

    A group with beam width ``best_of`` gets ``2 * best_of`` candidates, so
    that with high probability ``best_of`` candidates remain in addition to
    the finished sequences for the next iteration. See
    https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    for details. See also HF reference:
    https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065

    * Prompt group (exactly one sequence): the candidates are the top
      tokens of its single row, and each candidate's parent is 0.
    * Generation group: each row is shifted by that sequence's
      ``cumulative_logprob``, and the top candidates are chosen jointly over
      all (parent, token) pairs of the group.

    NOTE: this loops over groups in Python instead of being vectorized, so
    it can be slower than the other sampling methods.

    Returns:
        One ``(next_token_ids, parent_ids)`` pair per group. Parent ids index
        into that group's ``seq_ids``.
    """
    results: List[Tuple[List[int], List[int]]] = []
    cursor = 0
    for (seq_ids, sampling_params), is_prompt in zip(selected_seq_groups,
                                                     is_prompts):
        n_parents = len(seq_ids)
        num_candidates = 2 * sampling_params.best_of
        group_logprobs = logprobs[cursor:cursor + n_parents]
        if is_prompt:
            assert n_parents == 1, (
                "Prompt input should have only one seq.")
            _, top_tokens = torch.topk(group_logprobs[0], num_candidates)
            results.append((top_tokens.tolist(), [0] * num_candidates))
        else:
            running_scores = torch.tensor(
                [seq_data[seq_id].cumulative_logprob for seq_id in seq_ids],
                dtype=torch.float,
                device=group_logprobs.device)
            joint_scores = group_logprobs + running_scores.unsqueeze(dim=1)
            vocab = joint_scores.size(-1)
            _, flat_ids = torch.topk(joint_scores.flatten(), num_candidates)
            # Flat index -> (row within group, token id).
            pairs = [divmod(flat_id, vocab) for flat_id in flat_ids.tolist()]
            tokens = [token for _, token in pairs]
            parents = [parent for parent, _ in pairs]
            results.append((tokens, parents))
        cursor += n_parents
    assert cursor == logprobs.size(0)
    return results
346
347


348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
):
    """Draw ``num_samples`` token ids per row of ``probs``, with replacement.

    This is used instead of ``torch.multinomial``, which forces a GPU<->CPU
    sync. Dividing the probabilities by i.i.d. Exponential(1) noise and
    taking the argmax selects index ``i`` with probability proportional to
    ``probs[i]``.

    ``probs`` may be divided in place (always when ``num_samples == 1``).
    Callers are expected to pass a copy.

    Returns:
        A ``(num_rows, num_samples)`` tensor of sampled indices.
    """
    if num_samples > 1:
        # Replicate every row num_samples times, so that one argmax per
        # replicated row yields independent draws. This is equivalent to
        # torch.repeat_interleave, which also forces a GPU<->CPU sync.
        num_rows, vocab = probs.shape[0], probs.shape[1]
        replicated = probs[:, None, :].expand(num_rows, num_samples, vocab)
        probs = replicated.contiguous().view(-1, vocab)
    noise = torch.empty_like(probs).exponential_(1)
    winners = probs.div_(noise).argmax(dim=1)
    return winners.view(-1, num_samples)


370
371
372
def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> List[Tuple[List[int], List[int]]]:
    """Sample next tokens for every sequence group, dispatched by sampling type.

    Works in two passes. The first pass launches all tensor work (argmax,
    multinomial draws, row selection) for every sampling type without
    reading results back to the host. The second pass converts those
    results into Python lists, which is where the GPU<->CPU syncs happen.
    Counterintuitively, two loops are faster than one here.

    Returns:
        One ``(next_token_ids, parent_ids)`` pair per sequence group, in the
        order of ``sampling_metadata.seq_groups``.
    """
    # Group ids per sampling type, preserving batch order.
    group_ids_by_type: Dict[SamplingType, List[int]] = {
        t: []
        for t in SamplingType
    }
    for group_id, (_, params) in enumerate(sampling_metadata.seq_groups):
        group_ids_by_type[params.sampling_type].append(group_id)

    sample_indices_by_type = sampling_metadata.categorized_sample_indices

    # Pass 1: issue device work only.
    # sampling type -> (group_ids, groups, is_prompts, device result)
    pending = {}
    for sampling_type in SamplingType:
        sample_indices = sample_indices_by_type[sampling_type]
        if len(sample_indices) == 0:
            continue
        group_ids = group_ids_by_type[sampling_type]
        groups = [sampling_metadata.seq_groups[g] for g in group_ids]
        is_prompts = [g < sampling_metadata.num_prompts for g in group_ids]

        if sampling_type == SamplingType.GREEDY:
            device_result = torch.argmax(logprobs[sample_indices], dim=-1)
        elif sampling_type == SamplingType.RANDOM:
            # Enough draws per row for the widest prompt-phase best_of.
            prompt_best_ofs = [
                params.best_of
                for (_, params), is_prompt in zip(groups, is_prompts)
                if is_prompt
            ]
            max_best_of = max([1] + prompt_best_ofs)
            device_result = _multinomial(probs[sample_indices], max_best_of)
        elif sampling_type == SamplingType.BEAM:
            device_result = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")
        pending[sampling_type] = (group_ids, groups, is_prompts,
                                  device_result)

    # Pass 2: host-side post-processing; GPU<->CPU syncs happen here.
    results_by_group: Dict[int, Tuple[List[int], List[int]]] = {}
    for sampling_type, (group_ids, groups, is_prompts,
                        device_result) in pending.items():
        if sampling_type == SamplingType.GREEDY:
            group_results = _greedy_sample(groups, device_result)
        elif sampling_type == SamplingType.RANDOM:
            group_results = _random_sample(groups, is_prompts, device_result)
        elif sampling_type == SamplingType.BEAM:
            group_results = _beam_search_sample(groups, is_prompts,
                                                sampling_metadata.seq_data,
                                                device_result)
        results_by_group.update(zip(group_ids, group_results))

    return [
        results_by_group[group_id]
        for group_id in range(len(sampling_metadata.seq_groups))
    ]


def _get_logprobs(
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
        int, float]]]]:
    """Collect the prompt and sample logprobs requested by each group.

    Row layout of ``logprobs`` (walked identically by both passes below):
    for each sequence group in order, a prompt group (index <
    ``num_prompts``) with ``prompt_logprobs`` set first owns
    ``prompt_len - 1`` rows, where row ``j`` scores prompt token ``j + 1``.
    One row per sequence in the group follows.

    Returns:
        ``(prompt_logprobs, sample_logprobs)``, one entry per group.

        ``prompt_logprobs[g]`` is ``None`` unless group ``g`` requested prompt
        logprobs. In that case it is ``[None, d_1, d_2, ...]``: the first
        prompt token has no logprob, and each later token gets a dict.

        ``sample_logprobs[g]`` has one dict per sampled token.

        Each dict maps token id -> logprob and holds the chosen token, plus
        the top-N tokens of that row when N > 0 was requested.
    """
    # Pass 1: list every (row, token) pair to look up, and the widest top-N.
    query_rows: List[int] = []
    query_tokens: List[int] = []
    largest_num_logprobs = 0
    row = 0
    for i, ((seq_ids, sampling_params),
            (next_token_ids, parent_ids)) in enumerate(
                zip(sampling_metadata.seq_groups, sample_results)):
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            prompt_len = sampling_metadata.prompt_lens[i]
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            query_rows.extend(range(row, row + prompt_len - 1))
            query_tokens.extend(prompt_tokens[1:])
            row += prompt_len - 1
        query_rows.extend(row + parent_id for parent_id in parent_ids)
        query_tokens.extend(next_token_ids)
        if sampling_params.logprobs is not None:
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.logprobs)
        row += len(seq_ids)
    assert row == logprobs.size(0)

    # One batched gather for every chosen token's logprob. An explicit
    # (rows, cols) tuple is used for advanced indexing instead of a list
    # of lists, which only works through a legacy sequence-as-tuple rule.
    query_result = logprobs[query_rows, query_tokens]

    # One batched top-k for the widest request; rows slice what they need.
    if largest_num_logprobs > 0:
        top_logprobs, top_token_ids = torch.topk(logprobs,
                                                 largest_num_logprobs,
                                                 dim=-1)
        top_logprobs = top_logprobs.cpu()
        top_token_ids = top_token_ids.cpu()
    else:
        top_logprobs, top_token_ids = None, None

    # Single host transfer, then plain Python floats (no per-item .item()).
    chosen_logprobs: List[float] = query_result.cpu().tolist()

    # Pass 2: assemble the dicts, consuming queries in the same order.
    result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
    result_sample_logprobs: List[SampleLogprobs] = []
    row = 0
    query_idx = 0
    for i, ((seq_ids, sampling_params),
            (next_token_ids, parent_ids)) in enumerate(
                zip(sampling_metadata.seq_groups, sample_results)):
        # Prompt logprobs
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            num_logprobs = sampling_params.prompt_logprobs
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            group_prompt_logprobs: PromptLogprobs = [None]
            for token_id in prompt_tokens[1:]:
                prompt_logprobs_dict = {token_id: chosen_logprobs[query_idx]}
                if num_logprobs > 0:
                    prompt_logprobs_dict.update(
                        zip(top_token_ids[row, :num_logprobs].tolist(),
                            top_logprobs[row, :num_logprobs].tolist()))
                group_prompt_logprobs.append(prompt_logprobs_dict)
                row += 1
                query_idx += 1
            result_prompt_logprobs.append(group_prompt_logprobs)
        else:
            result_prompt_logprobs.append(None)

        # Sample logprobs
        num_logprobs = sampling_params.logprobs
        if num_logprobs is None:
            num_logprobs = 0
        group_sample_logprobs: SampleLogprobs = []
        for next_token_id, parent_id in zip(next_token_ids, parent_ids):
            sample_logprobs_dict = {next_token_id: chosen_logprobs[query_idx]}
            query_idx += 1
            if num_logprobs > 0:
                parent_row = row + parent_id
                sample_logprobs_dict.update(
                    zip(top_token_ids[parent_row, :num_logprobs].tolist(),
                        top_logprobs[parent_row, :num_logprobs].tolist()))
            group_sample_logprobs.append(sample_logprobs_dict)
        result_sample_logprobs.append(group_sample_logprobs)
        row += len(seq_ids)

    return result_prompt_logprobs, result_sample_logprobs


def _build_sampler_output(
    sample_results: List[Tuple[List[int], List[int]]],
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: List[Optional[PromptLogprobs]],
    sample_logprobs: List[SampleLogprobs],
) -> SamplerOutput:
    """Package per-group sampling results into the sampler's output list.

    For each group, every ``(parent_id, next_token_id, logprobs)`` triple
    becomes a ``SequenceOutput`` for ``seq_ids[parent_id]``. These are
    wrapped, together with the group's prompt logprobs, in a
    ``SequenceGroupOutput``.
    """
    sampler_output = []
    per_group = zip(sampling_metadata.seq_groups, sample_results,
                    prompt_logprobs, sample_logprobs)
    for ((seq_ids, _), (next_token_ids, parent_ids), group_prompt_logprobs,
         group_sample_logprobs) in per_group:
        seq_outputs = [
            SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)
            for parent_id, next_token_id, logprobs in zip(
                parent_ids, next_token_ids, group_sample_logprobs)
        ]
        sampler_output.append(
            SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
    return sampler_output