# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from typing import Optional

import torch
from torch import nn
import torch._dynamo
torch._dynamo.config.suppress_errors = True

class RMSNorm(torch.nn.Module):

    def __init__(self,
                 dim: int,
                 eps: float = 1e-6,
                 sequence_parallel: bool = False,
                 config: dict = None):
        """RMS Normaliation module

        Args:
            dim (int): The width of input, i.e. hidden size
            eps (float): epsilon to use for the norm, default to 1e-6
            sequence_parallel (bool): Set to true if sequence parallelism is being used,
              this marks the weights as needing to be allreduced.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

        setattr(self.weight, 'sequence_parallel', sequence_parallel)

wxj's avatar
wxj committed
30
    @torch.compile(mode="max-autotune-no-cudagraphs")
xingjinliang's avatar
xingjinliang committed
31
32
33
    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

wxj's avatar
wxj committed
34
    @torch.compile(mode="max-autotune-no-cudagraphs")
xingjinliang's avatar
xingjinliang committed
35
36
37
    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight