# block.py
import os
import torch.nn as nn
from awq.modules.fused.attn import QuantAttentionFused


class MixtralBlock(nn.Module):
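    """
    Mixtral decoder block: fused quantized attention followed by a sparse
    mixture-of-experts feed-forward, each wrapped in a pre-norm residual.
    """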
    def __init__(
        self,
        hidden_size,
        n_heads,
        n_kv_heads,
        qkv_layer,
        o_proj,
        moe,
        norm_1,
        norm_2,
        dev,
        max_seq_len,
        rope_theta,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.hidden_size = hidden_size
        self.norm_1 = norm_1.to(dev)
        self.attn = QuantAttentionFused(
            self.hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=False,
            rope_theta=rope_theta,
        ).to(dev)
        self.norm_2 = norm_2.to(dev)
        self.moe = moe
        self.device = dev

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
        )

        h = hidden_states.to(attn_output.device) + attn_output
        out = self.moe.forward(self.norm_2(h))
        out = h + out

        return out, None, past_key_value


class LlamaLikeBlock(nn.Module):
    """
    LlamaLikeBlock is intended to be reused across blocks that have
    an architecture that closely resembles Llama, e.g. Mistral and Aquila.
    """

    def __init__(
        self,
        hidden_size,
        n_heads,
        n_kv_heads,
        qkv_layer,
        o_proj,
        mlp,
        norm_1,
        norm_2,
        dev,
        max_seq_len,
        rope_theta=10000,
        use_alibi=False,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.hidden_size = hidden_size
        self.norm_1 = norm_1.to(dev)
        self.attn = QuantAttentionFused(
            self.hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=use_alibi,
            rope_theta=rope_theta,
        ).to(dev)
        self.norm_2 = norm_2.to(dev)
        self.mlp = mlp.to(dev)
        self.device = dev

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
        )

        h = hidden_states.to(attn_output.device) + attn_output
        out = h + self.mlp.forward(self.norm_2(h))

        return out, None, past_key_value


class MPTBlock(nn.Module):
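    """
    MPT decoder block: fused quantized attention with ALiBi positional biases
    (use_alibi=True, no RoPE) followed by the MPT feed-forward network.
    """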
    def __init__(
        self,
        hidden_size,
        n_heads,
        qkv_layer,
        o_proj,
        mpt_mlp,
        norm_1,
        norm_2,
        dev,
        max_seq_len,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = 0
        self.hidden_size = hidden_size
        self.norm_1 = norm_1
        self.attn = QuantAttentionFused(
            hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=True,
        ).to(dev)
        self.norm_2 = norm_2
        self.ffn = mpt_mlp.to(dev)
        self.device = dev

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True,
        )

        h = hidden_states.to(attn_output.device) + attn_output
        out = h + self.ffn.forward(self.norm_2(h))
        return out, None, past_key_value


class FalconDecoderLayer(nn.Module):
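    """
    Falcon decoder layer with fused quantized attention. new_decoder_arch
    selects the newer layout (separate ln_attn/ln_mlp layernorms and grouped
    KV heads) over the original layout (a single input_layernorm and
    multi-query attention, whose cache shapes come from _get_attention_shapes).
    """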
    def __init__(
        self,
        hidden_size,
        n_heads,
        qkv_layer,
        o_proj,
        mlp,
        dev,
        max_seq_len,
        input_layernorm=None,
        ln_attn=None,
        ln_mlp=None,
        new_decoder_arch=True,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = 8 if new_decoder_arch else 0
        self.hidden_size = hidden_size
        self.new_decoder_arch = new_decoder_arch

        if new_decoder_arch:
            attention_shapes = None
        else:
            attention_shapes = self._get_attention_shapes(
                n_heads, max_seq_len, self.hidden_size // n_heads
            )

        # TODO: Falcon has ALiBi implemented but which model uses it?
        self.attn = QuantAttentionFused(
            hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=False,
            attention_shapes=attention_shapes,
        ).to(dev)

        if new_decoder_arch:
            self.ln_attn = ln_attn  # before attention
            self.ln_mlp = ln_mlp  # before mlp
        else:
            self.input_layernorm = input_layernorm  # before attention

        self.mlp = mlp
        self.device = dev

    def _get_attention_shapes(self, n_heads, max_seq_len, head_dim):
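        # Builds KV-cache and QKV reshape specs for the old (multi-query)
        # Falcon attention layout; the cache batch dimension is controlled by
        # the AWQ_BATCH_SIZE environment variable.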
        batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))

        self.attention_shapes = {
            # following fastertransformer definition
            "cache_v": (
                batch_size,
                1,
                max_seq_len,
                head_dim,
            ),
            # the trailing dimension packs 8 fp16 values (FasterTransformer layout); with fp32 it would be 4
            "cache_k": (
                batch_size,
                1,
                head_dim // 8,
                max_seq_len,
                8,
            ),
            "xqkv_view": (n_heads + 2, head_dim),
            "xq_slice": lambda xqkv: xqkv[:, :, :-2],
            "xk_slice": lambda xqkv: xqkv[:, :, [-2]],
            "xv_slice": lambda xqkv: xqkv[:, :, [-1]],
            "xq_view": (n_heads, head_dim),
            "xk_view": (1, head_dim),
            "xv_view": (1, head_dim),
            "xk_reshape": (1, head_dim // 8, 8),
            "single_xq_view": (n_heads, head_dim),
            "single_xk_view": (1, head_dim),
            "single_xv_view": (1, head_dim),
        }

        return self.attention_shapes

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
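        # Falcon computes attention and MLP as parallel branches off the same
        # input hidden states; the new arch normalizes each branch separately.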
        if self.new_decoder_arch:
            layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            layernorm_out = self.input_layernorm(hidden_states)

        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=layernorm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True,
        )

        h_attn = hidden_states.to(attn_output.device) + attn_output

        if self.new_decoder_arch:
            h_mlp = self.mlp.forward(mlp_layernorm_out)
        else:
            h_mlp = self.mlp.forward(layernorm_out)

        out = h_attn + h_mlp

        return out, None, past_key_value