sampler.py 15.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
"""A layer that samples the next tokens from the model's outputs."""

import torch
import torch.nn as nn

8
from vllm.config.model import LogprobsMode
9
from vllm.utils.platform_utils import is_pin_memory_available
10
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
11
from vllm.v1.sample.metadata import SamplingMetadata
12
from vllm.v1.sample.ops.bad_words import apply_bad_words
13
from vllm.v1.sample.ops.logprobs import batched_count_greater_than
14
from vllm.v1.sample.ops.penalties import apply_all_penalties
15
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
Vedant V Jhaveri's avatar
Vedant V Jhaveri committed
16
from vllm.v1.worker.gpu.sample.logprob import compute_token_logprobs
17
18
19
20
21

# Temperature threshold separating greedy from random sampling: requests whose
# temperature is below this value get the argmax token, and their temperature
# is replaced by 1.0 before dividing the logits to avoid division by zero.
_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):
22
23
24
25
    """
    A layer that samples the next tokens from the model's outputs
    with the following steps in order:

    1. If logprobs are requested:
        a) If `logprobs_mode` is `raw_logprobs`, compute logprobs
           as the final logprobs to return.
        b) If `logprobs_mode` is `raw_logits`, clone the logits
           as the final logprobs to return.
    2. Convert logits to float32.
    3. Apply allowed token ids whitelist.
    4. Apply bad words exclusion.
    5. Apply logits processors which are not argmax-invariant,
       i.e. that can impact greedy sampling.
        a) Min tokens processor
        b) Logit bias processor
    6. Apply penalties
        a) Repetition penalty
        b) Frequency penalty
        c) Presence penalty
    7. Sample the next tokens. `sample` method performs the following steps:
        a) If not `all_random`, perform greedy sampling. If `all_greedy`,
           return the greedily sampled tokens and final logprobs if requested.
        b) Apply temperature.
        c) Apply logits processors which are argmax-invariant, by default
           the min_p processor.
        d) Apply top_k and/or top_p.
        e) Sample the next tokens with the probability distribution.
        f) If `all_random` or temperature >= epsilon (1e-5), return the
           randomly sampled tokens and final logprobs if requested. Else,
           return the greedily sampled tokens and logprobs if requested.
    8. Gather the logprobs of the top `max_num_logprobs` and sampled token
       (if requested). Note that if the sampled token is within the top
       `max_num_logprobs`, the logprob will eventually be merged in
       `LogprobsProcessor` during output processing. Therefore, the
       final output may contain either `max_num_logprobs + 1` or
       `max_num_logprobs` logprobs. If the sampling metadata carries
       `logprob_token_ids`, the logprobs of those specific tokens are
       returned instead.
    9. Return the final `SamplerOutput`.
    """

62
    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
        """Set up the sampler.

        Args:
            logprobs_mode: Default logprobs mode. It is stored on the
                instance and also handed to the top-k/top-p sampler.
        """
        super().__init__()
        self.logprobs_mode = logprobs_mode
        # The top-k/top-p sampler is configured with the same mode.
        self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
        self.pin_memory = is_pin_memory_available()
67

68
69
70
71
    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        predict_bonus_token: bool = False,
        logprobs_mode_override: LogprobsMode | None = None,
    ) -> SamplerOutput:
        """Process the logits, sample one token per request, collect logprobs.

        Args:
            logits: Model output logits.
                NOTE(review): presumably [num_requests, vocab_size] - confirm
                against the caller.
            sampling_metadata: Per-batch sampling settings read by this
                method and by the helpers it calls.
            predict_bonus_token: Forwarded to `apply_logits_processors`,
                where it controls whether speculative token ids are appended
                to the output token history.
            logprobs_mode_override: When set, takes the place of
                `self.logprobs_mode` for this call, both here and in
                `sample`.

        Returns:
            A `SamplerOutput` whose `sampled_token_ids` is an int32 tensor
            with a trailing dimension of size 1, and whose `logprobs_tensors`
            is None when no logprobs were requested.
        """
        logprobs_mode = logprobs_mode_override or self.logprobs_mode
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # is used for sampling (after penalties and temperature scaling).
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            if logprobs_mode == "raw_logprobs":
                raw_logprobs = self.compute_logprobs(logits)
            elif logprobs_mode == "raw_logits":
                # Always end up with an independent float32 tensor: the
                # processing below may modify `logits` in place.
                if logits.dtype == torch.float32:
                    raw_logprobs = logits.clone()
                else:
                    raw_logprobs = logits.to(torch.float32)

        # Use float32 for the logits.
        logits = logits.to(torch.float32)

        logits = self.apply_logits_processors(
            logits, sampling_metadata, predict_bonus_token
        )
        # Sample the next token. The resolved mode is forwarded so that
        # `sample` honours the override as well; otherwise it would fall back
        # to `self.logprobs_mode` and could disagree with the branch above.
        sampled, processed_logprobs = self.sample(
            logits, sampling_metadata, logprobs_mode_override=logprobs_mode
        )
        if processed_logprobs is not None:
            raw_logprobs = processed_logprobs
        # Convert sampled token ids to int64 (long) type to ensure compatibility
        # with subsequent operations that may use these values as indices.
        # This conversion is necessary because FlashInfer sampling operations
        # return int32 (while PyTorch argmax and topk return int64).
        sampled = sampled.long()

        # Handle logprob_token_ids if specified (more efficient than full vocab)
        # This is used by generative_scoring API to get logprobs for specific tokens
        logprob_token_ids_tensors = None
        if sampling_metadata.logprob_token_ids:
            logprob_token_ids_tensors = self.gather_specific_token_logprobs(
                logits, sampling_metadata.logprob_token_ids, sampled
            )

        if logprob_token_ids_tensors is not None:
            # Specific-token logprobs take precedence over `num_logprobs`
            # because they are more specific, so the top-k gather is skipped
            # entirely instead of being computed and discarded.
            logprobs_tensors = logprob_token_ids_tensors
        elif num_logprobs is None:
            logprobs_tensors = None
        elif num_logprobs == -1:
            # Return the full unsorted and unranked logprobs.
            logprobs_tensors = LogprobsTensors(
                torch.empty(0), raw_logprobs, torch.empty(0)
            )
        else:
            # Gather the logprobs and ranks of the topk and sampled token.
            logprobs_tensors = self.gather_logprobs(
                raw_logprobs, num_logprobs, token_ids=sampled
            )

        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

