import os
import torch.nn as nn
from awq.modules.fused.attn import QuantAttentionFused


class MixtralBlock(nn.Module):
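    """
    Fused Mixtral decoder block: pre-norm quantized fused attention followed
    by a mixture-of-experts (MoE) feed-forward, each wrapped in a residual
    connection.
    """
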
    def __init__(
        self,
        hidden_size,
        n_heads,
        n_kv_heads,
        qkv_layer,
        o_proj,
        moe,
        norm_1,
        norm_2,
        dev,
        max_seq_len,
        rope_theta,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.hidden_size = hidden_size
        self.norm_1 = norm_1.to(dev)
        self.attn = QuantAttentionFused(
            self.hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=False,
            rope_theta=rope_theta,
        ).to(dev)
        self.norm_2 = norm_2.to(dev)
        self.moe = moe
        self.device = dev

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
        )

        h = hidden_states.to(attn_output.device) + attn_output
        out = self.moe.forward(self.norm_2(h))
        out = h + out

        return out, None, past_key_value


class LlamaLikeBlock(nn.Module):
    """
    LlamaLikeBlock is intended to be reused across blocks that have
    an architecture that closely resembles Llama, e.g. Mistral and Aquila.
    """

    def __init__(
        self,
        hidden_size,
        n_heads,
        n_kv_heads,
        qkv_layer,
        o_proj,
        mlp,
        norm_1,
        norm_2,
        dev,
        max_seq_len,
        rope_theta=10000,
        use_alibi=False,
        head_dim=None,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = hidden_size // n_heads

        # Some models (e.g. gemma-7b) define head_dim separately from hidden_size // n_heads
        if head_dim:
            self.head_dim = head_dim
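            # (for reference, gemma-7b reportedly pairs hidden_size=3072 and
            #  n_heads=16 with head_dim=256, not 3072 // 16 = 192)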

        self.hidden_size = hidden_size
        self.norm_1 = norm_1.to(dev)
        self.attn = QuantAttentionFused(
            self.hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=use_alibi,
            rope_theta=rope_theta,
            head_dim=head_dim,
        ).to(dev)
        self.norm_2 = norm_2.to(dev)
        self.mlp = mlp.to(dev)
        self.device = dev

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
        )

        h = hidden_states.to(attn_output.device) + attn_output
        out = h + self.mlp.forward(self.norm_2(h))

        return out, None, past_key_value


class MPTBlock(nn.Module):
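    """
    Fused MPT decoder block: quantized fused attention using ALiBi positional
    biases (no rotary embeddings) followed by a feed-forward network, each
    with a residual connection.
    """
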
    def __init__(
        self,
        hidden_size,
        n_heads,
        qkv_layer,
        o_proj,
        mpt_mlp,
        norm_1,
        norm_2,
        dev,
        max_seq_len,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = 0
        self.hidden_size = hidden_size
        self.norm_1 = norm_1
        self.attn = QuantAttentionFused(
            hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=True,
        ).to(dev)
        self.norm_2 = norm_2
        self.ffn = mpt_mlp.to(dev)
        self.device = dev

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True,
        )

        h = hidden_states.to(attn_output.device) + attn_output
        out = h + self.ffn.forward(self.norm_2(h))
        return out, None, past_key_value


class FalconDecoderLayer(nn.Module):
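    """
    Fused Falcon decoder layer. Supports both the new decoder architecture
    (separate ln_attn/ln_mlp layer norms and grouped KV heads, as in the
    larger Falcon models) and the original multi-query layout (a single
    input_layernorm). In both cases the attention and MLP branches are
    computed in parallel and their outputs are summed into the residual
    stream.
    """
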
    def __init__(
        self,
        hidden_size,
        n_heads,
        qkv_layer,
        o_proj,
        mlp,
        dev,
        max_seq_len,
        input_layernorm=None,
        ln_attn=None,
        ln_mlp=None,
        new_decoder_arch=True,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = 8 if new_decoder_arch else 0
        self.hidden_size = hidden_size
        self.new_decoder_arch = new_decoder_arch

        if new_decoder_arch:
            attention_shapes = None
        else:
            attention_shapes = self._get_attention_shapes(
                n_heads, max_seq_len, self.hidden_size // n_heads
            )

        # TODO: Falcon has ALiBi implemented, but which model uses it?
        self.attn = QuantAttentionFused(
            hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=False,
            attention_shapes=attention_shapes,
        ).to(dev)

        if new_decoder_arch:
            self.ln_attn = ln_attn  # before attention
            self.ln_mlp = ln_mlp  # before mlp
        else:
            self.input_layernorm = input_layernorm  # before attention

        self.mlp = mlp
        self.device = dev

    def _get_attention_shapes(self, n_heads, max_seq_len, head_dim):
        batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))

        self.attention_shapes = {
            # following the FasterTransformer definition
            "cache_v": (
                batch_size,
                1,
                max_seq_len,
                head_dim,
            ),
            # 8: FasterTransformer packs 8 fp16 values per group (use 4 for fp32)
            "cache_k": (
                batch_size,
                1,
                head_dim // 8,
                max_seq_len,
                8,
            ),
            "xqkv_view": (n_heads + 2, head_dim),
            "xq_slice": lambda xqkv: xqkv[:, :, :-2],
            "xk_slice": lambda xqkv: xqkv[:, :, [-2]],
            "xv_slice": lambda xqkv: xqkv[:, :, [-1]],
            "xq_view": (n_heads, head_dim),
            "xk_view": (1, head_dim),
            "xv_view": (1, head_dim),
            "xk_reshape": (1, head_dim // 8, 8),
            "single_xq_view": (n_heads, head_dim),
            "single_xk_view": (1, head_dim),
            "single_xv_view": (1, head_dim),
        }

        return self.attention_shapes
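
    # Worked example (illustrative): with head_dim=64, max_seq_len=2048 and
    # AWQ_BATCH_SIZE=1, the shapes above become
    #   cache_v   -> (1, 1, 2048, 64)
    #   cache_k   -> (1, 1, 64 // 8, 2048, 8) == (1, 1, 8, 2048, 8)
    #   xqkv_view -> (n_heads + 2, 64), i.e. the fused QKV output is viewed as
    #   n_heads query heads plus one shared key head and one shared value head
    #   (multi-query attention), which is what the xq/xk/xv slices select.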

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        if self.new_decoder_arch:
            layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            layernorm_out = self.input_layernorm(hidden_states)

        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=layernorm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True,
        )

        h_attn = hidden_states.to(attn_output.device) + attn_output

        if self.new_decoder_arch:
            h_mlp = self.mlp.forward(mlp_layernorm_out)
        else:
            h_mlp = self.mlp.forward(layernorm_out)

        out = h_attn + h_mlp

        return out, None, past_key_value