sampler.py 11 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""A layer that samples the next tokens from the model's outputs."""

5
6
from typing import Optional

7
8
9
import torch
import torch.nn as nn

10
from vllm.config import LogprobsMode
11
from vllm.utils import is_pin_memory_available
12
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
13
from vllm.v1.sample.metadata import SamplingMetadata
14
from vllm.v1.sample.ops.bad_words import apply_bad_words
15
from vllm.v1.sample.ops.logprobs import batched_count_greater_than
16
from vllm.v1.sample.ops.penalties import apply_all_penalties
17
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
18
19
20
21
22

_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):
    """
    A layer that samples the next tokens from the model's outputs
    with the following steps in order:

    1. If logprobs are requested:
        a) If `logprobs_mode` is `raw_logprobs`, compute logprobs
           as the final logprobs to return.
        b) If `logprobs_mode` is `raw_logits`, clone the logits
           as the final logprobs to return.
        c) If `logprobs_mode` is `processed_logprobs` or
           `processed_logits`, the final logprobs are instead produced
           by `sample` (step 7) from the processed logits.
    2. Convert logits to float32.
    3. Apply allowed token ids whitelist.
    4. Apply bad words exclusion.
    5. Apply logit processors which are not argmax-invariant,
       i.e. that can impact greedy sampling.
        a) Min tokens processor
        b) Logit bias processor
    6. Apply penalties
        a) Repetition penalty
        b) Frequency penalty
        c) Presence penalty
    7. Sample the next tokens. `sample` method performs the following steps:
        a) If not `all_random`, perform greedy sampling. If `all_greedy`,
           return the greedily sampled tokens and final logprobs if requested.
        b) Apply temperature.
        c) Apply logit processors which are argmax-invariant, by default
           the min_p processor.
        d) Apply top_k and/or top_p.
        e) Sample the next tokens with the probability distribution.
        f) If `all_random` or temperature >= epsilon (1e-5), return the
           randomly sampled tokens and final logprobs if requested. Else,
           return the greedily sampled tokens and logprobs if requested.
    8. Gather the logprobs of the top `max_num_logprobs` and sampled token
       (if requested). Note that if the sampled token is within the top
       `max_num_logprobs`, the logprob will be eventually merged in
       `LogprobsProcessor` during output processing. Therefore, the
       final output may contain either `max_num_logprobs + 1` or
       `max_num_logprobs` logprobs.
    9. Return the final `SamplerOutput`.
    """

    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
        """Create the sampler.

        Args:
            logprobs_mode: Which tensor the returned logprobs are gathered
                from (`raw_logprobs`, `raw_logits`, or one of the
                `processed_*` modes handled in `sample`).
        """
        super().__init__()
        # The top-k/top-p sampler receives the same mode; `sample` takes the
        # processed logprobs for the random path from its return value.
        self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
        # NOTE(review): not read by any method in this class; kept because
        # code outside this class may rely on the attribute.
        self.pin_memory = is_pin_memory_available()
        self.logprobs_mode = logprobs_mode

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Sample one token per request and gather logprobs if requested.

        Args:
            logits: Model output logits, one row per request.
                NOTE(review): presumably (num_requests, vocab_size) — confirm.
            sampling_metadata: Sampling parameters for the batch.

        Returns:
            A `SamplerOutput` whose `sampled_token_ids` is an int32 tensor of
            shape [num_requests, 1], and whose `logprobs_tensors` is None
            unless `sampling_metadata.max_num_logprobs` is set.

        Raises:
            RuntimeError: If logprobs were requested but no source tensor is
                available (neither a `raw_*` mode nor processed logprobs
                returned by `sample`).
        """
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # are used for sampling (after penalties and temperature scaling).
        num_logprobs = sampling_metadata.max_num_logprobs
        # Source tensor for the final logprob gather. It is set here for the
        # `raw_*` modes and replaced below by the output of `sample()` for the
        # `processed_*` modes. Initialised to None so a missing value is
        # reported explicitly rather than surfacing as UnboundLocalError.
        raw_logprobs: Optional[torch.Tensor] = None
        if num_logprobs is not None:
            if self.logprobs_mode == "raw_logprobs":
                raw_logprobs = self.compute_logprobs(logits)
            elif self.logprobs_mode == "raw_logits":
                # Clone: the in-place ops below may otherwise alter it.
                raw_logprobs = logits.clone()

        # Use float32 for the logits. `.to` returns `logits` itself when it is
        # already float32, so the in-place ops below may mutate the caller's
        # tensor.
        logits = logits.to(torch.float32)
        # Apply allowed token ids.
        logits = self.apply_allowed_token_ids(logits, sampling_metadata)
        # Apply bad words exclusion.
        logits = self.apply_bad_words(logits, sampling_metadata)

        # Apply logits processors which can impact greedy sampling
        # (e.g. min_tokens, logit bias).
        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
            logits = processor.apply(logits)

        # Apply presence, frequency and repetition penalties.
        logits = self.apply_penalties(logits, sampling_metadata)

        # Sample the next token.
        sampled, processed_logprobs = self.sample(logits, sampling_metadata)
        if processed_logprobs is not None:
            raw_logprobs = processed_logprobs

        # Convert sampled token ids to int64 (long) type to ensure compatibility
        # with subsequent operations that may use these values as indices.
        # This conversion is necessary because FlashInfer sampling operations
        # return int32 (while PyTorch argmax and topk return int64).
        sampled = sampled.long()

        # Gather the logprobs and ranks of the top-k and sampled tokens
        # (if requested).
        logprobs_tensors = None
        if num_logprobs is not None:
            if raw_logprobs is None:
                raise RuntimeError(
                    "Logprobs were requested but none were produced for "
                    f"logprobs_mode={self.logprobs_mode!r}."
                )
            logprobs_tensors = self.gather_logprobs(
                raw_logprobs, num_logprobs, token_ids=sampled
            )

        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
        all_random: bool,
    ) -> torch.Tensor:
        """Divide each row of `logits` by its request's temperature, in place.

        Args:
            logits: Logits to scale; modified in place.
            temp: Per-request temperatures, broadcast across each row.
            all_random: If False, the batch may contain greedy requests
                (temperature < `_SAMPLING_EPS`); their temperature is
                replaced by 1.0 so the division never hits (near) zero.

        Returns:
            The same `logits` tensor after scaling.
        """
        # Avoid division by zero if there are greedy requests.
        if not all_random:
            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
        # Use in-place division to avoid creating a new tensor.
        return logits.div_(temp.unsqueeze(dim=1))

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the argmax token id of each row as a flat 1-D tensor."""
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.

        Returns:
            A tuple `(sampled, processed_logprobs)`. `sampled` holds one token
            id per request. In the all-greedy path `processed_logprobs` is set
            only for the `processed_*` modes when logprobs are requested; in
            the random path it is whatever `TopKTopPSampler` returns.
        """
        assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)
        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                processed_logprobs = None
                if sampling_metadata.max_num_logprobs is not None:
                    if self.logprobs_mode == "processed_logits":
                        processed_logprobs = logits
                    elif self.logprobs_mode == "processed_logprobs":
                        processed_logprobs = self.compute_logprobs(logits)
                return greedy_sampled, processed_logprobs

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(
            logits, sampling_metadata.temperature, sampling_metadata.all_random
        )

        # Apply logits processors that only apply to random sampling
        # (argmax invariant)
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)

        # Apply top_k and/or top_p.
        random_sampled, processed_logprobs = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        if greedy_sampled is None:
            return random_sampled, processed_logprobs

        # Mixed batch: per request, keep the greedy token when its temperature
        # is below the epsilon, otherwise the randomly sampled one.
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
            out=greedy_sampled,  # Reuse tensor
        )
        return sampled, processed_logprobs

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        """Return float32 log-softmax of `logits` over the last dimension."""
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements
                     Must be int64.

        Returns:
          Top-k int32 indices tensor, (num tokens) x (num_logprobs + 1)
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
          Sampled token rank tensor, (num tokens)
          In both 2-D tensors, column 0 is the given token and the remaining
          columns are the top-k entries.
        """
        assert token_ids.dtype == torch.int64
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)

        # Get with the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the ranks of the actual token.
        token_ranks = batched_count_greater_than(logprobs, token_logprobs)

        # Concatenate together with the topk.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Apply presence, frequency and repetition penalties.

        This is a no-op when `sampling_metadata.no_penalties` is set.
        Otherwise the prompt and output token ids are passed to
        `apply_all_penalties`.
        """
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_all_penalties(
                logits,
                sampling_metadata.prompt_token_ids,
                sampling_metadata.presence_penalties,
                sampling_metadata.frequency_penalties,
                sampling_metadata.repetition_penalties,
                sampling_metadata.output_token_ids,
            )
        return logits

    def apply_allowed_token_ids(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Mask out tokens outside the allowed-token whitelist, in place.

        Positions where `allowed_token_ids_mask` is True are set to -inf, so
        the mask marks the disallowed tokens.
        """
        if sampling_metadata.allowed_token_ids_mask is not None:
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))
        return logits

    def apply_bad_words(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Apply bad-words exclusion if any bad-word token ids are configured.

        The return value of the `apply_bad_words` op is ignored, so it
        presumably modifies `logits` in place — NOTE(review): confirm.
        """
        if sampling_metadata.bad_words_token_ids:
            apply_bad_words(
                logits,
                sampling_metadata.bad_words_token_ids,
                sampling_metadata.output_token_ids,
            )
        return logits