sampler.py 16.7 KB
Newer Older
1
"""A layer that samples the next tokens from the model's outputs."""
Zhuohan Li's avatar
Zhuohan Li committed
2
from typing import Dict, List, Tuple, Optional
Woosuk Kwon's avatar
Woosuk Kwon committed
3

4
import numpy as np
Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
7
import torch
import torch.nn as nn

Woosuk Kwon's avatar
Woosuk Kwon committed
8
9
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.parallel_utils.tensor_parallel import (
10
    gather_from_tensor_model_parallel_region)
Woosuk Kwon's avatar
Woosuk Kwon committed
11
12
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceOutputs
Woosuk Kwon's avatar
Woosuk Kwon committed
13

14
# Tolerance for floating-point comparisons on sampling parameters:
# a temperature below it selects greedy/beam-search handling, a top-p within
# it of 1.0 disables top-p truncation, and presence/frequency penalties whose
# magnitudes are both below it are treated as "no penalty".
_SAMPLING_EPS = 1e-5
Woosuk Kwon's avatar
Minor  
Woosuk Kwon committed
15

16

Woosuk Kwon's avatar
Woosuk Kwon committed
17
class Sampler(nn.Module):
    """Samples the next tokens from the model's outputs.

    This layer does the following:
    1. Discard the hidden states that are not used for sampling (i.e., all
        tokens except the final one in each prompt).
    2. Compute the logits for the next tokens.
    3. Apply presence and frequency penalties.
    4. Apply temperature scaling.
    5. Apply top-p and top-k truncation.
    6. Sample the next tokens.
    Here, each sequence group within the batch can have different sampling
    parameters (e.g., sampling method, temperature, top-p, top-k, etc.).

    Probabilities and log probabilities are both computed in float32 from
    the truncated logits. Tokens removed by top-p/top-k therefore have
    probability 0 and log probability -inf.
    """

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        # Number of real vocabulary entries. Logit columns at or beyond this
        # index are padding and are sliced off in ``forward``.
        self.vocab_size = vocab_size

    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Dict[int, SequenceOutputs]:
        """Compute next-token logits and sample one step for the batch.

        Args:
            embedding: Output projection weight. Logits are computed as
                ``hidden_states @ embedding.t()`` and then gathered across
                the tensor-parallel region, so this is presumably this
                rank's vocab shard (TODO confirm).
            hidden_states: Hidden states for the tokens of the batch. Only
                the rows selected by ``_prune_hidden_states`` are used.
            input_metadata: Batch layout and per-group sampling parameters.
            embedding_bias: Optional bias added to the logits before the
                tensor-parallel gather.

        Returns:
            Mapping from sequence id to its sampled ``SequenceOutputs``.
        """
        # Get the hidden states that we use for sampling.
        hidden_states = _prune_hidden_states(hidden_states, input_metadata)

        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias
        logits = gather_from_tensor_model_parallel_region(logits)
        # Remove paddings in vocab (if any).
        logits = logits[:, :self.vocab_size]

        # Apply presence and frequency penalties.
        output_tokens = _get_output_tokens(input_metadata)
        assert len(output_tokens) == logits.shape[0]
        presence_penalties, frequency_penalties = _get_penalties(
            input_metadata)
        assert len(presence_penalties) == logits.shape[0]
        assert len(frequency_penalties) == logits.shape[0]
        logits = _apply_penalties(logits, output_tokens, presence_penalties,
                                  frequency_penalties, self.vocab_size)

        # Apply temperature scaling.
        temperatures = _get_temperatures(input_metadata)
        assert len(temperatures) == logits.shape[0]
        if any(t != 1.0 for t in temperatures):
            t = torch.tensor(temperatures,
                             dtype=logits.dtype,
                             device=logits.device)
            # Use in-place division to avoid creating a new tensor.
            logits.div_(t.unsqueeze(dim=1))

        # Apply top-p and top-k truncation.
        top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
        assert len(top_ps) == len(top_ks) == logits.shape[0]
        do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
        do_top_k = any(k != self.vocab_size for k in top_ks)
        if do_top_p or do_top_k:
            logits = _apply_top_p_top_k(logits, top_ps, top_ks)

        # We use float32 for probabilities and log probabilities. Both are
        # computed from the logits *after* top-p/top-k truncation.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Use log_softmax rather than log(softmax(...)). A token whose
        # probability underflows to 0 in float32 then keeps a finite log
        # probability instead of becoming -inf.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        return _sample(probs, logprobs, input_metadata)


