# SPDX-License-Identifier: Apache-2.0
"""Custom normalization layers."""

from typing import Optional, Tuple, Union

5
6
7
import torch
import torch.nn as nn

8
import vllm.envs as envs
9
from vllm.model_executor.custom_op import CustomOp
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from vllm.platforms import current_platform


def is_rocm_aiter_rmsnorm_enabled() -> bool:
    """Report whether RMSNorm should be routed to the ROCm AITER kernels.

    Requires a ROCm platform plus both the ``VLLM_ROCM_USE_AITER_RMSNORM``
    and ``VLLM_ROCM_USE_AITER`` environment flags. The checks short-circuit
    left to right, so the env flags are only read on ROCm.
    """
    return (current_platform.is_rocm()
            and envs.VLLM_ROCM_USE_AITER_RMSNORM
            and envs.VLLM_ROCM_USE_AITER)


def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             variance_epsilon: float) -> torch.Tensor:
    """Out-of-place RMSNorm through the compiled ``_custom_ops`` kernel.

    A fresh buffer shaped like ``x`` is allocated and handed to
    ``ops.rms_norm`` as its output argument; that buffer is returned.

    Args:
        x: Input activations.
        weight: Per-feature scale forwarded to the kernel.
        variance_epsilon: Epsilon forwarded to the kernel.

    Returns:
        The newly allocated output tensor.
    """
    from vllm import _custom_ops as ops

    result = torch.empty_like(x)
    ops.rms_norm(result, x, weight, variance_epsilon)
    return result


