# rms_norm.py
# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py

import torch
from torch.nn import init

from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn


def rms_norm(x, weight, epsilon):
    return DropoutAddLayerNormFn.apply(x, None, weight, None, None, None, 0.0, epsilon, False,
                                       False, True)
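
# A hedged usage sketch (shapes, device, and dtype below are illustrative
# assumptions, not requirements of this module). The fused op computes
# y = x / sqrt(mean(x^2) + eps) * weight, matching this pure-PyTorch
# reference up to numerical tolerance:
#
#   x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
#   weight = torch.ones(1024, device='cuda', dtype=torch.float16)
#   y = rms_norm(x, weight, 1e-5)
#   ref = (x.float() * torch.rsqrt(x.float().square().mean(-1, keepdim=True) + 1e-5)
#          ).to(x.dtype) * weight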


def dropout_add_rms_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
                         layerscale=None, prenorm=False, residual_in_fp32=False,
                         return_dropout_mask=False):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32,
        prenorm, True, return_dropout_mask
    )
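
# Hedged sketch of the pre-norm residual pattern this fused op targets (shapes
# and the dropout/eps values are illustrative assumptions). Conceptually it fuses
#   z = dropout(x0) + residual; out = rms_norm(z, weight, epsilon)
# and with prenorm=True returns both the normalized output and z, so the caller
# can carry the residual stream forward:
#
#   out, residual = dropout_add_rms_norm(x0, residual, weight, None, 0.1, 1e-5,
#                                        prenorm=True, residual_in_fp32=True)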


def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
                                x0_subset=None, out_subset=None, rowscale_const=1.0,
                                out_numrows=0, prenorm=False, residual_in_fp32=False,
                                return_dropout_mask=False):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, residual_in_fp32, prenorm, True, return_dropout_mask
    )
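
# Hedged sketch (the index semantics here are assumptions inferred from the
# argument names, not a spec): x0_subset selects the rows of the residual
# stream that x0 is added into, out_subset selects the rows for which the
# normalized output is materialized, and out_numrows sizes that output:
#
#   out = dropout_add_rms_norm_subset(x0, residual, weight, None, 0.1, 1e-5,
#                                     x0_subset=x0_idx, out_subset=out_idx,
#                                     out_numrows=out_idx.numel())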


class DropoutAddRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.epsilon = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x0, residual=None):
        return dropout_add_rms_norm(x0, residual, self.weight, None,
                                    self.p if self.training else 0.0, self.epsilon,
                                    prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
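
# Hedged usage sketch for the module (device, dtype, and sizes are illustrative
# assumptions; the underlying fused kernel requires CUDA). With prenorm=True
# the forward pass returns (output, updated residual):
#
#   norm = DropoutAddRMSNorm(1024, prenorm=True, p=0.1, residual_in_fp32=True,
#                            device='cuda', dtype=torch.float16)
#   hidden, residual = norm(hidden, residual)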