# penalties.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import numpy as np
4
5
import torch

6
from vllm.sampling_params import SamplingParams
7
from vllm.triton_utils import tl, triton
8
from vllm.utils.math_utils import cdiv
9
from vllm.utils.torch_utils import async_tensor_h2d
10
from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
11
from vllm.v1.worker.gpu.states import RequestState
12
13
14


class PenaltiesState:
    """Per-request-slot state for repetition, frequency and presence penalties.

    Penalty parameters live in ``UvaBackedTensor`` buffers: they are written on
    the host through ``.np``, published with ``copy_to_uva()``, and read by the
    penalty kernel through ``.gpu``. Alongside them, two GPU statistics tensors
    are kept per request slot:

    * ``prompt_bin_mask`` -- bit-packed set of token ids that occur in the
      prompt (token ``t`` is bit ``t % 32`` of word ``t // 32``).
    * ``output_bin_counts`` -- per-token-id counts of tokens that follow the
      prompt (rebuilt by ``bincount`` from positions ``[prompt_len, prefill_len)``).
    """

    def __init__(self, req_states: RequestState):
        self.req_states = req_states

        max_num_reqs = req_states.max_num_reqs
        self.vocab_size = req_states.vocab_size
        self.device = req_states.device

        self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.presence_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        # Host-side flag per slot; lets apply_penalties skip the kernel launch
        # entirely when no request in the batch uses any penalty.
        self.use_penalty = np.zeros(max_num_reqs, dtype=bool)

        # Initialize repetition penalty manually because 0 is an invalid value for it.
        self.repetition_penalty.np.fill(1.0)
        self.repetition_penalty.copy_to_uva()

        # Statistics for penalties.
        # One bit per vocab entry, packed into int32 words.
        self.prompt_bin_mask = torch.zeros(
            max_num_reqs,
            cdiv(self.vocab_size, 32),
            dtype=torch.int32,
            device=self.device,
        )
        # TODO(woosuk): This tensor is rarely used but can be very large, taking up
        # GBs of GPU memory. Optimize the memory usage.
        self.output_bin_counts = torch.zeros(
            max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
        )

        # Slots whose token statistics must be rebuilt in apply_staged_writes.
        self._new_penalties_reqs: list[int] = []

    def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
        """Stage the penalty parameters of the request placed in slot ``req_idx``.

        Only host-side buffers are written here. ``apply_staged_writes`` copies
        them to the device and, for requests with an active penalty, rebuilds
        the slot's prompt/output token statistics.
        """
        self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
        self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
        self.presence_penalty.np[req_idx] = sampling_params.presence_penalty

        do_penalty = use_penalty(sampling_params)
        self.use_penalty[req_idx] = do_penalty
        if do_penalty:
            self._new_penalties_reqs.append(req_idx)

    def apply_staged_writes(self) -> None:
        """Flush writes staged by ``add_request`` to the device.

        Rebuilds token statistics for newly staged penalty-using slots, then
        copies all three penalty parameter buffers via ``copy_to_uva()``.
        """
        if self._new_penalties_reqs:
            # A slot can be staged more than once before a flush (e.g. reused by a
            # new request). bincount resets each row once but accumulates output
            # counts once per entry, so duplicates would double-count. Dedupe while
            # preserving order.
            req_idxs = list(dict.fromkeys(self._new_penalties_reqs))
            idx_mapping = async_tensor_h2d(
                req_idxs,
                dtype=torch.int32,
                target_device=self.device,
                pin_memory=True,
            )

            prefill_lens = self.req_states.prefill_len.np[req_idxs]
            max_prefill_len = int(prefill_lens.max())
            bincount(
                idx_mapping,
                self.req_states.all_token_ids.gpu,
                self.req_states.prompt_len.gpu,
                self.req_states.prefill_len.gpu,
                self.prompt_bin_mask,
                self.output_bin_counts,
                max_prefill_len,
            )
            self._new_penalties_reqs.clear()

        self.repetition_penalty.copy_to_uva()
        self.frequency_penalty.copy_to_uva()
        self.presence_penalty.copy_to_uva()

    def apply_penalties(
        self,
        logits: torch.Tensor,
        expanded_idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        input_ids: torch.Tensor,
        expanded_local_pos: torch.Tensor,
        num_speculative_tokens: int,
    ) -> None:
        """Apply all penalties to ``logits`` in place.

        Args:
            logits: 2-D logits, one row per token; modified in place.
            expanded_idx_mapping: Per-logits-row request slot index (device).
            idx_mapping_np: Request slots in the batch (host); used only to
                decide whether the kernel launch can be skipped.
            input_ids: Token ids used to count preceding draft tokens.
            expanded_local_pos: Per-row position within its request's rows.
            num_speculative_tokens: Maximum number of draft tokens per request.
        """
        if not np.any(self.use_penalty[idx_mapping_np]):
            # No request uses penalties. Skip the kernel launch.
            return

        apply_penalties(
            logits,
            expanded_idx_mapping,
            input_ids,
            expanded_local_pos,
            self.repetition_penalty.gpu,
            self.frequency_penalty.gpu,
            self.presence_penalty.gpu,
            self.prompt_bin_mask,
            self.output_bin_counts,
            num_speculative_tokens,
        )