def _prune_hidden_states(
    hidden_states: torch.Tensor,
    input_metadata: InputMetadata,
) -> torch.Tensor:
    """Select the rows of ``hidden_states`` that the sampler needs.

    Row layout:
    - All prompt tokens come first, prompt after prompt, with lengths
      ``input_metadata.prompt_lens``. Only the last token of each prompt
      is kept.
    - Then ``input_metadata.num_generation_tokens`` generation tokens
      follow, and all of them are kept.
    """
    selected: List[int] = []
    offset = 0
    for length in input_metadata.prompt_lens:
        offset += length
        # Index of the final token of the prompt that just ended.
        selected.append(offset - 1)
    # Generation tokens sit right after the prompts; each one is sampled.
    num_gen = input_metadata.num_generation_tokens
    selected.extend(range(offset, offset + num_gen))
    index = torch.tensor(selected, device=hidden_states.device)
    return hidden_states.index_select(0, index)
105
106


107
def _get_penalties(
        input_metadata: InputMetadata) -> Tuple[List[float], List[float]]:
    """Expand per-group presence/frequency penalties to one entry per row.

    A prompt group contributes a single row. A generation group contributes
    one row per sequence id it contains.

    Returns:
        ``(presence_penalties, frequency_penalties)``, aligned with the
        rows of the logits.
    """
    presence: List[float] = []
    frequency: List[float] = []
    for group_idx, (seq_ids, params) in enumerate(input_metadata.seq_groups):
        presence_value = params.presence_penalty
        frequency_value = params.frequency_penalty
        is_prompt = group_idx < input_metadata.num_prompts
        num_rows = 1 if is_prompt else len(seq_ids)
        presence.extend([presence_value] * num_rows)
        frequency.extend([frequency_value] * num_rows)
    return presence, frequency


127
def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
    """Return the already-generated token ids for every logits row.

    A prompt group contributes one row, taken from its first sequence id.
    A generation group contributes one row per sequence id, in order.
    """
    token_lists: List[List[int]] = []
    for group_idx, (seq_ids, _) in enumerate(input_metadata.seq_groups):
        if group_idx < input_metadata.num_prompts:
            # NOTE: While the prompt input usually has no output tokens,
            # it may have output tokens in the case of recomputation.
            row_seq_ids = [seq_ids[0]]
        else:
            row_seq_ids = seq_ids
        token_lists.extend(input_metadata.seq_data[seq_id].output_token_ids
                           for seq_id in row_seq_ids)
    return token_lists


