# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py

import torch
from torch.nn import init

from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn
from flash_attn.ops.layer_norm import DropoutAddLayerNormParallelResidualFn


def rms_norm(x, weight, epsilon):
    # No residual, bias, rowscale, colscale, or dropout; the final True selects
    # the RMSNorm (rather than LayerNorm) path of the shared fused kernel.
    return DropoutAddLayerNormFn.apply(x, None, weight, None, None, None, 0.0, epsilon, False,
                                       False, True)
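

# Pure-PyTorch reference for the fused op above (a minimal sketch for clarity,
# not used at runtime): RMSNorm scales each row by the reciprocal root mean
# square over the last dimension, with no mean subtraction and no bias. The
# fp32 upcast for the statistics is an assumption here; the fused kernel's
# exact dtype handling may differ.
def rms_norm_ref(x, weight, epsilon):
    rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + epsilon)
    return (x.float() * rstd).to(dtype=x.dtype) * weight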


def dropout_add_rms_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
                         layerscale=None, prenorm=False, residual_in_fp32=False,
                         return_dropout_mask=False):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon,
        residual_in_fp32, prenorm, True, return_dropout_mask
    )
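

# Semantically (ignoring rowscale/layerscale), the fused op computes
#     rms_norm(dropout(x0, p=dropout_p) + residual)
# in one kernel. A usage sketch for the residual branch of a pre-norm
# transformer block, assuming CUDA tensors of a dtype the kernel supports
# (e.g. fp16/bf16):
#
#     hidden_states, residual = dropout_add_rms_norm(
#         x0, residual, weight, None, 0.1, 1e-5,
#         prenorm=True, residual_in_fp32=True,
#     )
#
# With prenorm=True the pre-norm sum is returned alongside the normalized
# output so it can be threaded through as the residual stream; with
# prenorm=False only the normalized output is returned.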


def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
                                x0_subset=None, out_subset=None, rowscale_const=1.0,
                                out_numrows=0, prenorm=False, residual_in_fp32=False,
                                return_dropout_mask=False):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, residual_in_fp32, prenorm, True, return_dropout_mask
    )
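

# The subset variant restricts work to selected rows: x0_subset picks which
# rows of the flattened batch x0 provides, out_subset picks which rows are
# normalized and written out, rowscale_const applies a constant scale to x0 in
# place of a per-row rowscale tensor, and out_numrows gives the number of
# output rows when out_subset is set. A hedged call sketch (index-tensor
# conventions follow the underlying DropoutAddLayerNormSubsetFn):
#
#     out = dropout_add_rms_norm_subset(
#         x0, residual, weight, None, 0.1, 1e-5,
#         x0_subset=x0_subset, out_subset=out_subset, out_numrows=num_out_rows,
#     )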


def dropout_add_rms_norm_parallel_residual(
    x0, x1, residual, weight0, bias0, weight1, bias1,
    dropout_p, epsilon, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon,
        residual_in_fp32, prenorm, True, return_dropout_mask
    )
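

# Parallel-residual variant, as in GPT-J/GPT-NeoX-style blocks where the
# attention and MLP branches share one residual stream. Roughly (a sketch of
# the semantics, not the kernel):
#     residual = dropout(x0) + dropout(x1) + residual
#     out0 = rms_norm(residual, weight0); out1 = rms_norm(residual, weight1)
# so each parallel branch of the next layer gets its own normalization.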


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x):
        return rms_norm(x, self.weight, self.eps)
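

# Usage sketch (assumes a CUDA build of the fused kernel and a supported
# hidden size/dtype such as fp16):
#
#     norm = RMSNorm(1024, device='cuda', dtype=torch.float16)
#     y = norm(torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16))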


class DropoutAddRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x0, residual=None):
        return dropout_add_rms_norm(x0, residual, self.weight, None,
                                    self.p if self.training else 0.0, self.eps,
                                    prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
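

# Usage sketch for a pre-norm residual block; dropout is only active in
# training mode, since forward passes p=0.0 under eval():
#
#     norm = DropoutAddRMSNorm(1024, prenorm=True, p=0.1, residual_in_fp32=True,
#                              device='cuda', dtype=torch.float16)
#     hidden_states, residual = norm(x0, residual)
#
# With prenorm=False (the default), forward returns only the normalized output.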