sampler.py 11.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""A layer that samples the next tokens from the model's outputs."""

5
6
from typing import Optional

7
8
9
import torch
import torch.nn as nn

10
from vllm.config.model import LogprobsMode
11
from vllm.utils import is_pin_memory_available
12
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
13
from vllm.v1.sample.metadata import SamplingMetadata
14
from vllm.v1.sample.ops.bad_words import apply_bad_words
15
from vllm.v1.sample.ops.logprobs import batched_count_greater_than
16
from vllm.v1.sample.ops.penalties import apply_all_penalties
17
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
18
19
20
21
22

# Temperatures below this threshold are treated as greedy (argmax) sampling
# instead of being used as a divisor for the logits.
_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):
23
24
25
26
    """
    A layer that samples the next tokens from the model's outputs
    with the following steps in order:

    1. If logprobs are requested:
        a) If `logprobs_mode` is `raw_logprobs`, compute logprobs
           as the final logprobs to return.
        b) If `logprobs_mode` is `raw_logits`, clone the logits
           as the final logprobs to return.
    2. Convert logits to float32.
    3. Apply allowed token ids whitelist.
    4. Apply bad words exclusion.
    5. Apply logit processors which are not argmax-invariant,
       i.e. that can impact greedy sampling.
        a) Min tokens processor
        b) Logit bias processor
    6. Apply penalties
        a) Repetition penalty
        b) Frequency penalty
        c) Presence penalty
    7. Sample the next tokens. `sample` method performs the following steps:
        a) If not `all_random`, perform greedy sampling. If `all_greedy`,
           return the greedily sampled tokens and final logprobs if requested.
        b) Apply temperature.
        c) Apply logit processors which are argmax-invariant, by default
           the min_p processor.
        d) Apply top_k and/or top_p.
        e) Sample the next tokens with the probability distribution.
        f) If `all_random` or temperature >= epsilon (1e-5), return the
           randomly sampled tokens and final logprobs if requested. Else,
           return the greedily sampled tokens and logprobs if requested.
    8. Gather the logprobs of the top `max_num_logprobs` and sampled token
       (if requested). Note that if the sampled token is within the top
       `max_num_logprobs`, the logprob will be eventually merged in
       `LogprobsProcessor` during output processing. Therefore, the
       final output may contain either `max_num_logprobs + 1` or
       `max_num_logprobs` logprobs.
    9. Return the final `SamplerOutput`.
    """

63
    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
        """Set up the sampler and its top-k/top-p sub-sampler.

        Args:
            logprobs_mode: Selects which tensor is reported as logprobs.
                `forward` handles "raw_logprobs" and "raw_logits"; `sample`
                handles "processed_logits" and "processed_logprobs". The
                same value is handed to the `TopKTopPSampler`.
        """
        super().__init__()
        self.logprobs_mode = logprobs_mode
        self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
        self.pin_memory = is_pin_memory_available()
68

69
70
71
72
    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
73
        predict_bonus_token: bool = False,
74
    ) -> SamplerOutput:
75
76
77
78
79
80
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # is used for sampling (after penalties and temperature scaling).
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
81
            if self.logprobs_mode == "raw_logprobs":
82
                raw_logprobs = self.compute_logprobs(logits)
83
            elif self.logprobs_mode == "raw_logits":
84
                raw_logprobs = logits.clone()
85

86
87
        # Use float32 for the logits.
        logits = logits.to(torch.float32)
88

89
90
91
        logits = self.apply_logits_processors(
            logits, sampling_metadata, predict_bonus_token
        )
92
        # Sample the next token.
93
94
95
        sampled, processed_logprobs = self.sample(logits, sampling_metadata)
        if processed_logprobs is not None:
            raw_logprobs = processed_logprobs
96
97
98
99
100
        # Convert sampled token ids to int64 (long) type to ensure compatibility
        # with subsequent operations that may use these values as indices.
        # This conversion is necessary because FlashInfer sampling operations
        # return int32 (while PyTorch argmax and topk return int64).
        sampled = sampled.long()
101
102
103

        # Gather the logprobs of the topk and sampled token (if requested).
        # Get logprobs and rank tensors (if requested)
104
105
106
107
108
        logprobs_tensors = (
            None
            if num_logprobs is None
            else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)
        )
109

110
111
112
        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

113
        # These are GPU tensors.
114
        sampler_output = SamplerOutput(
115
116
117
118
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
119
            logprobs_tensors=logprobs_tensors,
120
121
122
        )
        return sampler_output

123
    @staticmethod
124
125
126
    def apply_temperature(
        logits: torch.Tensor,
        temp: torch.Tensor,
127
        all_random: bool,
128
129
    ) -> torch.Tensor:
        # Use in-place division to avoid creating a new tensor.
130
131
132
        # Avoid division by zero if there are greedy requests.
        if not all_random:
            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
133
        return logits.div_(temp.unsqueeze(dim=1))
134

135
136
    @staticmethod
    def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