def _apply_penalties(
    logits: torch.Tensor,
    output_tokens: List[List[int]],
    presence_penalties: List[float],
    frequency_penalties: List[float],
    vocab_size: int,
) -> torch.Tensor:
    """Penalize tokens that a sequence has already generated.

    A row ``i`` is penalized only if it has at least one output token and
    its penalties are not both below ``_SAMPLING_EPS`` in magnitude. For
    each penalized row, this subtracts from the logits:
    - ``frequency_penalties[i] * count(token)``, and
    - ``presence_penalties[i] * [count(token) > 0]``.

    When any row is penalized, ``logits`` is modified in place. The same
    tensor is returned in either case.
    """
    def _needs_penalty(row: int) -> bool:
        if not output_tokens[row]:
            return False
        negligible = (abs(presence_penalties[row]) < _SAMPLING_EPS
                      and abs(frequency_penalties[row]) < _SAMPLING_EPS)
        return not negligible

    rows = [row for row in range(logits.shape[0]) if _needs_penalty(row)]
    # Nothing to penalize: hand back the input tensor untouched.
    if not rows:
        return logits

    # Per-row token histograms, at least ``vocab_size`` columns wide.
    histograms = np.stack(
        [np.bincount(output_tokens[row], minlength=vocab_size)
         for row in rows],
        axis=0,
    )
    counts = torch.from_numpy(histograms).to(dtype=logits.dtype,
                                             device=logits.device)

    def _column(values: List[float]) -> torch.Tensor:
        # Gather the penalized rows' values into a (len(rows), 1) column.
        picked = [values[row] for row in rows]
        as_tensor = torch.tensor(picked,
                                 dtype=logits.dtype,
                                 device=logits.device)
        return as_tensor.unsqueeze(dim=1)

    freq_col = _column(frequency_penalties)
    pres_col = _column(presence_penalties)

    # We follow the definition in OpenAI API.
    # Refer to https://platform.openai.com/docs/api-reference/parameter-details
    logits[rows] -= freq_col * counts
    seen = (counts > 0.0).to(dtype=logits.dtype)
    logits[rows] -= pres_col * seen
    return logits


193
def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
    """Expand per-group temperatures to one entry per logits row.

    A prompt group contributes one entry. A generation group contributes one
    entry per sequence id.

    Temperatures below ``_SAMPLING_EPS`` are replaced by 1.0, so the later
    division leaves those rows unchanged.
    """
    temperatures: List[float] = []
    for group_idx, (seq_ids, params) in enumerate(input_metadata.seq_groups):
        raw = params.temperature
        # NOTE: Zero temperature means deterministic sampling
        # (i.e., greedy sampling or beam search).
        # Set the temperature to 1 to avoid division by zero.
        effective = 1.0 if raw < _SAMPLING_EPS else raw
        if group_idx < input_metadata.num_prompts:
            temperatures.append(effective)
        else:
            temperatures.extend([effective] * len(seq_ids))
    return temperatures


Woosuk Kwon's avatar
Woosuk Kwon committed
214
def _get_top_p_top_k(
    input_metadata: InputMetadata,
    vocab_size: int,
) -> Tuple[List[float], List[int]]:
    """Expand per-group top-p / top-k settings to one entry per logits row.

    ``top_k`` is clamped to ``vocab_size``. The value ``-1`` (no
    truncation) is mapped to ``vocab_size``, so every entry is a concrete k.

    Returns:
        ``(top_ps, top_ks)``. Prompt groups contribute one entry each;
        generation groups contribute one entry per sequence id.
    """
    top_ps: List[float] = []
    top_ks: List[int] = []
    for group_idx, (seq_ids, params) in enumerate(input_metadata.seq_groups):
        p = params.top_p
        # k should not be greater than the vocab size; k=-1 means no
        # truncation.
        k = min(params.top_k, vocab_size)
        if k == -1:
            k = vocab_size
        repeat = 1 if group_idx < input_metadata.num_prompts else len(seq_ids)
        top_ps.extend([p] * repeat)
        top_ks.extend([k] * repeat)
    return top_ps, top_ks
236
237


Woosuk Kwon's avatar
Woosuk Kwon committed
238
def _apply_top_p_top_k(
239
    logits: torch.Tensor,
240
241
    top_ps: List[float],
    top_ks: List[int],
242
) -> torch.Tensor:
243
244
245
    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
Woosuk Kwon's avatar
Woosuk Kwon committed
246
247

    # Apply top-p.
248
249
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
Woosuk Kwon's avatar
Woosuk Kwon committed
250
    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
251
    logits_sort[top_p_mask] = -float("inf")
Woosuk Kwon's avatar
Woosuk Kwon committed
252
253
254

    # Apply top-k.
    # Create a mask for the top-k elements.
