sampler.py 11.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""A layer that samples the next tokens from the model's outputs."""

5
6
from typing import Optional

7
8
9
import torch
import torch.nn as nn

10
from vllm.config import LogprobsMode
11
from vllm.utils import is_pin_memory_available
12
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
13
from vllm.v1.sample.metadata import SamplingMetadata
14
from vllm.v1.sample.ops.bad_words import apply_bad_words
15
from vllm.v1.sample.ops.logprobs import batched_count_greater_than
16
from vllm.v1.sample.ops.penalties import apply_all_penalties
17
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
18
19
20
21
22

_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):
    """
    A layer that samples the next tokens from the model's outputs
    with the following steps in order:

    1. If logprobs are requested:  
        a) If `logprobs_mode` is `raw_logprobs`, compute logprobs
           as the final logprobs to return.  
        b) If `logprobs_mode` is `raw_logits`, clone the logits
           as the final logprobs to return.  
        c) If `logprobs_mode` is `processed_logits` or
           `processed_logprobs`, the final logprobs are instead produced
           by the `sample` step (7) below.  
    2. Convert logits to float32.  
    3. Apply allowed token ids whitelist.  
    4. Apply bad words exclusion.  
    5. Apply logit processors which are not argmax-invariant,
       i.e. that can impact greedy sampling.  
        a) Min tokens processor  
        b) Logit bias processor  
    6. Apply penalties  
        a) Repetition penalty  
        b) Frequency penalty  
        c) Presence penalty  
    7. Sample the next tokens. `sample` method performs the following steps:  
        a) If not `all_random`, perform greedy sampling. If `all_greedy`,
           return the greedily sampled tokens and final logprobs if requested.  
        b) Apply temperature.  
        c) Apply logit processors which are argmax-invariant, by default
           the min_p processor.  
        d) Apply top_k and/or top_p.  
        e) Sample the next tokens with the probability distribution.  
        f) If `all_random` or temperature >= epsilon (1e-5), return the
           randomly sampled tokens and final logprobs if requested. Else,
           return the greedily sampled tokens and logprobs if requested.  
    8. Gather the logprobs of the top `max_num_logprobs` and sampled token
       (if requested). Note that if the sampled token is within the top
       `max_num_logprobs`, the logprob will be eventually merged in
       `LogprobsProcessor` during output processing. Therefore, the
       final output may contain either `max_num_logprobs + 1` or
       `max_num_logprobs` logprobs.  
    9. Return the final `SamplerOutput`.
    """

    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
        """Create the sampler.

        Args:
            logprobs_mode: Which logits/logprobs are reported when logprobs
                are requested (`raw_logprobs`, `raw_logits`,
                `processed_logits` or `processed_logprobs`). It is also
                forwarded to the top-k/top-p sampler.
        """
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
        self.pin_memory = is_pin_memory_available()
        self.logprobs_mode = logprobs_mode

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Sample the next token for every request in the batch.

        Args:
            logits: Model output logits; expected to be 2D with one row per
                request (TODO confirm: num_requests x vocab). If it is
                already float32, later steps modify it in place.
            sampling_metadata: Per-batch sampling parameters.

        Returns:
            A `SamplerOutput` whose `sampled_token_ids` is an int32 tensor
            of shape [num_requests, 1], plus the gathered logprobs tensors
            when logprobs were requested (otherwise None).
        """
        # NOTE(woosuk): In the raw logprobs modes, use the original logits
        # (before any penalties or temperature scaling) for the top-k
        # logprobs. This is different from the V0 sampler, which uses the
        # logits that are used for sampling (after penalties and temperature
        # scaling).
        num_logprobs = sampling_metadata.max_num_logprobs
        raw_logprobs: Optional[torch.Tensor] = None
        if num_logprobs is not None:
            if self.logprobs_mode == "raw_logprobs":
                raw_logprobs = self.compute_logprobs(logits)
            elif self.logprobs_mode == "raw_logits":
                # Clone: `.to(torch.float32)` below returns the same tensor
                # when it is already float32, and the following steps then
                # modify it in place.
                raw_logprobs = logits.clone()

        # Use float32 for the logits.
        logits = logits.to(torch.float32)

        # Apply allowed token ids.
        logits = self.apply_allowed_token_ids(logits, sampling_metadata)

        # Apply bad words exclusion.
        logits = self.apply_bad_words(logits, sampling_metadata)

        # Apply logits processors which can impact greedy sampling
        # (e.g. min_tokens, logit bias).
        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
            logits = processor.apply(logits)

        # Apply penalties (presence, frequency and repetition).
        logits = self.apply_penalties(logits, sampling_metadata)

        # Sample the next token. In the `processed_*` logprobs modes,
        # `sample` also returns the logprobs to report.
        sampled, processed_logprobs = self.sample(logits, sampling_metadata)
        if processed_logprobs is not None:
            raw_logprobs = processed_logprobs

        # Convert sampled token ids to int64 (long) type to ensure compatibility
        # with subsequent operations that may use these values as indices.
        # This conversion is necessary because FlashInfer sampling operations
        # return int32 (while PyTorch argmax and topk return int64).
        sampled = sampled.long()

        # Gather the logprobs and ranks of the top-k and sampled tokens
        # (if requested).
        logprobs_tensors = None
        if num_logprobs is not None:
            # Fail loudly with a descriptive message (instead of an
            # UnboundLocalError) if no step produced a logprobs tensor.
            assert raw_logprobs is not None, (
                "No logprobs tensor was produced for logprobs_mode="
                f"{self.logprobs_mode!r}")
            logprobs_tensors = self.gather_logprobs(raw_logprobs,
                                                    num_logprobs,
                                                    token_ids=sampled)

        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
        all_random: bool,
    ) -> torch.Tensor:
        """Divide each row of `logits` in place by its request's temperature.

        Args:
            logits: 2D logits tensor; modified in place.
            temp: Per-request temperatures, one per row of `logits`.
            all_random: True if the batch contains no greedy requests.

        Returns:
            The same `logits` tensor after the in-place division.
        """
        # Avoid division by zero if there are greedy requests: their
        # near-zero temperature is replaced by 1.0. Their final tokens come
        # from greedy sampling anyway (see `sample`).
        if not all_random:
            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
        # Use in-place division to avoid creating a new tensor.
        return logits.div_(temp.unsqueeze(dim=1))

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the argmax token id of each row as a flat 1D tensor."""
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.

        Returns:
            A tuple `(sampled, processed_logprobs)`. `sampled` holds one
            token id per request. On the all-greedy path,
            `processed_logprobs` is set only when logprobs are requested in a
            `processed_logits` / `processed_logprobs` mode (else None). On
            the other paths it is whatever `TopKTopPSampler` returns.
        """
        assert not (sampling_metadata.all_greedy
                    and sampling_metadata.all_random)
        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                processed_logprobs = None
                if sampling_metadata.max_num_logprobs is not None:
                    if self.logprobs_mode == "processed_logits":
                        processed_logprobs = logits
                    elif self.logprobs_mode == "processed_logprobs":
                        processed_logprobs = self.compute_logprobs(logits)
                return greedy_sampled, processed_logprobs

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(logits, sampling_metadata.temperature,
                                        sampling_metadata.all_random)

        # Apply logits processors that only apply to random sampling
        # (argmax invariant)
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)

        # Apply top_k and/or top_p.
        random_sampled, processed_logprobs = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
            max_top_k=sampling_metadata.max_top_k,
            has_any_no_top_k=sampling_metadata.has_any_no_top_k,
        )

        if greedy_sampled is None:
            return random_sampled, processed_logprobs

        # Mixed batch: per request, keep the greedy token when its
        # temperature is below the epsilon, otherwise the random one.
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
            out=greedy_sampled,  # Reuse tensor
        )
        return sampled, processed_logprobs

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        """Return float32 log-softmax of `logits` over the last dimension."""
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements
                     Must be int64.

        Returns:
          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
          Sampled token rank tensor, (num tokens)
        """
        assert token_ids.dtype == torch.int64
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs,
                                                 num_logprobs,
                                                 dim=-1)

        # Get the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the ranks of the actual token.
        token_ranks = batched_count_greater_than(logprobs, token_logprobs)

        # Concatenate together with the topk; the prompt/sampled token is
        # placed in column 0.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Apply presence, frequency and repetition penalties, if any.

        Returns `logits` unchanged when `sampling_metadata.no_penalties` is
        set; otherwise returns the result of `apply_all_penalties`.
        """
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_all_penalties(
                logits,
                sampling_metadata.prompt_token_ids,
                sampling_metadata.presence_penalties,
                sampling_metadata.frequency_penalties,
                sampling_metadata.repetition_penalties,
                sampling_metadata.output_token_ids,
            )
        return logits

    def apply_allowed_token_ids(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Mask out tokens that are not in the allowed-token-ids whitelist.

        Positions where `allowed_token_ids_mask` is True are set to -inf in
        place, i.e. the mask marks the tokens that are NOT allowed.
        """
        if sampling_metadata.allowed_token_ids_mask is not None:
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
                                float("-inf"))
        return logits

    def apply_bad_words(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Exclude bad-word tokens, if any are configured.

        The return value of the `apply_bad_words` op is ignored, so it
        presumably modifies `logits` in place (TODO confirm in ops module).
        """
        if sampling_metadata.bad_words_token_ids:
            apply_bad_words(
                logits,
                sampling_metadata.bad_words_token_ids,
                sampling_metadata.output_token_ids,
            )
        return logits