137
138
139
        return logits.argmax(dim=-1).view(-1)

    def sample(
140
141
142
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
143
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
144
145
146
147
148
149
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.
        """

150
        assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)
151
152
153
154
155
        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
156
157
                processed_logprobs = None
                if sampling_metadata.max_num_logprobs is not None:
158
                    if self.logprobs_mode == "processed_logits":
159
                        processed_logprobs = logits
160
                    elif self.logprobs_mode == "processed_logprobs":
161
162
                        processed_logprobs = self.compute_logprobs(logits)
                return greedy_sampled, processed_logprobs
163

164
165
        assert sampling_metadata.temperature is not None

166
        # Apply temperature.
167
168
169
        logits = self.apply_temperature(
            logits, sampling_metadata.temperature, sampling_metadata.all_random
        )
170

171
172
173
174
        # Apply logits processors that only apply to random sampling
        # (argmax invariant)
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)
175
176

        # Apply top_k and/or top_p.
177
        random_sampled, processed_logprobs = self.topk_topp_sampler(
178
            logits,
179
            sampling_metadata.generators,
180
181
182
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )
183

184
        if greedy_sampled is None:
185
            return random_sampled, processed_logprobs
186
187
188
189
190

        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
191
            out=greedy_sampled,  # Reuse tensor
192
        )
193
        return sampled, processed_logprobs
194

195
196
    @staticmethod
    def compute_logprobs(logits: torch.Tensor) -> torch.Tensor:
197
198
        return logits.log_softmax(dim=-1, dtype=torch.float32)

199
    @staticmethod
    def gather_logprobs(
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Collect the logprobs of the top-k tokens and of the sampled/prompt
        token, together with the rank of the latter.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements.
                     Must be int64.

        Returns:
          A `LogprobsTensors` built from:
          int32 indices tensor, (num tokens) x (num_logprobs + 1), with the
          given token in column 0 followed by the top-k token ids;
          float logprobs tensor, (num tokens) x (num_logprobs + 1), in the
          same column order;
          token rank tensor as returned by `batched_count_greater_than`.
        """
        assert token_ids.dtype == torch.int64

        # Logprob of the prompt or sampled token, kept as a column so it can
        # be concatenated in front of the top-k columns.
        chosen_ids = token_ids.unsqueeze(-1)
        chosen_logprobs = logprobs.gather(-1, chosen_ids)

        # Rank of the actual token within its row.
        chosen_ranks = batched_count_greater_than(logprobs, chosen_logprobs)

        # Top-k values and their vocabulary indices.
        top_values, top_ids = torch.topk(logprobs, num_logprobs, dim=-1)

        # Given token first, then the top-k entries. The indices are stored
        # as int32 to reduce the tensor size.
        merged_ids = torch.cat((chosen_ids, top_ids), dim=1).to(torch.int32)
        merged_logprobs = torch.cat((chosen_logprobs, top_values), dim=1)

        return LogprobsTensors(merged_ids, merged_logprobs, chosen_ranks)
242

243
    @staticmethod
244
245
246
247
248
249
250
251
252
253
254
255
256
    def _combine_outputs_with_spec_tokens(
        output_token_ids: list[list[int]],
        spec_token_ids: Optional[list[list[int]]] = None,
    ) -> list[list[int]]:
        if spec_token_ids is None:
            return output_token_ids

        return [
            [*out, *spec] if spec else out
            for out, spec in zip(output_token_ids, spec_token_ids)
        ]

    def apply_logits_processors(
257
258
259
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
260
        predict_bonus_token: bool,
261
    ) -> torch.Tensor:
262
        bad_words_token_ids = sampling_metadata.bad_words_token_ids
263
        any_penalties_or_bad_words = (
264
            bool(bad_words_token_ids) or not sampling_metadata.no_penalties
265
266
267
268
269
270
271
        )

        output_token_ids = sampling_metadata.output_token_ids
        if predict_bonus_token and any_penalties_or_bad_words:
            # Combine base outputs with spec tokens when speculative decoding
            # is enabled.
            output_token_ids = self._combine_outputs_with_spec_tokens(
272
                output_token_ids,
273
                sampling_metadata.spec_token_ids,
274
            )
275

276
        # Apply allowed token ids.
277
        if sampling_metadata.allowed_token_ids_mask is not None:
278
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))
279
280

        # Apply bad words exclusion.
281
282
        if bad_words_token_ids:
            apply_bad_words(logits, bad_words_token_ids, output_token_ids)
283
284
285
286
287
288
289

        # Apply logits processors which can impact greedy sampling.
        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
            logits = processor.apply(logits)

        # Apply penalties (e.g., freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
290
        return logits
291

292
    @staticmethod
293
    def apply_penalties(
294
295
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
296
        output_token_ids: list[list[int]],
297
    ) -> torch.Tensor:
298
299
300
301
302
303
304
305
306
307
308
309
        if sampling_metadata.no_penalties:
            return logits

        assert sampling_metadata.prompt_token_ids is not None
        return apply_all_penalties(
            logits,
            sampling_metadata.prompt_token_ids,
            sampling_metadata.presence_penalties,
            sampling_metadata.frequency_penalties,
            sampling_metadata.repetition_penalties,
            output_token_ids,
        )