layernorm.py 7.2 KB
Newer Older
1
"""Custom normalization layers."""
2
3
from typing import Optional, Tuple, Union

4
5
6
import torch
import torch.nn as nn

7
from vllm.model_executor.custom_op import CustomOp
zhuwenwen's avatar
zhuwenwen committed
8
import vllm.envs as envs
9
10


11
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467

    Args:
        hidden_size: Expected size of the last dimension of the input.
        eps: Constant added to the variance before the reciprocal sqrt.
        var_hidden_size: If given and different from ``hidden_size``, the
            variance is computed over only the first ``var_hidden_size``
            channels of the last dimension, while the normalization is still
            applied to every channel. The fused backends do not take this
            path and defer to :meth:`forward_native`.
        has_weight: If False, ``self.weight`` is a plain tensor of ones
            instead of a learnable ``nn.Parameter``, and the native path
            skips the weight multiplication.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # Only keep an override when it actually differs from hidden_size;
        # None means "compute the variance over the full last dimension".
        self.variance_size_override = (None if var_hidden_size == hidden_size
                                       else var_hidden_size)
        self.has_weight = has_weight

        self.weight = torch.ones(hidden_size)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            x: Input whose last dimension must equal ``hidden_size``.
            residual: Optional tensor added to ``x`` (in float32) before
                normalization.

        Returns:
            The normalized tensor in ``x``'s original dtype or, when
            ``residual`` is given, a tuple ``(normalized, x + residual)``
            where the second element is also cast to ``x``'s original dtype.

        Raises:
            ValueError: If the last dimension of ``x`` is not
                ``hidden_size``, or is smaller than the variance-size
                override.
        """
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        hidden_size = x.shape[-1]
        if hidden_size != self.hidden_size:
            raise ValueError("Expected hidden_size to be "
                             f"{self.hidden_size}, but found: {hidden_size}")

        if self.variance_size_override is None:
            x_var = x
        else:
            if hidden_size < self.variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{self.variance_size_override}, but found: {hidden_size}")
            # Slice only the last dim via an ellipsis so inputs of any rank
            # work (e.g. flattened 2-D (num_tokens, hidden) activations);
            # the previous ``x[:, :, :k]`` only accepted 3-D tensors.
            x_var = x[..., :self.variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)

        # Cast back to the input dtype *before* applying the weight.
        x = x.to(orig_dtype)
        if self.has_weight:
            x = x * self.weight

        if residual is None:
            return x
        return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Custom-kernel implementation.

        With ``residual`` the fused add+norm op is called on ``x`` and
        ``residual``, and those same tensors are returned. The op is
        therefore expected to update them in place. Without ``residual`` the
        result is written into a freshly allocated ``out`` tensor.
        ``envs.VLLM_USE_OPT_OP`` selects the ``*_opt`` kernel variants.
        """
        # The custom kernels are not used with a variance-size override.
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm import _custom_ops as ops

        use_opt = envs.VLLM_USE_OPT_OP
        if residual is not None:
            fused_add_rms_norm = (ops.fused_add_rms_norm_opt
                                  if use_opt else ops.fused_add_rms_norm)
            fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual

        rms_norm = ops.rms_norm_opt if use_opt else ops.rms_norm
        out = torch.empty_like(x)
        rms_norm(
            out,
            x,
            self.weight.data,
            self.variance_epsilon,
        )
        return out

    def forward_hpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """HPU implementation using ``HPUFusedRMSNorm``.

        Falls back to :meth:`forward_native` when a variance-size override
        is set or when the fused op is unavailable (imported as None).
        """
        # HPUFusedRMSNorm only receives the weight and eps, so it cannot
        # honor a variance-size override. Use the native path, matching
        # forward_cuda and forward_xpu.
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm_hpu_extension.ops import HPUFusedRMSNorm
        if HPUFusedRMSNorm is None:
            return self.forward_native(x, residual)
        if residual is not None:
            orig_shape = x.shape
            # In-place residual update; x is viewed to residual's shape.
            residual += x.view(residual.shape)
            # Note: HPUFusedRMSNorm requires 3D tensors as inputs
            x = HPUFusedRMSNorm.apply(residual, self.weight,
                                      self.variance_epsilon)
            return x.view(orig_shape), residual

        x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon)
        return x

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """XPU implementation using the IPEX ops.

        With ``residual`` the fused op is called on ``x`` and ``residual``,
        and those same tensors are returned, i.e. it is expected to update
        them in place.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s


@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the above RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.

    The weight is initialized to zeros, so a freshly constructed layer
    scales the normalized input by ``1 + 0 = 1``.

    Args:
        hidden_size: Size of the (learnable) weight vector.
        eps: Constant added to the variance before the reciprocal sqrt.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        The residual (if any) is added in ``x``'s original dtype. The sum is
        returned as the new residual, and the normalization plus
        ``(1 + weight)`` scaling is done in float32 before casting back.
        """
        orig_dtype = x.dtype
        if residual is not None:
            x = x + residual
            residual = x

        x = x.float()
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        x = x * (1.0 + weight.float())
        x = x.to(orig_dtype)
        return x if residual is None else (x, residual)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        return self.forward_static(self.weight.data, self.variance_epsilon, x,
                                   residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run forward_static through torch.compile (compiled lazily)."""
        # Already being traced by an outer compile: call the plain path.
        if torch.compiler.is_compiling():
            return self.forward_native(x, residual)

        # Compile once per instance. The compiled function is stored as an
        # instance attribute that shadows the staticmethod, so
        # forward_native picks it up on this and later calls.
        if not getattr(self, "_is_compiled", False):
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static)
            self._is_compiled = True
        return self.forward_native(x, residual)

    def extra_repr(self) -> str:
        # Mirrors RMSNorm.extra_repr so printed models show size and eps.
        return (f"hidden_size={self.weight.data.size(0)}, "
                f"eps={self.variance_epsilon}")