sampler.py 24.7 KB
Newer Older
1
"""A layer that samples the next tokens from the model's outputs."""
2
from typing import Dict, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
3
4
5
6

import torch
import torch.nn as nn

7
from vllm.model_executor.parallel_utils.communication_op import (
8
    tensor_model_parallel_gather)
9
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
10
from vllm.sampling_params import SamplingParams, SamplingType
11
12
13
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
                           SamplerOutput, SequenceData, SequenceGroupOutput,
                           SequenceOutput)
14
from vllm.utils import is_neuron
Woosuk Kwon's avatar
Woosuk Kwon committed
15

16

Woosuk Kwon's avatar
Woosuk Kwon committed
17
class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
        tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply the per-request logits processors (if any).
    4. Apply presence, frequency and repetition penalties.
    5. Apply temperature scaling.
    6. Apply top-p, top-k and min-p truncation.
    7. Sample the next tokens.
    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
    """

    def __init__(self,
                 vocab_size: int,
                 org_vocab_size: Optional[int] = None) -> None:
        """Create the sampler.

        Args:
            vocab_size: Size of the vocabulary (stored as-is).
            org_vocab_size: Original vocabulary size (without LoRA). Falls
                back to `vocab_size` when omitted (or falsy).
        """
        super().__init__()
        self.vocab_size = vocab_size
        # Transformers-neuronx generate outputs as logits directly.
        self.logits_as_hidden_states = is_neuron()
        # original vocabulary size (without LoRA).
        self.org_vocab_size = org_vocab_size or vocab_size

    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Project hidden states onto the vocabulary and gather the result.

        Args:
            hidden_states: Hidden states selected for sampling.
            embedding: Weight matrix; its transpose is used for the
                projection.
            embedding_bias: Optional bias added to the projected logits.

        Returns:
            Whatever `tensor_model_parallel_gather` returns; when that is not
            None it is truncated to the first `org_vocab_size` columns.
        """
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = tensor_model_parallel_gather(logits)
        # Remove paddings in vocab (if any).
        if logits is not None:
            logits = logits[:, :self.org_vocab_size]
        return logits

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[SamplerOutput]:
        """Run the full sampling pipeline for one batch.

        Args:
            embedding: Weight used to compute logits from hidden states.
            hidden_states: Model outputs. Used directly as logits when
                `self.logits_as_hidden_states` is set; in that case the
                in-place operations below also modify this tensor.
            sampling_metadata: Per-batch sampling information.
            embedding_bias: Optional bias forwarded to `_get_logits`.

        Returns:
            The sampler output, or None when
            `sampling_metadata.perform_sampling` is false.
        """
        # Get the hidden states that we use for sampling.
        if self.logits_as_hidden_states:
            logits = hidden_states
        else:
            hidden_states = _prune_hidden_states(hidden_states,
                                                 sampling_metadata)

            # Get the logits for the next tokens.
            logits = self._get_logits(hidden_states, embedding, embedding_bias)

        # Only perform sampling in the driver worker.
        # Note: `_get_logits` is still distributed across TP workers because
        # the `embedding` weight is distributed across TP workers.
        # TODO(zhuohan): Change the get_logits part to a separate stage.
        if not sampling_metadata.perform_sampling:
            return None

        assert logits is not None
        _, vocab_size = logits.shape

        # Apply logits processors (if any).
        logits = _apply_logits_processors(logits, sampling_metadata)

        # Prepare sampling tensors with pinned memory to avoid blocking.
        (sampling_tensors, do_penalties, do_top_p_top_k,
         do_min_p) = SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype)

        # Apply presence, frequency and repetition penalties.
        if do_penalties:
            logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
                                      sampling_tensors.output_tokens,
                                      sampling_tensors.presence_penalties,
                                      sampling_tensors.frequency_penalties,
                                      sampling_tensors.repetition_penalties)

        # Apply temperature scaling.
        # Use in-place division to avoid creating a new tensor. The
        # temperatures are broadcast through a view so that the tensor held
        # by `sampling_tensors` keeps its original shape.
        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

        if do_top_p_top_k:
            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                        sampling_tensors.top_ks)

        if do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        # Use log_softmax to ensure numerical stability.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        sample_results = _sample(probs, logprobs, sampling_metadata)
        # Get the logprobs query results.
        prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, sampling_metadata, sample_results)
        return _build_sampler_output(sample_results, sampling_metadata,
                                     prompt_logprobs, sample_logprobs)
