layernorm.py 14.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Custom normalization layers."""
4

5
6
import torch
import torch.nn as nn
7
import torch.nn.functional as F
8

9
from vllm._aiter_ops import rocm_aiter_ops
10
from vllm.model_executor.custom_op import CustomOp
zhuwenwen's avatar
zhuwenwen committed
11

12
13
from vllm.model_executor.layers.batch_invariant import (
    rms_norm_batch_invariant,
14
    vllm_is_batch_invariant,
15
)
16
from vllm.platforms import current_platform
17
from vllm import envs
18
19


20
21
22
def rms_norm(
    x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Apply RMS normalization to ``x`` and return the result as a new tensor.

    Args:
        x: Input activations.
        weight: Per-channel scale forwarded to the kernel.
        variance_epsilon: Epsilon forwarded to the kernel.

    Returns:
        A freshly allocated tensor created with ``torch.empty_like(x)`` and
        filled by the kernel. In batch-invariant mode the result of
        ``rms_norm_batch_invariant`` is returned instead.

    ``envs.VLLM_USE_OPT_OP`` selects ``ops.rms_norm_opt`` over
    ``ops.rms_norm``; both are called with the same (out, x, weight, eps)
    argument order.
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        return rms_norm_batch_invariant(x, weight, variance_epsilon)

    result = torch.empty_like(x)
    # Only the selected attribute is looked up on ``ops``.
    kernel = ops.rms_norm_opt if envs.VLLM_USE_OPT_OP else ops.rms_norm
    kernel(result, x, weight, variance_epsilon)
    return result


def fused_add_rms_norm(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Add ``residual`` to ``x`` and RMS-normalize the sum.

    Args:
        x: Input activations.
        residual: Residual stream to add to ``x``.
        weight: Per-channel scale forwarded to the norm.
        variance_epsilon: Epsilon forwarded to the norm.

    Returns:
        ``(normed, new_residual)``. In batch-invariant mode ``new_residual``
        is the freshly computed ``x + residual``. In the custom-kernel path
        the kernel receives ``x`` and ``residual`` and those same tensor
        objects are returned, so the kernel presumably updates them in place
        (normed output in ``x``, sum in ``residual``); confirm against the
        kernel definition.
    """
    from vllm import _custom_ops as ops

    if vllm_is_batch_invariant():
        # Materialize the sum once: it is both the normalization input and
        # the updated residual (previously it was computed twice).
        summed = x + residual
        return rms_norm_batch_invariant(summed, weight, variance_epsilon), summed

    kernel = (
        ops.fused_add_rms_norm_opt
        if envs.VLLM_USE_OPT_OP
        else ops.fused_add_rms_norm
    )
    kernel(x, residual, weight, variance_epsilon)
    return x, residual


