sampler.py 27 KB
Newer Older
1
"""A layer that samples the next tokens from the model's outputs."""
2
from typing import Dict, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
3
4
5
6

import torch
import torch.nn as nn

7
8
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_gather)
9
from vllm.model_executor.sampling_metadata import SamplingMetadata
10
from vllm.sampling_params import SamplingParams, SamplingType
11
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
Zhuohan Li's avatar
Zhuohan Li committed
12
                           SequenceData, SequenceGroupOutput, SequenceOutput)
Woosuk Kwon's avatar
Woosuk Kwon committed
13

14
# Tolerance used when comparing float sampling parameters against their
# "disabled" values (temperature ~ 0, top-p ~ 1, min-p ~ 0, penalties ~ 0/1).
_SAMPLING_EPS = 1e-5
Woosuk Kwon's avatar
Minor  
Woosuk Kwon committed
15

16

Woosuk Kwon's avatar
Woosuk Kwon committed
17
class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
        tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply logits processors, then presence, frequency and repetition
        penalties.
    4. Apply temperature scaling.
    5. Apply top-p, top-k and min-p truncation.
    6. Sample the next tokens.
    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
    """

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        # Width of the real vocabulary: logits are truncated to this many
        # columns and top-k values are clamped to it.
        self.vocab_size = vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> SamplerOutput:
        """Run the full sampling pipeline for one batch.

        Args:
            embedding: Output embedding matrix; logits are computed as
                ``hidden_states @ embedding.T`` (plus ``embedding_bias``).
            hidden_states: Model hidden states; flattened to 2-D and pruned
                to ``sampling_metadata.selected_token_indices``.
            sampling_metadata: Per-batch sampling information (sequence
                groups, prompt lengths, selected/categorized indices, ...).
            embedding_bias: Optional bias added to the logits.

        Returns:
            The result of ``_build_sampler_output``: one entry per sequence
            group with its sampled tokens and requested logprobs.
        """
        # Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)

        # Get the logits for the next tokens.
        logits = _get_logits(hidden_states, embedding, embedding_bias,
                             self.vocab_size)

        # Apply logits processors (if any).
        logits = _apply_logits_processors(logits, sampling_metadata)

        # Apply presence, frequency and repetition penalties.
        presence_penalties, frequency_penalties, repetition_penalties = (
            _get_penalties(sampling_metadata))
        assert len(presence_penalties) == logits.shape[0]
        assert len(frequency_penalties) == logits.shape[0]
        assert len(repetition_penalties) == logits.shape[0]
        logits = _apply_penalties(logits, sampling_metadata,
                                  presence_penalties, frequency_penalties,
                                  repetition_penalties)

        # Apply temperature scaling.
        temperatures = _get_temperatures(sampling_metadata)
        assert len(temperatures) == logits.shape[0]
        if any(t != 1.0 for t in temperatures):
            t = torch.tensor(temperatures,
                             dtype=logits.dtype,
                             device=logits.device)
            # Use in-place division to avoid creating a new tensor.
            logits.div_(t.unsqueeze(dim=1))

        # Apply top-p and top-k truncation.
        top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
            sampling_metadata, self.vocab_size)
        # All three lists are broadcast against the rows of `logits`, so all
        # three must be row-aligned (min_ps was previously unchecked).
        assert len(top_ps) == len(top_ks) == len(min_ps) == logits.shape[0]
        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
        do_top_k = any(k != self.vocab_size for k in top_ks)
        if do_top_p or do_top_k:
            logits = _apply_top_p_top_k(logits, top_ps, top_ks)

        # Apply min-p truncation (threshold relative to each row's top prob).
        do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps)
        if do_min_p:
            logits = _apply_min_p(logits, min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        # Use log_softmax to ensure numerical stability.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        sample_results = _sample(probs, logprobs, sampling_metadata)
        # Get the logprobs query results.
        prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, sampling_metadata, sample_results)
        return _build_sampler_output(sample_results, sampling_metadata,
                                     prompt_logprobs, sample_logprobs)
99
100


101
102
103
104
105
106
107
def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
                embedding_bias: Optional[torch.Tensor],
                vocab_size: int) -> torch.Tensor:
    """Project hidden states onto the vocabulary.

    Multiplies by the transposed embedding matrix, adds the optional bias,
    passes the result through ``tensor_model_parallel_all_gather`` and keeps
    only the first ``vocab_size`` columns (anything beyond is vocab padding).
    """
    projected = hidden_states @ embedding.t()
    if embedding_bias is not None:
        projected += embedding_bias
    gathered = tensor_model_parallel_all_gather(projected)
    # Drop padded vocabulary columns (if any).
    return gathered[:, :vocab_size]


114
115
def _prune_hidden_states(
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Keep only the hidden-state rows that will be sampled from.

    All leading dimensions are flattened so each row is one token, then the
    rows listed in ``sampling_metadata.selected_token_indices`` are selected
    in that order.
    """
    hidden_size = hidden_states.shape[-1]
    token_rows = hidden_states.view(-1, hidden_size)
    keep = sampling_metadata.selected_token_indices
    return token_rows.index_select(0, keep)