107
108
109


@triton.jit
def _penalties_kernel(
    logits_ptr,
    logits_stride,
    expanded_idx_mapping_ptr,
    token_ids_ptr,
    expanded_local_pos_ptr,
    repetition_penalty_ptr,
    frequency_penalty_ptr,
    presence_penalty_ptr,
    prompt_bin_mask_ptr,
    prompt_bin_mask_stride,
    output_bin_counts_ptr,
    output_bin_counts_stride,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    MAX_SPEC_LEN: tl.constexpr,
):
    # Applies repetition, frequency and presence penalties in place to one
    # BLOCK_SIZE-wide vocab slice of one logits row.
    # program_id(0) selects the logits row, program_id(1) the vocab slice.
    token_idx = tl.program_id(0)
    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)

    rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
    freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
    pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)

    # 1.0 (repetition) and 0.0 (frequency/presence) are the no-op values.
    use_rep_penalty = rep_penalty != 1.0
    use_freq_penalty = freq_penalty != 0.0
    use_pres_penalty = pres_penalty != 0.0
    use_penalty = use_rep_penalty or use_freq_penalty or use_pres_penalty
    if not use_penalty:
        # Early return to avoid loading logits.
        return

    # Row offsets are formed in int64: row * stride is otherwise an int32
    # product that can wrap for large batches / vocabularies.
    logits_row_ptr = logits_ptr + token_idx.to(tl.int64) * logits_stride
    req_row = req_state_idx.to(tl.int64)

    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size
    logits = tl.load(logits_row_ptr + block, mask=mask)
    logits = logits.to(tl.float32)

    base_output_counts = tl.load(
        output_bin_counts_ptr + req_row * output_bin_counts_stride + block,
        mask=mask,
        other=0,
    )

    # Compute cumulative draft_counts from previous positions in this request:
    # rows start_idx+1 .. start_idx+pos hold the draft tokens that precede this
    # row, so they count as already-generated output for this position.
    pos = tl.load(expanded_local_pos_ptr + token_idx)
    start_idx = token_idx - pos
    draft_counts = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
    for prev_pos in tl.static_range(MAX_SPEC_LEN):
        if prev_pos < pos:
            prev_token = tl.load(token_ids_ptr + start_idx + prev_pos + 1)
            token_match = block == prev_token
            draft_counts = draft_counts + token_match.to(tl.int32)

    # Total counts = base output counts + cumulative draft counts
    output_bin_counts = base_output_counts + draft_counts
    output_bin_mask = output_bin_counts > 0

    # Apply repetition penalties.
    if use_rep_penalty:
        # Unpack this slice's BLOCK_SIZE // 32 words of the prompt bitmask into
        # one flag per vocab entry (bit j of word i <-> token 32 * i + j).
        packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32)
        packed_mask = tl.load(
            prompt_bin_mask_ptr + req_row * prompt_bin_mask_stride + packed_block,
            mask=packed_block < tl.cdiv(vocab_size, 32),
            other=0,
        )
        prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1
        prompt_bin_mask = prompt_bin_mask.to(tl.int1)
        prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)

        # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
        scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0)
        # If logits are positive, divide by penalty, otherwise multiply by penalty.
        logits *= tl.where(logits > 0, 1.0 / scale, scale)

    # Apply frequency penalties.
    logits -= freq_penalty * output_bin_counts
    # Apply presence penalties.
    logits -= pres_penalty * output_bin_mask
    # Store back to logits.
    tl.store(logits_row_ptr + block, logits, mask=mask)
