gumbel.py 5.27 KB
Newer Older
1
2
3
4
5
6
7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.triton_utils import tl, triton


8
9
10
11
@triton.jit
def _temperature_kernel(
    logits_ptr,
    logits_stride,
    expanded_idx_mapping_ptr,
    temperature_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Divide one block of a logits row by the row's temperature, in place.

    Grid layout: program_id(0) selects the logits row (token) and
    program_id(1) selects a BLOCK_SIZE-wide slice of the vocabulary in that
    row. The temperature for a row is read from
    ``temperature_ptr[expanded_idx_mapping_ptr[row]]``.

    Rows whose temperature is exactly 0.0 or 1.0 are left untouched.
    Elements inside a row are addressed with unit stride.
    """
    # Widen the row index to int64 so that `token_idx * logits_stride` cannot
    # wrap around in 32-bit arithmetic for large logits buffers.
    token_idx = tl.program_id(0).to(tl.int64)
    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
    temperature = tl.load(temperature_ptr + req_state_idx).to(tl.float32)
    if temperature == 0.0 or temperature == 1.0:
        # Early return to avoid loading logits.
        return

    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Guard the tail block: lanes at or beyond vocab_size are neither read
    # nor written.
    mask = block < vocab_size
    row_ptrs = logits_ptr + token_idx * logits_stride + block

    logits = tl.load(row_ptrs, mask=mask)
    # The division is carried out in float32 regardless of the stored dtype.
    logits = logits.to(tl.float32)
    logits = logits / temperature
    tl.store(row_ptrs, logits, mask=mask)
32
33
34
35


def apply_temperature(
    logits: torch.Tensor,
    expanded_idx_mapping: torch.Tensor,
    temperature: torch.Tensor,
) -> None:
    """Divide each row of ``logits`` by its temperature, in place.

    Launches ``_temperature_kernel`` on a ``(num_tokens, num_blocks)`` grid,
    where ``num_blocks`` is the number of 8192-wide blocks needed to cover
    the vocabulary dimension.

    Args:
        logits: 2D tensor whose shape is read as ``(num_tokens, vocab_size)``.
            It is modified in place. The kernel addresses elements within a
            row with unit stride, so the last dimension is expected to be
            contiguous.
        expanded_idx_mapping: For each row of ``logits``, the index into
            ``temperature`` to use for that row.
        temperature: Temperatures, indexed by the values stored in
            ``expanded_idx_mapping``. Rows whose temperature is 0.0 or 1.0
            are skipped by the kernel.
    """
    num_tokens, vocab_size = logits.shape
    BLOCK_SIZE = 8192
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    _temperature_kernel[(num_tokens, num_blocks)](
        logits,
        logits.stride(0),
        expanded_idx_mapping,
        temperature,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )


52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
@triton.jit
def tl_rand64(seed, offset, includes_zero: tl.constexpr):
    """Return float64 uniform random values built from 64 random bits.

    Two of the four words produced by ``tl.randint4x(seed, offset)`` are
    combined into one unsigned 64-bit integer, which is converted to float64
    and scaled by 2**-64. The result is always strictly below 1.0. When
    ``includes_zero`` is false it is additionally clamped from below to the
    smallest positive normal float64, so that it is safe to pass to ``log``.
    """
    lo, hi, _, _ = tl.randint4x(seed, offset)
    # Reinterpret each word as unsigned 32-bit before widening to 64 bits.
    lo = lo.to(tl.uint32, bitcast=True).to(tl.uint64)
    hi = hi.to(tl.uint32, bitcast=True).to(tl.uint64)
    r = (hi << 32) | lo

    # 1 / 2**64
    scale = 5.421010862427522170037e-20
    u = r.to(tl.float64) * scale
    # float64 carries a 53-bit significand, so the largest values of r round
    # up to 2**64 during the conversion and u comes out as exactly 1.0. That
    # would turn -log(-log(u)) into +inf. Step those lanes down by 2**-53,
    # which yields the largest float64 below 1.0. All other lanes are left
    # untouched. (2**-53 is a power of two, so the literal is exact.)
    u = tl.where(u < 1.0, u, u - 1.1102230246251565e-16)
    if not includes_zero:
        u = tl.maximum(u, 2.2250738585072014e-308)  # float64 tiny
    return u


67
68
69
70
71
72
@triton.jit
def _gumbel_sample_kernel(
    local_argmax_ptr,
    local_argmax_stride,
    local_max_ptr,
    local_max_stride,
    processed_logits_ptr,
    processed_logits_stride,
    logits_ptr,
    logits_stride,
    expanded_idx_mapping_ptr,
    seeds_ptr,
    pos_ptr,
    temp_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    APPLY_TEMPERATURE: tl.constexpr,
):
    """Compute the block-local (max, argmax) of noise-perturbed logits.

    Grid layout: program_id(0) selects the logits row (token) and
    program_id(1) selects a BLOCK_SIZE-wide slice of the vocabulary. Each
    program:

    1. loads its slice of logits as float32 (out-of-range lanes read -inf);
    2. divides by the row's temperature when APPLY_TEMPERATURE is set and
       the temperature is non-zero;
    3. optionally stores these logits to ``processed_logits_ptr``;
    4. when the temperature is non-zero, adds Gumbel noise in float64;
    5. writes the slice's maximum value and the vocabulary index of that
       maximum to ``local_max_ptr`` / ``local_argmax_ptr`` at
       ``[row, block_idx]``.

    The reduction across blocks is left to the caller.
    """
    # Widen row indices to int64 so that the `index * stride` offsets below
    # cannot wrap around in 32-bit arithmetic for large buffers.
    token_idx = tl.program_id(0).to(tl.int64)
    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx).to(tl.int64)

    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size
    logits = tl.load(
        logits_ptr + token_idx * logits_stride + block,
        mask=mask,
        other=float("-inf"),
    )
    logits = logits.to(tl.float32)

    temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
    if temp != 0.0 and APPLY_TEMPERATURE:
        # Apply temperature.
        # NOTE(woosuk): Match the behavior of _temperature_kernel.
        # E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
        logits = logits / temp

    # Store the temperature-applied logits.
    # NOTE(review): the output row is selected by req_state_idx, not by
    # token_idx, so token rows that share a req_state_idx write the same
    # output row -- confirm this is intended by the caller.
    if processed_logits_ptr is not None:
        tl.store(
            processed_logits_ptr + req_state_idx * processed_logits_stride + block,
            logits,
            mask=mask,
        )

    logits = logits.to(tl.float64)
    # A temperature of exactly 0.0 skips the noise, so the reduction below
    # is a plain max/argmax over the logits.
    if temp != 0.0:
        # Calculate the seed for gumbel noise.
        seed = tl.load(seeds_ptr + req_state_idx)
        pos = tl.load(pos_ptr + token_idx)
        gumbel_seed = tl.randint(seed, pos)

        # tl.rand returns fp32, so build a true fp64 uniform from 64 random
        # bits before applying the double-log transform.
        u = tl_rand64(gumbel_seed, block, includes_zero=False)
        gumbel_noise = -tl.log(-tl.log(u))

        # Apply gumbel noise.
        logits = tl.where(mask, logits + gumbel_noise, float("-inf"))

    value, idx = tl.max(logits, axis=0, return_indices=True)
    # Convert the within-block position into a vocabulary index.
    token_id = block_idx * BLOCK_SIZE + idx
    tl.store(local_argmax_ptr + token_idx * local_argmax_stride + block_idx, token_id)
    tl.store(local_max_ptr + token_idx * local_max_stride + block_idx, value)
132
133
134


def gumbel_sample(
    logits: torch.Tensor,  # [num_tokens, vocab_size]
    expanded_idx_mapping: torch.Tensor,  # [num_tokens]
    temperature: torch.Tensor,  # [max_num_reqs]
    seed: torch.Tensor,  # [max_num_reqs]
    pos: torch.Tensor,  # [num_tokens]
    apply_temperature: bool,
    processed_logits_out: torch.Tensor | None = None,  # [num_reqs, vocab_size]
) -> torch.Tensor:
    """Pick one token id per row of ``logits`` via a max over noisy logits.

    The work is split in two stages. ``_gumbel_sample_kernel`` runs on a
    ``(num_tokens, num_blocks)`` grid with 1024-wide vocabulary blocks and
    records, for every block, the maximum of the (optionally
    temperature-scaled, Gumbel-perturbed) logits and the vocabulary index
    where it occurs. The per-block results are then reduced here with
    ``argmax`` and ``gather``.

    Rows whose temperature is exactly 0.0 receive no noise in the kernel,
    so for them the result is the plain argmax of the logits.

    Args:
        logits: Input logits; not modified by this function.
        expanded_idx_mapping: For each row of ``logits``, the index used to
            look up ``temperature`` and ``seed`` (and the output row of
            ``processed_logits_out``).
        temperature: Temperatures, indexed through ``expanded_idx_mapping``.
        seed: Random seeds, indexed through ``expanded_idx_mapping``.
        pos: Per-row value mixed with the seed to derive the noise seed.
        apply_temperature: Whether the kernel divides the logits by the
            temperature before adding noise.
        processed_logits_out: Optional buffer that receives the logits as
            they are just before noise is added.

    Returns:
        An int64 tensor with ``num_tokens`` elements holding the selected
        vocabulary index for each row.
    """
    num_tokens, vocab_size = logits.shape
    BLOCK_SIZE = 1024
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    # Per-block partial results, filled in by the kernel.
    local_argmax = logits.new_empty(num_tokens, num_blocks, dtype=torch.int64)
    local_max = logits.new_empty(num_tokens, num_blocks, dtype=torch.float64)
    _gumbel_sample_kernel[(num_tokens, num_blocks)](
        local_argmax,
        local_argmax.stride(0),
        local_max,
        local_max.stride(0),
        processed_logits_out,
        processed_logits_out.stride(0) if processed_logits_out is not None else 0,
        logits,
        logits.stride(0),
        expanded_idx_mapping,
        seed,
        pos,
        temperature,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
        APPLY_TEMPERATURE=apply_temperature,
    )
    # NOTE(woosuk): Use int64 for later indexing.
    # Find the winning block for each row, then read that block's argmax.
    max_block_idx = local_max.argmax(dim=-1, keepdim=True)
    sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
    return sampled