121
122


123
def _get_penalties(
    sampling_metadata: SamplingMetadata
) -> Tuple[List[float], List[float], List[float]]:
    """Expand per-group penalty settings to one entry per logits row.

    Returns ``(presence, frequency, repetition)`` lists. A prompt group that
    requests prompt logprobs owns ``prompt_len - 1`` extra rows before its
    sampled row; those rows get neutral values (0, 0 and 1) because no new
    token is sampled there.
    """
    presence: List[float] = []
    frequency: List[float] = []
    repetition: List[float] = []
    groups = sampling_metadata.seq_groups
    for group_idx, (seq_ids, params) in enumerate(groups):
        has_prompt_rows = (group_idx < sampling_metadata.num_prompts
                           and params.prompt_logprobs is not None)
        if has_prompt_rows:
            n_prompt_rows = sampling_metadata.prompt_lens[group_idx] - 1
            presence.extend([0] * n_prompt_rows)
            frequency.extend([0] * n_prompt_rows)
            repetition.extend([1] * n_prompt_rows)
        n_seqs = len(seq_ids)
        presence.extend([params.presence_penalty] * n_seqs)
        frequency.extend([params.frequency_penalty] * n_seqs)
        repetition.extend([params.repetition_penalty] * n_seqs)
    return presence, frequency, repetition
147
148


149
def _get_prompt_and_output_tokens(
    sampling_metadata: SamplingMetadata,
) -> Tuple[List[List[int]], List[List[int]]]:
    """Collect prompt and output token ids for every logits row.

    Rows that exist only for prompt logprobs get fresh empty lists, since
    those positions do not need tokens to compute penalties. Every sampled
    sequence contributes its own ``prompt_token_ids`` and
    ``output_token_ids`` lists.
    """
    prompt_tokens: List[List[int]] = []
    output_tokens: List[List[int]] = []
    groups = sampling_metadata.seq_groups
    for group_idx, (seq_ids, params) in enumerate(groups):
        if (group_idx < sampling_metadata.num_prompts
                and params.prompt_logprobs is not None):
            n_prompt_rows = sampling_metadata.prompt_lens[group_idx] - 1
            prompt_tokens += [[] for _ in range(n_prompt_rows)]
            output_tokens += [[] for _ in range(n_prompt_rows)]
        group_data = [sampling_metadata.seq_data[sid] for sid in seq_ids]
        prompt_tokens += [data.prompt_token_ids for data in group_data]
        output_tokens += [data.output_token_ids for data in group_data]
    return prompt_tokens, output_tokens


