layernorm.py 5.64 KB
Newer Older
Nicolas Patry's avatar
Nicolas Patry committed
1
2
3
4
5
import torch
from torch import nn
from accelerate import init_empty_weights
from text_generation_server.utils.import_utils import (
    SYSTEM,
Wang, Yi's avatar
Wang, Yi committed
6
    IPEX_AVAIL,
Nicolas Patry's avatar
Nicolas Patry committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
)


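# Monkey patching: the loader classmethods defined here are attached to
# torch.nn.LayerNorm below as `load` and `load_no_bias`.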
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    """Build a ``cls`` layer norm whose weight and bias come from ``weights``.

    Args:
        prefix: Tensor name prefix; ``{prefix}.weight`` and ``{prefix}.bias``
            are requested from ``weights``.
        weights: Object exposing ``get_tensor(name)``.
        eps: Epsilon forwarded to the ``cls`` constructor.

    Returns:
        The constructed module with ``weight`` and ``bias`` replaced by the
        loaded tensors wrapped in ``nn.Parameter``.
    """
    gamma = weights.get_tensor(f"{prefix}.weight")
    beta = weights.get_tensor(f"{prefix}.bias")

    # The constructor's own parameters are thrown away just below, so the
    # module is instantiated inside accelerate's ``init_empty_weights``.
    with init_empty_weights():
        module = cls(gamma.shape, eps=eps)

    module.weight = nn.Parameter(gamma)
    module.bias = nn.Parameter(beta)
    return module


@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    """Build a ``cls`` layer norm from ``{prefix}.weight`` only, with no bias.

    Args:
        prefix: Tensor name prefix; only ``{prefix}.weight`` is requested.
        weights: Object exposing ``get_tensor(name)``.
        eps: Epsilon forwarded to the ``cls`` constructor.

    Returns:
        The constructed module with ``weight`` set to the loaded tensor
        (wrapped in ``nn.Parameter``) and ``bias`` set to ``None``.
    """
    gamma = weights.get_tensor(f"{prefix}.weight")

    # The constructor's own parameters are thrown away just below, so the
    # module is instantiated inside accelerate's ``init_empty_weights``.
    with init_empty_weights():
        module = cls(gamma.shape, eps=eps)

    module.weight = nn.Parameter(gamma)
    module.bias = None
    return module


# Attach the loaders to the stock LayerNorm class so that it, and every
# subclass of it, can be built with ``.load(...)`` / ``.load_no_bias(...)``.
nn.LayerNorm.load = load_layer_norm
nn.LayerNorm.load_no_bias = load_layer_norm_no_bias

if SYSTEM == "cuda":
    import dropout_layer_norm

    class FastLayerNorm(nn.LayerNorm):
        """LayerNorm whose ``forward`` also handles the residual add.

        Uses the ``dropout_layer_norm`` extension for inputs whose last
        dimension is at most 8192 and plain ``nn.LayerNorm`` otherwise.
        """

        def forward(self, hidden_states, residual=None):
            """Normalize ``hidden_states`` after adding ``residual`` (if any).

            Args:
                hidden_states: Input tensor. On the wide (> 8192) path it is
                    updated in place with ``+= residual``.
                residual: Optional tensor to add before normalizing.

            Returns:
                ``(normed_hidden_states, residual)``. On the wide path the
                second element is ``hidden_states`` itself (after the
                in-place add). On the kernel path it is the kernel's second
                output, or ``hidden_states`` when that output is ``None``.
            """
            if hidden_states.shape[-1] <= 8192:
                # NOTE(review): 8192 is presumably the widest row the fused
                # kernel supports — confirm against the extension.
                # Arguments are positional; the final ``False`` is the slot
                # where FastRMSNorm passes ``True`` to activate RMSNorm.
                normed, summed, *_ = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    self.bias,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.eps,
                    1.0,
                    0,
                    None,
                    False,
                    False,
                )
                if summed is None:
                    summed = hidden_states
                return normed, summed

            # Wide inputs: add the residual in place, then defer to the
            # stock LayerNorm implementation.
            if residual is not None:
                hidden_states += residual
            return super().forward(hidden_states), hidden_states

elif SYSTEM == "rocm":
fxmarty's avatar
fxmarty committed
76
    from vllm._C import ops
Nicolas Patry's avatar
Nicolas Patry committed
77
78
79
80
81
82
83
84
85

    class FastLayerNorm(nn.LayerNorm):
        """LayerNorm whose ``forward`` also handles the residual add."""

        def forward(self, hidden_states, residual=None):
            """Add ``residual`` in place (if given), then layer-normalize.

            Args:
                hidden_states: Input tensor; updated in place with
                    ``+= residual`` when a residual is supplied.
                residual: Optional tensor to add before normalizing.

            Returns:
                ``(normed, hidden_states)`` where the second element is the
                (possibly updated) input tensor itself.
            """
            has_residual = residual is not None
            if has_residual:
                hidden_states += residual

            normed = super().forward(hidden_states)
            return normed, hidden_states

Wang, Yi's avatar
Wang, Yi committed
86
elif IPEX_AVAIL:
Nicolas Patry's avatar
Nicolas Patry committed
87
88
89
90
91
    import intel_extension_for_pytorch as ipex

    class FastLayerNorm(nn.LayerNorm):
        """LayerNorm backed by ``ipex.llm.functional.add_layer_norm``."""

        def forward(self, hidden_states, residual=None):
            """Layer-normalize ``hidden_states`` together with ``residual``.

            Args:
                hidden_states: Input tensor.
                residual: Optional residual tensor handed to the kernel.

            Returns:
                ``(out, residual)`` when a residual was supplied, otherwise
                ``(out, hidden_states)``.
            """
            # The last argument is True only when a residual was supplied.
            # NOTE(review): presumably this asks the kernel to write the sum
            # back into ``residual`` — confirm against the IPEX documentation.
            out = ipex.llm.functional.add_layer_norm(
                residual,
                hidden_states,
                self.weight,
                self.bias,
                self.eps,
                residual is not None,
            )
            return out, (residual if residual is not None else hidden_states)
Nicolas Patry's avatar
Nicolas Patry committed
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114


class FastRMSNorm(nn.Module):
    """RMS normalization with an optional residual add.

    ``forward`` selects an implementation from the module-level
    ``IPEX_AVAIL`` and ``SYSTEM`` values and always returns a
    ``(normalized, residual)`` pair.
    """

    def __init__(self, weight: torch.Tensor, eps: float):
        """Wrap ``weight`` as the scale parameter and remember ``eps``."""
        super().__init__()

        self.weight = nn.Parameter(weight)
        self.variance_epsilon = eps

    @classmethod
    def load(cls, prefix, weights, eps=1e-6):
        """Build the module from the ``{prefix}.weight`` tensor of ``weights``.

        Args:
            prefix: Tensor name prefix.
            weights: Object exposing ``get_tensor(name)``.
            eps: Epsilon stored on the module (default ``1e-6``).
        """
        weight = weights.get_tensor(f"{prefix}.weight")
        return cls(weight, eps)

    def forward(self, hidden_states, residual=None):
        """RMS-normalize ``hidden_states`` after adding ``residual`` (if any).

        Branches are tried in this order: IPEX, pure PyTorch for a last
        dimension above 8192, the CUDA ``dropout_layer_norm`` kernel, and
        the ROCm vLLM kernel.

        Args:
            hidden_states: Input tensor. In the PyTorch and ROCm branches it
                is updated in place with ``+= residual`` when a residual is
                supplied.
            residual: Optional tensor to add before normalizing.

        Returns:
            ``(normed_hidden_states, residual)``.

        Raises:
            ValueError: If none of the branches applies to this system.
        """
        if IPEX_AVAIL:
            # NOTE(review): ``ipex`` is imported at module level only inside
            # the ``elif IPEX_AVAIL`` arm, i.e. when SYSTEM is neither "cuda"
            # nor "rocm" — confirm IPEX_AVAIL cannot be true otherwise.
            out = ipex.llm.functional.add_rms_norm(
                residual,
                hidden_states,
                self.weight,
                None,
                self.variance_epsilon,
                residual is not None,
            )
            return out, (residual if residual is not None else hidden_states)
        elif hidden_states.shape[-1] > 8192:
            if residual is not None:
                hidden_states += residual
            # Keep the pre-normalization sum (in its original dtype) as the
            # residual handed back to the caller.
            residual = hidden_states

            # Compute the statistics in float32.
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(
                variance + self.variance_epsilon
            )

            # convert into half-precision if necessary
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)

            return self.weight * hidden_states, residual
        elif SYSTEM == "cuda":
            # faster post attention rms norm
            (
                normed_hidden_states,
                res,
                *rest,
            ) = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                None,
                None,
                None,
                None,
                None,
                0.0,
                self.variance_epsilon,
                1.0,
                0,
                None,
                False,
                True,  # Activate RMSNorm
            )
            if res is None:
                res = hidden_states

            return normed_hidden_states, res
        elif SYSTEM == "rocm":
            # We use the vLLM RMSNorm kernel, which can be compiled for ROCm,
            # instead of the Flash Attention ones, which cannot.
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            # The kernel writes its result into the preallocated ``out``.
            out = torch.empty_like(hidden_states)
            ops.rms_norm(
                out,
                hidden_states,
                self.weight.data,
                self.variance_epsilon,
            )
            return out, residual
        else:
            raise ValueError(
                "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
            )