# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A layer that samples the next tokens from the model's outputs."""

import torch
import torch.nn as nn

8
from vllm import envs
9
from vllm.config.model import LogprobsMode
10
from vllm.utils.platform_utils import is_pin_memory_available
11
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
12
from vllm.v1.sample.metadata import SamplingMetadata
13
from vllm.v1.sample.ops.bad_words import apply_bad_words
14
from vllm.v1.sample.ops.logprobs import batched_count_greater_than
15
from vllm.v1.sample.ops.penalties import apply_all_penalties
16
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
17
18
19
20
21

# Temperatures below this threshold are treated as greedy in this module:
# they are replaced by 1.0 before the logits are divided, and the argmax
# token is selected instead of the randomly sampled one.
_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):
    """
    A layer that samples the next tokens from the model's outputs
    with the following steps in order:

    1. If logprobs are requested:
        a) If `logprobs_mode` is `raw_logprobs`, compute logprobs
           as the final logprobs to return.
        b) If `logprobs_mode` is `raw_logits`, clone the logits
           as the final logprobs to return.
    2. Convert logits to float32.
    3. Apply allowed token ids whitelist.
    4. Apply bad words exclusion.
    5. Apply logit processors which are not argmax-invariant,
       i.e. that can impact greedy sampling.
        a) Min tokens processor
        b) Logit bias processor
    6. Apply penalties
        a) Repetition penalty
        b) Frequency penalty
        c) Presence penalty
    7. Sample the next tokens. `sample` method performs the following steps:
        a) If not `all_random`, perform greedy sampling. If `all_greedy`,
           return the greedily sampled tokens and final logprobs if requested.
        b) Apply temperature.
        c) Apply logit processors which are argmax-invariant, by default
           the min_p processor.
        d) Apply top_k and/or top_p.
        e) Sample the next tokens with the probability distribution.
        f) If `all_random` or temperature >= epsilon (1e-5), return the
           randomly sampled tokens and final logprobs if requested. Else,
           return the greedily sampled tokens and logprobs if requested.
    8. Gather the logprobs of the top `max_num_logprobs` and sampled token
       (if requested). Note that if the sampled token is within the top
       `max_num_logprobs`, the logprob will be eventually merged in
       `LogprobsProcessor` during output processing. Therefore, the
       final output may contain either `max_num_logprobs + 1` or
       `max_num_logprobs` logprobs.
    9. Return the final `SamplerOutput`.
    """

62
    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
        """Set up the sampler.

        Args:
            logprobs_mode: Selects which values are reported as logprobs.
                It is stored on the instance and also handed to the
                top-k/top-p sampler. The modes handled in this class are
                `raw_logprobs`, `raw_logits`, `processed_logits` and
                `processed_logprobs`.
        """
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
        self.pin_memory = is_pin_memory_available()
        self.logprobs_mode = logprobs_mode
67

