# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""A layer that samples the next tokens from the model's outputs."""

import torch
import torch.nn as nn

from vllm.config.model import LogprobsMode
from vllm.utils import is_pin_memory_available
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words
from vllm.v1.sample.ops.logprobs import batched_count_greater_than
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler

# Temperature threshold below which a request is treated as greedy: in a
# mixed batch such rows keep their argmax token instead of the randomly
# sampled one, and their temperature is replaced by 1.0 before the logits
# are divided, so nothing is divided by (near) zero.
_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):
    """
    A layer that samples the next tokens from the model's outputs
    with the following steps in order:

    1. If logprobs are requested:
        a) If `logprobs_mode` is `raw_logprobs`, compute logprobs
           as the final logprobs to return.
        b) If `logprobs_mode` is `raw_logits`, clone the logits
           as the final logprobs to return.
        c) For `processed_logits` / `processed_logprobs`, the values to
           return are produced later by `sample` from the processed logits.
    2. Convert logits to float32.
    3. Apply allowed token ids whitelist.
    4. Apply bad words exclusion.
    5. Apply logit processors which are not argmax-invariant,
       i.e. that can impact greedy sampling.
        a) Min tokens processor
        b) Logit bias processor
    6. Apply penalties
        a) Repetition penalty
        b) Frequency penalty
        c) Presence penalty
    7. Sample the next tokens. `sample` method performs the following steps:
        a) If not `all_random`, perform greedy sampling. If `all_greedy`,
           return the greedily sampled tokens and final logprobs if requested.
        b) Apply temperature.
        c) Apply logit processors which are argmax-invariant, by default
           the min_p processor.
        d) Apply top_k and/or top_p.
        e) Sample the next tokens with the probability distribution.
        f) If `all_random` or temperature >= epsilon (1e-5), return the
           randomly sampled tokens and final logprobs if requested. Else,
           return the greedily sampled tokens and logprobs if requested.
    8. Gather the logprobs of the top `max_num_logprobs` and sampled token
       (if requested). Note that if the sampled token is within the top
       `max_num_logprobs`, the logprob will be eventually merged in
       `LogprobsProcessor` during output processing. Therefore, the
       final output may contain either `max_num_logprobs + 1` or
       `max_num_logprobs` logprobs.
    9. Return the final `SamplerOutput`.
    """

    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
        # Not read by any method of this class; presumably consumed by
        # callers or subclasses.
        self.pin_memory = is_pin_memory_available()
        self.logprobs_mode = logprobs_mode

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        predict_bonus_token: bool = False,
    ) -> SamplerOutput:
        """Sample one next token per row of `logits`.

        Args:
            logits: Model output logits, one row per request.
            sampling_metadata: Batch-level sampling parameters (temperature,
                top-k/top-p, penalties, logits processors, logprob counts).
            predict_bonus_token: When True and penalties or bad words are
                active, each request's speculative tokens are appended to its
                output history before those are applied.

        Returns:
            A `SamplerOutput` with int32 `sampled_token_ids` of shape
            `[num_rows, 1]` and, when `max_num_logprobs` is set, the gathered
            `LogprobsTensors` (otherwise None).
        """
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # is used for sampling (after penalties and temperature scaling).
        num_logprobs = sampling_metadata.max_num_logprobs
        # Bound up front: in the processed_* modes this is only filled in
        # from `sample`'s second return value below.
        raw_logprobs: torch.Tensor | None = None
        if num_logprobs is not None:
            if self.logprobs_mode == "raw_logprobs":
                raw_logprobs = self.compute_logprobs(logits)
            elif self.logprobs_mode == "raw_logits":
                raw_logprobs = logits.clone()

        # Use float32 for the logits.
        logits = logits.to(torch.float32)

        logits = self.apply_logits_processors(
            logits, sampling_metadata, predict_bonus_token
        )
        # Sample the next token.
        sampled, processed_logprobs = self.sample(logits, sampling_metadata)
        if processed_logprobs is not None:
            raw_logprobs = processed_logprobs

        # Convert sampled token ids to int64 (long) type to ensure compatibility
        # with subsequent operations that may use these values as indices.
        # This conversion is necessary because FlashInfer sampling operations
        # return int32 (while PyTorch argmax and topk return int64).
        sampled = sampled.long()

        # Gather the logprobs of the topk and sampled token (if requested),
        # together with the rank of the sampled token.
        logprobs_tensors = None
        if num_logprobs is not None:
            # Fail with a clear message instead of an UnboundLocalError if
            # no path above produced logprobs for the configured mode.
            assert raw_logprobs is not None, (
                "No logprobs were produced for logprobs_mode="
                f"{self.logprobs_mode!r}"
            )
            logprobs_tensors = self.gather_logprobs(
                raw_logprobs, num_logprobs, token_ids=sampled
            )

        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

    @staticmethod
    def apply_temperature(
        logits: torch.Tensor,
        temp: torch.Tensor,
        all_random: bool,
    ) -> torch.Tensor:
        """Divide each row of `logits` in place by its temperature.

        `temp` is broadcast via `unsqueeze(dim=1)`, so it is expected to hold
        one value per row. Unless `all_random` is True, temperatures below
        `_SAMPLING_EPS` (greedy rows) are replaced by 1.0 first.

        Returns:
            `logits` itself, modified in place.
        """
        # Avoid division by zero if there are greedy requests.
        if not all_random:
            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
        # Use in-place division to avoid creating a new tensor.
        return logits.div_(temp.unsqueeze(dim=1))

    @staticmethod
    def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
        """Return the argmax token id of each row, flattened to 1-D."""
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.

        Returns:
            A `(sampled_token_ids, processed_logprobs)` pair. The second item
            is only non-None when `logprobs_mode` is one of the processed_*
            modes and the sampling path produced it.
        """
        assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)
        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                processed_logprobs = None
                if sampling_metadata.max_num_logprobs is not None:
                    if self.logprobs_mode == "processed_logits":
                        processed_logprobs = logits
                    elif self.logprobs_mode == "processed_logprobs":
                        processed_logprobs = self.compute_logprobs(logits)
                return greedy_sampled, processed_logprobs

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(
            logits, sampling_metadata.temperature, sampling_metadata.all_random
        )

        # Apply logits processors that only apply to random sampling
        # (argmax invariant)
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)

        # Apply top_k and/or top_p.
        random_sampled, processed_logprobs = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        if greedy_sampled is None:
            return random_sampled, processed_logprobs

        # Mixed batch: greedy rows (temperature below the threshold) keep
        # their argmax token, all other rows take the random sample.
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
            out=greedy_sampled,  # Reuse tensor
        )
        return sampled, processed_logprobs

    @staticmethod
    def compute_logprobs(logits: torch.Tensor) -> torch.Tensor:
        """Return the float32 log-softmax of `logits` over the last dim."""
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    @staticmethod
    def gather_logprobs(
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements
                     Must be int64.

        Returns:
          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
          Sampled token rank tensor, (num tokens)
        """
        assert token_ids.dtype == torch.int64
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)

        # Get the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the ranks of the actual token.
        token_ranks = batched_count_greater_than(logprobs, token_logprobs)

        # Concatenate together with the topk; the requested token comes first.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    @staticmethod
    def _combine_outputs_with_spec_tokens(
        output_token_ids: list[list[int]],
        spec_token_ids: list[list[int]] | None = None,
    ) -> list[list[int]]:
        """Append each request's speculative tokens to its output tokens.

        Returns `output_token_ids` unchanged when `spec_token_ids` is None.
        Otherwise builds a new outer list; rows with no speculative tokens
        reuse the original inner list. The two lists are paired with `zip`,
        so they are expected to be aligned per request (a length mismatch
        would silently truncate).
        """
        if spec_token_ids is None:
            return output_token_ids

        return [
            [*out, *spec] if spec else out
            for out, spec in zip(output_token_ids, spec_token_ids)
        ]

    def apply_logits_processors(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        predict_bonus_token: bool,
    ) -> torch.Tensor:
        """Apply every argmax-affecting transformation to `logits`.

        In order: the allowed-token-ids mask, bad-words exclusion, the
        non-argmax-invariant logits processors, then penalties. `logits` may
        be modified in place. The processed logits are returned.
        """
        bad_words_token_ids = sampling_metadata.bad_words_token_ids
        any_penalties_or_bad_words = (
            bool(bad_words_token_ids) or not sampling_metadata.no_penalties
        )

        output_token_ids = sampling_metadata.output_token_ids
        if predict_bonus_token and any_penalties_or_bad_words:
            # Combine base outputs with spec tokens when speculative decoding
            # is enabled.
            output_token_ids = self._combine_outputs_with_spec_tokens(
                output_token_ids,
                sampling_metadata.spec_token_ids,
            )

        # Apply allowed token ids: positions where the mask is True are
        # set to -inf.
        if sampling_metadata.allowed_token_ids_mask is not None:
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))

        # Apply bad words exclusion (return value is not used; presumably
        # modifies `logits` in place).
        if bad_words_token_ids:
            apply_bad_words(logits, bad_words_token_ids, output_token_ids)

        # Apply logits processors which can impact greedy sampling.
        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
            logits = processor.apply(logits)

        # Apply penalties (e.g., freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
        return logits

    @staticmethod
    def apply_penalties(
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        output_token_ids: list[list[int]],
    ) -> torch.Tensor:
        """Apply presence, frequency and repetition penalties to `logits`.

        Returns `logits` untouched when `sampling_metadata.no_penalties` is
        set; otherwise returns the result of `apply_all_penalties`.
        """
        if sampling_metadata.no_penalties:
            return logits

        assert sampling_metadata.prompt_token_ids is not None
        return apply_all_penalties(
            logits,
            sampling_metadata.prompt_token_ids,
            sampling_metadata.presence_penalties,
            sampling_metadata.frequency_penalties,
            sampling_metadata.repetition_penalties,
            output_token_ids,
        )