# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14
15
"""Fused operators for normalization layers."""

16
import logging
17
18
19
20
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
21
from packaging.version import Version
22

Lianmin Zheng's avatar
Lianmin Zheng committed
23
from sglang.srt.custom_op import CustomOp
24
25
26
27
28
from sglang.srt.utils import (
    cpu_has_amx_support,
    get_bool_env_var,
    is_cpu,
    is_cuda,
29
    is_flashinfer_available,
30
31
    is_hip,
    is_npu,
32
    is_xpu,
33
    supports_custom_op,
34
)
35

36
# Platform / capability flags, evaluated once when this module is imported.
_is_cuda = is_cuda()
_is_flashinfer_available = is_flashinfer_available()
_is_hip = is_hip()
_is_npu = is_npu()
# The aiter kernels are only selected on HIP and only when SGLANG_USE_AITER is set.
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_xpu = is_xpu()

# CUDA and XPU share the sgl_kernel normalization entry points.
if _is_cuda or _is_xpu:
    from sgl_kernel import (
        fused_add_rmsnorm,
        gemma_fused_add_rmsnorm,
        gemma_rmsnorm,
        rmsnorm,
    )

# On HIP, `rms_norm` / `fused_add_rms_norm` come either from aiter or from vLLM.
# The two providers are called with different argument lists (see
# RMSNorm.forward_aiter and RMSNorm.forward_hip).
if _use_aiter:
    from aiter import rmsnorm2d_fwd as rms_norm
    from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
elif _is_hip:
    import vllm
    from vllm._custom_ops import fused_add_rms_norm, rms_norm

    # RMSNorm.forward_hip picks the fused-add call signature based on this.
    _vllm_version = Version(vllm.__version__)

logger = logging.getLogger(__name__)

if _is_npu:
    import torch_npu

69

