layernorm.py 5.76 KB
Newer Older
1
"""Custom normalization layers."""
2
3
from typing import Optional, Tuple, Union

4
5
6
import torch
import torch.nn as nn

7
from vllm.model_executor.custom_op import CustomOp
8
9


10
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
    ) -> None:
        """Create the layer with a weight of ones.

        Args:
            hidden_size: Size of the last dimension of the inputs; also the
                length of the learned weight vector.
            eps: Constant added to the mean of squares before the reciprocal
                square root.
            var_hidden_size: If given and different from ``hidden_size``,
                only the first ``var_hidden_size`` features of the last
                dimension are used to compute the mean of squares.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # An override equal to hidden_size is a no-op, so it is stored as
        # None; this keeps the custom-kernel paths eligible.
        self.variance_size_override = (None if var_hidden_size == hidden_size
                                       else var_hidden_size)
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            x: Input whose last dimension must equal ``self.hidden_size``.
                Any number of leading dimensions is accepted.
            residual: Optional tensor added to ``x`` (in float32) before
                normalization.

        Returns:
            The normalized tensor, or ``(normalized, new_residual)`` when
            ``residual`` is given, where ``new_residual`` is ``x + residual``
            cast back to the dtype of ``x``.

        Raises:
            ValueError: If the last dimension of ``x`` differs from
                ``self.hidden_size``, or is smaller than the variance size
                override.
        """
        orig_dtype = x.dtype
        # All arithmetic is done in float32 and cast back at the end.
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        hidden_size = x.shape[-1]
        if hidden_size != self.hidden_size:
            raise ValueError("Expected hidden_size to be "
                             f"{self.hidden_size}, but found: {hidden_size}")

        if self.variance_size_override is None:
            x_var = x
        else:
            if hidden_size < self.variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{self.variance_size_override}, but found: {hidden_size}")

            # Slice the last dimension with an ellipsis so inputs of any
            # rank (e.g. 2-D [num_tokens, hidden]) work, not only 3-D ones.
            x_var = x[..., :self.variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + self.variance_epsilon)
        # The weight is applied after casting back to the original dtype.
        x = x.to(orig_dtype) * self.weight
        if residual is None:
            return x
        else:
            return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Normalize using the ops from ``vllm._custom_ops``.

        Falls back to ``forward_native`` when a variance size override is
        set. With a residual, the same ``x`` and ``residual`` objects that
        were passed to ``fused_add_rms_norm`` are returned, so that op is
        expected to update them in place.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm import _custom_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        # Without a residual the op writes into a freshly allocated output.
        out = torch.empty_like(x)
        ops.rms_norm(
            out,
            x,
            self.weight.data,
            self.variance_epsilon,
        )
        return out

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Normalize using the ops from ``vllm._ipex_ops``.

        Falls back to ``forward_native`` when a variance size override is
        set. With a residual, the same ``x`` and ``residual`` objects that
        were passed to ``fused_add_rms_norm`` are returned, so that op is
        expected to update them in place.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        # Unlike the CUDA path, this rms_norm returns its result.
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        """Return the ``hidden_size=..., eps=...`` summary for the repr."""
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s
Woosuk Kwon's avatar
Woosuk Kwon committed
123
124


125
@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the above RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        """Create the layer with a zero weight of length ``hidden_size``."""
        super().__init__()
        # The weight starts at zero because the scale applied is (1 + w).
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Returns the normalized tensor cast back to the dtype of ``x``, or
        ``(normalized, x + residual)`` when ``residual`` is given.
        """
        # Captured before the residual add so the output uses x's dtype.
        out_dtype = x.dtype
        if residual is not None:
            residual = x + residual
            x = residual

        x_fp32 = x.float()
        mean_sq = x_fp32.pow(2).mean(dim=-1, keepdim=True)
        normed = x_fp32 * torch.rsqrt(mean_sq + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        scaled = normed * (1.0 + weight.float())
        result = scaled.to(out_dtype)

        if residual is None:
            return result
        return result, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(
            self.weight.data,
            self.variance_epsilon,
            x,
            residual,
        )

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run the native path, compiling ``forward_static`` on first use.

        When called while ``torch.compile`` is already tracing, the native
        path is used directly and no compilation is set up.
        """
        # Short-circuit keeps the attribute lookup out of the tracing path.
        needs_compile = (not torch.compiler.is_compiling()
                         and not getattr(self, "_is_compiled", False))
        if needs_compile:
            # Shadow the static method on this instance with its compiled
            # version; the flag makes this a one-time step per instance.
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static)
            self._is_compiled = True
        return self.forward_native(x, residual)