# Copyright (c) 2022, Tri Dao.
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py

import torch
from torch.nn import init

from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn
from flash_attn.ops.layer_norm import DropoutAddLayerNormParallelResidualFn


def rms_norm(x, weight, epsilon):
    # Reuse the fused dropout + add + LayerNorm kernel with no residual, no dropout and no bias;
    # the final True flag selects its RMS-norm variant.
    return DropoutAddLayerNormFn.apply(x, None, weight, None, None, None, 0.0, epsilon, False,
                                       False, True)

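# For reference, a rough pure-PyTorch sketch of what rms_norm computes (illustrative only;
# the fused kernel's exact dtype handling may differ):
#
#     def rms_norm_ref(x, weight, epsilon):
#         rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + epsilon)
#         return (x.float() * rstd * weight.float()).to(x.dtype)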

def dropout_add_rms_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None,
                         layerscale=None, prenorm=False, residual_in_fp32=False,
                         return_dropout_mask=False):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormFn.apply(
        x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm,
        True, return_dropout_mask
    )

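# Usage sketch for the pre-norm residual pattern (names and shapes are illustrative, e.g. x0 of
# shape (batch, seqlen, hidden)). With prenorm=True the call returns both the normalized output
# and the updated residual stream:
#
#     out, residual = dropout_add_rms_norm(x0, residual, weight, None, 0.1, 1e-5,
#                                          prenorm=True, residual_in_fp32=True)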

def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None,
                                x0_subset=None, out_subset=None, rowscale_const=1.0,
                                out_numrows=0, prenorm=False, residual_in_fp32=False,
                                return_dropout_mask=False):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormSubsetFn.apply(
        x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon,
        rowscale_const, out_numrows, residual_in_fp32, prenorm, True, return_dropout_mask
    )


def dropout_add_rms_norm_parallel_residual(
    x0, x1, residual, weight0, bias0, weight1, bias1,
    dropout_p, epsilon, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
):
    """residual_in_fp32 only has an effect if residual is None.
    Otherwise residual dtype is residual.dtype.
    """
    return DropoutAddLayerNormParallelResidualFn.apply(
        x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32, prenorm,
        True, return_dropout_mask
    )

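# Usage sketch for parallel attention/MLP blocks (GPT-J / GPT-NeoX style); names are illustrative.
# Both x0 and x1 go through dropout and are added to the shared residual, which is then
# RMS-normalized separately with (weight0, bias0) and (weight1, bias1). With prenorm=True the
# updated residual is returned as well:
#
#     out0, out1, residual = dropout_add_rms_norm_parallel_residual(
#         x0, x1, residual, weight0, None, weight1, None, 0.1, 1e-5,
#         prenorm=True, residual_in_fp32=True,
#     )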

class DropoutAddRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.prenorm = prenorm
        self.p = p
        self.epsilon = eps
        self.residual_in_fp32 = residual_in_fp32
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.ones_(self.weight)

    def forward(self, x0, residual=None):
        return dropout_add_rms_norm(x0, residual, self.weight, None,
                                    self.p if self.training else 0.0, self.epsilon,
                                    prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32)
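

# Usage sketch (illustrative): stacking pre-norm blocks on a fused residual stream.
#
#     norm = DropoutAddRMSNorm(hidden_size, prenorm=True, p=0.1, residual_in_fp32=True)
#     hidden, residual = norm(x0)            # first layer: no residual yet
#     hidden, residual = norm(x1, residual)  # later layers: pass the running residual
#
# With prenorm=False the module simply returns the normalized tensor.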