70
class RMSNorm(CustomOp):
    """RMS normalization with a learnable scale and per-backend kernels.

    The reference computation is ``forward_native``:
    ``x * rsqrt(mean(x_var ** 2, dim=-1) + eps) * weight`` evaluated in
    float32 and cast back to the input dtype, where ``x_var`` is ``x`` or,
    when a variance-size override is configured, the leading
    ``variance_size_override`` features of ``x``.

    When ``residual`` is passed, every forward variant returns a
    ``(normalized, updated_residual)`` tuple; otherwise it returns only the
    normalized tensor.

    None of the backend kernels used here is given a variance-size argument,
    so every backend path falls back to ``forward_native`` whenever
    ``variance_size_override`` is set.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
    ) -> None:
        """Create the layer.

        Args:
            hidden_size: Size of the last dimension of the inputs; also the
                length of the ``weight`` parameter (initialized to ones).
            eps: Value added to the variance before ``rsqrt``.
            var_hidden_size: Number of leading features used to compute the
                variance. ``None`` or a value equal to ``hidden_size`` means
                no override.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.hidden_size = hidden_size
        # An override equal to hidden_size is the same as no override at all.
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        # Re-dispatch. NOTE(review): `_forward_method` is presumably the hook
        # that CustomOp.forward calls -- confirm against CustomOp.
        if _use_aiter:
            self._forward_method = self.forward_aiter
        # Checked last so the deterministic setting wins over the aiter path.
        if get_bool_env_var("SGLANG_ENABLE_DETERMINISTIC_INFERENCE"):
            self._forward_method = self.forward_native

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Normalize with the sgl_kernel ``rmsnorm`` / ``fused_add_rmsnorm`` ops."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)
        if residual is not None:
            # The kernel's return value is not used; `x` and `residual` are
            # handed back directly (presumably updated in place by the kernel).
            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
            return x, residual
        out = rmsnorm(x, self.weight.data, self.variance_epsilon)
        return out

    def forward_npu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Normalize with the ``torch_npu`` RMS-norm ops."""
        # Same fallback as forward_cuda/forward_xpu: the NPU ops receive no
        # variance-size argument, so they cannot honour the override.
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)
        if residual is not None:
            out, _, residual_out = torch_npu.npu_add_rms_norm(
                residual, x, self.weight.data, self.variance_epsilon
            )
            return out, residual_out
        return torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]

    def forward_aiter(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Normalize with the aiter kernels, writing into fresh output tensors."""
        # Same fallback as forward_cuda/forward_xpu: the aiter kernels receive
        # no variance-size argument, so they cannot honour the override.
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)
        if residual is not None:
            residual_out = torch.empty_like(x)
            output = torch.empty_like(x)
            fused_add_rms_norm(
                output,
                x,
                residual,
                residual_out,
                self.weight.data,
                self.variance_epsilon,
            )
            return output, residual_out
        return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Normalize with the vLLM custom ops on HIP.

        The fused-add call signature depends on the installed vLLM version:
        before 0.9 it is called with four arguments and ``x`` / ``residual``
        are returned afterwards; from 0.9 on it is called with explicit
        output tensors.
        """
        # Same fallback as forward_cuda/forward_xpu: the vLLM ops receive no
        # variance-size argument, so they cannot honour the override.
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)
        if not x.is_contiguous():
            # NOTE: Remove this if aiter kernel supports discontinuous input
            x = x.contiguous()
        if residual is not None:
            if _vllm_version < Version("0.9"):
                fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
                return x, residual
            else:
                residual_out = torch.empty_like(x)
                output = torch.empty_like(x)
                # NOTE(review): `residual_out` precedes `residual` here, the
                # opposite of the aiter call in forward_aiter -- confirm this
                # matches the vLLM >= 0.9 op signature.
                fused_add_rms_norm(
                    output,
                    x,
                    residual_out,
                    residual,
                    self.weight.data,
                    self.variance_epsilon,
                )
                return output, residual_out
        out = torch.empty_like(x)
        rms_norm(out, x, self.weight.data, self.variance_epsilon)
        return out

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Pure-PyTorch reference implementation.

        Raises:
            ValueError: If the last dimension of ``x`` differs from
                ``hidden_size``, or is smaller than ``variance_size_override``.
        """
        if not x.is_contiguous():
            x = x.contiguous()
        orig_dtype = x.dtype
        # All arithmetic is done in float32; results are cast back at the end.
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            # The updated residual is the pre-normalization sum in the input dtype.
            residual = x.to(orig_dtype)

        hidden_size = x.shape[-1]
        if hidden_size != self.hidden_size:
            raise ValueError(
                "Expected hidden_size to be "
                f"{self.hidden_size}, but found: {hidden_size}"
            )

        if self.variance_size_override is None:
            x_var = x
        else:
            if hidden_size < self.variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{self.variance_size_override}, but found: {hidden_size}"
                )

            # Only the leading features contribute to the variance.
            x_var = x[..., : self.variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = (x * self.weight).to(orig_dtype)
        if residual is None:
            return x
        else:
            return x, residual

    def forward_cpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Normalize with the sgl_kernel CPU ops when AMX is available."""
        # Without AMX there is no CPU kernel; with an override the kernel
        # cannot honour it (it receives no variance-size argument).
        if not _is_cpu_amx_available or self.variance_size_override is not None:
            return self.forward_native(x, residual)
        if residual is not None:
            # As in forward_cuda, the op's return value is not used.
            torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
                x, residual, self.weight.data, self.variance_epsilon
            )
            return x, residual
        return torch.ops.sgl_kernel.rmsnorm_cpu(
            x, self.weight.data, self.variance_epsilon
        )

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Normalize with the sgl_kernel ops on XPU (same calls as forward_cuda)."""
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)
        if residual is not None:
            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
            return x, residual
        out = rmsnorm(x, self.weight.data, self.variance_epsilon)
        return out

    def forward_with_allreduce_fusion(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward method with allreduce fusion, prioritizing flashinfer fused operations

        The fused op is attempted only when ``residual`` is given and the
        tensor-model-parallel world size is greater than one. If it is not
        attempted, or its first result element is ``None``, this falls back
        to ``self.forward(x, residual)``.
        """
        if residual is not None:
            # Imported lazily, only on the path that needs them.
            from sglang.srt.distributed import get_tensor_model_parallel_world_size
            from sglang.srt.layers.flashinfer_comm_fusion import (
                flashinfer_allreduce_residual_rmsnorm,
            )

            fused_op = (
                torch.ops.sglang.flashinfer_allreduce_residual_rmsnorm
                if supports_custom_op()
                else flashinfer_allreduce_residual_rmsnorm
            )

            if get_tensor_model_parallel_world_size() > 1:
                fused_result = fused_op(
                    input_tensor=x,
                    residual=residual,
                    weight=self.weight,
                    eps=self.variance_epsilon,
                )
                # A None first element is treated as "fusion unavailable".
                if fused_result[0] is not None:
                    return fused_result

        return self.forward(x, residual)

262
263
264
265
266
267
268
269
270
271
272

class GemmaRMSNorm(CustomOp):
    """Gemma-style RMS normalization.

    The reference path (``forward_native``) scales the normalized input by
    ``1 + weight``; ``weight`` is initialized to zeros, so a freshly built
    layer applies a scale of one.

    With a ``residual`` argument each forward variant returns
    ``(normalized, updated_residual)``; without it, only the normalized tensor.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        """Create the layer.

        Args:
            hidden_size: Length of the ``weight`` parameter.
            eps: Value added to the variance before ``rsqrt``.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

        # Re-dispatch: HIP always runs the pure-PyTorch implementation.
        if _is_hip:
            self._forward_method = self.forward_native

    def _forward_impl(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Shared sgl_kernel path used by ``forward_cuda`` and ``forward_xpu``."""
        if residual is None:
            return gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)

        # The kernel's return value is not used; `x` and `residual` are handed
        # back directly (presumably updated in place by the kernel).
        gemma_fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
        return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Pure-PyTorch reference implementation."""
        input_dtype = x.dtype
        if residual is not None:
            # The residual add happens in the incoming dtype, before the
            # float32 upcast; the sum is also the updated residual.
            x = x + residual
            residual = x

        x_fp32 = x.float()
        mean_square = x_fp32.pow(2).mean(dim=-1, keepdim=True)
        normed = x_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        scaled = normed * (1.0 + self.weight.float())
        result = scaled.to(input_dtype)

        if residual is None:
            return result
        return result, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """CUDA entry point; delegates to ``_forward_impl``."""
        return self._forward_impl(x, residual)

    def forward_npu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Normalize with ``torch_npu.npu_gemma_rms_norm``.

        The residual add is done in PyTorch before the NPU op is called.
        """
        has_residual = residual is not None
        if has_residual:
            x = x + residual
            residual = x

        # Only the first element of the op's result is used.
        normed, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
        return (normed, residual) if has_residual else normed

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """XPU entry point; delegates to ``_forward_impl``."""
        return self._forward_impl(x, residual)