68
69
70
71
    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        predict_bonus_token: bool = False,
        logprobs_mode_override: LogprobsMode | None = None,
    ) -> SamplerOutput:
        """Process the logits, sample one token per request and collect logprobs.

        Args:
            logits: Model output logits. They are converted to float32
                before any processing.
            sampling_metadata: Per-batch sampling parameters.
            predict_bonus_token: Forwarded to `apply_logits_processors`,
                where it decides whether speculative tokens are appended to
                the output tokens used for penalties and bad words.
            logprobs_mode_override: When set, used instead of
                `self.logprobs_mode` for this call, including the sampling
                step.

        Returns:
            A `SamplerOutput` whose `sampled_token_ids` is an int32 tensor
            with a trailing dimension of size 1, and whose
            `logprobs_tensors` is None when no logprobs were requested.
        """
        logprobs_mode = logprobs_mode_override or self.logprobs_mode

        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # is used for sampling (after penalties and temperature scaling).
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            if logprobs_mode == "raw_logprobs":
                raw_logprobs = self.compute_logprobs(logits)
            elif logprobs_mode == "raw_logits":
                # Always hand out a separate float32 tensor: the logits are
                # modified by the processing steps below.
                if logits.dtype == torch.float32:
                    raw_logprobs = logits.clone()
                else:
                    raw_logprobs = logits.to(torch.float32)

        # Use float32 for the logits.
        logits = logits.to(torch.float32)

        logits = self.apply_logits_processors(
            logits, sampling_metadata, predict_bonus_token
        )

        # Sample the next token. The override is forwarded so that `sample`
        # resolves the same logprobs mode as this method; without it the
        # greedy path would fall back to `self.logprobs_mode`.
        sampled, processed_logprobs = self.sample(
            logits,
            sampling_metadata,
            logprobs_mode_override=logprobs_mode_override,
        )
        if processed_logprobs is not None:
            raw_logprobs = processed_logprobs

        # Convert sampled token ids to int64 (long) type to ensure compatibility
        # with subsequent operations that may use these values as indices.
        # This conversion is necessary because FlashInfer sampling operations
        # return int32 (while PyTorch argmax and topk return int64).
        sampled = sampled.long()

        if num_logprobs is None:
            logprobs_tensors = None
        elif num_logprobs == -1:
            # Return the full unsorted and unranked logprobs.
            logprobs_tensors = LogprobsTensors(
                torch.empty(0), raw_logprobs, torch.empty(0)
            )
        else:
            # Gather the logprobs and ranks of the topk and sampled token.
            logprobs_tensors = self.gather_logprobs(
                raw_logprobs, num_logprobs, token_ids=sampled
            )

        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

132
    @staticmethod
    def apply_temperature(
        logits: torch.Tensor,
        temp: torch.Tensor,
        all_random: bool,
    ) -> torch.Tensor:
        """Divide `logits` by the temperature, modifying `logits` in place.

        Args:
            logits: Logits to scale; this tensor is overwritten.
            temp: Temperatures. A trailing dimension is added at dim 1 so
                the values broadcast against `logits`.
            all_random: When False, temperatures below `_SAMPLING_EPS` are
                replaced by 1.0 before dividing.

        Returns:
            The same `logits` tensor after the in-place division.
        """
        # Avoid division by zero if there are greedy requests.
        if not all_random:
            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
        # Use in-place division to avoid creating a new tensor.
        return logits.div_(temp.unsqueeze(dim=1))
143

144
145
    @staticmethod
    def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
146
147
148
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        logprobs_mode_override: LogprobsMode | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.

        Args:
            logits: Logits to sample from.
            sampling_metadata: Per-batch sampling parameters.
            logprobs_mode_override: When set, used instead of
                `self.logprobs_mode` on the all-greedy path.

        Returns:
            A tuple of the sampled token ids and the processed
            logprobs/logits, or None for the latter when none were produced.
        """
        logprobs_mode = logprobs_mode_override or self.logprobs_mode
        assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)

        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                # Nothing is sampled randomly, so return right away.
                processed_logprobs = None
                if sampling_metadata.max_num_logprobs is not None:
                    if logprobs_mode == "processed_logits":
                        processed_logprobs = logits
                    elif logprobs_mode == "processed_logprobs":
                        processed_logprobs = self.compute_logprobs(logits)
                return greedy_sampled, processed_logprobs

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(
            logits, sampling_metadata.temperature, sampling_metadata.all_random
        )

        # Apply logits processors that only apply to random sampling
        # (argmax invariant)
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)

        # Apply top_k and/or top_p.
        if (
            envs.VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER
            and sampling_metadata.top_k is not None
            and sampling_metadata.top_p is not None
            and sampling_metadata.max_top_k is not None
            and not sampling_metadata.has_any_no_top_k
            and self.topk_topp_sampler.forward.__name__ == "forward_native"
        ):
            random_sampled, processed_logprobs = self.topk_topp_sampler(
                logits,
                sampling_metadata.generators,
                sampling_metadata.top_k,
                sampling_metadata.top_p,
                max_top_k=sampling_metadata.max_top_k,
                has_any_no_top_k=sampling_metadata.has_any_no_top_k,
            )
        else:
            random_sampled, processed_logprobs = self.topk_topp_sampler(
                logits,
                sampling_metadata.generators,
                sampling_metadata.top_k,
                sampling_metadata.top_p,
            )

        if greedy_sampled is None:
            return random_sampled, processed_logprobs

        # Mixed batch: keep the greedy token where the temperature is below
        # the threshold and the random token everywhere else.
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
            out=greedy_sampled,  # Reuse tensor
        )
        return sampled, processed_logprobs
222

223
224
    @staticmethod
    def compute_logprobs(logits: torch.Tensor) -> torch.Tensor:
225
226
        return logits.log_softmax(dim=-1, dtype=torch.float32)

227
    @staticmethod
    def gather_logprobs(
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements
                     Must be int64.

        Returns:
          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
          Sampled token rank tensor, (num tokens)
        """
        assert token_ids.dtype == torch.int64
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)

        # Get the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the ranks of the actual token.
        token_ranks = batched_count_greater_than(logprobs, token_logprobs)

        # Concatenate together with the topk; the actual token comes first.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)
