from dataclasses import dataclass
from typing import Any, Optional

import torch
from sgl_kernel.utils import get_cuda_stream, is_hopper_arch


# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
# Kudos to @yzh119
def rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
    enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
    r"""Root mean square normalization.

    ``out[i] = (input[i] / RMS(input)) * weight[i]``

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size).
    weight: torch.Tensor
        Weight tensor, shape (hidden_size,).
    eps: float
        Epsilon for numerical stability.
    out: Optional[torch.Tensor]
        The output tensor. If specified, the kernel writes the result into this tensor in place.
    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
        If None, PDL is automatically enabled on Hopper architecture.

    Returns
    -------
    output: torch.Tensor
        Normalized tensor, shape (batch_size, hidden_size).
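
    Examples
    --------
    Illustrative sketch (assumes a CUDA device and the compiled
    ``sgl_kernel`` extension):

    >>> x = torch.randn(4, 4096, device="cuda", dtype=torch.float16)
    >>> w = torch.ones(4096, device="cuda", dtype=torch.float16)
    >>> y = rmsnorm(x, w)  # allocate a fresh output tensor
    >>> y = rmsnorm(x, w, out=y)  # or write into a preallocated buffer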
    """
    if out is None:
        out = torch.empty_like(input)
    if enable_pdl is None:
        enable_pdl = is_hopper_arch()
    torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl)
    return out


def fused_add_rmsnorm(
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    enable_pdl: Optional[bool] = None,
) -> None:
    r"""Fused add root mean square normalization.

    Step 1:
    ``residual[i] += input[i]``

    Step 2:
    ``input[i] = (residual[i] / RMS(residual)) * weight[i]``

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size).
    residual: torch.Tensor
        Residual tensor, shape (batch_size, hidden_size).
    weight: torch.Tensor
        Weight tensor, shape (hidden_size,).
    eps: float
        Epsilon for numerical stability.
    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
        If None, PDL is automatically enabled on Hopper architecture.
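
    Examples
    --------
    Illustrative sketch (assumes a CUDA device and the compiled
    ``sgl_kernel`` extension); ``input`` and ``residual`` are updated in place:

    >>> hidden = torch.randn(4, 4096, device="cuda", dtype=torch.float16)
    >>> residual = torch.randn(4, 4096, device="cuda", dtype=torch.float16)
    >>> w = torch.ones(4096, device="cuda", dtype=torch.float16)
    >>> fused_add_rmsnorm(hidden, residual, w)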
    """
    if enable_pdl is None:
        enable_pdl = is_hopper_arch()
    torch.ops.sgl_kernel.fused_add_rmsnorm.default(
        input, residual, weight, eps, enable_pdl
    )


def gemma_rmsnorm(
    input: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    out: Optional[torch.Tensor] = None,
    enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
    r"""Gemma-style root mean square normalization.

    ``out[i] = (input[i] / RMS(input)) * (weight[i] + 1)``

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size).
    weight: torch.Tensor
        Weight tensor, shape (hidden_size,).
    eps: float
        Epsilon for numerical stability.
    out: Optional[torch.Tensor]
        The output tensor. If specified, the kernel writes the result into this tensor in place.
    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
        If None, PDL is automatically enabled on Hopper architecture.

    Returns
    -------
    output: torch.Tensor
        Gemma-normalized tensor, shape (batch_size, hidden_size).
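
    Examples
    --------
    Illustrative sketch (assumes a CUDA device); note the ``weight + 1`` gain,
    so a zero weight corresponds to plain RMS normalization:

    >>> x = torch.randn(4, 4096, device="cuda", dtype=torch.float16)
    >>> w = torch.zeros(4096, device="cuda", dtype=torch.float16)
    >>> y = gemma_rmsnorm(x, w)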
    """
    if out is None:
        out = torch.empty_like(input)
    if enable_pdl is None:
        enable_pdl = is_hopper_arch()
    torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl)
    return out


def gemma_fused_add_rmsnorm(
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    enable_pdl: Optional[bool] = None,
) -> None:
    r"""Gemma-style fused add root mean square normalization.

    Step 1:
    ``residual[i] += input[i]``

    Step 2:
    ``input[i] = (residual[i] / RMS(residual)) * (weight[i] + 1)``

    Parameters
    ----------
    input: torch.Tensor
        Input tensor, shape (batch_size, hidden_size).
    residual: torch.Tensor
        Residual tensor, shape (batch_size, hidden_size).
    weight: torch.Tensor
        Weight tensor, shape (hidden_size,).
    eps: float
        Epsilon for numerical stability.
    enable_pdl: Optional[bool]
        Whether to enable `programmatic dependent launch
        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_.
        If None, PDL is automatically enabled on Hopper architecture.
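
    Examples
    --------
    Illustrative sketch (assumes a CUDA device); ``input`` and ``residual``
    are updated in place:

    >>> hidden = torch.randn(4, 4096, device="cuda", dtype=torch.float16)
    >>> residual = torch.randn(4, 4096, device="cuda", dtype=torch.float16)
    >>> w = torch.zeros(4096, device="cuda", dtype=torch.float16)
    >>> gemma_fused_add_rmsnorm(hidden, residual, w)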
    """
    if enable_pdl is None:
        enable_pdl = is_hopper_arch()
    torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
        input, residual, weight, eps, enable_pdl
    )


def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
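    # The gated activation ("*_and_mul") kernels below pack the two gating
    # halves along the input's last dimension, so the output's last dimension
    # must be exactly half of the input's; all leading dimensions must match.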
    assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}"
    assert (
        input.shape[:-1] == output.shape[:-1]
    ), f"{input.shape[:-1]} != {output.shape[:-1]}"
    assert (
        input.shape[-1] == 2 * output.shape[-1]
    ), f"{input.shape[-1]} != {2 * output.shape[-1]}"


def silu_and_mul(input: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
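    r"""Fused SiLU-and-multiply (SwiGLU-style gating).

    ``out = silu(input[..., :d]) * input[..., d:]`` with ``d = input.shape[-1] // 2``.

    Example (illustrative, assumes a CUDA device and the compiled
    ``sgl_kernel`` extension):

    >>> x = torch.randn(4, 8192, device="cuda", dtype=torch.float16)
    >>> y = silu_and_mul(x)  # y.shape == (4, 4096)
    """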
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The last dimension of input must be a multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    torch.ops.sgl_kernel.silu_and_mul.default(out, input)
    return out


def gelu_tanh_and_mul(input: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
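    r"""Fused GELU (tanh approximation)-and-multiply (GeGLU-style gating).

    ``out = gelu_tanh(input[..., :d]) * input[..., d:]`` with ``d = input.shape[-1] // 2``.

    Example (illustrative, assumes a CUDA device):

    >>> x = torch.randn(4, 8192, device="cuda", dtype=torch.float16)
    >>> y = gelu_tanh_and_mul(x)  # y.shape == (4, 4096)
    """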
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The last dimension of input must be a multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input)
    return out


def gelu_and_mul(input: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
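    r"""Fused GELU-and-multiply (GeGLU-style gating, exact GELU).

    ``out = gelu(input[..., :d]) * input[..., d:]`` with ``d = input.shape[-1] // 2``.

    Example (illustrative, assumes a CUDA device):

    >>> x = torch.randn(4, 8192, device="cuda", dtype=torch.float16)
    >>> y = gelu_and_mul(x)  # y.shape == (4, 4096)
    """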
    if input.shape[-1] * input.dtype.itemsize % 16 != 0:
        raise ValueError("The last dimension of input must be a multiple of 16 bytes.")
    if out is not None:
        _check_shape(input, out)
    else:
        out = torch.empty(
            input.shape[:-1] + (input.shape[-1] // 2,),
            device=input.device,
            dtype=input.dtype,
        )
    torch.ops.sgl_kernel.gelu_and_mul.default(out, input)
    return out


if torch.version.hip is not None:

    def gelu_quick(input: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Quick-GELU:  y = x * sigmoid(1.702 * x)

        The CUDA/HIP kernel uses 128-bit (16-byte) vector loads & stores,
        so the last-dimension byte length must be a multiple of 16 bytes.
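
        Example (illustrative; requires a ROCm/HIP build of ``sgl_kernel``):

        >>> x = torch.randn(4, 4096, device="cuda", dtype=torch.float16)
        >>> y = gelu_quick(x)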
        """
        if input.shape[-1] * input.dtype.itemsize % 16 != 0:
            raise ValueError(
                f"The last dimension ({input.shape[-1]}) x itemsize "
                f"({input.dtype.itemsize}) must be a multiple of 16 bytes."
            )

        if out is not None:
            assert input.shape == out.shape, f"{input.shape} != {out.shape}"
        else:
            out = torch.empty_like(input)

        torch.ops.sgl_kernel.gelu_quick(out, input)
        return out


@dataclass
class FusedSetKVBufferArg:
    """
    value : Optional[torch.Tensor]
        Value tensor, shape: ``(nnz, num_v_heads * head_size)``.
    k_buffer : Optional[torch.Tensor]
        Buffer for keys, shape: ``(nnz, num_k_heads * head_size)``.
    v_buffer : Optional[torch.Tensor]
        Buffer for values, shape: ``(nnz, num_v_heads * head_size)``.
    k_scale : Optional[float]
        Scale factor for keys.
    v_scale : Optional[float]
        Scale factor for values.
    cache_loc : Optional[torch.Tensor]
        Cache location tensor, used for indexing kv cache.
    """

    value: torch.Tensor
    k_buffer: torch.Tensor
    v_buffer: torch.Tensor
    k_scale: Optional[float]
    v_scale: Optional[float]
    cache_loc: torch.Tensor


def apply_rope_with_cos_sin_cache_inplace(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
    fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> None:
    r"""
    Apply rotary embedding to keys and queries with precomputed cos/sin values.
    This is designed to be compatible with the SGL/vLLM implementation.
    The result is applied in place to the input tensors.

    Parameters
    ----------
    positions : torch.Tensor
        Position indices, shape: ``(nnz)``.
    query : torch.Tensor
        Query tensor, shape: ``(nnz, num_q_heads * head_size)``.
    key : torch.Tensor
        Key tensor, shape: ``(nnz, num_k_heads * head_size)``.
    head_size : int
        The size of each attention head.
    cos_sin_cache : torch.Tensor
        Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``.
        Cosine is the first half and Sine is the second half on rotary_dim.
    is_neox : bool
        Whether to use Neox style RoPE, default: ``True``.

        * If ``True``, the last dimension of the query/key tensor is not interleaved, i.e.,
          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
          dimensions ``([..., head_dim//2:])``.

        * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.

    fused_set_kv_buffer_arg : Optional[FusedSetKVBufferArg]
        If provided, fuse the set-kv-buffer operation into this kernel.

    Note
    ----
    The rotary dimension is determined by the cosine cache and sine cache.
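
    Examples
    --------
    Illustrative sketch (assumes a CUDA device and the compiled ``sgl_kernel``
    extension; shapes follow the conventions documented above):

    >>> nnz, num_q_heads, num_k_heads, head_size = 8, 32, 8, 128
    >>> positions = torch.arange(nnz, device="cuda")
    >>> query = torch.randn(nnz, num_q_heads * head_size, device="cuda", dtype=torch.float16)
    >>> key = torch.randn(nnz, num_k_heads * head_size, device="cuda", dtype=torch.float16)
    >>> cos_sin_cache = torch.randn(4096, head_size, device="cuda", dtype=torch.float32)
    >>> apply_rope_with_cos_sin_cache_inplace(
    ...     positions, query, key, head_size, cos_sin_cache, is_neox=True
    ... )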
    """
    if cos_sin_cache.dtype != torch.float32:
        raise ValueError("cos_sin_cache should be float32")

    if (a := fused_set_kv_buffer_arg) is not None:
        assert a.k_scale is None, "k_scale is not yet supported"
        assert a.v_scale is None, "v_scale is not yet supported"
        assert a.cache_loc.dtype == torch.int64, f"{a.cache_loc.dtype=}"

    def _view_3d(x):
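        # The underlying kernel works on (nnz, num_heads, head_size) views of
        # the flat (nnz, num_heads * head_size) tensors passed by callers.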
        return x.view(x.shape[0], -1, head_size)

    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
        _view_3d(query),
        _view_3d(key),
        _view_3d(query),
        _view_3d(key),
        cos_sin_cache,
        positions.long(),
        (not is_neox),
        get_cuda_stream(),
        (
            _view_3d(fused_set_kv_buffer_arg.value)
            if fused_set_kv_buffer_arg is not None
            else None
        ),
        (
            _view_3d(fused_set_kv_buffer_arg.k_buffer)
            if fused_set_kv_buffer_arg is not None
            else None
        ),
        (
            _view_3d(fused_set_kv_buffer_arg.v_buffer)
            if fused_set_kv_buffer_arg is not None
            else None
        ),
        (
            fused_set_kv_buffer_arg.cache_loc
            if fused_set_kv_buffer_arg is not None
            else None
        ),
    )