# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A layer that samples the next tokens from the model's outputs."""

import torch
import torch.nn as nn

8
from vllm.config.model import LogprobsMode
9
from vllm.utils import is_pin_memory_available
10
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
11
from vllm.v1.sample.metadata import SamplingMetadata
12
from vllm.v1.sample.ops.bad_words import apply_bad_words
13
from vllm.v1.sample.ops.logprobs import batched_count_greater_than
14
from vllm.v1.sample.ops.penalties import apply_all_penalties
15
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
16
17
18
19
20

# Temperatures strictly below this threshold are treated as greedy requests:
# they are not used as a divisor when scaling logits, and the argmax token is
# selected for them instead of the randomly sampled one.
_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):
    """
    A layer that samples the next tokens from the model's outputs
    with the following steps in order:

    1. If logprobs are requested:
        a) If `logprobs_mode` is `raw_logprobs`, compute logprobs
           as the final logprobs to return.
        b) If `logprobs_mode` is `raw_logits`, clone the logits
           as the final logprobs to return.
    2. Convert logits to float32.
    3. Apply allowed token ids whitelist.
    4. Apply bad words exclusion.
    5. Apply logit processors which are not argmax-invariant,
       i.e. that can impact greedy sampling.
        a) Min tokens processor
        b) Logit bias processor
    6. Apply penalties
        a) Repetition penalty
        b) Frequency penalty
        c) Presence penalty
    7. Sample the next tokens. `sample` method performs the following steps:
        a) If not `all_random`, perform greedy sampling. If `all_greedy`,
           return the greedily sampled tokens and final logprobs if requested.
        b) Apply temperature.
        c) Apply logit processors which are argmax-invariant, by default
           the min_p processor.
        d) Apply top_k and/or top_p.
        e) Sample the next tokens with the probability distribution.
        f) If `all_random` or temperature >= epsilon (1e-5), return the
           randomly sampled tokens and final logprobs if requested. Else,
           return the greedily sampled tokens and logprobs if requested.
    8. Gather the logprobs of the top `max_num_logprobs` and sampled token
       (if requested). Note that if the sampled token is within the top
       `max_num_logprobs`, the logprob will be eventually merged in
       `LogprobsProcessor` during output processing. Therefore, the
       final output may contain either `max_num_logprobs + 1` or
       `max_num_logprobs` logprobs.
    9. Return the final `SamplerOutput`.
    """

    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
        """Create the sampler.

        Args:
            logprobs_mode: Default mode selecting which values are returned
                as logprobs. `forward` and `sample` accept a per-call
                override.
        """
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
        self.pin_memory = is_pin_memory_available()
        self.logprobs_mode = logprobs_mode

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        predict_bonus_token: bool = False,
        logprobs_mode_override: LogprobsMode | None = None,
    ) -> SamplerOutput:
        """Run the full sampling pipeline on a batch of logits.

        Args:
            logits: Model output logits; converted to float32 before any
                logits processing is applied.
            sampling_metadata: Per-batch sampling parameters and state.
            predict_bonus_token: Forwarded to `apply_logits_processors`,
                where it controls whether speculative tokens are appended to
                the output token ids used by bad-words and penalties.
            logprobs_mode_override: If given, used instead of
                `self.logprobs_mode` for this call, including inside
                `sample`.

        Returns:
            A `SamplerOutput` holding the sampled token ids as an int32
            tensor with a trailing dimension of size 1, and the logprobs
            tensors (None when logprobs were not requested).
        """
        logprobs_mode = logprobs_mode_override or self.logprobs_mode
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # is used for sampling (after penalties and temperature scaling).
        num_logprobs = sampling_metadata.max_num_logprobs
        # Stays None for the "processed_*" modes until `sample` supplies the
        # processed values below.
        raw_logprobs: torch.Tensor | None = None
        if num_logprobs is not None:
            if logprobs_mode == "raw_logprobs":
                raw_logprobs = self.compute_logprobs(logits)
            elif logprobs_mode == "raw_logits":
                raw_logprobs = logits.clone()

        # Use float32 for the logits.
        logits = logits.to(torch.float32)

        logits = self.apply_logits_processors(
            logits, sampling_metadata, predict_bonus_token
        )
        # Sample the next token. The override is forwarded so that the greedy
        # path in `sample` picks the same logprobs mode as this method.
        # NOTE(review): the random path obtains processed logprobs from
        # `self.topk_topp_sampler`, which was constructed with the instance's
        # mode; confirm whether the override needs to reach it as well.
        sampled, processed_logprobs = self.sample(
            logits,
            sampling_metadata,
            logprobs_mode_override=logprobs_mode_override,
        )
        if processed_logprobs is not None:
            raw_logprobs = processed_logprobs
        # Convert sampled token ids to int64 (long) type to ensure compatibility
        # with subsequent operations that may use these values as indices.
        # This conversion is necessary because FlashInfer sampling operations
        # return int32 (while PyTorch argmax and topk return int64).
        sampled = sampled.long()

        if num_logprobs is None:
            logprobs_tensors = None
        else:
            # Fail with a clear message instead of an unbound-variable error
            # when neither the raw nor the processed path produced values.
            assert raw_logprobs is not None, (
                f"No logprobs were produced for logprobs_mode={logprobs_mode!r}"
            )
            if num_logprobs == -1:
                # Return the full unsorted and unranked logprobs.
                logprobs_tensors = LogprobsTensors(
                    torch.empty(0), raw_logprobs, torch.empty(0)
                )
            else:
                # Gather the logprobs and ranks of the topk and sampled token.
                logprobs_tensors = self.gather_logprobs(
                    raw_logprobs, num_logprobs, token_ids=sampled
                )

        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

    @staticmethod
    def apply_temperature(
        logits: torch.Tensor,
        temp: torch.Tensor,
        all_random: bool,
    ) -> torch.Tensor:
        """Divide `logits` by the temperature, in place.

        Args:
            logits: Logits to scale; modified in place and returned.
            temp: Temperatures; a dimension is inserted at position 1 so the
                values broadcast against `logits`.
            all_random: When False, temperatures below `_SAMPLING_EPS` are
                replaced by 1.0 before dividing.

        Returns:
            The same `logits` tensor after in-place division.
        """
        # Use in-place division to avoid creating a new tensor.
        # Avoid division by zero if there are greedy requests.
        if not all_random:
            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
        return logits.div_(temp.unsqueeze(dim=1))

    @staticmethod
    def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
        """Return the argmax over the last dimension as a flat 1D tensor."""
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        logprobs_mode_override: LogprobsMode | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.

        Args:
            logits: Logits to sample from.
            sampling_metadata: Per-batch sampling parameters and state.
            logprobs_mode_override: If given, used instead of
                `self.logprobs_mode` when choosing what the all-greedy path
                returns as processed logprobs.

        Returns:
            A tuple of the sampled token ids and the processed
            logits/logprobs. The second element is None when the all-greedy
            path is taken and logprobs are not requested or the mode is not
            a "processed_*" mode; otherwise it is whatever
            `self.topk_topp_sampler` returns.
        """

        logprobs_mode = logprobs_mode_override or self.logprobs_mode
        assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)
        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            # Take the argmax before temperature scaling mutates the logits.
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                processed_logprobs = None
                if sampling_metadata.max_num_logprobs is not None:
                    if logprobs_mode == "processed_logits":
                        processed_logprobs = logits
                    elif logprobs_mode == "processed_logprobs":
                        processed_logprobs = self.compute_logprobs(logits)
                return greedy_sampled, processed_logprobs

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(
            logits, sampling_metadata.temperature, sampling_metadata.all_random
        )

        # Apply logits processors that only apply to random sampling
        # (argmax invariant)
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)

        # Apply top_k and/or top_p.
        random_sampled, processed_logprobs = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        if greedy_sampled is None:
            return random_sampled, processed_logprobs

        # Mixed batch: keep the greedy token where the temperature is below
        # the threshold and the randomly sampled token everywhere else.
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
            out=greedy_sampled,  # Reuse tensor
        )
        return sampled, processed_logprobs

    @staticmethod
    def compute_logprobs(logits: torch.Tensor) -> torch.Tensor:
        """Return the float32 log-softmax of `logits` over the last dimension."""
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    @staticmethod
    def gather_logprobs(
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements
                     Must be int64.

        Returns:
          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
          Sampled token rank tensor, (num tokens)
        """
        assert token_ids.dtype == torch.int64
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)

        # Get the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the ranks of the actual token.
        token_ranks = batched_count_greater_than(logprobs, token_logprobs)

        # Concatenate together with the topk; the prompt/sampled token is
        # placed in the first column.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    @staticmethod
    def _combine_outputs_with_spec_tokens(
        output_token_ids: list[list[int]],
        spec_token_ids: list[list[int]] | None = None,
    ) -> list[list[int]]:
        """Append each request's speculative tokens to its output tokens.

        Args:
            output_token_ids: Output token ids, one list per request.
            spec_token_ids: Speculative token ids, one list per request, or
                None.

        Returns:
            `output_token_ids` itself when `spec_token_ids` is None.
            Otherwise a new outer list in which a request with a non-empty
            spec list gets a new concatenated list, and a request with an
            empty spec list keeps its original list object.
        """
        if spec_token_ids is None:
            return output_token_ids

        return [
            [*out, *spec] if spec else out
            for out, spec in zip(output_token_ids, spec_token_ids)
        ]

    def apply_logits_processors(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        predict_bonus_token: bool,
    ) -> torch.Tensor:
        """Apply every logits transformation that can change the argmax.

        In order: allowed-token-ids mask, bad words exclusion,
        non-argmax-invariant logits processors, then penalties.

        Args:
            logits: Logits to process; may be modified in place.
            sampling_metadata: Per-batch sampling parameters and state.
            predict_bonus_token: When True and bad words or penalties are
                active, speculative tokens are appended to the output token
                ids passed to bad words and penalties.

        Returns:
            The processed logits.
        """
        bad_words_token_ids = sampling_metadata.bad_words_token_ids
        any_penalties_or_bad_words = (
            bool(bad_words_token_ids) or not sampling_metadata.no_penalties
        )

        output_token_ids = sampling_metadata.output_token_ids
        if predict_bonus_token and any_penalties_or_bad_words:
            # Combine base outputs with spec tokens when speculative decoding
            # is enabled.
            output_token_ids = self._combine_outputs_with_spec_tokens(
                output_token_ids,
                sampling_metadata.spec_token_ids,
            )

        # Apply allowed token ids. Positions where the mask is True are set
        # to -inf.
        if sampling_metadata.allowed_token_ids_mask is not None:
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))

        # Apply bad words exclusion.
        if bad_words_token_ids:
            apply_bad_words(logits, bad_words_token_ids, output_token_ids)

        # Apply logits processors which can impact greedy sampling.
        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
            logits = processor.apply(logits)

        # Apply penalties (e.g., freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
        return logits

    @staticmethod
    def apply_penalties(
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        output_token_ids: list[list[int]],
    ) -> torch.Tensor:
        """Apply presence, frequency and repetition penalties.

        Args:
            logits: Logits to penalize.
            sampling_metadata: Supplies the penalty tensors and prompt token
                ids; `no_penalties` short-circuits the call.
            output_token_ids: Output token ids, one list per request.

        Returns:
            `logits` unchanged when `sampling_metadata.no_penalties` is set,
            otherwise the result of `apply_all_penalties`.
        """
        if sampling_metadata.no_penalties:
            return logits

        assert sampling_metadata.prompt_token_ids is not None
        return apply_all_penalties(
            logits,
            sampling_metadata.prompt_token_ids,
            sampling_metadata.presence_penalties,
            sampling_metadata.frequency_penalties,
            sampling_metadata.repetition_penalties,
            output_token_ids,
        )