Vedant V Jhaveri's avatar
Vedant V Jhaveri committed
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
    def gather_specific_token_logprobs(
        self,
        logits: torch.Tensor,
        logprob_token_ids: dict[int, list[int]],
        sampled: torch.Tensor,
    ) -> LogprobsTensors | None:
        """Compute logprobs for per-request lists of specific token IDs.

        Requests may ask for different numbers of tokens, so the lists are
        padded to the longest one. Column 0 of every row is the sampled
        token; the requested IDs follow from column 1. Padded positions hold
        token ID 0 and have their logprob forced to -inf.

        The logprobs themselves come from `compute_token_logprobs`, which the
        original author describes as a fused Triton log_softmax + gather
        kernel (reported as ~1.4x faster than a sparse gather for batch
        sizes > 1; not verifiable from this file).

        Args:
            logits: [batch_size, vocab_size] tensor of logits
            logprob_token_ids: dict mapping req_index -> list of token IDs
            sampled: [batch_size] tensor of sampled token IDs; used as a
                gather index, so it must be int64.

        Returns:
            LogprobsTensors with logprobs for the specified tokens, or None
            if no requests have logprob_token_ids.
        """
        if not logprob_token_ids:
            return None

        batch_size = logits.shape[0]
        device = logits.device

        # +1 for the sampled token, which occupies the first column.
        num_columns = 1 + max(len(ids) for ids in logprob_token_ids.values())

        # Assemble the padded IDs and the validity mask on the host, so the
        # device receives one copy per tensor rather than one per request.
        token_ids_cpu = torch.zeros(batch_size, num_columns, dtype=torch.int64)
        valid_mask_cpu = torch.zeros(batch_size, num_columns, dtype=torch.bool)
        valid_mask_cpu[:, 0] = True  # Sampled token is always valid
        for req_idx, token_ids in logprob_token_ids.items():
            end = len(token_ids) + 1
            token_ids_cpu[req_idx, 1:end] = torch.tensor(
                token_ids, dtype=torch.int64
            )
            valid_mask_cpu[req_idx, 1:end] = True

        token_ids_tensor = token_ids_cpu.to(device)
        valid_mask = valid_mask_cpu.to(device)
        token_ids_tensor[:, 0] = sampled  # First column is sampled token

        # Compute logprobs for the selected tokens (log_softmax + gather).
        logprobs = compute_token_logprobs(logits, token_ids_tensor)

        # Mask invalid (padded) positions with -inf
        logprobs = logprobs.masked_fill(~valid_mask, float("-inf"))

        # Rank of the sampled token. Use the same helper as `gather_logprobs`
        # so both code paths report ranks with a single convention; ranking
        # on logits is equivalent to ranking on logprobs because log_softmax
        # preserves the ordering within a row.
        sampled_logits = logits.gather(-1, sampled.unsqueeze(-1))
        token_ranks = batched_count_greater_than(logits, sampled_logits)

        return LogprobsTensors(
            logprob_token_ids=token_ids_tensor.to(torch.int32),
            logprobs=logprobs,
            selected_token_ranks=token_ranks,
        )

216
    @staticmethod
217
218
219
    def apply_temperature(
        logits: torch.Tensor,
        temp: torch.Tensor,
220
        all_random: bool,
221
222
    ) -> torch.Tensor:
        # Use in-place division to avoid creating a new tensor.
223
224
225
        # Avoid division by zero if there are greedy requests.
        if not all_random:
            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
226
        return logits.div_(temp.unsqueeze(dim=1))
227

228
229
    @staticmethod
    def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
