# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""

import torch
import torch.nn as nn

from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words
from vllm.v1.sample.ops.penalties import (apply_all_penalties,
                                          apply_min_token_penalties)
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler

_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):
    """Turns a batch of logits into one sampled next token per row.

    ``forward`` optionally computes logprobs on the raw logits, then casts to
    float32 and applies, in order: allowed-token-id masking, bad-word
    exclusion, logit bias, penalties, and finally greedy and/or random
    sampling.
    """

    def __init__(self):
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Sample the next token for every row of ``logits``.

        Args:
          logits: (num requests) x (vocab) tensor of model outputs.
          sampling_metadata: batch-wide sampling parameters.

        Returns:
          A SamplerOutput whose ``sampled_token_ids`` is an int32 tensor of
          shape [num_requests, 1], and whose ``logprobs_tensors`` is None
          unless ``sampling_metadata.max_num_logprobs`` is set.
        """
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # are used for sampling (after penalties and temperature scaling).
        # TODO(rob): provide option for logprobs post sampling.
        # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            raw_logprobs = self.compute_logprobs(logits)

        # Use float32 for the logits.
        logits = logits.to(torch.float32)
        # Apply allowed token ids.
        logits = self.apply_allowed_token_ids(logits, sampling_metadata)
        # Apply bad words exclusion.
        logits = self.apply_bad_words(logits, sampling_metadata)
        # Apply logits bias.
        logits = self.apply_logits_bias(logits, sampling_metadata)
        # Apply penalties (e.g., min_tokens, freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata)
        # Sample the next token.
        sampled = self.sample(logits, sampling_metadata)

        # Gather the logprobs of the topk and sampled token (if requested).
        # Get logprobs and rank tensors (if requested)
        logprobs_tensors = None if num_logprobs is None else \
            self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)

        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
        """Divide each row of ``logits`` by its temperature, in place."""
        # Use in-place division to avoid creating a new tensor.
        return logits.div_(temp.unsqueeze(dim=1))

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the argmax token id of each row as a 1D tensor."""
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Pick one token per row, greedily and/or randomly.

        In a mixed batch, both greedy and random candidates are computed.
        Rows whose temperature is below ``_SAMPLING_EPS`` then take the
        greedy candidate. Note that ``logits`` is modified in place by
        temperature scaling (and min_p, if set).
        """
        assert not (sampling_metadata.all_greedy
                    and sampling_metadata.all_random)
        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            # Take the argmax before temperature scaling mutates ``logits``.
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                return greedy_sampled

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(logits, sampling_metadata.temperature)

        # Apply min_p.
        if sampling_metadata.min_p is not None:
            logits = self.apply_min_p(logits, sampling_metadata.min_p)

        # Apply top_k and/or top_p.
        random_sampled = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        if greedy_sampled is None:
            return random_sampled

        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
            out=greedy_sampled,  # Reuse tensor
        )
        return sampled

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        """Return float32 log-softmax of ``logits`` over the last dim."""
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
          logprobs: (num tokens) x (vocab) tensor of log-probabilities
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements

        Returns:
          Top-k int32 indices tensor, (num tokens) x (num_logprobs + 1);
            column 0 holds ``token_ids``, the rest the top-k token ids
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1),
            laid out the same way
          Token rank tensor, (num tokens): the number of vocab entries whose
            logprob is >= that of the given token (so the best token has
            rank 1, and ties share the worse rank)
        """
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs,
                                                 num_logprobs,
                                                 dim=-1)

        # Get the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1).to(torch.long)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the ranks of the actual token.
        token_ranks = (logprobs >= token_logprobs).sum(-1)

        # Concatenate together with the topk.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Apply min-token penalties and, if enabled, the other penalties.

        The presence, frequency and repetition penalties are skipped when
        ``sampling_metadata.no_penalties`` is set.
        """
        if sampling_metadata.min_tokens:
            # Return value is not used; presumably modifies logits in
            # place — NOTE(review): confirm in ops.penalties.
            apply_min_token_penalties(logits,
                                      sampling_metadata.output_token_ids,
                                      sampling_metadata.min_tokens)
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_all_penalties(
                logits,
                sampling_metadata.prompt_token_ids,
                sampling_metadata.presence_penalties,
                sampling_metadata.frequency_penalties,
                sampling_metadata.repetition_penalties,
                sampling_metadata.output_token_ids,
            )
        return logits

    def apply_min_p(
        self,
        logits: torch.Tensor,
        min_p: torch.Tensor,
    ) -> torch.Tensor:
        """
        Filters logits using adaptive probability thresholding.

        A token survives if its probability is at least ``min_p`` times the
        row's maximum probability. Other tokens have their logits set to
        -inf in place.
        """
        # Convert logits to probability distribution
        probability_values = torch.nn.functional.softmax(logits, dim=-1)
        # Calculate maximum probabilities per sequence
        max_probabilities = torch.amax(probability_values,
                                       dim=-1,
                                       keepdim=True)
        # Reshape min_p for broadcasting
        adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
        # Identify valid tokens using threshold comparison
        valid_token_mask = probability_values >= adjusted_min_p
        # masked_fill_ is equivalent to boolean-index assignment, but avoids
        # materializing the indices of the masked elements.
        logits.masked_fill_(~valid_token_mask, -float('inf'))
        return logits

    def apply_logits_bias(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Add per-request token biases to ``logits`` in place.

        ``sampling_metadata.logit_bias[i]`` is either falsy or a mapping
        from token id to the bias added to ``logits[i, token_id]``.

        Raises:
          ValueError: if a biased token id is outside ``[0, vocab)``.
            Negative ids would otherwise silently wrap to the end of the
            vocabulary.
        """
        # TODO(houseroad): a fused PyTorch C++ op and a better logit_bias
        # layout could avoid building the index lists on the host.
        vocab_size = logits.shape[-1]
        rows: list[int] = []
        cols: list[int] = []
        vals: list[float] = []
        for i, logit_bias in enumerate(sampling_metadata.logit_bias):
            if not logit_bias:
                continue
            for token_id, bias in logit_bias.items():
                if token_id < 0 or token_id >= vocab_size:
                    raise ValueError(
                        f"token_id {token_id} in logit_bias is out of "
                        f"vocabulary range [0, {vocab_size})")
                rows.append(i)
                cols.append(token_id)
                vals.append(bias)

        if rows:
            # One batched scatter-add instead of one tiny op per entry.
            device = logits.device
            row_idx = torch.tensor(rows, dtype=torch.long, device=device)
            col_idx = torch.tensor(cols, dtype=torch.long, device=device)
            values = torch.tensor(vals, dtype=logits.dtype, device=device)
            logits.index_put_((row_idx, col_idx), values, accumulate=True)
        return logits

    def apply_allowed_token_ids(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Set logits to -inf, in place, wherever the token mask is True.

        Despite the attribute name, True entries of
        ``allowed_token_ids_mask`` mark tokens to EXCLUDE.
        """
        if sampling_metadata.allowed_token_ids_mask is not None:
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
                                float("-inf"))
        return logits

    def apply_bad_words(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Exclude configured bad-word token sequences from ``logits``.

        Delegates to ``apply_bad_words`` with the requests' output token ids.
        Its return value is not used, so it presumably edits ``logits`` in
        place — NOTE(review): confirm in ops.bad_words.
        """
        if sampling_metadata.bad_words_token_ids:
            apply_bad_words(
                logits,
                sampling_metadata.bad_words_token_ids,
                sampling_metadata.output_token_ids,
            )
        return logits