255
256
    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
Woosuk Kwon's avatar
Woosuk Kwon committed
257
    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
258
    logits_sort[top_k_mask] = -float("inf")
Woosuk Kwon's avatar
Woosuk Kwon committed
259
260

    # Re-sort the probabilities.
261
262
263
264
    logits = torch.gather(logits_sort,
                          dim=-1,
                          index=torch.argsort(logits_idx, dim=-1))
    return logits
265
266
267
268


def _get_topk_logprobs(
    logprobs: torch.Tensor,
Zhuohan Li's avatar
Zhuohan Li committed
269
    num_logprobs: Optional[int],
270
) -> Dict[int, float]:
Zhuohan Li's avatar
Zhuohan Li committed
271
    if num_logprobs is None or num_logprobs == 0:
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
        return {}

    topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs)
    if num_logprobs == 1:
        topk_logprobs = [topk_logprobs.item()]
        topk_ids = [topk_ids.item()]
    else:
        topk_logprobs = topk_logprobs.tolist()
        topk_ids = topk_ids.tolist()

    token_to_logprob: Dict[int, float] = {}
    for token_id, logprob in zip(topk_ids, topk_logprobs):
        token_to_logprob[token_id] = logprob
    return token_to_logprob


def _sample_from_prompt(
    prob: torch.Tensor,
    sampling_params: SamplingParams,
) -> List[int]:
    """Pick the first generated token(s) for a prompt.

    Returns:
        - Beam search: the ids of the ``best_of`` most likely tokens.
        - Greedy sampling (temperature below ``_SAMPLING_EPS``): a single
          argmax id.
        - Random sampling: ``best_of`` ids drawn with replacement.
    """
    if sampling_params.use_beam_search:
        # Beam search: seed one beam from each of the likeliest tokens.
        _, top_ids = torch.topk(prob, sampling_params.best_of)
        return top_ids.tolist()

    if sampling_params.temperature < _SAMPLING_EPS:
        # Greedy sampling.
        assert sampling_params.best_of == 1
        return [torch.argmax(prob).item()]

    # Random sampling: draw `best_of` tokens for the prompt.
    draws = torch.multinomial(prob,
                              num_samples=sampling_params.best_of,
                              replacement=True)
    return draws.tolist()


def _sample_from_generation_tokens(
    seq_ids: List[int],
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    seq_logprobs: List[float],
    sampling_params: SamplingParams,
) -> Tuple[List[int], List[int]]:
    """Choose the next token for every running sequence of one group.

    Row ``j`` of ``probs``/``logprobs`` and ``seq_logprobs[j]`` belong to
    ``seq_ids[j]``.

    Returns:
        ``(parent_seq_ids, next_token_ids)``, both aligned with ``seq_ids``.
        Entry ``j`` means ``seq_ids[j]`` continues from sequence
        ``parent_seq_ids[j]`` with token ``next_token_ids[j]``. Only beam
        search can make the parent differ from ``seq_ids[j]``: a dropped
        beam is replaced by a fork of another beam.
    """
    # NOTE(woosuk): sampling_params.best_of can be greater than
    # len(seq_ids) because some sequences in the group might have
    # been already terminated.
    if not sampling_params.use_beam_search:
        if sampling_params.temperature < _SAMPLING_EPS:
            # Greedy sampling: a single sequence takes its argmax token.
            assert len(seq_ids) == 1
            best = torch.argmax(probs, dim=-1)
            return seq_ids, [int(best.item())]
        # Random sampling: one draw per sequence from its own row.
        draws = torch.multinomial(probs, num_samples=1, replacement=True)
        return seq_ids, draws.squeeze(dim=-1).tolist()

    # Beam search: score every (sequence, token) pair by the sequence's
    # cumulative logprob plus the token's logprob. Keep the len(seq_ids)
    # best pairs across the whole group.
    cumulative = torch.tensor(seq_logprobs,
                              dtype=torch.float,
                              device=logprobs.device)
    scores = logprobs + cumulative.unsqueeze(dim=1)
    vocab_size = scores.size(-1)
    _, flat_ids = torch.topk(scores.flatten(), len(seq_ids))

    candidates: List[Tuple[int, int]] = []
    for flat_id in flat_ids.tolist():
        row, token_id = divmod(flat_id, vocab_size)
        candidates.append((seq_ids[row], token_id))

    # A beam that appears among the winners keeps its first (highest-ranked)
    # continuation. Further winners from the same beam become spares.
    chosen: Dict[int, Tuple[int, int]] = {}
    spares: List[Tuple[int, int]] = []
    for parent_id, token_id in candidates:
        if parent_id in chosen:
            spares.append((parent_id, token_id))
        else:
            chosen[parent_id] = (parent_id, token_id)

    # A beam with no winning continuation is discarded. Its slot is refilled
    # by forking a spare, taken from the end of the spare list.
    for seq_id in seq_ids:
        if seq_id not in chosen:
            chosen[seq_id] = spares.pop()
    assert not spares

    parent_seq_ids = [chosen[seq_id][0] for seq_id in seq_ids]
    next_token_ids = [chosen[seq_id][1] for seq_id in seq_ids]
    return parent_seq_ids, next_token_ids