230
231
232
        return logits.argmax(dim=-1).view(-1)

    def sample(
233
234
235
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
236
        logprobs_mode_override: LogprobsMode | None = None,
237
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
238
239
240
241
242
243
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.
        """

244
        logprobs_mode = logprobs_mode_override or self.logprobs_mode
245
        assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)
246
247
248
249
250
        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
251
252
                processed_logprobs = None
                if sampling_metadata.max_num_logprobs is not None:
253
                    if logprobs_mode == "processed_logits":
254
                        processed_logprobs = logits
255
                    elif logprobs_mode == "processed_logprobs":
256
257
                        processed_logprobs = self.compute_logprobs(logits)
                return greedy_sampled, processed_logprobs
258

259
260
        assert sampling_metadata.temperature is not None

261
        # Apply temperature.
262
263
264
        logits = self.apply_temperature(
            logits, sampling_metadata.temperature, sampling_metadata.all_random
        )
265

266
267
268
269
        # Apply logits processors that only apply to random sampling
        # (argmax invariant)
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)
270
271

        # Apply top_k and/or top_p.
272
        random_sampled, processed_logprobs = self.topk_topp_sampler(
273
            logits,
274
            sampling_metadata.generators,
275
276
277
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )
278

279
        if greedy_sampled is None:
280
            return random_sampled, processed_logprobs
281
282
283
284
285

        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
286
            out=greedy_sampled,  # Reuse tensor
287
        )
288
        return sampled, processed_logprobs
289

290
291
    @staticmethod
    def compute_logprobs(logits: torch.Tensor) -> torch.Tensor:
292
293
        return logits.log_softmax(dim=-1, dtype=torch.float32)

294
    @staticmethod
    def gather_logprobs(
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Collect the top-k logprobs plus the logprob of a chosen token.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: number of top entries to keep per token
          token_ids: 1D tensor with (num tokens) elements holding the
                     prompt tokens (for prompt logprobs) or the sampled
                     tokens (for sample logprobs). Must be int64.

        Returns:
          A `LogprobsTensors` built from, in order:
          int32 token indices, (num tokens) x (num_logprobs + 1), with the
          chosen token in column 0 followed by the top-k indices;
          the matching logprobs, (num tokens) x (num_logprobs + 1);
          the rank of the chosen token, (num tokens).
        """
        assert token_ids.dtype == torch.int64

        # Top-k values and their vocabulary indices.
        top = torch.topk(logprobs, num_logprobs, dim=-1)

        # Logprob of the prompt/sampled token itself.
        token_column = token_ids.unsqueeze(-1)
        chosen_logprobs = torch.gather(logprobs, -1, token_column)

        # Rank of the prompt/sampled token within its row.
        ranks = batched_count_greater_than(logprobs, chosen_logprobs)

        # Chosen token first, then the top-k; indices are stored as int32 to
        # reduce the tensor size.
        merged_ids = torch.cat((token_column, top.indices), dim=1).to(torch.int32)
        merged_logprobs = torch.cat((chosen_logprobs, top.values), dim=1)

        return LogprobsTensors(merged_ids, merged_logprobs, ranks)
337

338
    @staticmethod
339
340
    def _combine_outputs_with_spec_tokens(
        output_token_ids: list[list[int]],
341
        spec_token_ids: list[list[int]] | None = None,
342
343
344
345
346
347
348
349
350
351
    ) -> list[list[int]]:
        if spec_token_ids is None:
            return output_token_ids

        return [
            [*out, *spec] if spec else out
            for out, spec in zip(output_token_ids, spec_token_ids)
        ]

    def apply_logits_processors(
352
353
354
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
355
        predict_bonus_token: bool,
356
    ) -> torch.Tensor:
357
        bad_words_token_ids = sampling_metadata.bad_words_token_ids
358
        any_penalties_or_bad_words = (
359
            bool(bad_words_token_ids) or not sampling_metadata.no_penalties
360
361
362
363
364
365
366
        )

        output_token_ids = sampling_metadata.output_token_ids
        if predict_bonus_token and any_penalties_or_bad_words:
            # Combine base outputs with spec tokens when speculative decoding
            # is enabled.
            output_token_ids = self._combine_outputs_with_spec_tokens(
367
                output_token_ids,
368
                sampling_metadata.spec_token_ids,
369
            )
370

371
        # Apply allowed token ids.
372
        if sampling_metadata.allowed_token_ids_mask is not None:
373
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))
374
375

        # Apply bad words exclusion.
376
377
        if bad_words_token_ids:
            apply_bad_words(logits, bad_words_token_ids, output_token_ids)
378
379
380
381
382
383
384

        # Apply logits processors which can impact greedy sampling.
        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
            logits = processor.apply(logits)

        # Apply penalties (e.g., freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
385
        return logits
386

387
    @staticmethod
388
    def apply_penalties(
389
390
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
391
        output_token_ids: list[list[int]],
392
    ) -> torch.Tensor:
393
394
395
396
397
398
399
400
401
402
403
404
        if sampling_metadata.no_penalties:
            return logits

        assert sampling_metadata.prompt_token_ids is not None
        return apply_all_penalties(
            logits,
            sampling_metadata.prompt_token_ids,
            sampling_metadata.presence_penalties,
            sampling_metadata.frequency_penalties,
            sampling_metadata.repetition_penalties,
            output_token_ids,
        )