122
123
124
125


def _prune_hidden_states(
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Keep only the hidden-state rows that are used for sampling.

    Args:
        hidden_states: Tensor whose last dimension is the hidden size; all
            leading dimensions are flattened into one row axis.
        sampling_metadata: Provides `selected_token_indices`, the row indices
            (into the flattened tensor) to keep.

    Returns:
        A 2-D tensor holding the selected rows, in index order.
    """
    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    return hidden_states.index_select(0,
                                      sampling_metadata.selected_token_indices)
131
132


133
def _get_bin_counts_and_mask(
    tokens: torch.Tensor,
    vocab_size: int,
    num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Count how often each vocabulary id occurs in every row of `tokens`.

    Args:
        tokens: Integer tensor with `num_seqs` rows of token ids. The id
            `vocab_size` is accepted and treated as padding.
        vocab_size: Number of real vocabulary entries.
        num_seqs: Number of rows in `tokens`.

    Returns:
        A pair `(bin_counts, mask)`, both of shape `(num_seqs, vocab_size)`:
        the per-token occurrence counts and a boolean mask that is True
        where the count is positive.
    """
    # Compute the bin counts for the tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=tokens.device)
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
    # Drop the extra padding column again.
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0

    return bin_counts, mask
148
149


150
151
152
153
def _apply_logits_processors(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Run each request's logits processors over its rows of `logits`.

    Rows are consumed in the order of `sampling_metadata.seq_groups`, one row
    per sequence id. Every processor is called as
    `processor(output_token_ids, logits_row)` and its return value replaces
    the row. `logits` is updated in place and also returned.

    Args:
        logits: 2-D logits tensor, one row per sequence.
        sampling_metadata: Provides `seq_groups` (pairs of sequence ids and
            sampling params) and `seq_data` (per-sequence token data).

    Returns:
        The same `logits` tensor after all processors have been applied.
    """
    logits_row_idx = 0
    found_logits_processors = False
    for seq_ids, sampling_params in sampling_metadata.seq_groups:
        logits_processors = sampling_params.logits_processors
        if logits_processors:
            found_logits_processors = True
            for seq_id in seq_ids:
                logits_row = logits[logits_row_idx]
                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
                for logits_processor in logits_processors:
                    logits_row = logits_processor(token_ids, logits_row)
                logits[logits_row_idx] = logits_row
                logits_row_idx += 1
        else:
            # No processors for this group: just skip over its rows.
            logits_row_idx += len(seq_ids)
    if found_logits_processors:
        # Sanity check: every row must have been accounted for.
        assert logits_row_idx == logits.shape[0]
    return logits


174
175
176
177
178
def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                     output_tokens_tensor: torch.Tensor,
                     presence_penalties: torch.Tensor,
                     frequency_penalties: torch.Tensor,
                     repetition_penalties: torch.Tensor) -> torch.Tensor:
    """Apply repetition, frequency and presence penalties to `logits`.

    Args:
        logits: 2-D tensor of shape `(num_seqs, vocab_size)`.
        prompt_tokens_tensor: Token ids of the prompts, one row per sequence.
        output_tokens_tensor: Token ids generated so far, one row per
            sequence.
        presence_penalties: One value per sequence, subtracted once for every
            token that already appears in the output.
        frequency_penalties: One value per sequence, subtracted once per
            occurrence of a token in the output.
        repetition_penalties: One value per sequence; positive logits of
            tokens seen in the prompt or output are divided by it, the other
            seen logits are multiplied by it.

    Returns:
        A new tensor with the penalized logits. None of the penalty tensors
        passed in is modified.
    """
    num_seqs, vocab_size = logits.shape
    _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size,
                                              num_seqs)
    output_bin_counts, output_mask = _get_bin_counts_and_mask(
        output_tokens_tensor, vocab_size, num_seqs)

    # Expand to one penalty per (sequence, token) and neutralise the penalty
    # for tokens that appear neither in the prompt nor in the output.
    repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
    repetition_penalties[~(prompt_mask | output_mask)] = 1.0
    logits = torch.where(logits > 0, logits / repetition_penalties,
                         logits * repetition_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    # Use `unsqueeze` (a view) rather than `unsqueeze_` so the caller's
    # penalty tensors keep their shape and the function can be called again
    # with the same inputs.
    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
    return logits


197
def _apply_top_k_top_p(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """Mask out logits that fall outside the top-k / top-p selection.

    Args:
        logits: 2-D tensor with one row of logits per sequence.
        p: Per-row top-p (nucleus) thresholds.
        k: Per-row top-k values; used as `row_length - k` to index the
            ascending-sorted row.

    Returns:
        A new tensor in the original token order where every filtered entry
        is `-inf`. The input `logits` is left untouched.
    """
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Apply top-k.
    top_k_mask = logits_sort.size(1) - k.to(torch.long)
    # Get all the top_k values.
    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
    top_k_mask = logits_sort < top_k_mask
    logits_sort.masked_fill_(top_k_mask, -float("inf"))

    # Apply top-p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
    # at least one
    top_p_mask[:, -1] = False
    logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    # Build the inverse permutation of `logits_idx` and use it to scatter the
    # filtered values back to their original positions.
    src = torch.arange(logits_idx.shape[-1],
                       device=logits_idx.device).expand_as(logits_idx)
    logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
                                                           index=logits_idx,
                                                           src=src)
    logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
    return logits
227
228


Roy's avatar
Roy committed
229
230
def _apply_min_p(
    logits: torch.Tensor,
    min_p: torch.Tensor,
) -> torch.Tensor:
    """Mask tokens whose probability is below `min_p` times the top one.

    Adapted from
    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17

    Args:
        logits: 2-D tensor with one row of logits per sequence. It is
            modified in place.
        min_p: Per-row min-p values. Not modified.

    Returns:
        The same `logits` tensor with the filtered entries set to `-inf`.
    """
    probs = torch.softmax(logits, dim=-1)
    top_probs, _ = probs.max(dim=-1, keepdim=True)
    # Broadcast through a view (`unsqueeze`, not `unsqueeze_`) so the
    # caller's `min_p` tensor keeps its shape.
    scaled_min_p = min_p.unsqueeze(dim=1) * top_probs
    tokens_to_remove = probs < scaled_min_p
    logits = logits.masked_fill_(tokens_to_remove, -float("inf"))

    return logits


246
247
def _greedy_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    """Turn greedy (argmax) token ids into per-group sample results.

    Args:
        selected_seq_groups: The sequence groups that use greedy sampling, as
            `(seq_ids, sampling_params)` pairs. Each must hold one sequence.
        samples: 1-D tensor with one chosen token id per sequence, in the
            same order as the groups.

    Returns:
        One `(next_token_ids, parent_ids)` pair per group.
    """
    # Single device-to-host transfer for the whole batch.
    samples = samples.tolist()
    sample_idx = 0
    results = []
    for seq_group in selected_seq_groups:
        seq_ids, _ = seq_group
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        parent_ids = list(range(num_parent_seqs))
        next_token_ids = [samples[sample_idx]]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _random_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    random_samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    """Turn pre-drawn random samples into per-group sample results.

    Args:
        selected_seq_groups: The sequence groups that use random sampling, as
            `(seq_ids, sampling_params)` pairs.
        is_prompts: For each group, whether it is in the prompt phase.
        random_samples: 2-D tensor of sampled token ids; one row per
            sequence, one column per drawn sample.

    Returns:
        One `(next_token_ids, parent_ids)` pair per group. A prompt group
        yields `best_of` tokens that all descend from parent 0; a generation
        group yields one token (column 0) per parent sequence.
    """
    # Move the samples to the CPU once, before slicing per group.
    random_samples = random_samples.cpu()
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase.
            parent_ids = [0] * sampling_params.best_of
            next_token_ids = random_samples[
                sample_idx, :sampling_params.best_of].tolist()
        else:
            # Generation phase.
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = random_samples[sample_idx:sample_idx +
                                            num_parent_seqs, 0].tolist()
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    return results


def _beam_search_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    seq_data: Dict[int, SequenceData],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    """Pick `2 * beam_width` beam-search candidates for every group.

    Args:
        selected_seq_groups: The sequence groups that use beam search, as
            `(seq_ids, sampling_params)` pairs; `best_of` is the beam width.
        is_prompts: For each group, whether it is in the prompt phase.
        seq_data: Maps a sequence id to its data; `cumulative_logprob` is
            read for generation-phase groups.
        logprobs: 2-D tensor of log-probabilities, one row per sequence in
            group order.

    Returns:
        One `(next_token_ids, parent_ids)` pair per group, each list holding
        `2 * beam_width` entries.
    """
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # NOTE: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        beam_width = sampling_params.best_of
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * (2 * beam_width)
            _, next_token_ids = torch.topk(seq_group_logprobs[0],
                                           2 * beam_width)
            next_token_ids = next_token_ids.tolist()
        else:
            # Generation phase.
            cumulative_logprobs = [
                seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
            ]
            cumulative_logprobs = torch.tensor(
                cumulative_logprobs,
                dtype=torch.float,
                device=seq_group_logprobs.device)
            seq_group_logprobs = (seq_group_logprobs +
                                  cumulative_logprobs.unsqueeze(dim=1))
            # Rank all (parent, token) pairs together; a flat index encodes
            # the pair as `parent * vocab_size + token`.
            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                     2 * beam_width)
            topk_ids = topk_ids.tolist()
            vocab_size = seq_group_logprobs.size(-1)
            parent_ids = [i // vocab_size for i in topk_ids]
            next_token_ids = [i % vocab_size for i in topk_ids]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results
343
344


345
346
347
348
349
350
351
352
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
    generators: Optional[List[torch.Generator]] = None,
) -> torch.Tensor:
    """Draw `num_samples` token ids per row of `probs`, with replacement.

    Sampling is done by dividing the probabilities by exponential noise and
    taking the argmax of each row.

    Args:
        probs: 2-D tensor of probabilities, one row per sequence. When
            `num_samples` is 1 this tensor itself is divided in place.
        num_samples: Number of samples to draw for every row.
        seq_groups: When given, the noise for each group's rows is drawn
            from the matching entry of `generators`.
        generators: One generator per entry of `seq_groups`; only read when
            `seq_groups` is given.

    Returns:
        Integer tensor of shape `(num_rows, num_samples)` with the sampled
        token ids.
    """
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleaved (which also
        # forces a GPU<->CPU sync).
        # This allows us to do sampling with replacement by creating
        # num_samples copies of each row in the tensor, and then
        # batch sampling the resulting tensor.
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
    q = torch.empty_like(probs)
    if seq_groups is None:
        q.exponential_()
    else:
        # Fill each group's slice of the noise from its own generator.
        sample_idx = 0
        for (seq_ids, _), generator in zip(seq_groups, generators):
            next_sample_idx = sample_idx + len(seq_ids) * num_samples
            q[sample_idx:next_sample_idx].exponential_(generator=generator)
            sample_idx = next_sample_idx
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)


377
378
379
def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> List[Tuple[List[int], List[int]]]:
    """Sample next tokens for every sequence group in the batch.

    Groups are bucketed by their sampling type, the tensor work for each
    bucket is launched first, and only afterwards are the results converted
    into Python lists.

    Args:
        probs: 2-D tensor of probabilities, one row per sequence.
        logprobs: 2-D tensor of log-probabilities, same layout as `probs`.
        sampling_metadata: Provides `seq_groups`, `num_prompts`, `seq_data`,
            `generators` and `categorized_sample_indices` (row indices per
            sampling type).

    Returns:
        One `(next_token_ids, parent_ids)` pair per sequence group, in the
        order of `sampling_metadata.seq_groups`.

    Raises:
        ValueError: If a non-empty bucket has an unsupported sampling type.
    """
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        _, sampling_params = seq_group
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata = {}
    multinomial_samples = {}

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue
        seq_group_ids = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                          is_prompts, sample_indices)
        if sampling_type == SamplingType.GREEDY:
            greedy_samples = torch.argmax(logprobs[sample_indices.long()],
                                          dim=-1)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            # Prompt-phase groups may need several samples from one row.
            max_best_of = 1
            for seq_group, is_prompt in zip(seq_groups, is_prompts):
                if is_prompt:
                    _, sampling_params = seq_group
                    max_best_of = max(max_best_of, sampling_params.best_of)
            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
                "seq_groups": seq_groups,
                "generators": sampling_metadata.generators,
            }
            multinomial_samples[sampling_type] = _multinomial(
                probs[sample_indices.long()], max_best_of, **seeded_args)
        elif sampling_type == SamplingType.BEAM:
            # Index with a long tensor, like the other branches do.
            beam_search_logprobs = logprobs[sample_indices.long()]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    # GPU<->CPU sync happens in the loop below.

    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[
            sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, greedy_samples)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(seq_groups, is_prompts,
                                            multinomial_samples[sampling_type])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                 sampling_metadata.seq_data,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_ids, sample_results))

    # Restore the original sequence-group order.
    sample_results = [
        sample_results_dict[i]
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results


def _get_logprobs(
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
    """Collect the requested log-probabilities for prompts and samples.

    Args:
        logprobs: 2-D tensor of log-probabilities. Its rows must line up with
            the prompt positions (for prompt groups that request prompt
            logprobs) followed by the sequences of each group.
        sampling_metadata: Provides `seq_groups`, `num_prompts`,
            `prompt_lens` and `seq_data`.
        sample_results: One `(next_token_ids, parent_ids)` pair per sequence
            group, as produced by the sampling step.

    Returns:
        A pair `(prompt_logprobs, sample_logprobs)` with one entry per
        sequence group. The prompt entry is None when prompt logprobs were
        not requested; otherwise it is a list starting with None (first
        prompt token) followed by one `{token_id: Logprob}` dict per
        remaining prompt token. The sample entry holds one such dict per
        sampled token.
    """
    # Prepare query indices
    batched_logprobs_query_seq_indices: List[int] = []
    batched_logprobs_query_token_indices: List[int] = []
    largest_num_logprobs = 0
    sample_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(sampling_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result
        num_parent_seqs = len(seq_ids)
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            prompt_len = sampling_metadata.prompt_lens[i]
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            batched_logprobs_query_seq_indices.extend(
                sample_idx + j for j in range(prompt_len - 1))
            batched_logprobs_query_token_indices.extend(
                token_id for token_id in prompt_tokens[1:])
            sample_idx += prompt_len - 1
        batched_logprobs_query_seq_indices.extend(
            [sample_idx + parent_id for parent_id in parent_ids])
        batched_logprobs_query_token_indices.extend(next_token_ids)
        if sampling_params.logprobs is not None:
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.logprobs)
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)

    # Batched query for logprobs of selected token
    # Explicit (row, column) indexing: picks one element per index pair.
    batched_logprobs_query_result = logprobs[
        batched_logprobs_query_seq_indices,
        batched_logprobs_query_token_indices]

    # Batched query for logprobs of topk tokens
    if largest_num_logprobs > 0:
        top_logprobs, top_token_ids = torch.topk(logprobs,
                                                 largest_num_logprobs,
                                                 dim=-1)
        top_logprobs = top_logprobs.cpu()
        top_token_ids = top_token_ids.cpu()
    else:
        top_logprobs, top_token_ids = None, None

    batched_logprobs_query_result = batched_logprobs_query_result.cpu()

    # Gather results
    result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
    result_sample_logprobs: List[SampleLogprobs] = []
    sample_idx = 0
    query_result_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(sampling_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result

        # Prompt logprobs
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            num_logprobs = sampling_params.prompt_logprobs
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            # No logprob is produced for the first prompt token.
            group_prompt_logprobs: PromptLogprobs = [None]
            for token_id in prompt_tokens[1:]:
                prompt_logprobs_dict = {
                    token_id:
                    batched_logprobs_query_result[query_result_idx].item()
                }
                if num_logprobs > 0:
                    prompt_logprobs_dict.update(
                        zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
                            top_logprobs[sample_idx, :num_logprobs].tolist()))
                group_prompt_logprobs.append({
                    tok_id: Logprob(tok_logprob)
                    for tok_id, tok_logprob in prompt_logprobs_dict.items()
                })
                sample_idx += 1
                query_result_idx += 1
            result_prompt_logprobs.append(group_prompt_logprobs)
        else:
            result_prompt_logprobs.append(None)

        # Sample logprobs
        num_logprobs = sampling_params.logprobs
        if num_logprobs is None:
            num_logprobs = 0
        group_sample_logprobs: SampleLogprobs = []
        for next_token_id, parent_id in zip(next_token_ids, parent_ids):
            sample_logprobs_dict = {
                next_token_id:
                batched_logprobs_query_result[query_result_idx].item()
            }
            query_result_idx += 1
            if num_logprobs > 0:
                sample_logprobs_dict.update(
                    zip(
                        top_token_ids[sample_idx +
                                      parent_id, :num_logprobs].tolist(),
                        top_logprobs[sample_idx +
                                     parent_id, :num_logprobs].tolist()))
            group_sample_logprobs.append({
                tok_id: Logprob(tok_logprob)
                for tok_id, tok_logprob in sample_logprobs_dict.items()
            })
        result_sample_logprobs.append(group_sample_logprobs)
        sample_idx += len(seq_ids)

    return result_prompt_logprobs, result_sample_logprobs


def _build_sampler_output(
    sample_results: List[Tuple[List[int], List[int]]],
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: List[Optional[PromptLogprobs]],
    sample_logprobs: List[SampleLogprobs],
) -> SamplerOutput:
    """Package the sampling results into per-group output objects.

    Args:
        sample_results: One `(next_token_ids, parent_ids)` pair per sequence
            group.
        sampling_metadata: Provides `seq_groups`, whose sequence ids are
            looked up through the parent ids.
        prompt_logprobs: Prompt logprobs per sequence group (or None).
        sample_logprobs: Sample logprobs per sequence group, one entry per
            sampled token.

    Returns:
        A list with one `SequenceGroupOutput` per sequence group, each
        holding one `SequenceOutput` per sampled token.
    """
    sampler_output = []
    for (seq_group, sample_result, group_prompt_logprobs,
         group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                       sample_results, prompt_logprobs,
                                       sample_logprobs):
        seq_ids, _ = seq_group
        next_token_ids, parent_ids = sample_result
        seq_outputs = []
        for parent_id, next_token_id, token_logprobs in zip(
                parent_ids, next_token_ids, group_sample_logprobs):
            # `parent_id` is an index into this group's sequence ids.
            seq_outputs.append(
                SequenceOutput(seq_ids[parent_id], next_token_id,
                               token_logprobs))
        sampler_output.append(
            SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
    return sampler_output