# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""

import torch
import torch.nn as nn

7
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
8
from vllm.v1.sample.metadata import SamplingMetadata
Woosuk Kwon's avatar
Woosuk Kwon committed
9
10
from vllm.v1.sample.ops.penalties import (apply_all_penalties,
                                          apply_min_token_penalties)
11
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
12
13
14
15
16
17

# Temperature threshold used by `Sampler.sample`: rows whose temperature is
# below this value keep their greedy (argmax) token instead of the randomly
# sampled one.
_SAMPLING_EPS = 1e-5


class Sampler(nn.Module):

18
19
20
21
    def __init__(self) -> None:
        """Create the sampler together with its top-k/top-p sampling helper."""
        super().__init__()
        # Called by `sample()` with the scaled logits, the generators and the
        # top-k / top-p settings to draw tokens for non-greedy requests.
        self.topk_topp_sampler = TopKTopPSampler()

22
23
24
25
26
    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        """Sample the next token for every row of `logits`.

        Args:
          logits: model scores, indexed as [request, token].
          sampling_metadata: batch-level sampling settings. When its
                             `max_num_logprobs` is not None, logprobs are
                             returned as well.

        Returns:
          A SamplerOutput whose `sampled_token_ids` is an int32 tensor of
          shape [num_requests, 1] (one generated token per request) and
          whose `logprobs_tensors` is None unless logprobs were requested.
        """
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # is used for sampling (after penalties and temperature scaling).
        # TODO(rob): provide option for logprobs post sampling.
        # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
        num_logprobs = sampling_metadata.max_num_logprobs
        want_logprobs = num_logprobs is not None
        raw_logprobs = self.compute_logprobs(logits) if want_logprobs else None

        # All further processing happens in float32.
        logits = logits.to(torch.float32)

        # Order matters: allowed-token mask, then logit bias, then
        # penalties (e.g., min_tokens, freq_penalties).
        for transform in (
                self.apply_allowed_token_ids,
                self.apply_logits_bias,
                self.apply_penalties,
        ):
            logits = transform(logits, sampling_metadata)

        # Sample the next token.
        sampled = self.sample(logits, sampling_metadata)

        # Gather the logprobs and ranks of the top-k and sampled tokens
        # (if requested). This uses `sampled` before its int32 conversion.
        if want_logprobs:
            logprobs_tensors = self.gather_logprobs(raw_logprobs,
                                                    num_logprobs,
                                                    token_ids=sampled)
        else:
            logprobs_tensors = None

        # int32 reduces the tensor size; unsqueeze turns the 1D result into
        # a [num_requests, 1] tensor with one generated token per row.
        return SamplerOutput(
            sampled_token_ids=sampled.to(torch.int32).unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
        """Divide each row of `logits` by its temperature, in place.

        Args:
          logits: scores to scale; modified in place.
          temp: temperatures; a trailing axis is added at dim 1 so that
                entry i scales row i of `logits`.

        Returns:
          The same `logits` tensor object, after scaling.
        """
        per_row_temp = torch.unsqueeze(temp, dim=1)
        # In-place division avoids allocating a second logits tensor.
        logits.div_(per_row_temp)
        return logits
73

74
75
76
77
    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the argmax token id over the last dim as a flat 1D tensor."""
        best_token_ids = torch.argmax(logits, dim=-1)
        return best_token_ids.view(-1)

    def sample(
78
79
80
81
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
82
83
        assert not (sampling_metadata.all_greedy
                    and sampling_metadata.all_random)
84
85
86
87
88
89
        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                return greedy_sampled
90

91
92
        assert sampling_metadata.temperature is not None

93
94
95
96
        # Apply temperature.
        logits = self.apply_temperature(logits, sampling_metadata.temperature)

        # Apply min_p.
97
        if sampling_metadata.min_p is not None:
98
99
100
            logits = self.apply_min_p(logits, sampling_metadata.min_p)

        # Apply top_k and/or top_p.
101
        random_sampled = self.topk_topp_sampler(
102
            logits,
103
            sampling_metadata.generators,
104
105
106
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )
107

108
        if greedy_sampled is None:
109
            return random_sampled
110
111
112
113
114

        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
115
            out=greedy_sampled,  # Reuse tensor
116
117
118
        )
        return sampled

119
120
121
122
123
124
125
126
    def compute_probs(self, logits: torch.Tensor,
                      sampling_metadata: SamplingMetadata) -> torch.Tensor:
        if sampling_metadata.all_greedy:
            return logits
        # Apply temperature. This is an in-place op changing logits.
        logits = self.apply_temperature(logits, sampling_metadata.temperature)
        return logits.softmax(dim=-1, dtype=torch.float32)

127
128
129
130
    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the float32 log-softmax of `logits` over the last dim."""
        return torch.log_softmax(logits, dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D integer token ID
                     tensor with (num tokens) elements

        Returns:
          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
          Sampled token rank tensor, (num tokens)
        """
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs,
                                                 num_logprobs,
                                                 dim=-1)

        # Get the logprob of the prompt or sampled token.
        # `gather` needs an int64 index; the cast is a no-op when the ids
        # already are int64 and lets callers pass e.g. int32 ids.
        token_ids = token_ids.to(torch.int64).unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the ranks of the actual token: the number of vocabulary
        # entries whose logprob is >= the token's own (so the best is 1).
        token_ranks = (logprobs >= token_logprobs).sum(-1)

        # Concatenate together with the topk; column 0 is the actual token.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)
173

174
175
176
177
178
    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
179
180
181
182
        if sampling_metadata.min_tokens:
            apply_min_token_penalties(logits,
                                      sampling_metadata.output_token_ids,
                                      sampling_metadata.min_tokens)
183
184
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
Woosuk Kwon's avatar
Woosuk Kwon committed
185
            logits = apply_all_penalties(
186
187
                logits,
                sampling_metadata.prompt_token_ids,
Woosuk Kwon's avatar
Woosuk Kwon committed
188
189
190
                sampling_metadata.presence_penalties,
                sampling_metadata.frequency_penalties,
                sampling_metadata.repetition_penalties,
191
192
                sampling_metadata.output_token_ids,
            )
193
        return logits
194

195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
    def apply_min_p(
        self,
        logits: torch.Tensor,
        min_p: torch.Tensor,
    ) -> torch.Tensor:
        """
        Filters logits using adaptive probability thresholding.

        A token is kept only if its softmax probability is at least
        `min_p` times the probability of the most likely token in the same
        row; every other logit is set to -inf.

        Args:
          logits: scores indexed as [request, token]; modified in place.
          min_p: one threshold factor per row of `logits`.

        Returns:
          The same `logits` tensor object, after masking.
        """
        # Convert logits to probability distribution
        probability_values = torch.nn.functional.softmax(logits, dim=-1)
        # Calculate maximum probabilities per sequence
        max_probabilities = torch.amax(probability_values,
                                       dim=-1,
                                       keepdim=True)
        # Reshape min_p for broadcasting
        adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
        # Identify valid tokens using threshold comparison
        valid_token_mask = probability_values >= adjusted_min_p
        # Same in-place write as `logits[~mask] = -inf`, expressed as one
        # elementwise fill (the form apply_allowed_token_ids also uses).
        logits.masked_fill_(~valid_token_mask, float("-inf"))
        return logits

217
218
219
220
221
222
223
224
225
226
227
228
229
    def apply_logits_bias(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        # TODO(houseroad): this implementation is extremely inefficient.
        # One idea is implement this as a PyTorch C++ op, and we may
        # even optimize the logit_bias layout.
        for i, logit_bias in enumerate(sampling_metadata.logit_bias):
            if logit_bias:
                for token_id, bias in logit_bias.items():
                    logits[i, token_id] += bias
        return logits
230
231
232
233
234
235
236
237
238
239

    def apply_allowed_token_ids(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        if sampling_metadata.allowed_token_ids_mask is not None:
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
                                float("-inf"))
        return logits