270

271
    @staticmethod
272
273
    def _combine_outputs_with_spec_tokens(
        output_token_ids: list[list[int]],
274
        spec_token_ids: list[list[int]] | None = None,
275
276
277
278
279
280
281
282
283
284
    ) -> list[list[int]]:
        if spec_token_ids is None:
            return output_token_ids

        return [
            [*out, *spec] if spec else out
            for out, spec in zip(output_token_ids, spec_token_ids)
        ]

    def apply_logits_processors(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        predict_bonus_token: bool,
    ) -> torch.Tensor:
        """Apply the processing steps that can change the greedy result.

        In order: the allowed-token mask, bad words exclusion, the
        non-argmax-invariant logits processors, and penalties. Some steps
        modify `logits` in place.

        Args:
            logits: Logits to process.
            sampling_metadata: Per-batch sampling parameters.
            predict_bonus_token: When True and penalties or bad words are
                active, speculative tokens are appended to the output tokens
                before they are used.

        Returns:
            The processed logits.
        """
        bad_words_token_ids = sampling_metadata.bad_words_token_ids
        any_penalties_or_bad_words = (
            bool(bad_words_token_ids) or not sampling_metadata.no_penalties
        )

        output_token_ids = sampling_metadata.output_token_ids
        if predict_bonus_token and any_penalties_or_bad_words:
            # Combine base outputs with spec tokens when speculative decoding
            # is enabled.
            output_token_ids = self._combine_outputs_with_spec_tokens(
                output_token_ids,
                sampling_metadata.spec_token_ids,
            )

        # Apply allowed token ids.
        if sampling_metadata.allowed_token_ids_mask is not None:
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))

        # Apply bad words exclusion.
        if bad_words_token_ids:
            apply_bad_words(logits, bad_words_token_ids, output_token_ids)

        # Apply logits processors which can impact greedy sampling.
        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
            logits = processor.apply(logits)

        # Apply penalties (e.g., freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
        return logits
319

320
    @staticmethod
    def apply_penalties(
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        output_token_ids: list[list[int]],
    ) -> torch.Tensor:
        """Apply presence, frequency and repetition penalties to `logits`.

        Args:
            logits: Logits to penalize.
            sampling_metadata: Source of the penalty values and the prompt
                token ids.
            output_token_ids: One list of output token ids per request.

        Returns:
            `logits` unchanged when `sampling_metadata.no_penalties` is set,
            otherwise the result of `apply_all_penalties`.
        """
        if sampling_metadata.no_penalties:
            return logits

        assert sampling_metadata.prompt_token_ids is not None
        return apply_all_penalties(
            logits,
            sampling_metadata.prompt_token_ids,
            sampling_metadata.presence_penalties,
            sampling_metadata.frequency_penalties,
            sampling_metadata.repetition_penalties,
            output_token_ids,
        )