333
334

class Gemma3RMSNorm(CustomOp):
    """Gemma3 RMS normalization without a residual input.

    The input is normalized in float32, scaled by ``1 + weight`` (``weight``
    starts at zeros) and cast back to the input's type.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """Create the layer with a ``dim``-long weight and epsilon ``eps``."""
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        """Divide ``x`` by its root mean square over the last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward_native(self, x):
        """Pure-PyTorch implementation."""
        normed = self._norm(x.float())
        # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        scaled = normed * (1.0 + self.weight.float())
        return scaled.type_as(x)

    def forward_cuda(self, x):
        """CUDA entry point; uses the PyTorch implementation."""
        return self.forward_native(x)

    def forward_npu(self, x):
        """Normalize with ``torch_npu.npu_gemma_rms_norm`` (first result only)."""
        normed, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.eps)
        return normed

    def extra_repr(self):
        """Return the weight shape and epsilon for the module's repr."""
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.eps}"


362
363
364
# When none of the platforms checked below is active, rebind the public class
# names of this module to vLLM's layernorm layers. This must stay after the
# class definitions above so that the import takes precedence.
if not any(
    (
        _is_cuda,
        _is_hip,
        _is_npu,
        _is_cpu and _is_cpu_amx_available,
        _is_xpu,
    )
):
    logger.info(
        "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
    )
    from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm  # noqa: F401