def _get_bin_counts_and_mask(
    logits: torch.Tensor,
    tokens: List[List[int]],
    vocab_size: int,
    num_seqs: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Count how often each vocabulary token occurs in each row's token list.

    Args:
        logits: Only used to pick the device of the created tensors.
        tokens: One list of token ids per row; lists may differ in length
            and may be empty.
        vocab_size: Number of real token ids. The id ``vocab_size`` itself is
            used internally as the padding id and is dropped from the result.
        num_seqs: Number of rows; expected to equal ``len(tokens)``.

    Returns:
        ``(bin_counts, mask)``: ``bin_counts`` has shape
        ``(num_seqs, vocab_size)`` and dtype long, with ``bin_counts[i, t]``
        the number of occurrences of token ``t`` in ``tokens[i]``; ``mask`` is
        ``bin_counts > 0``.
    """
    # `default=0` keeps an empty batch from raising ValueError in `max`.
    max_len = max((len(row) for row in tokens), default=0)
    # Right-pad every row with the out-of-vocabulary id `vocab_size` so the
    # rows stack into one rectangular tensor.
    padded_tokens = [
        row + [vocab_size] * (max_len - len(row)) for row in tokens
    ]
    # The explicit reshape keeps the index tensor 2-D even when `tokens` is
    # empty (torch.tensor([]) is 1-D), which `scatter_add_` requires.
    tokens_tensor = torch.tensor(padded_tokens,
                                 dtype=torch.long,
                                 device=logits.device).reshape(
                                     len(padded_tokens), max_len)

    # Compute the bin counts for the tokens.
    # vocab_size + 1 columns: the extra column absorbs the padding id.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=logits.device)
    bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor))
    bin_counts = bin_counts[:, :vocab_size]
    mask = bin_counts > 0

    return bin_counts, mask
194
195


196
197
198
199
def _apply_logits_processors(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Run each sequence's logits processors on its logits row.

    Each processor is called as ``processor(output_token_ids, logits_row)``
    and its return value becomes the new row. A prompt group that requests
    ``prompt_logprobs`` owns ``prompt_len - 1`` rows in front of its sampled
    row. Those rows are skipped, so processors only see the rows that are
    actually sampled from.

    Rows of ``logits`` are written in place and the same tensor is returned.
    """
    logits_row_idx = 0
    found_logits_processors = False
    for i, (seq_ids, sampling_params) in enumerate(
            sampling_metadata.seq_groups):
        # Skip the rows reserved for prompt-logprob positions. Without this,
        # the row index drifts and processors run on the wrong rows.
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            logits_row_idx += sampling_metadata.prompt_lens[i] - 1

        logits_processors = sampling_params.logits_processors
        if not logits_processors:
            logits_row_idx += len(seq_ids)
            continue

        found_logits_processors = True
        for seq_id in seq_ids:
            logits_row = logits[logits_row_idx]
            token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
            for logits_processor in logits_processors:
                logits_row = logits_processor(token_ids, logits_row)
            logits[logits_row_idx] = logits_row
            logits_row_idx += 1
    if found_logits_processors:
        # Every row (prompt-logprob rows included) must have been accounted for.
        assert logits_row_idx == logits.shape[0]
    return logits


220
221
def _apply_penalties(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    presence_penalties: List[float],
    frequency_penalties: List[float],
    repetition_penalties: List[float],
) -> torch.Tensor:
    """Apply repetition, frequency and presence penalties to ``logits``.

    Each penalty list holds one value per logits row. If every row has
    neutral settings (presence 0, frequency 0, repetition 1, within
    ``_SAMPLING_EPS``), ``logits`` is returned untouched. Otherwise a new
    tensor is returned.
    """
    num_seqs, vocab_size = logits.shape

    def _has_no_penalty(row: int) -> bool:
        return (abs(presence_penalties[row]) < _SAMPLING_EPS
                and abs(frequency_penalties[row]) < _SAMPLING_EPS
                and abs(repetition_penalties[row] - 1.0) < _SAMPLING_EPS)

    if all(_has_no_penalty(row) for row in range(num_seqs)):
        # Return early if all sequences have zero penalties.
        return logits

    prompt_tokens, output_tokens = _get_prompt_and_output_tokens(
        sampling_metadata)
    assert len(prompt_tokens) == num_seqs
    assert len(output_tokens) == num_seqs

    _, prompt_mask = _get_bin_counts_and_mask(logits, prompt_tokens,
                                              vocab_size, num_seqs)
    output_bin_counts, output_mask = _get_bin_counts_and_mask(
        logits, output_tokens, vocab_size, num_seqs)

    like_logits = dict(dtype=logits.dtype, device=logits.device)
    rep_t = torch.tensor(repetition_penalties, **like_logits)
    freq_t = torch.tensor(frequency_penalties, **like_logits)
    pres_t = torch.tensor(presence_penalties, **like_logits)

    # Repetition penalty only touches tokens already seen in the prompt or
    # the output; every other column keeps a factor of 1.0.
    rep_matrix = rep_t.unsqueeze(dim=1).repeat(1, vocab_size)
    rep_matrix.masked_fill_(~(prompt_mask | output_mask), 1.0)
    # Positive logits are divided and negative ones multiplied, so a
    # penalty > 1 lowers every nonzero logit.
    logits = torch.where(logits > 0, logits / rep_matrix,
                         logits * rep_matrix)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= freq_t.unsqueeze(dim=1) * output_bin_counts
    logits -= pres_t.unsqueeze(dim=1) * output_mask
    return logits


272
def _get_temperatures(sampling_metadata: SamplingMetadata) -> List[float]:
    """Return one temperature per logits row.

    A temperature below ``_SAMPLING_EPS`` means deterministic sampling
    (greedy or beam search). It is replaced by 1.0 so the later division is
    a no-op instead of a division by zero. Prompt-logprob rows reuse their
    group's temperature.
    """
    temperatures: List[float] = []
    groups = sampling_metadata.seq_groups
    for group_idx, (seq_ids, params) in enumerate(groups):
        temperature = (1.0 if params.temperature < _SAMPLING_EPS else
                       params.temperature)
        if (group_idx < sampling_metadata.num_prompts
                and params.prompt_logprobs is not None):
            n_prompt_rows = sampling_metadata.prompt_lens[group_idx] - 1
            temperatures.extend([temperature] * n_prompt_rows)
        temperatures.extend([temperature] * len(seq_ids))
    return temperatures


Roy's avatar
Roy committed
291
def _get_top_p_top_k_min_p(
    sampling_metadata: SamplingMetadata,
    vocab_size: int,
) -> Tuple[List[float], List[int], List[float]]:
    """Expand top-p, top-k and min-p settings to one entry per logits row.

    ``top_k`` is clamped to ``vocab_size``, and ``-1`` (no truncation) is
    mapped to ``vocab_size``. Prompt-logprob rows repeat their group's values.
    """
    top_ps: List[float] = []
    top_ks: List[int] = []
    min_ps: List[float] = []

    def _push(count: int, top_p: float, top_k: int, min_p: float) -> None:
        top_ps.extend([top_p] * count)
        top_ks.extend([top_k] * count)
        min_ps.extend([min_p] * count)

    groups = sampling_metadata.seq_groups
    for group_idx, (seq_ids, params) in enumerate(groups):
        # k should not be greater than the vocab size.
        clamped_k = min(params.top_k, vocab_size)
        # k=-1 means no truncation.
        top_k = vocab_size if clamped_k == -1 else clamped_k
        if (group_idx < sampling_metadata.num_prompts
                and params.prompt_logprobs is not None):
            _push(sampling_metadata.prompt_lens[group_idx] - 1,
                  params.top_p, top_k, params.min_p)
        _push(len(seq_ids), params.top_p, top_k, params.min_p)
    return top_ps, top_ks, min_ps
316
317


Woosuk Kwon's avatar
Woosuk Kwon committed
318
def _apply_top_p_top_k(
    logits: torch.Tensor,
    top_ps: List[float],
    top_ks: List[int],
) -> torch.Tensor:
    """Set logits outside each row's top-p / top-k set to ``-inf``.

    Both filters use the same descending sort of each row. Top-p drops a
    token once the probability mass of the tokens ranked strictly above it
    exceeds ``p``. Top-k drops every token at rank ``k`` or later. The result
    is returned in the original token order.
    """
    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    num_rows, num_cols = sorted_idx.shape[0], sorted_idx.shape[-1]

    # Top-p: mass accumulated before each sorted position.
    sorted_probs = sorted_logits.softmax(dim=-1)
    mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    sorted_logits[mass_before > p.unsqueeze(dim=1)] = -float("inf")

    # Top-k: mask every sorted position whose rank is >= k.
    ranks = torch.arange(num_cols, device=sorted_idx.device)
    ranks = ranks.expand(num_rows, -1)
    sorted_logits[ranks >= k.unsqueeze(dim=1)] = -float("inf")

    # Undo the sort: write each value back to its original column.
    restored = torch.empty_like(sorted_logits)
    restored.scatter_(-1, sorted_idx, sorted_logits)
    return restored
345
346


Roy's avatar
Roy committed
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
def _apply_min_p(
    logits: torch.Tensor,
    min_ps: List[float],
) -> torch.Tensor:
    """
    Adapted from
    https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17

    A token is dropped (logit set to ``-inf``) when its probability is below
    ``min_p`` times the highest probability in its row.
    """
    thresholds = torch.tensor(min_ps, dtype=logits.dtype,
                              device=logits.device)
    probs = logits.softmax(dim=-1)
    row_max = probs.max(dim=-1, keepdim=True).values
    cutoff = thresholds[:, None] * row_max
    return logits.masked_fill(probs < cutoff, -float("inf"))


365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
def _greedy_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    """Pick the arg-max token for every greedy sequence group.

    Each group must contain exactly one sequence. Its result is
    ``([next_token_id], [0])``.
    """
    best_tokens = logprobs.argmax(dim=-1).cpu()
    results: List[Tuple[List[int], List[int]]] = []
    row = 0
    for seq_ids, _ in selected_seq_groups:
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        results.append(([best_tokens[row].item()],
                        list(range(num_parent_seqs))))
        row += num_parent_seqs
    assert row == logprobs.size(0)
    return results


def _random_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    probs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    """Draw next tokens by multinomial sampling.

    One ``torch.multinomial`` call (with replacement) draws, for every row,
    as many samples as the largest ``best_of`` among prompt groups (at
    least 1).
    - Prompt groups (one sequence each) take the first ``best_of`` draws of
      their row, all with parent 0.
    - Generation groups take the first draw of each of their sequences' rows.
    """
    groups_with_phase = list(zip(selected_seq_groups, is_prompts))
    # Find the maximum best_of value of the prompt phase requests.
    prompt_best_ofs = [
        params.best_of for (_, params), is_prompt in groups_with_phase
        if is_prompt
    ]
    max_best_of = max([1] + prompt_best_ofs)
    random_samples = torch.multinomial(probs,
                                       num_samples=max_best_of,
                                       replacement=True).cpu()

    results: List[Tuple[List[int], List[int]]] = []
    row = 0
    for (seq_ids, params), is_prompt in groups_with_phase:
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase: the single sequence forks into best_of samples.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            best_of = params.best_of
            next_token_ids = random_samples[row, :best_of].tolist()
            parent_ids = [0] * best_of
        else:
            # Generation phase: one draw per running sequence.
            next_token_ids = random_samples[row:row + num_parent_seqs,
                                            0].tolist()
            parent_ids = list(range(num_parent_seqs))
        results.append((next_token_ids, parent_ids))
        row += num_parent_seqs
    assert row == probs.size(0)
    return results


def _beam_search_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    seq_data: Dict[int, SequenceData],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    """Propose ``2 * beam_width`` candidate tokens per beam-search group.

    ``beam_width`` is ``sampling_params.best_of``.
    - Prompt groups (one sequence) take the top candidates of their single
      row.
    - Generation groups add each sequence's ``cumulative_logprob`` to its
      row and take the top candidates over all (sequence, token) pairs,
      reporting the sequence index as the parent.
    """
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # NOTE: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    results: List[Tuple[List[int], List[int]]] = []
    row = 0
    for (seq_ids, params), is_prompt in zip(selected_seq_groups, is_prompts):
        num_parent_seqs = len(seq_ids)
        num_candidates = 2 * params.best_of
        group_logprobs = logprobs[row:row + num_parent_seqs]
        row += num_parent_seqs

        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            _, top_ids = torch.topk(group_logprobs[0], num_candidates)
            results.append((top_ids.tolist(), [0] * num_candidates))
            continue

        # Generation phase: score = own cumulative logprob + next-token logprob.
        cumulative = torch.tensor(
            [seq_data[seq_id].cumulative_logprob for seq_id in seq_ids],
            dtype=torch.float,
            device=group_logprobs.device)
        scores = group_logprobs + cumulative.unsqueeze(dim=1)
        vocab_size = scores.size(-1)
        _, flat_ids = torch.topk(scores.flatten(), num_candidates)
        parent_ids: List[int] = []
        next_token_ids: List[int] = []
        for flat_id in flat_ids.tolist():
            parent, token = divmod(flat_id, vocab_size)
            parent_ids.append(parent)
            next_token_ids.append(token)
        results.append((next_token_ids, parent_ids))
    assert row == logprobs.size(0)
    return results
473
474
475
476
477


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> List[Tuple[List[int], List[int]]]:
    """Sample next tokens for every sequence group, batched by sampling type.

    Groups are bucketed by ``sampling_params.sampling_type``. Each non-empty
    bucket is sampled in one call on the rows listed in
    ``sampling_metadata.categorized_sample_indices``. Results are returned as
    ``(next_token_ids, parent_ids)`` in the original sequence-group order.
    """
    seq_groups = sampling_metadata.seq_groups
    categorized_indices = sampling_metadata.categorized_sample_indices

    # Bucket sequence-group positions by their sampling type.
    groups_by_type = {t: [] for t in SamplingType}
    for group_idx, (_, params) in enumerate(seq_groups):
        groups_by_type[params.sampling_type].append(group_idx)

    results_by_group: Dict[int, Tuple[List[int], List[int]]] = {}
    for sampling_type in SamplingType:
        group_ids = groups_by_type[sampling_type]
        rows = categorized_indices[sampling_type]
        if len(rows) == 0:
            continue
        type_groups = [seq_groups[i] for i in group_ids]
        is_prompts = [i < sampling_metadata.num_prompts for i in group_ids]
        if sampling_type == SamplingType.GREEDY:
            type_results = _greedy_sample(type_groups, logprobs[rows])
        elif sampling_type == SamplingType.RANDOM:
            type_results = _random_sample(type_groups, is_prompts,
                                          probs[rows])
        elif sampling_type == SamplingType.BEAM:
            type_results = _beam_search_sample(type_groups, is_prompts,
                                               sampling_metadata.seq_data,
                                               logprobs[rows])
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")
        results_by_group.update(zip(group_ids, type_results))

    # Restore the original sequence-group order.
    return [results_by_group[i] for i in range(len(seq_groups))]


def _get_logprobs(
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
        int, float]]]]:
    """Collect the logprobs requested for prompts and sampled tokens.

    Rows of ``logprobs`` are laid out group by group. A prompt group with
    ``prompt_logprobs`` set first owns ``prompt_len - 1`` rows, followed by
    one row per sequence in the group.

    The work is done in two passes that must walk the groups in the same
    order:
    1. The first pass builds (row, token) index lists, so every chosen-token
       logprob is fetched with one batched gather plus at most one batched
       ``topk``.
    2. The second pass slices those results back out per group.

    Returns:
        ``(prompt_logprobs, sample_logprobs)``.
        - ``prompt_logprobs[i]`` is ``None`` unless group ``i`` is a prompt
          with ``prompt_logprobs`` set. In that case it is
          ``[None, {token: logprob, ...}, ...]`` with one dict per prompt
          token after the first.
        - ``sample_logprobs[i]`` holds one dict per sampled candidate. It
          maps the chosen token, plus the top ``sampling_params.logprobs``
          tokens of the candidate's parent row, to their logprobs.
    """
    # Pass 1: prepare query indices.
    batched_logprobs_query_seq_indices: List[int] = []
    batched_logprobs_query_token_indices: List[int] = []
    largest_num_logprobs = 0
    sample_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(sampling_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result
        num_parent_seqs = len(seq_ids)
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.prompt_logprobs)
            prompt_len = sampling_metadata.prompt_lens[i]
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            # Row `sample_idx + j` is queried for prompt token `j + 1`.
            batched_logprobs_query_seq_indices.extend(
                range(sample_idx, sample_idx + prompt_len - 1))
            batched_logprobs_query_token_indices.extend(prompt_tokens[1:])
            sample_idx += prompt_len - 1
        # Each sampled candidate is looked up on its parent sequence's row.
        batched_logprobs_query_seq_indices.extend(
            sample_idx + parent_id for parent_id in parent_ids)
        batched_logprobs_query_token_indices.extend(next_token_ids)
        if sampling_params.logprobs is not None:
            largest_num_logprobs = max(largest_num_logprobs,
                                       sampling_params.logprobs)
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)

    # Batched query for logprobs of the selected tokens. Index with an
    # explicit (rows, cols) tuple of index lists (advanced indexing) rather
    # than a list of lists, whose tuple interpretation is legacy behavior.
    batched_logprobs_query_result = logprobs[
        batched_logprobs_query_seq_indices,
        batched_logprobs_query_token_indices].cpu()

    # Batched query for logprobs of topk tokens
    if largest_num_logprobs > 0:
        top_logprobs, top_token_ids = torch.topk(logprobs,
                                                 largest_num_logprobs,
                                                 dim=-1)
        top_logprobs = top_logprobs.cpu()
        top_token_ids = top_token_ids.cpu()
    else:
        top_logprobs, top_token_ids = None, None

    # Pass 2: gather results, walking rows and query results in the same
    # order as pass 1.
    result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
    result_sample_logprobs: List[SampleLogprobs] = []
    sample_idx = 0
    query_result_idx = 0
    for i, (seq_group, sample_result) in enumerate(
            zip(sampling_metadata.seq_groups, sample_results)):
        seq_ids, sampling_params = seq_group
        next_token_ids, parent_ids = sample_result

        # Prompt logprobs
        if (i < sampling_metadata.num_prompts
                and sampling_params.prompt_logprobs is not None):
            num_logprobs = sampling_params.prompt_logprobs
            prompt_tokens = sampling_metadata.seq_data[
                seq_ids[0]].prompt_token_ids
            # The first prompt token gets None (no row was queried for it).
            group_prompt_logprobs: PromptLogprobs = [None]
            for token_id in prompt_tokens[1:]:
                prompt_logprobs_dict = {
                    token_id:
                    batched_logprobs_query_result[query_result_idx].item()
                }
                if num_logprobs > 0:
                    prompt_logprobs_dict.update(
                        zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
                            top_logprobs[sample_idx, :num_logprobs].tolist()))
                group_prompt_logprobs.append(prompt_logprobs_dict)
                sample_idx += 1
                query_result_idx += 1
            result_prompt_logprobs.append(group_prompt_logprobs)
        else:
            result_prompt_logprobs.append(None)

        # Sample logprobs
        num_logprobs = sampling_params.logprobs
        if num_logprobs is None:
            num_logprobs = 0
        group_sample_logprobs: SampleLogprobs = []
        for next_token_id, parent_id in zip(next_token_ids, parent_ids):
            sample_logprobs_dict = {
                next_token_id:
                batched_logprobs_query_result[query_result_idx].item()
            }
            query_result_idx += 1
            if num_logprobs > 0:
                sample_logprobs_dict.update(
                    zip(
                        top_token_ids[sample_idx +
                                      parent_id, :num_logprobs].tolist(),
                        top_logprobs[sample_idx +
                                     parent_id, :num_logprobs].tolist()))
            group_sample_logprobs.append(sample_logprobs_dict)
        result_sample_logprobs.append(group_sample_logprobs)
        sample_idx += len(seq_ids)

    return result_prompt_logprobs, result_sample_logprobs


def _build_sampler_output(
    sample_results: List[Tuple[List[int], List[int]]],
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: List[Optional[PromptLogprobs]],
    sample_logprobs: List[SampleLogprobs],
) -> SamplerOutput:
    """Package per-group sampling results into ``SequenceGroupOutput``s.

    ``parent_ids`` index into each group's ``seq_ids``. Every candidate
    becomes ``SequenceOutput(parent_seq_id, next_token_id, logprobs)``, and
    the group's prompt logprobs are attached to its ``SequenceGroupOutput``.
    """
    sampler_output = []
    per_group = zip(sampling_metadata.seq_groups, sample_results,
                    prompt_logprobs, sample_logprobs)
    for group, result, group_prompt_lps, group_sample_lps in per_group:
        seq_ids, _ = group
        next_token_ids, parent_ids = result
        seq_outputs = [
            SequenceOutput(seq_ids[parent_id], token_id, token_lps)
            for parent_id, token_id, token_lps in zip(
                parent_ids, next_token_ids, group_sample_lps)
        ]
        sampler_output.append(
            SequenceGroupOutput(seq_outputs, group_prompt_lps))
    return sampler_output