190
191


192
def apply_penalties(
    logits: torch.Tensor,
    expanded_idx_mapping: torch.Tensor,
    token_ids: torch.Tensor,
    expanded_local_pos: torch.Tensor,
    repetition_penalty: torch.Tensor,
    frequency_penalty: torch.Tensor,
    presence_penalty: torch.Tensor,
    prompt_bin_mask: torch.Tensor,
    output_bin_counts: torch.Tensor,
    num_speculative_tokens: int,
) -> None:
    """Apply repetition, frequency and presence penalties to ``logits`` in place.

    Launches ``_penalties_kernel`` on a ``(num_tokens, cdiv(vocab_size,
    BLOCK_SIZE))`` grid, i.e. one program per logits row and vocab slice.

    Args:
        logits: 2-D ``(num_tokens, vocab_size)`` logits; modified in place.
        expanded_idx_mapping: Per-row request slot index.
        token_ids: Token ids from which preceding draft tokens are read.
        expanded_local_pos: Per-row position within its request's rows.
        repetition_penalty: Per-slot repetition penalty (1.0 = disabled).
        frequency_penalty: Per-slot frequency penalty (0.0 = disabled).
        presence_penalty: Per-slot presence penalty (0.0 = disabled).
        prompt_bin_mask: Per-slot bit-packed prompt-token set.
        output_bin_counts: Per-slot per-token output counts.
        num_speculative_tokens: Upper bound on draft tokens per request; the
            kernel unrolls its draft-count loop to this length.
    """
    num_tokens, vocab_size = logits.shape
    # Must stay a multiple of 32 so each vocab slice maps onto whole
    # prompt_bin_mask words (the kernel unpacks BLOCK_SIZE // 32 words).
    BLOCK_SIZE = 8192
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    _penalties_kernel[(num_tokens, num_blocks)](
        logits,
        logits.stride(0),
        expanded_idx_mapping,
        token_ids,
        expanded_local_pos,
        repetition_penalty,
        frequency_penalty,
        presence_penalty,
        prompt_bin_mask,
        prompt_bin_mask.stride(0),
        output_bin_counts,
        output_bin_counts.stride(0),
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
        MAX_SPEC_LEN=num_speculative_tokens,
    )
224
225


