"""Custom normalization layers."""

from typing import Optional, Tuple, Union

4
5
6
import torch
import torch.nn as nn

7
from vllm._C import ops
8
9
10


class RMSNorm(nn.Module):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight
    and the mean is taken over the last dimension of x.
    Refer to https://arxiv.org/abs/1910.07467

    Args:
        hidden_size: Number of elements in the learned weight (initialized
            to ones); expected to match the last dimension of the input.
        eps: Constant added to the mean square before the reciprocal
            square root, for numerical stability.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def _forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native reference implementation of forward().

        Args:
            x: Input tensor, normalized over its last dimension.
            residual: Optional tensor added to ``x`` before normalization.

        Returns:
            The normalized tensor when ``residual`` is None. Otherwise a
            tuple ``(normalized, x + residual)``, where the sum is cast back
            to ``x``'s original dtype. Unlike forward(), this method never
            modifies its input tensors; it always returns new tensors.
        """
        orig_dtype = x.dtype
        # Do the arithmetic in float32 so low-precision inputs do not lose
        # precision in the residual add and the sum of squares.
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            # The updated residual is returned in the input dtype, while the
            # normalization below continues from the float32 sum.
            residual = x.to(orig_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        # Cast back to the input dtype before scaling; the product's dtype
        # then follows PyTorch type promotion with self.weight's dtype.
        x = x.to(orig_dtype) * self.weight
        if residual is None:
            return x
        return x, residual

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Apply RMSNorm using the custom ``vllm._C`` kernels.

        Args:
            x: Input tensor, normalized over its last dimension.
            residual: Optional tensor to add to ``x`` before normalization.

        Returns:
            The normalized tensor (a freshly allocated buffer) when
            ``residual`` is None. Otherwise the tuple ``(x, residual)`` made
            of the very tensor objects passed in. ``ops.fused_add_rms_norm``
            is therefore expected to update them in place (presumably
            x <- normalized output, residual <- x + residual; TODO confirm
            against the kernel). Callers should use the returned values.
        """
        if residual is not None:
            # No output buffer is allocated on this path; the fused op's
            # results are read back from the argument tensors. ``.data``
            # hands the kernel the underlying tensor, not the Parameter.
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        # ``out`` is the destination buffer the kernel writes into.
        out = torch.empty_like(x)
        ops.rms_norm(
            out,
            x,
            self.weight.data,
            self.variance_epsilon,
        )
        return out