import os
import torch.nn as nn
from awq.modules.fused.attn import QuantAttentionFused


class MixtralBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        n_heads,
        n_kv_heads,
        qkv_layer,
        o_proj,
        moe,
        norm_1,
        norm_2,
        dev,
        max_seq_len,
        rope_theta,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.hidden_size = hidden_size
        self.norm_1 = norm_1.to(dev)
        self.attn = QuantAttentionFused(
            self.hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=False,
            rope_theta=rope_theta,
        ).to(dev)
        self.norm_2 = norm_2.to(dev)
        self.moe = moe
        self.device = dev

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
        )

        h = hidden_states.to(attn_output.device) + attn_output
        out = self.moe.forward(self.norm_2(h))
        out = h + out

        return out, None, past_key_value
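
# Note (added commentary): every block in this file exposes the same forward
# signature and return tuple, so a fused model's decode loop can treat the
# blocks interchangeably, e.g. (hypothetical loop, names are illustrative):
#
#     for layer in fused_blocks:
#         hidden_states, _, past_key_value = layer(
#             hidden_states,
#             past_key_value=past_key_value,
#             attention_mask=attention_mask,
#         )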


class LlamaLikeBlock(nn.Module):
    """
    LlamaLikeBlock is intended to be reused across blocks whose architecture
    closely resembles Llama, e.g. Mistral and Aquila. An illustrative
    construction sketch follows the class definition.
    """

    def __init__(
        self,
        hidden_size,
        n_heads,
        n_kv_heads,
        qkv_layer,
        o_proj,
        mlp,
        norm_1,
        norm_2,
        dev,
        max_seq_len,
        rope_theta=10000,
        partial_rotary_factor=1.0,
        use_alibi=False,
        head_dim=None,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = hidden_size // n_heads

        # Some models (e.g. Gemma-7B: hidden_size=3072, n_heads=16, head_dim=256)
        # define head_dim separately from hidden_size // n_heads, so an
        # explicitly provided head_dim takes precedence.
        if head_dim:
            self.head_dim = head_dim

        self.hidden_size = hidden_size
        self.norm_1 = norm_1.to(dev)
        self.attn = QuantAttentionFused(
            self.hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=use_alibi,
            rope_theta=rope_theta,
            partial_rotary_factor=partial_rotary_factor,
            head_dim=head_dim,
        ).to(dev)
        self.norm_2 = norm_2.to(dev)
        self.mlp = mlp.to(dev)
        self.device = dev

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
        )

        h = hidden_states.to(attn_output.device) + attn_output
        out = h + self.mlp.forward(self.norm_2(h))

        return out, None, past_key_value
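

# Illustrative sketch (added commentary, not used elsewhere in this module):
# one way a Hugging Face Llama-style decoder layer could be mapped onto
# LlamaLikeBlock. The attribute names below (`self_attn`, `mlp`,
# `input_layernorm`, `post_attention_layernorm`) and the pre-built fused,
# quantized `qkv_layer` / `o_proj` modules are assumptions for illustration.
def _example_build_llama_like_block(
    module, qkv_layer, o_proj, dev, max_seq_len, rope_theta=10000
):
    return LlamaLikeBlock(
        hidden_size=module.hidden_size,
        n_heads=module.self_attn.num_heads,
        n_kv_heads=module.self_attn.num_key_value_heads,
        qkv_layer=qkv_layer,
        o_proj=o_proj,
        mlp=module.mlp,
        norm_1=module.input_layernorm,
        norm_2=module.post_attention_layernorm,
        dev=dev,
        max_seq_len=max_seq_len,
        rope_theta=rope_theta,
    )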


class MPTBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        n_heads,
        qkv_layer,
        o_proj,
        mpt_mlp,
        norm_1,
        norm_2,
        dev,
        max_seq_len,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = 0
        self.hidden_size = hidden_size
        self.norm_1 = norm_1
        self.attn = QuantAttentionFused(
            hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
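            # MPT uses ALiBi position biases instead of rotary embeddings,
            # hence use_alibi=True and no rope_theta argument.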
            use_alibi=True,
        ).to(dev)
        self.norm_2 = norm_2
        self.ffn = mpt_mlp.to(dev)
        self.device = dev

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True,
        )

        h = hidden_states.to(attn_output.device) + attn_output
        out = h + self.ffn.forward(self.norm_2(h))
        return out, None, past_key_value


class FalconDecoderLayer(nn.Module):
    def __init__(
        self,
        hidden_size,
        n_heads,
        qkv_layer,
        o_proj,
        mlp,
        dev,
        max_seq_len,
        input_layernorm=None,
        ln_attn=None,
        ln_mlp=None,
        new_decoder_arch=True,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = 8 if new_decoder_arch else 0
        self.hidden_size = hidden_size
        self.new_decoder_arch = new_decoder_arch

        if new_decoder_arch:
            attention_shapes = None
        else:
            attention_shapes = self._get_attention_shapes(
                n_heads, max_seq_len, self.hidden_size // n_heads
            )

        # TODO: Falcon has ALiBi implemented, but which model variant actually uses it?
        self.attn = QuantAttentionFused(
            hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=False,
            attention_shapes=attention_shapes,
        ).to(dev)

        if new_decoder_arch:
            self.ln_attn = ln_attn  # before attention
            self.ln_mlp = ln_mlp  # before mlp
        else:
            self.input_layernorm = input_layernorm  # before attention

        self.mlp = mlp
        self.device = dev

    def _get_attention_shapes(self, n_heads, max_seq_len, head_dim):
        batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
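        # Worked example (illustrative numbers only): with n_heads=64,
        # head_dim=64, max_seq_len=2048 and batch_size=1, cache_v below is
        # (1, 1, 2048, 64), cache_k is (1, 1, 8, 2048, 8) since
        # head_dim // 8 == 8, and xqkv_view is (66, 64) because the fused QKV
        # projection holds n_heads query heads plus one key and one value head.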

        self.attention_shapes = {
            # Cache layouts follow the FasterTransformer definition.
            "cache_v": (
                batch_size,
                1,
                max_seq_len,
                head_dim,
            ),
            # 8: FasterTransformer packs 8 fp16 values per group;
            # with fp32 the packing factor would be 4.
            "cache_k": (
                batch_size,
                1,
                head_dim // 8,
                max_seq_len,
                8,
            ),
            "xqkv_view": (n_heads + 2, head_dim),
            "xq_slice": lambda xqkv: xqkv[:, :, :-2],
            "xk_slice": lambda xqkv: xqkv[:, :, [-2]],
            "xv_slice": lambda xqkv: xqkv[:, :, [-1]],
            "xq_view": (n_heads, head_dim),
            "xk_view": (1, head_dim),
            "xv_view": (1, head_dim),
            "xk_reshape": (1, head_dim // 8, 8),
            "single_xq_view": (n_heads, head_dim),
            "single_xk_view": (1, head_dim),
            "single_xv_view": (1, head_dim),
        }

        return self.attention_shapes

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        if self.new_decoder_arch:
            layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            layernorm_out = self.input_layernorm(hidden_states)

        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=layernorm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True,
        )

        h_attn = hidden_states.to(attn_output.device) + attn_output

        if self.new_decoder_arch:
            h_mlp = self.mlp.forward(mlp_layernorm_out)
        else:
            h_mlp = self.mlp.forward(layernorm_out)

        out = h_attn + h_mlp

        return out, None, past_key_value