sampler.py 12 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""A layer that samples the next tokens from the model's outputs."""

5
6
from typing import Optional

7
8
9
import torch
import torch.nn as nn

10
from vllm.config.model import LogprobsMode
11
from vllm.utils import is_pin_memory_available
12
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
13
from vllm.v1.sample.metadata import SamplingMetadata
14
from vllm.v1.sample.ops.bad_words import apply_bad_words
15
from vllm.v1.sample.ops.logprobs import batched_count_greater_than
16
from vllm.v1.sample.ops.penalties import apply_all_penalties
17
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
18
19
20
21
22

_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):
    """
    A layer that samples the next tokens from the model's outputs
    with the following steps in order:

    1. If logprobs are requested:
        a) If `logprobs_mode` is `raw_logprobs`, compute logprobs
           as the final logprobs to return.
        b) If `logprobs_mode` is `raw_logits`, copy the logits (converted
           to float32) as the final logprobs to return.
        c) For the `processed_*` modes, the logprobs are instead taken from
           the logits after processing, as produced inside `sample`.
    2. Convert logits to float32.
    3. Apply allowed token ids whitelist.
    4. Apply bad words exclusion.
    5. Apply logit processors which are not argmax-invariant,
       i.e. that can impact greedy sampling.
        a) Min tokens processor
        b) Logit bias processor
    6. Apply penalties
        a) Repetition penalty
        b) Frequency penalty
        c) Presence penalty
    7. Sample the next tokens. `sample` method performs the following steps:
        a) If not `all_random`, perform greedy sampling. If `all_greedy`,
           return the greedily sampled tokens and final logprobs if requested.
        b) Apply temperature.
        c) Apply logit processors which are argmax-invariant, by default
           the min_p processor.
        d) Apply top_k and/or top_p.
        e) Sample the next tokens with the probability distribution.
        f) If `all_random` or temperature >= epsilon (1e-5), return the
           randomly sampled tokens and final logprobs if requested. Else,
           return the greedily sampled tokens and logprobs if requested.
    8. Gather the logprobs of the top `max_num_logprobs` and sampled token
       (if requested). Note that if the sampled token is within the top
       `max_num_logprobs`, the logprob will be eventually merged in
       `LogprobsProcessor` during output processing. Therefore, the
       final output may contain either `max_num_logprobs + 1` or
       `max_num_logprobs` logprobs.
    9. Return the final `SamplerOutput`.
    """

    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
        """Create the sampler.

        Args:
            logprobs_mode: Which logits/logprobs are returned when logprobs
                are requested; also forwarded to the top-k/top-p sampler.
        """
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
        self.pin_memory = is_pin_memory_available()
        self.logprobs_mode = logprobs_mode

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        predict_bonus_token: bool = False,
    ) -> SamplerOutput:
        """Sample one token per request and optionally gather logprobs.

        Args:
            logits: Model output logits, one row per request (presumably
                [num_requests, vocab_size] — TODO confirm against callers).
                May be modified in place only if it is already float32;
                otherwise a float32 copy is processed.
            sampling_metadata: Per-batch sampling parameters.
            predict_bonus_token: If True, speculative tokens from
                `sampling_metadata.spec_token_ids` are appended to each
                request's output tokens before penalties/bad words apply.

        Returns:
            A `SamplerOutput` whose `sampled_token_ids` is an int32 tensor
            of shape [num_requests, 1], plus gathered logprobs when
            `sampling_metadata.max_num_logprobs` is not None.
        """
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # is used for sampling (after penalties and temperature scaling).
        num_logprobs = sampling_metadata.max_num_logprobs
        # Bound up front so a mode that produces no logprobs fails with a
        # clear assertion below instead of an UnboundLocalError.
        raw_logprobs: Optional[torch.Tensor] = None
        if num_logprobs is not None:
            if self.logprobs_mode == "raw_logprobs":
                raw_logprobs = self.compute_logprobs(logits)
            elif self.logprobs_mode == "raw_logits":
                # Return float32 like every other logprobs mode. `.to()`
                # already copies when the dtype changes; for float32 input
                # an explicit clone is required because the logits are
                # modified in place further down (masked_fill_, div_).
                if logits.dtype == torch.float32:
                    raw_logprobs = logits.clone()
                else:
                    raw_logprobs = logits.to(torch.float32)

        # Use float32 for the logits.
        logits = logits.to(torch.float32)

        logits = self.apply_logits_processors(
            logits, sampling_metadata, predict_bonus_token
        )

        # Sample the next token. For the `processed_*` logprobs modes this
        # also returns the logprobs computed from the processed logits.
        sampled, processed_logprobs = self.sample(logits, sampling_metadata)
        if processed_logprobs is not None:
            raw_logprobs = processed_logprobs

        # Convert sampled token ids to int64 (long) type to ensure compatibility
        # with subsequent operations that may use these values as indices.
        # This conversion is necessary because FlashInfer sampling operations
        # return int32 (while PyTorch argmax and topk return int64).
        sampled = sampled.long()

        # Gather the logprobs and ranks of the topk and sampled token
        # (if requested).
        logprobs_tensors = None
        if num_logprobs is not None:
            assert raw_logprobs is not None, (
                "No logprobs were produced for "
                f"logprobs_mode={self.logprobs_mode!r}"
            )
            logprobs_tensors = self.gather_logprobs(
                raw_logprobs, num_logprobs, token_ids=sampled
            )

        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # These are expected to be GPU tensors.
        return SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
        all_random: bool,
    ) -> torch.Tensor:
        """Divide each row of `logits` in place by its request temperature.

        When the batch may contain greedy requests (`all_random` is False),
        temperatures below `_SAMPLING_EPS` are replaced with 1.0 so the
        division is safe; those rows are later resolved greedily anyway.
        """
        # Avoid division by zero if there are greedy requests.
        if not all_random:
            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
        # Use in-place division to avoid creating a new tensor.
        return logits.div_(temp.unsqueeze(dim=1))

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the argmax token id of each row as a flat tensor."""
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.

        Returns:
            A tuple `(sampled_token_ids, processed_logprobs)`. The second
            element is None unless a `processed_*` logprobs mode produced
            logprobs from the processed logits.
        """
        assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)

        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                processed_logprobs = None
                if sampling_metadata.max_num_logprobs is not None:
                    if self.logprobs_mode == "processed_logits":
                        processed_logprobs = logits
                    elif self.logprobs_mode == "processed_logprobs":
                        processed_logprobs = self.compute_logprobs(logits)
                return greedy_sampled, processed_logprobs

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(
            logits, sampling_metadata.temperature, sampling_metadata.all_random
        )

        # Apply logits processors that only apply to random sampling
        # (argmax invariant)
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)

        # Apply top_k and/or top_p.
        random_sampled, processed_logprobs = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        if greedy_sampled is None:
            return random_sampled, processed_logprobs

        # Mixed batch: keep the greedy token for near-zero temperatures and
        # the random token otherwise.
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
            out=greedy_sampled,  # Reuse tensor
        )
        return sampled, processed_logprobs

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        """Return float32 log-softmax of `logits` over the last dimension."""
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements.
                     Must be int64.

        Returns:
          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
          Sampled token rank tensor, (num tokens)
          Column 0 of the first two tensors holds the given token; the
          remaining columns hold the top-k entries.
        """
        assert token_ids.dtype == torch.int64
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)

        # Get with the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the ranks of the actual token.
        token_ranks = batched_count_greater_than(logprobs, token_logprobs)

        # Concatenate together with the topk.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    def _combine_outputs_with_spec_tokens(
        self,
        output_token_ids: list[list[int]],
        spec_token_ids: Optional[list[list[int]]] = None,
    ) -> list[list[int]]:
        """Append each request's speculative tokens to its output tokens.

        Returns `output_token_ids` unchanged when `spec_token_ids` is None;
        requests with an empty speculative list keep their original list.
        """
        if spec_token_ids is None:
            return output_token_ids

        return [
            [*out, *spec] if spec else out
            for out, spec in zip(output_token_ids, spec_token_ids)
        ]

    def apply_logits_processors(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        predict_bonus_token: bool,
    ) -> torch.Tensor:
        """Apply whitelist, bad words, non-argmax-invariant processors and
        penalties to `logits` (mostly in place) and return the result.

        When `predict_bonus_token` is True and penalties or bad words are
        active, speculative tokens are treated as already-generated output.
        """
        any_penalties_or_bad_words = bool(
            sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties
        )

        output_token_ids = sampling_metadata.output_token_ids
        if predict_bonus_token and any_penalties_or_bad_words:
            # Combine base outputs with spec tokens when speculative decoding
            # is enabled.
            output_token_ids = self._combine_outputs_with_spec_tokens(
                sampling_metadata.output_token_ids,
                sampling_metadata.spec_token_ids,
            )

        # Apply allowed token ids.
        if sampling_metadata.allowed_token_ids_mask is not None:
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))

        # Apply bad words exclusion. `output_token_ids` is derived from
        # `sampling_metadata.output_token_ids`, so no fallback is needed.
        if sampling_metadata.bad_words_token_ids:
            apply_bad_words(
                logits,
                sampling_metadata.bad_words_token_ids,
                output_token_ids,
            )

        # Apply logits processors which can impact greedy sampling.
        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
            logits = processor.apply(logits)

        # Apply penalties (e.g., freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
        return logits

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        output_token_ids: Optional[list[list[int]]] = None,
    ) -> torch.Tensor:
        """Apply presence, frequency and repetition penalties, if any.

        Args:
            logits: Logits to penalize.
            sampling_metadata: Supplies the penalty values and prompt tokens.
            output_token_ids: Generated tokens per request; defaults to
                `sampling_metadata.output_token_ids` when None.
        """
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_all_penalties(
                logits,
                sampling_metadata.prompt_token_ids,
                sampling_metadata.presence_penalties,
                sampling_metadata.frequency_penalties,
                sampling_metadata.repetition_penalties,
                output_token_ids
                if output_token_ids is not None
                else sampling_metadata.output_token_ids,
            )
        return logits