def _sample(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    input_metadata: InputMetadata,
) -> Dict[int, SequenceOutputs]:
    """Sample the next token for every sequence in the batch.

    Rows of ``probs``/``logprobs`` are consumed group by group: one row per
    prompt group, then ``len(seq_ids)`` rows per generation group.

    Returns:
        Mapping from sequence id to ``SequenceOutputs(seq_id, parent_seq_id,
        token_id, logprobs)``. The logprob dict holds the requested top
        entries plus the sampled token.
    """
    seq_outputs: Dict[int, SequenceOutputs] = {}

    # TODO(woosuk): Optimize.
    row = 0
    for group_idx, (seq_ids, params) in enumerate(input_metadata.seq_groups):
        if group_idx < input_metadata.num_prompts:
            # A prompt group owns a single row, shared by all of its
            # best_of sequences.
            assert len(seq_ids) == params.best_of
            prompt_prob = probs[row]
            prompt_logprob = logprobs[row]
            row += 1

            sampled = _sample_from_prompt(prompt_prob, params)
            top = _get_topk_logprobs(prompt_logprob, params.logprobs)
            for seq_id, token_id in zip(seq_ids, sampled):
                # Always report the sampled token's logprob, even when it is
                # not among the requested top entries.
                token_logprobs = dict(top)
                token_logprobs[token_id] = prompt_logprob[token_id].item()
                seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id,
                                                      token_id,
                                                      token_logprobs)
            continue

        # A generation group owns one row per running sequence.
        end = row + len(seq_ids)
        group_probs = probs[row:end]
        group_logprobs = logprobs[row:end]
        row = end

        cumulative = [
            input_metadata.seq_data[seq_id].cumulative_logprob
            for seq_id in seq_ids
        ]
        parent_ids, sampled = _sample_from_generation_tokens(
            seq_ids, group_probs, group_logprobs, cumulative, params)

        top_by_seq = {
            seq_id: _get_topk_logprobs(group_logprobs[j], params.logprobs)
            for j, seq_id in enumerate(seq_ids)
        }

        for seq_id, parent_id, token_id in zip(seq_ids, parent_ids, sampled):
            # The new token was chosen from the parent's row, so both the top
            # entries and the sampled token's logprob come from that row.
            parent_row = seq_ids.index(parent_id)
            token_logprobs = dict(top_by_seq[parent_id])
            token_logprobs[token_id] = group_logprobs[parent_row,
                                                      token_id].item()
            seq_outputs[seq_id] = SequenceOutputs(
                seq_id,
                parent_id,
                token_id,
                token_logprobs,
            )
    return seq_outputs