"projects/gemv/run-all.sh" did not exist on "845b2d24266c2ccc519324ba97c0a5ce49b014f9"
sampler.py 20.1 KB
Newer Older
1
"""A layer that samples the next tokens from the model's outputs."""
2
from typing import Dict, List, Optional, Tuple
Woosuk Kwon's avatar
Woosuk Kwon committed
3
4
5
6

import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_gather)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceOutputs

_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
        tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence and frequency penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.
    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
    """

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.vocab_size = vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> SamplerOutput:
        # Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, input_metadata)

        # Get the logits for the next tokens.
        logits = _get_logits(hidden_states, embedding, embedding_bias,
                             self.vocab_size)

        # Apply presence and frequency penalties.
        output_tokens = _get_output_tokens(input_metadata)
        assert len(output_tokens) == logits.shape[0]
        presence_penalties, frequency_penalties = _get_penalties(
            input_metadata)
        assert len(presence_penalties) == logits.shape[0]
        assert len(frequency_penalties) == logits.shape[0]
        logits = _apply_penalties(logits, output_tokens, presence_penalties,
                                  frequency_penalties)

        # Apply temperature scaling.
        temperatures = _get_temperatures(input_metadata)
        assert len(temperatures) == logits.shape[0]
        if any(t != 1.0 for t in temperatures):
            t = torch.tensor(temperatures,
                             dtype=logits.dtype,
                             device=logits.device)
            # Use in-place division to avoid creating a new tensor.
            logits.div_(t.unsqueeze(dim=1))

        # Apply top-p and top-k truncation.
        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
        assert len(top_ps) == len(top_ks) == logits.shape[0]
        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
        do_top_k = any(k != self.vocab_size for k in top_ks)
        if do_top_p or do_top_k:
            logits = _apply_top_p_top_k(logits, top_ps, top_ks)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        # Use log_softmax to ensure numerical stability.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        return _sample(probs, logprobs, input_metadata)


def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
                embedding_bias: Optional[torch.Tensor],
                vocab_size: int) -> torch.Tensor:
    # Get the logits for the next tokens.
    logits = torch.matmul(hidden_states, embedding.t())
    if embedding_bias is not None:
        logits += embedding_bias
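    # With tensor parallelism, each rank computes logits only for its shard of
    # the vocabulary; the all-gather below assembles the full-vocabulary
    # logits on every rank.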
    logits = tensor_model_parallel_all_gather(logits)
    # Remove paddings in vocab (if any).
    logits = logits[:, :vocab_size]
    return logits


def _prune_hidden_states(
    hidden_states: torch.Tensor,
    input_metadata: InputMetadata,
) -> torch.Tensor:
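    # For a prompt, only the hidden state of its last token is needed for
    # sampling; sequences in the generation phase each contribute exactly one
    # hidden state, so all of those are kept.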
    last_token_indices = []
    start_idx = 0
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, _ = seq_group
        if i < input_metadata.num_prompts:
            assert len(seq_ids) == 1, "Prompt input should have only one seq."
            prompt_len = input_metadata.prompt_lens[i]
            last_token_indices.append(start_idx + prompt_len - 1)
            start_idx += prompt_len
        else:
            num_seqs = len(seq_ids)
            last_token_indices.extend(range(start_idx, start_idx + num_seqs))
            start_idx += num_seqs

    last_token_indices = torch.tensor(last_token_indices,
                                      dtype=torch.long,
                                      device=hidden_states.device)
    return hidden_states.index_select(0, last_token_indices)


def _get_penalties(
        input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
    # Collect the presence and frequency penalties.
    presence_penalties: List[float] = []
    frequency_penalties: List[float] = []
    for seq_group in input_metadata.seq_groups:
        seq_ids, sampling_params = seq_group
        p = sampling_params.presence_penalty
        f = sampling_params.frequency_penalty
        presence_penalties += [p] * len(seq_ids)
        frequency_penalties += [f] * len(seq_ids)
    return presence_penalties, frequency_penalties


def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
    output_tokens: List[List[int]] = []
    for seq_group in input_metadata.seq_groups:
        seq_ids, _ = seq_group
        for seq_id in seq_ids:
            seq_data = input_metadata.seq_data[seq_id]
            output_tokens.append(seq_data.output_token_ids)
    return output_tokens


def _apply_penalties(
    logits: torch.Tensor,
    output_tokens: List[List[int]],
    presence_penalties: List[float],
    frequency_penalties: List[float],
) -> torch.Tensor:
    num_seqs, vocab_size = logits.shape
    for i in range(num_seqs):
        if not output_tokens[i]:
            continue
        p = presence_penalties[i]
        f = frequency_penalties[i]
        if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS:
            continue
        break
    else:
        # Return early if all sequences have zero penalties.
        return logits

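    # Pad every sequence's output tokens to the same length with the index
    # `vocab_size`, so the lists can be stacked into a single tensor; the
    # extra bin that collects these padding entries is dropped below.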
    max_output_len = max(len(tokens) for tokens in output_tokens)
    padded_output_tokens = [
        tokens + [vocab_size] * (max_output_len - len(tokens))
        for tokens in output_tokens
    ]
    output_tokens_tensor = torch.tensor(padded_output_tokens,
                                        dtype=torch.long,
                                        device=logits.device)

    # Compute the bin counts for the output tokens.
    # vocab_size + 1 for padding.
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long,
                             device=logits.device)
    bin_counts.scatter_add_(1, output_tokens_tensor,
                            torch.ones_like(output_tokens_tensor))
    bin_counts = bin_counts[:, :vocab_size]  # Remove the padding bin.

    frequency_penalties = torch.tensor(frequency_penalties,
                                       dtype=logits.dtype,
                                       device=logits.device)
    presence_penalties = torch.tensor(presence_penalties,
                                      dtype=logits.dtype,
                                      device=logits.device)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts
    logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0)
    return logits


def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
    # Collect the temperatures for the logits.
    temperatures: List[float] = []
    for seq_group in input_metadata.seq_groups:
        seq_ids, sampling_params = seq_group
        temperature = sampling_params.temperature
        if temperature < _SAMPLING_EPS:
            # NOTE: Zero temperature means deterministic sampling
            # (i.e., greedy sampling or beam search).
            # Set the temperature to 1 to avoid division by zero.
            temperature = 1.0
        temperatures += [temperature] * len(seq_ids)
    return temperatures


def _get_top_p_top_k(
    input_metadata: InputMetadata,
    vocab_size: int,
) -> Tuple[List[float], List[int]]:
    top_ps: List[float] = []
    top_ks: List[int] = []
    for seq_group in input_metadata.seq_groups:
        seq_ids, sampling_params = seq_group
        top_p = sampling_params.top_p
        # k should not be greater than the vocab size.
        top_k = min(sampling_params.top_k, vocab_size)
        # k=-1 means no truncation.
        top_k = vocab_size if top_k == -1 else top_k
        top_ps += [top_p] * len(seq_ids)
        top_ks += [top_k] * len(seq_ids)
    return top_ps, top_ks


def _apply_top_p_top_k(
    logits: torch.Tensor,
    top_ps: List[float],
    top_ks: List[int],
) -> torch.Tensor:
    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)

    # Apply top-p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
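    # Subtracting probs_sort keeps at least one token per row: a token is
    # masked only when the cumulative probability of the tokens ranked above
    # it already exceeds p.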
    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
    logits_sort[top_p_mask] = -float("inf")

    # Apply top-k.
    # Create a mask for the top-k elements.
    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
    logits_sort[top_k_mask] = -float("inf")

    # Re-sort the probabilities.
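    # torch.argsort(logits_idx) is the inverse permutation of the sort above,
    # so the gather below restores the original vocabulary order.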
    logits = torch.gather(logits_sort,
                          dim=-1,
                          index=torch.argsort(logits_idx, dim=-1))
    return logits


def _get_topk_logprobs(
    logprobs: torch.Tensor,
    num_logprobs: Optional[int],
) -> List[Dict[int, float]]:
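    # For each sequence, build a dict mapping token id -> log probability for
    # its `num_logprobs` most likely tokens (empty if none were requested).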
    num_seqs = logprobs.size(0)
    if num_logprobs is None or num_logprobs == 0:
        return [{} for _ in range(num_seqs)]

    all_topk_logprobs, all_topk_ids = torch.topk(logprobs,
                                                 num_logprobs,
                                                 dim=-1)
    all_topk_logprobs = all_topk_logprobs.cpu()
    all_topk_ids = all_topk_ids.cpu()
    all_token_to_logprob = []
    for topk_logprobs, topk_ids in zip(all_topk_logprobs, all_topk_ids):
        token_to_logprob: Dict[int, float] = {}
        for token_id, logprob in zip(topk_ids, topk_logprobs):
            token_to_logprob[token_id.item()] = logprob.item()
        all_token_to_logprob.append(token_to_logprob)
    return all_token_to_logprob


def _build_sequence_outputs(
    parent_ids: List[int],
    next_token_ids: List[int],
    selected_token_logprobs: List[float],
    parent_seq_ids: List[int],
    parent_logprobs: torch.Tensor,
    num_output_logprobs: Optional[int],
) -> List[SequenceOutputs]:
    # Get top-k log probabilities for the next tokens.
    next_logprobs = _get_topk_logprobs(parent_logprobs, num_output_logprobs)
    seq_outputs: List[SequenceOutputs] = []
    for parent_id, next_token_id, token_logprob in zip(
            parent_ids, next_token_ids, selected_token_logprobs):
        output_logprobs = next_logprobs[parent_id].copy()
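        # Always include the sampled token's own log probability, even if it
        # did not make the top-k list.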
        output_logprobs[next_token_id] = token_logprob
        seq_outputs.append(
            SequenceOutputs(parent_seq_ids[parent_id], next_token_id,
                            output_logprobs))
    return seq_outputs


def _greedy_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
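    # Returns, per sequence group, (next_token_ids, parent_ids); greedy
    # sampling has exactly one parent per group and picks the argmax token.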
    samples = torch.argmax(logprobs, dim=-1).cpu()
    sample_idx = 0
    results = []
    for seq_group in selected_seq_groups:
        seq_ids, _ = seq_group
        num_parent_seqs = len(seq_ids)
        assert num_parent_seqs == 1, (
            "Greedy sampling should have only one seq.")
        parent_ids = list(range(num_parent_seqs))
        next_token_ids = [samples[sample_idx].item()]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results


def _random_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    probs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # Find the maximum best_of value of the prompt phase requests.
    max_best_of = 1
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        if is_prompt:
            seq_ids, sampling_params = seq_group
            max_best_of = max(max_best_of, sampling_params.best_of)
    random_samples = torch.multinomial(probs,
                                       num_samples=max_best_of,
                                       replacement=True).cpu()
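    # One multinomial call draws `max_best_of` samples per row; prompt-phase
    # groups keep their first `best_of` columns, while generation-phase
    # sequences keep only the first column.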
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * sampling_params.best_of
            next_token_ids = random_samples[
                sample_idx, :sampling_params.best_of].tolist()
        else:
            # Generation phase.
            parent_ids = list(range(num_parent_seqs))
            next_token_ids = random_samples[sample_idx:sample_idx +
                                            num_parent_seqs, 0].tolist()
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == probs.size(0)
    return results


def _beam_search_sample(
    selected_seq_groups: List[Tuple[List[int], SamplingParams]],
    is_prompts: List[bool],
    seq_data: Dict[int, SequenceData],
    logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
    # We sample 2 * beam_width candidates to make sure that with high
    # probability we can get `beam_width` candidates in addition to
    # the finished sequences for the next iteration. See
    # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
    # for details. See also HF reference:
    # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
    #
    # Note: Beam search is not vectorized, so its speed can be slower than
    # other sampling methods.
    sample_idx = 0
    results = []
    for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
        seq_ids, sampling_params = seq_group
        num_parent_seqs = len(seq_ids)
        beam_width = sampling_params.best_of
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
            assert num_parent_seqs == 1, (
                "Prompt input should have only one seq.")
            parent_ids = [0] * (2 * beam_width)
            _, next_token_ids = torch.topk(seq_group_logprobs[0],
                                           2 * beam_width)
            next_token_ids = next_token_ids.tolist()
        else:
            # Generation phase.
            cumulative_logprobs = [
                seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
            ]
            cumulative_logprobs = torch.tensor(
                cumulative_logprobs,
                dtype=torch.float,
                device=seq_group_logprobs.device)
            seq_group_logprobs = (seq_group_logprobs +
                                  cumulative_logprobs.unsqueeze(dim=1))
            _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
                                     2 * beam_width)
            topk_ids = topk_ids.tolist()
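            # Each flattened top-k index encodes (parent sequence, token):
            # the row gives the parent, the column gives the token id.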
            vocab_size = seq_group_logprobs.size(-1)
            parent_ids = [i // vocab_size for i in topk_ids]
            next_token_ids = [i % vocab_size for i in topk_ids]
        results.append((next_token_ids, parent_ids))
        sample_idx += num_parent_seqs
    assert sample_idx == logprobs.size(0)
    return results


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    input_metadata: InputMetadata,
) -> SamplerOutput:
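    # Group sequences by sampling type (greedy / random / beam search) so each
    # category can be processed in one batched pass over its rows.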
    categorized_seq_group_ids = {t: [] for t in SamplingType}
    categorized_seq_ids = {t: [] for t in SamplingType}
    start_idx = 0
    for i, seq_group in enumerate(input_metadata.seq_groups):
        seq_ids, sampling_params = seq_group
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)
        num_seqs = len(seq_ids)
        categorized_seq_ids[sampling_type].extend(
            range(start_idx, start_idx + num_seqs))
        start_idx += num_seqs
    seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {}
    for sampling_type in SamplingType:
        seq_group_ids = categorized_seq_group_ids[sampling_type]
        seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
        is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
        num_tokens = len(categorized_seq_ids[sampling_type])
        if num_tokens == 0:
            continue
        category_logprobs = logprobs[categorized_seq_ids[sampling_type]]
        category_probs = probs[categorized_seq_ids[sampling_type]]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, category_logprobs)
        elif sampling_type == SamplingType.RANDOM:
            sample_results = _random_sample(seq_groups, is_prompts,
                                            category_probs)
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups, is_prompts,
                                                 input_metadata.seq_data,
                                                 category_logprobs)
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

        # Batched query for logprobs of selected token
        batched_logprobs_query_seq_indices: List[int] = []
        batched_logprobs_query_token_indices: List[int] = []
        sample_idx = 0
        for seq_group_id, seq_group, sample_result in zip(
                seq_group_ids, seq_groups, sample_results):
            seq_ids, sampling_params = seq_group
            next_token_ids, parent_ids = sample_result
            num_parent_seqs = len(seq_ids)
            batched_logprobs_query_seq_indices.extend(
                [sample_idx + parent_id for parent_id in parent_ids])
            batched_logprobs_query_token_indices.extend(next_token_ids)
            sample_idx += num_parent_seqs
        assert sample_idx == num_tokens
        batched_logprobs_query_result = category_logprobs[[
            batched_logprobs_query_seq_indices,
            batched_logprobs_query_token_indices
        ]].tolist()
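        # The fancy indexing above gathers one scalar per sampled token: the
        # log probability of that token under its parent sequence's
        # distribution.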

        # Build the sequence outputs.
        sample_idx = 0
        result_idx = 0
        for seq_group_id, seq_group, sample_result in zip(
                seq_group_ids, seq_groups, sample_results):
            seq_ids, sampling_params = seq_group
            next_token_ids, parent_ids = sample_result
            num_results = len(next_token_ids)
            num_parent_seqs = len(seq_ids)
            parent_logprobs = category_logprobs[sample_idx:sample_idx +
                                                num_parent_seqs]
            selected_token_logprobs = batched_logprobs_query_result[
                result_idx:result_idx + num_results]
            seq_output = _build_sequence_outputs(parent_ids, next_token_ids,
                                                 selected_token_logprobs,
                                                 seq_ids, parent_logprobs,
                                                 sampling_params.logprobs)
            seq_outputs_dict[seq_group_id] = seq_output
            sample_idx += num_parent_seqs
            result_idx += num_results
        assert sample_idx == num_tokens

    return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))]