226
@triton.jit
def _bincount_kernel(
    expanded_idx_mapping_ptr,
    all_token_ids_ptr,
    all_token_ids_stride,
    prompt_len_ptr,
    prefill_len_ptr,
    prompt_bin_mask_ptr,
    prompt_bin_mask_stride,
    output_bin_counts_ptr,
    output_bin_counts_stride,
    BLOCK_SIZE: tl.constexpr,
):
    # Builds token statistics for one request slot from positions
    # [0, prefill_len) of its all_token_ids row:
    #   - positions [0, prompt_len) set bits in prompt_bin_mask,
    #   - positions [prompt_len, prefill_len) increment output_bin_counts.
    # program_id(0) selects the slot entry, program_id(1) a BLOCK_SIZE chunk
    # of positions.
    token_idx = tl.program_id(0)
    block_idx = tl.program_id(1)
    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)

    prefill_len = tl.load(prefill_len_ptr + req_state_idx)
    if block_idx * BLOCK_SIZE >= prefill_len:
        # The whole chunk lies past this request's prefilled tokens.
        return

    prompt_len = tl.load(prompt_len_ptr + req_state_idx)
    # Row offsets are formed in int64: row * stride is otherwise an int32
    # product that can wrap for large max_num_reqs * row-length.
    req_row = req_state_idx.to(tl.int64)
    token_ids_row_ptr = all_token_ids_ptr + req_row * all_token_ids_stride

    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    if block_idx * BLOCK_SIZE < prompt_len:
        # Chunk overlaps the prompt: set bit (t % 32) of word (t // 32).
        mask = block < prompt_len
        prompt_tokens = tl.load(token_ids_row_ptr + block, mask=mask)
        idx = prompt_tokens // 32
        bit_idx = prompt_tokens % 32
        bit = tl.full((BLOCK_SIZE,), 1, tl.int32) << bit_idx
        # Atomic because several positions (and chunks) can hit the same word.
        tl.atomic_or(
            prompt_bin_mask_ptr + req_row * prompt_bin_mask_stride + idx,
            bit,
            mask=mask,
        )

    if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
        # Chunk may overlap the post-prompt region [prompt_len, prefill_len).
        mask = block < prefill_len
        mask &= block >= prompt_len
        output_tokens = tl.load(token_ids_row_ptr + block, mask=mask)
        # Atomic because repeated tokens increment the same counter.
        tl.atomic_add(
            output_bin_counts_ptr
            + req_row * output_bin_counts_stride
            + output_tokens,
            1,
            mask=mask,
        )
276
277
278


def bincount(
    expanded_idx_mapping: torch.Tensor,
    all_token_ids: torch.Tensor,
    prompt_len: torch.Tensor,
    prefill_len: torch.Tensor,
    prompt_bin_mask: torch.Tensor,
    output_bin_counts: torch.Tensor,
    max_prefill_len: int,
) -> None:
    """Rebuild prompt/output token statistics for the given request slots.

    The selected rows of ``prompt_bin_mask`` and ``output_bin_counts`` are
    reset to zero, then ``_bincount_kernel`` scans positions
    ``[0, prefill_len)`` of each slot's ``all_token_ids`` row.

    Args:
        expanded_idx_mapping: Request slot indices to rebuild. Entries should be
            unique: rows are reset once, but output counts are accumulated once
            per entry, so a repeated slot would be double-counted.
        all_token_ids: Per-slot token id rows.
        prompt_len: Per-slot prompt length.
        prefill_len: Per-slot number of valid tokens to scan.
        prompt_bin_mask: Per-slot bit-packed prompt-token set (updated in place).
        output_bin_counts: Per-slot per-token output counts (updated in place).
        max_prefill_len: Largest ``prefill_len`` among the selected slots; sizes
            the second grid dimension.
    """
    prompt_bin_mask[expanded_idx_mapping] = 0
    output_bin_counts[expanded_idx_mapping] = 0
    num_tokens = expanded_idx_mapping.shape[0]
    BLOCK_SIZE = 1024
    num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE)
    _bincount_kernel[(num_tokens, num_blocks)](
        expanded_idx_mapping,
        all_token_ids,
        all_token_ids.stride(0),
        prompt_len,
        prefill_len,
        prompt_bin_mask,
        prompt_bin_mask.stride(0),
        output_bin_counts,
        output_bin_counts.stride(0),
        BLOCK_SIZE=BLOCK_SIZE,
    )
304
305
306
307
308
309
310
311


def use_penalty(sampling_params: SamplingParams) -> bool:
    """Return whether any penalty differs from its no-op value.

    The no-op values are 1.0 for ``repetition_penalty`` and 0.0 for both
    ``frequency_penalty`` and ``presence_penalty``. Checks stop at the first
    active penalty.
    """
    repetition_active = sampling_params.repetition_penalty != 1.0
    if repetition_active:
        return repetition_active
    frequency_active = sampling_params.frequency_penalty != 0.0
    if frequency_active:
        return frequency_active
    return sampling_params.presence_penalty != 0.0