# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py

import torch
from torch.nn import init

from flash_attn.ops.layer_norm import (
    DropoutAddLayerNormFn,
    DropoutAddLayerNormParallelResidualFn,
    DropoutAddLayerNormSubsetFn,
)


def rms_norm(x, weight, epsilon):
    """RMS-normalize ``x`` with scale ``weight``: calls the fused kernel with no
    residual, dropout 0.0, and the final ``True`` selecting the RMSNorm variant."""
    return DropoutAddLayerNormFn.apply(
        x, None, weight, None, None, None, 0.0, epsilon, False, False, True
    )
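
# Usage sketch (illustrative, not part of the module): the last dimension of x must
# match weight; the shapes, dtypes and CUDA device below are assumptions for the
# example.
#
#     x = torch.randn(8, 512, 1024, device="cuda", dtype=torch.float16)
#     w = torch.ones(1024, device="cuda", dtype=torch.float16)
#     out = rms_norm(x, w, epsilon=1e-5)  # same shape and dtype as x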


def dropout_add_rms_norm(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    rowscale=None,
    layerscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0,
        residual,
        weight,
        bias,
        rowscale,
        layerscale,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,  # is_rms_norm: use the RMSNorm variant of the fused kernel
        return_dropout_mask,
    )
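
# Usage sketch (illustrative): dropout is applied to x0 (scaled by rowscale/layerscale
# if given), the residual is added, and the sum is RMS-normalized. With prenorm=True
# the updated residual stream should be returned alongside the normalized output.
#
#     x0 = torch.randn(8, 512, 1024, device="cuda", dtype=torch.float16)
#     res = torch.randn_like(x0)
#     w = torch.ones(1024, device="cuda", dtype=torch.float16)
#     out, new_res = dropout_add_rms_norm(x0, res, w, None, 0.1, 1e-5, prenorm=True)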


def dropout_add_rms_norm_subset(
    x0,
    residual,
    weight,
    bias,
    dropout_p,
    epsilon,
    layerscale=None,
    x0_subset=None,
    out_subset=None,
    rowscale_const=1.0,
    out_numrows=0,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0,
        residual,
        weight,
        bias,
        layerscale,
        x0_subset,
        out_subset,
        dropout_p,
        epsilon,
        rowscale_const,
        out_numrows,
        residual_in_fp32,
        prenorm,
        True,  # is_rms_norm: use the RMSNorm variant of the fused kernel
        return_dropout_mask,
    )
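
# Usage sketch (illustrative): with x0_subset/out_subset left as None and
# rowscale_const=1.0 this should reduce to the same computation as
# dropout_add_rms_norm; the subset arguments are only needed when only a subset of
# rows takes part in the add/normalization.
#
#     out = dropout_add_rms_norm_subset(x0, None, w, None, 0.1, 1e-5)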


def dropout_add_rms_norm_parallel_residual(
    x0,
    x1,
    residual,
    weight0,
    bias0,
    weight1,
    bias1,
    dropout_p,
    epsilon,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0,
        x1,
        residual,
        weight0,
        bias0,
        weight1,
        bias1,
        dropout_p,
        epsilon,
        residual_in_fp32,
        prenorm,
        True,  # is_rms_norm: use the RMSNorm variant of the fused kernel
        return_dropout_mask,
    )
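
# Usage sketch (illustrative, argument roles hedged): x0 and x1 are two parallel
# branches (e.g. attention and MLP outputs) that are dropped out and added to the
# shared residual; the summed stream is then RMS-normalized separately with
# (weight0, bias0) and (weight1, bias1), with the residual also returned when
# prenorm=True.
#
#     out0, out1, new_res = dropout_add_rms_norm_parallel_residual(
#         x0, x1, res, w0, None, w1, None, 0.1, 1e-5, prenorm=True
#     )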


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x):
        return rms_norm(x, self.weight, self.eps)
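
# Usage sketch (illustrative): RMSNorm is a drop-in torch.nn.Module with a
# ones-initialized weight and no bias.
#
#     norm = RMSNorm(1024, device="cuda", dtype=torch.float16)
#     y = norm(x)  # x: (..., 1024) tensor on the same device/dtype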


class DropoutAddRMSNorm(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        prenorm=False,
        p=0.0,
        eps=1e-5,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.eps = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x0, residual=None):
        return dropout_add_rms_norm(
            x0,
            residual,
            self.weight,
            None,
            self.p if self.training else 0.0,
            self.eps,
            prenorm=self.prenorm,
            residual_in_fp32=self.residual_in_fp32,
        )
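
# Usage sketch (illustrative): inside a pre-norm transformer block, DropoutAddRMSNorm
# fuses "dropout(x0) + residual" with the following RMSNorm; dropout probability p is
# only applied in training mode. With prenorm=True the module should return both the
# normalized hidden states and the updated residual stream.
#
#     norm = DropoutAddRMSNorm(1024, prenorm=True, p=0.1, device="cuda", dtype=torch.float16)
#     hidden, residual = norm(x0, residual)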