74
75
def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
    """Run the ``ops.poly_norm`` custom kernel on ``x``.

    The kernel writes into a new tensor allocated with
    ``torch.empty_like(x)``, which is returned. ``weight``, ``bias`` and
    ``variance_epsilon`` are passed to the kernel unchanged.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.poly_norm(result, x, weight, bias, variance_epsilon)
    return result


90
91
92
93
def dispatch_rocm_rmsnorm_func(
    with_fused_add: bool, dtype: torch.dtype, use_aiter: bool = False
):
    """Choose the RMSNorm callable to use on ROCm.

    AITER kernels are selected only when ``use_aiter`` is set and ``dtype``
    is float16 or bfloat16. Otherwise the module-level
    ``fused_add_rms_norm`` / ``rms_norm`` wrappers are returned.

    Args:
        with_fused_add: Return the residual-add variant when True.
        dtype: Weight dtype, used to gate AITER eligibility.
        use_aiter: Whether AITER RMSNorm is enabled.

    Returns:
        The selected callable (not called here).
    """
    aiter_ok = use_aiter and dtype in (torch.float16, torch.bfloat16)

    if not aiter_ok:
        # fall back to CUDA implementation
        return fused_add_rms_norm if with_fused_add else rms_norm

    return (
        rocm_aiter_ops.rms_norm2d_with_add
        if with_fused_add
        else rocm_aiter_ops.rms_norm
    )
107
108


109
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: int | None = None,
        has_weight: bool = True,
        dtype: torch.dtype | None = None,
    ) -> None:
        """
        Args:
            hidden_size: Expected size of the input's last dimension.
            eps: Added to the mean square before the rsqrt.
            var_hidden_size: If set and different from ``hidden_size``, the
                mean square is taken over only the first ``var_hidden_size``
                channels. Every accelerated forward routes this case to
                ``forward_native``.
            has_weight: If False, ``weight`` is a plain tensor of ones rather
                than an ``nn.Parameter``, and the native path applies no
                weight.
            dtype: Weight dtype; defaults to ``torch.get_default_dtype()``.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        self.variance_size_override = (
            None if var_hidden_size == hidden_size else var_hidden_size
        )
        weight_dtype = dtype or torch.get_default_dtype()
        self.has_weight = has_weight
        self.weight = torch.ones(hidden_size, dtype=weight_dtype)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

        if current_platform.is_rocm():
            aiter_rmsnorm_enabled = rocm_aiter_ops.is_rmsnorm_enabled()
            self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
                with_fused_add=False,
                dtype=weight_dtype,
                use_aiter=aiter_rmsnorm_enabled,
            )
            self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
                with_fused_add=True, dtype=weight_dtype, use_aiter=aiter_rmsnorm_enabled
            )

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        variance_epsilon: float,
        hidden_size: int,
        orig_dtype: torch.dtype,
        weight: torch.Tensor | None = None,
        residual: torch.Tensor | None = None,
        variance_size_override: int | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        This is a true staticmethod. It has no ``self``, so the positional
        arguments passed by ``forward_native`` line up with the parameters.

        Args:
            x: Input whose last dimension must equal ``hidden_size``.
            variance_epsilon: Added to the mean square before the rsqrt.
            hidden_size: Expected size of ``x``'s last dimension.
            orig_dtype: Dtype the normalized output (and new residual) are
                cast to.
            weight: Optional scale, multiplied in after the cast to
                ``orig_dtype``.
            residual: If given, added to ``x`` (in float32) before
                normalizing; the sum cast to ``orig_dtype`` is returned as
                the new residual.
            variance_size_override: If given, the mean square is computed
                over only the first ``variance_size_override`` entries of
                the last dimension.

        Returns:
            The normalized tensor, or ``(normalized, new_residual)`` when
            ``residual`` is given.

        Raises:
            ValueError: If ``x.shape[-1] != hidden_size``, or if
                ``hidden_size < variance_size_override``.
        """
        x = x.to(torch.float32)
        if residual is not None:
            # residual promoted f16->f32 automatically,
            # otherwise Inductor eliminates the casts to and from f16,
            # increasing memory usage (and complicating pattern matching)
            x = x + residual
            residual = x.to(orig_dtype)

        if x.shape[-1] != hidden_size:
            raise ValueError(
                f"Expected hidden_size to be {hidden_size}, but found: {x.shape[-1]}"
            )

        if variance_size_override is None:
            x_var = x
        else:
            if hidden_size < variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{variance_size_override}, but found: {hidden_size}"
                )
            # Slice only the last dim, so 2-D and N-D inputs both work
            # (the previous ``x[:, :, :k]`` required exactly 3-D input).
            x_var = x[..., :variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        x = x.to(orig_dtype)
        if weight is not None:
            x = x * weight
        if residual is None:
            return x
        return x, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        When ``envs.VLLM_USE_OPT_OP`` is set and we are not being traced by
        ``torch.compile``, this delegates to ``forward_cuda``. The
        delegation is skipped when ``variance_size_override`` is set:
        ``forward_cuda`` routes that case back here, so the two methods
        would otherwise recurse forever.
        """
        if (
            not torch.compiler.is_compiling()
            and envs.VLLM_USE_OPT_OP
            and self.variance_size_override is None
        ):
            return self.forward_cuda(x, residual)

        return self.forward_static(
            x,
            self.variance_epsilon,
            self.hidden_size,
            x.dtype,
            self.weight.data if self.has_weight else None,
            residual,
            self.variance_size_override,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Custom-kernel path via ``rms_norm`` / ``fused_add_rms_norm``.

        Uses ``forward_native`` when ``variance_size_override`` is set.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            return fused_add_rms_norm(
                x, residual, self.weight.data, self.variance_epsilon
            )
        return rms_norm(x, self.weight.data, self.variance_epsilon)

    def forward_hip(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """ROCm path using the callables chosen in ``__init__``.

        Uses ``forward_native`` when ``variance_size_override`` is set.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            return self.rocm_norm_func_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )
        return self.rocm_norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_apex(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Path using apex's ``fused_rms_norm_affine`` for the no-residual case.

        The residual-add case uses ``rocm_norm_func_with_add``. That
        attribute is only set on ROCm, so elsewhere this falls back to the
        module-level ``fused_add_rms_norm`` instead of raising
        AttributeError.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        if residual is not None:
            norm_with_add = getattr(
                self, "rocm_norm_func_with_add", fused_add_rms_norm
            )
            return norm_with_add(
                x, residual, self.weight.data, self.variance_epsilon
            )

        # Imported lazily: apex is only needed on this branch.
        from apex.normalization.fused_layer_norm import fused_rms_norm_affine

        return fused_rms_norm_affine(
            x,
            self.weight.data,
            torch.Size((x.shape[-1],)),
            self.variance_epsilon,
        )

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """XPU path using IPEX ops.

        Uses ``forward_native`` when ``variance_size_override`` is set.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s


297
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the standard RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Initialized to zeros; the effective scale applied is (1 + weight).
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward().

        When ``residual`` is given, the pre-norm sum ``x + residual`` is
        returned as the new residual. For float16 inputs that sum is taken
        (and returned) in float32; other dtypes add in their own dtype.
        Normalization and scaling are done in float32 and cast back to the
        input dtype at the end.
        """
        input_dtype = x.dtype

        if residual is not None:
            if input_dtype == torch.float16:
                x = x.float() + residual.float()
            else:
                x = x + residual
            residual = x

        hidden = x.float()
        mean_sq = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_sq + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        hidden = hidden * (1.0 + weight.float())
        out = hidden.to(input_dtype)

        if residual is None:
            return out
        return out, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(
            self.weight.data, self.variance_epsilon, x, residual
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Run ``forward_static`` through ``torch.compile``.

        On the first eager call, the compiled function is stored on the
        instance, shadowing the staticmethod, so ``forward_native`` uses it
        from then on. Under an outer ``torch.compile`` trace no nested
        compile is done.
        """
        needs_compile = not torch.compiler.is_compiling() and not getattr(
            self, "_is_compiled", False
        )
        if needs_compile:
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static
            )
            self._is_compiled = True
        return self.forward_native(x, residual)
363
364


365
366
367
@CustomOp.register("rms_norm_gated")
class RMSNormGated(CustomOp):
    """RMS Normalization with optional SiLU gating.

    Supports:
    - Standard RMS normalization over the last dimension
    - Group RMS normalization (each group of ``group_size`` channels is
      normalized independently)
    - Optional gating with SiLU activation, before or after the norm

    ``forward_native`` is pure PyTorch; ``forward_cuda`` delegates to
    ``rmsnorm_fn`` from ``layernorm_guard``.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Initialize RMSNormGated.

        Args:
            hidden_size: Size of the hidden dimension.
            eps: Epsilon for numerical stability.
            group_size: If not None, normalize each group of ``group_size``
                consecutive channels separately. None is equivalent to
                ``group_size=hidden_size`` (a single group).
            norm_before_gate: If True and z is provided: out = norm(x) * silu(z).
                If False and z is provided: out = norm(x * silu(z)).
            device: Device to create parameters on.
            dtype: Data type for parameters.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(
            torch.empty(hidden_size, device=device, dtype=dtype)
        )
        # No bias; registered as None so ``self.bias`` exists for rmsnorm_fn.
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        """Set the scale to all ones."""
        torch.nn.init.ones_(self.weight)

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Pure-PyTorch RMS normalization with optional gating.

        Args:
            x: Input tensor.
            z: Optional gating tensor.

        Returns:
            Normalized (and optionally gated) tensor. With ``z``:
            ``norm_before_gate=True`` gives ``norm(x) * silu(z)`` and
            ``norm_before_gate=False`` gives ``norm(x * silu(z))``.
        """
        gate = None if z is None else F.silu(z)

        if gate is not None and not self.norm_before_gate:
            x = x * gate

        if self.group_size is None:
            mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
            out = x * torch.rsqrt(mean_sq + self.eps) * self.weight
        else:
            from einops import rearrange

            grouped = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
            mean_sq = grouped.pow(2).mean(dim=-1, keepdim=True)
            grouped = grouped * torch.rsqrt(mean_sq + self.eps)
            out = rearrange(grouped, "... g d -> ... (g d)") * self.weight

        if gate is not None and self.norm_before_gate:
            out = out * gate

        return out

    def forward_cuda(
        self, x: torch.Tensor, z: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Delegate to the ``rmsnorm_fn`` kernel wrapper with this module's settings."""
        from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn

        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480


class LayerNorm(nn.Module):
    """Layer normalization computed in float32.

    The input is upcast to float32 and normalized over its last ``dim``
    features with float32 ``weight`` and ``bias``. The result is cast back
    to the input's dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Affine parameters are always float32, independent of default dtype.
        self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))

    def forward(self, x: torch.Tensor):
        """Return the layer-normalized ``x`` in ``x``'s original dtype."""
        normalized = F.layer_norm(
            x.float(),
            normalized_shape=(self.dim,),
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        return normalized.type_as(x)