penalties.py 4.9 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.triton_utils import tl, triton
6
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
7
8
9


@triton.jit
def _penalties_and_temperature_kernel(
    logits_ptr,
    logits_stride,
    repetition_penalty_ptr,
    frequency_penalty_ptr,
    presence_penalty_ptr,
    temperature_ptr,
    idx_mapping_ptr,
    prompt_bin_counts_ptr,
    prompt_bin_counts_stride,
    output_bin_counts_ptr,
    output_bin_counts_stride,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Apply sampling penalties and temperature to one slice of logits, in place.

    Program axis 0 (``batch_idx``) picks a logits row and that row's
    penalty/temperature values; program axis 1 picks a ``BLOCK_SIZE``-wide
    slice of the vocabulary. Vocabulary columns are addressed with unit
    stride, so the last dimension of ``logits`` and of both bin-count tables
    must be contiguous.

    Per element, in this order:
      1. Repetition penalty (only when it is != 1.0): for tokens whose prompt
         count or output count is > 0, positive logits are divided by the
         penalty and non-positive logits are multiplied by it.
      2. Frequency penalty: subtract ``frequency_penalty * output_count``.
      3. Presence penalty: subtract ``presence_penalty`` where
         ``output_count > 0``.
      4. Temperature: divide by the temperature, with 0.0 treated as 1.0.
    Steps 2-3 run whenever any of the three penalties is active; they are
    no-ops when their own penalty is 0.0. If no penalty is active and the
    effective temperature is 1.0, the program returns without loading or
    storing any logits.

    Args:
        logits_ptr: Logits updated in place; row ``b`` starts at
            ``b * logits_stride``. The math is done in float32 and the result
            is stored back through the same pointer.
        logits_stride: Element stride between logits rows.
        repetition_penalty_ptr: One repetition penalty per batch row.
        frequency_penalty_ptr: One frequency penalty per batch row.
        presence_penalty_ptr: One presence penalty per batch row.
        temperature_ptr: One temperature per batch row.
        idx_mapping_ptr: Maps a batch row to its row in the bin-count tables.
        prompt_bin_counts_ptr: Per-token prompt occurrence counts (read only
            when the repetition penalty is active).
        prompt_bin_counts_stride: Element stride between prompt-count rows.
        output_bin_counts_ptr: Per-token output occurrence counts.
        output_bin_counts_stride: Element stride between output-count rows.
        vocab_size: Number of valid columns; lanes past it are masked off.
        BLOCK_SIZE: Vocabulary columns handled by each program.
    """
    batch_idx = tl.program_id(0)
    rep_penalty = tl.load(repetition_penalty_ptr + batch_idx)
    freq_penalty = tl.load(frequency_penalty_ptr + batch_idx)
    pres_penalty = tl.load(presence_penalty_ptr + batch_idx)

    temperature = tl.load(temperature_ptr + batch_idx)
    # A temperature of 0.0 would divide by zero below; treat it as 1.0
    # (presumably 0.0 marks greedy requests — confirm with the sampler).
    temperature = tl.where(temperature == 0.0, 1.0, temperature)

    # Neutral values: 1.0 for repetition, 0.0 for frequency/presence.
    use_rep_penalty = rep_penalty != 1.0
    use_freq_penalty = freq_penalty != 0.0
    use_pres_penalty = pres_penalty != 0.0
    use_penalty = use_rep_penalty or use_freq_penalty or use_pres_penalty
    use_temperature = temperature != 1.0
    if not (use_penalty or use_temperature):
        # Early return to avoid loading logits.
        return

    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size
    logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
    logits = logits.to(tl.float32)

    if use_penalty:
        # Bin-count tables are indexed by the mapped row, not by batch_idx.
        req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
        output_bin_counts = tl.load(
            output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
            mask=mask,
        )
        output_bin_mask = output_bin_counts > 0

        # Apply repetition penalties.
        if use_rep_penalty:
            prompt_bin_counts = tl.load(
                prompt_bin_counts_ptr
                + req_state_idx * prompt_bin_counts_stride
                + block,
                mask=mask,
            )
            prompt_bin_mask = prompt_bin_counts > 0
            # If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
            scale = tl.where(prompt_bin_mask | output_bin_mask, rep_penalty, 1.0)
            # If logits are positive, divide by penalty, otherwise multiply by penalty.
            logits *= tl.where(logits > 0, 1.0 / scale, scale)

        # Apply frequency penalties.
        logits -= freq_penalty * output_bin_counts
        # Apply presence penalties.
        logits -= pres_penalty * output_bin_mask

    # Apply temperature.
    logits = logits / temperature

    # Store back to logits.
    tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)


81
82
83
84
def apply_penalties_and_temperature(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> None:
    """Apply repetition/frequency/presence penalties and temperature in place.

    Launches ``_penalties_and_temperature_kernel`` on a
    ``(num_reqs, cdiv(vocab_size, BLOCK_SIZE))`` grid, i.e. one program per
    (logits row, 8192-column vocabulary slice).

    Args:
        logits: 2-D ``(num_reqs, vocab_size)`` tensor, modified in place. The
            kernel walks its last dimension with unit stride, so that
            dimension must be contiguous.
        sampling_metadata: Supplies the per-row repetition/frequency/presence
            penalties and temperatures, ``idx_mapping``, and the prompt/output
            bin-count tables. NOTE(review): presumably the per-row tensors
            hold ``num_reqs`` entries — confirm against ``SamplingMetadata``.
    """
    num_reqs, vocab_size = logits.shape
    BLOCK_SIZE = 8192
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    _penalties_and_temperature_kernel[(num_reqs, num_blocks)](
        logits,
        logits.stride(0),
        sampling_metadata.repetition_penalty,
        sampling_metadata.frequency_penalty,
        sampling_metadata.presence_penalty,
        sampling_metadata.temperature,
        sampling_metadata.idx_mapping,
        sampling_metadata.prompt_bin_counts,
        sampling_metadata.prompt_bin_counts.stride(0),
        sampling_metadata.output_bin_counts,
        sampling_metadata.output_bin_counts.stride(0),
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148


@triton.jit(do_not_specialize=["prefill_len", "prompt_len"])
def _bincount_kernel(
    prefill_token_ids_ptr,
    prefill_len,
    prompt_len,
    prompt_bin_counts_ptr,
    output_bin_counts_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    """Histogram one BLOCK_SIZE chunk of prefill tokens into two count tables.

    Positions ``[0, prompt_len)`` of ``prefill_token_ids`` are counted into
    ``prompt_bin_counts`` and positions ``[prompt_len, prefill_len)`` into
    ``output_bin_counts``. Both tables are indexed directly by token id and
    updated with atomic adds, so chunks handled by different programs can
    hit the same bin safely. The tables are not cleared here.
    """
    chunk_start = tl.program_id(0) * BLOCK_SIZE
    # Programs whose chunk begins past the prefill have nothing to count.
    if chunk_start >= prefill_len:
        return

    offsets = chunk_start + tl.arange(0, BLOCK_SIZE)

    # The chunk overlaps the prompt: count prompt tokens.
    if chunk_start < prompt_len:
        in_prompt = offsets < prompt_len
        token_ids = tl.load(prefill_token_ids_ptr + offsets, mask=in_prompt)
        tl.atomic_add(prompt_bin_counts_ptr + token_ids, 1, mask=in_prompt)

    # The chunk extends to or past the end of the prompt: count the tokens
    # that follow the prompt into the output table.
    if chunk_start + BLOCK_SIZE >= prompt_len:
        in_output = (offsets >= prompt_len) & (offsets < prefill_len)
        token_ids = tl.load(prefill_token_ids_ptr + offsets, mask=in_output)
        tl.atomic_add(output_bin_counts_ptr + token_ids, 1, mask=in_output)


def bincount(
    prefill_token_ids: torch.Tensor,
    prefill_len: int,
    prompt_len: int,
    prompt_bin_counts: torch.Tensor,
    output_bin_counts: torch.Tensor,
) -> None:
    """Rebuild prompt and output token histograms from scratch, in place.

    The first ``prompt_len`` entries of ``prefill_token_ids`` are counted
    into ``prompt_bin_counts``. Entries ``[prompt_len, prefill_len)`` are
    counted into ``output_bin_counts``. Both tables are indexed by token id.
    """
    # The kernel only accumulates (atomic adds), so both tables must start
    # at zero.
    for counts in (prompt_bin_counts, output_bin_counts):
        counts.zero_()

    block_size = 1024
    grid = (triton.cdiv(prefill_len, block_size),)
    _bincount_kernel[grid](
        prefill_token_ids,
        prefill_len,
        prompt_len,
        prompt_bin_counts,
        output_bin_counts,
        BLOCK_SIZE=block_size,
    )