def fused_add_rms_norm(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fused residual-add + RMSNorm through the compiled ``_custom_ops`` kernel.

    No output buffers are allocated. ``x`` and ``residual`` are passed
    straight to ``ops.fused_add_rms_norm`` and the very same tensor objects
    are returned, so the kernel presumably writes its results into them in
    place (confirm against the kernel if relying on this).

    Returns:
        The ``(x, residual)`` pair that was passed in.
    """
    from vllm import _custom_ops as ops

    ops.fused_add_rms_norm(x, residual, weight, variance_epsilon)
    return x, residual


def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
                        variance_epsilon: float) -> torch.Tensor:
    """RMSNorm through AITER's ``rms_norm`` (ROCm).

    ``aiter`` is imported inside the function, so this module does not need
    it at import time.

    Returns:
        Whatever ``aiter.rms_norm`` returns for the given arguments.
    """
    import aiter as rocm_aiter

    normed = rocm_aiter.rms_norm(x, weight, variance_epsilon)
    return normed


def rocm_aiter_fused_add_rms_norm(
        x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
        variance_epsilon: float) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fused residual-add + RMSNorm through AITER (ROCm).

    ``x`` is passed as both the output and the input buffer, and
    ``residual`` as both the residual input and the residual output buffer.
    The same two tensor objects are then returned.

    Returns:
        The ``(x, residual)`` pair that was passed in.
    """
    import aiter as rocm_aiter

    out_buf, res_out_buf = x, residual
    # NOTE(review): the positional order (out, input, residual_in,
    # residual_out, weight, eps) was assumed by the original author and is
    # not verified here; confirm it against the installed aiter version.
    # Each output also aliases its input, which relies on the kernel
    # tolerating that aliasing.
    rocm_aiter.rmsnorm2d_fwd_with_add(
        out_buf,
        x,
        residual,
        res_out_buf,
        weight,
        variance_epsilon,
    )
    return out_buf, res_out_buf


def dispatch_cuda_rmsnorm_func(add_residual: bool):
    """Pick the RMSNorm implementation used by ``RMSNorm.forward_cuda``.

    Args:
        add_residual: Whether the caller passes a residual. If True, a fused
            add-and-normalize function with signature
            ``(x, residual, weight, eps)`` is returned; otherwise a plain one
            with signature ``(x, weight, eps)``.

    Returns:
        The AITER variant when ``is_rocm_aiter_rmsnorm_enabled()`` is true,
        otherwise the ``_custom_ops``-backed variant.
    """
    use_aiter = is_rocm_aiter_rmsnorm_enabled()
    if add_residual:
        return rocm_aiter_fused_add_rms_norm if use_aiter else fused_add_rms_norm
    return rocm_aiter_rms_norm if use_aiter else rms_norm


@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
    """Root mean square normalization.

    Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
    Refer to https://arxiv.org/abs/1910.07467

    Args:
        hidden_size: Expected size of the input's last dimension.
        eps: Constant added to the mean square before the reciprocal sqrt.
        var_hidden_size: Optional number of leading features of the last
            dimension used to compute the mean square. None, or a value equal
            to ``hidden_size``, means the whole dimension is used. The
            normalization itself is always applied to the full dimension.
            Only ``forward_native`` implements this, so the backend-specific
            paths fall back to it whenever an override is active.
        has_weight: If True, ``weight`` is a learnable ``nn.Parameter``.
            Otherwise it stays a plain all-ones tensor, and ``forward_native``
            skips the multiplication.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
        var_hidden_size: Optional[int] = None,
        has_weight: bool = True,
    ) -> None:
        super().__init__()

        self.hidden_size = hidden_size
        self.variance_epsilon = eps
        # None means "compute the variance over the full hidden dimension".
        self.variance_size_override = (None if var_hidden_size == hidden_size
                                       else var_hidden_size)
        self.has_weight = has_weight

        self.weight = torch.ones(hidden_size)
        if self.has_weight:
            self.weight = nn.Parameter(self.weight)

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            x: Input of any rank whose last dimension is ``hidden_size``.
            residual: Optional tensor added to ``x`` before normalizing.

        Returns:
            The normalized tensor in ``x``'s original dtype. When
            ``residual`` is given, returns ``(normalized, x + residual)``,
            with the sum also cast back to that dtype.

        Raises:
            ValueError: If the last dimension of ``x`` is not
                ``hidden_size``, or is smaller than the variance override.
        """
        orig_dtype = x.dtype
        # Compute in float32; results are cast back to orig_dtype below.
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        hidden_size = x.shape[-1]
        if hidden_size != self.hidden_size:
            raise ValueError("Expected hidden_size to be "
                             f"{self.hidden_size}, but found: {hidden_size}")

        if self.variance_size_override is None:
            x_var = x
        else:
            if hidden_size < self.variance_size_override:
                raise ValueError(
                    "Expected hidden_size to be at least "
                    f"{self.variance_size_override}, but found: {hidden_size}")

            # Slice the last dimension via Ellipsis so inputs of any rank
            # work. A fixed `x[:, :, :n]` only handles 3-D tensors: it fails
            # on 2-D (tokens, hidden) input and slices the wrong axis on 4-D.
            x_var = x[..., :self.variance_size_override]

        variance = x_var.pow(2).mean(dim=-1, keepdim=True)

        x = x * torch.rsqrt(variance + self.variance_epsilon)
        # Cast back before applying the weight.
        x = x.to(orig_dtype)
        if self.has_weight:
            x = x * self.weight

        if residual is None:
            return x
        else:
            return x, residual

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run the function chosen by ``dispatch_cuda_rmsnorm_func``.

        The dispatched functions take only the weight and epsilon. An active
        variance-size override is therefore handled by ``forward_native``.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        add_residual = residual is not None
        norm_func = dispatch_cuda_rmsnorm_func(add_residual)

        if add_residual:
            return norm_func(x, residual, self.weight.data,
                             self.variance_epsilon)
        else:
            return norm_func(x, self.weight.data, self.variance_epsilon)

    def forward_hpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run HPU's fused RMSNorm, falling back to ``forward_native``.

        The fallback is used when a variance-size override is set, matching
        ``forward_cuda`` and ``forward_xpu``: ``HPUFusedRMSNorm.apply`` only
        receives the weight and epsilon, so it cannot honor the override.
        It is also used when ``HPUFusedRMSNorm`` is unavailable (None).
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm_hpu_extension.ops import HPUFusedRMSNorm
        if HPUFusedRMSNorm is None:
            return self.forward_native(x, residual)
        if residual is not None:
            orig_shape = x.shape
            # Accumulate x into residual in place. The sum is both the tensor
            # that gets normalized and the residual that is returned.
            residual += x.view(residual.shape)
            # Note: HPUFusedRMSNorm requires 3D tensors as inputs
            x = HPUFusedRMSNorm.apply(residual, self.weight,
                                      self.variance_epsilon)
            return x.view(orig_shape), residual

        x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon)
        return x

    def forward_xpu(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run the IPEX RMSNorm ops on XPU.

        An active variance-size override falls back to ``forward_native``.
        With a residual, the fused op is called on ``x`` and ``residual``,
        and those same tensor objects are returned.
        """
        if self.variance_size_override is not None:
            return self.forward_native(x, residual)

        from vllm._ipex_ops import ipex_ops as ops

        if residual is not None:
            ops.fused_add_rms_norm(
                x,
                residual,
                self.weight.data,
                self.variance_epsilon,
            )
            return x, residual
        return ops.rms_norm(
            x,
            self.weight.data,
            self.variance_epsilon,
        )

    def extra_repr(self) -> str:
        """Summarize the hidden size and epsilon for ``repr(module)``."""
        s = f"hidden_size={self.weight.data.size(0)}"
        s += f", eps={self.variance_epsilon}"
        return s


@CustomOp.register("gemma_rms_norm")
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the RMSNorm class in this module:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.

    The weight is zero-initialized, so with the ``1 + w`` scaling a freshly
    constructed layer scales by exactly 1.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.zeros(hidden_size))

    @staticmethod
    def forward_static(
        weight: torch.Tensor,
        variance_epsilon: float,
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        The residual add happens in the input dtype. Normalization and
        weighting happen in float32, and the result is cast back to the
        input dtype at the end.

        Returns:
            The normalized tensor. When ``residual`` is given, returns
            ``(normalized, x + residual)`` instead.
        """
        input_dtype = x.dtype
        if residual is not None:
            x = x + residual
            residual = x

        hidden = x.float()
        mean_sq = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_sq + variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        hidden = hidden * (1.0 + weight.float())
        out = hidden.to(input_dtype)

        if residual is None:
            return out
        return out, residual

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        weight = self.weight.data
        return self.forward_static(weight, self.variance_epsilon, x, residual)

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run ``forward_native``, compiling ``forward_static`` on first use.

        The compiled function is stored as an instance attribute, which
        shadows the staticmethod for this instance, so ``forward_native``
        calls the compiled version from then on. When already inside a
        ``torch.compile`` trace, no nested compile is set up and the native
        path runs directly.
        """
        needs_compile = (not torch.compiler.is_compiling()
                         and not getattr(self, "_is_compiled", False))
        if needs_compile:
            self.forward_static = torch.compile(  # type: ignore
                self.forward_static)
            self._is_compiled = True
        return self.forward_native(x, residual)