import os
import torch.nn as nn
from awq.modules.fused.attn import QuantAttentionFused

class MixtralBlock(nn.Module):
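    """
    Fused Mixtral decoder block: pre-norm fused attention (RoPE, no ALiBi)
    followed by a pre-norm sparse mixture-of-experts feed-forward, each wrapped
    in a residual connection.
    """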
    def __init__(
        self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, 
        moe, norm_1, norm_2, dev, max_seq_len, rope_theta
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.hidden_size = hidden_size
        self.norm_1 = norm_1.to(dev)
        self.attn = QuantAttentionFused(
            self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
            dev=dev, max_seq_len=max_seq_len, use_alibi=False, rope_theta=rope_theta
        ).to(dev)
        self.norm_2 = norm_2.to(dev)
        self.moe = moe
        self.device = dev

    def forward(
        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask
        )

        h = hidden_states.to(attn_output.device) + attn_output
        out, _ = self.moe.forward(self.norm_2(h))
        out = h + out

        return out, None, past_key_value
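
# Hedged usage sketch (illustration only, not part of the original module):
# how a MixtralBlock might be assembled from already-quantized sub-modules.
# `config`, `quantized_qkv`, `quantized_o_proj` and `fused_moe` are hypothetical
# placeholders; in practice the AWQ model builders construct and pass these in.
def _example_build_mixtral_block(config, quantized_qkv, quantized_o_proj,
                                 fused_moe, norm_1, norm_2, device,
                                 max_seq_len=2048):
    return MixtralBlock(
        hidden_size=config.hidden_size,
        n_heads=config.num_attention_heads,
        n_kv_heads=config.num_key_value_heads,
        qkv_layer=quantized_qkv,
        o_proj=quantized_o_proj,
        moe=fused_moe,
        norm_1=norm_1,
        norm_2=norm_2,
        dev=device,
        max_seq_len=max_seq_len,
        rope_theta=getattr(config, "rope_theta", 1e6),  # assumed default
    )
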

class LlamaLikeBlock(nn.Module):
    """
    LlamaLikeBlock is intended to be reused across models whose decoder
    blocks closely resemble Llama's, e.g. Mistral and Aquila.
    """
    def __init__(
        self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, 
        mlp, norm_1, norm_2, dev, max_seq_len, rope_theta=10000, use_alibi=False
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.hidden_size = hidden_size
        self.norm_1 = norm_1.to(dev)
        self.attn = QuantAttentionFused(
            self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
            dev=dev, max_seq_len=max_seq_len, use_alibi=use_alibi, rope_theta=rope_theta
        ).to(dev)
        self.norm_2 = norm_2.to(dev)
        self.mlp = mlp.to(dev)
        self.device = dev

    def forward(
        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask
        )

        h = hidden_states.to(attn_output.device) + attn_output
        out = h + self.mlp.forward(self.norm_2(h))

        return out, None, past_key_value
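
# Hedged usage sketch (illustration only): a LlamaLikeBlock is called like any
# pre-norm transformer layer. `block` is an already-constructed LlamaLikeBlock
# and `hidden_states` is assumed to have shape (batch, seq_len, hidden_size);
# past_key_value is threaded straight through to the fused attention module.
def _example_run_llama_like_block(block, hidden_states, past_key_value=None):
    out, _, past_key_value = block(
        hidden_states,
        past_key_value,
        attention_mask=None,
    )
    return out, past_key_value
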

class MPTBlock(nn.Module):
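    """
    Fused MPT decoder block: pre-norm fused attention with ALiBi positional
    bias (use_alibi=True) plus a pre-norm feed-forward, each wrapped in a
    residual connection. n_kv_heads is fixed to 0, i.e. no separate grouped
    KV head count is used.
    """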
    def __init__(
        self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp,
        norm_1, norm_2, dev, max_seq_len
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = 0
        self.hidden_size = hidden_size
        self.norm_1 = norm_1
        self.attn = QuantAttentionFused(
            hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj, 
            dev=dev, max_seq_len=max_seq_len, use_alibi=True
        ).to(dev)
        self.norm_2 = norm_2
        self.ffn = mpt_mlp.to(dev)
        self.device = dev

    def forward(
        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True
        )

        h = hidden_states.to(attn_output.device) + attn_output
        out = h + self.ffn.forward(self.norm_2(h))
        return out, None, past_key_value
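
# Hedged usage sketch (illustration only): the blocks in this module share the
# same forward contract, (hidden_states, past_key_value) ->
# (hidden_states, None, past_key_value), so a decoder stack reduces to a plain
# loop. Passing None for past_key_value is an assumption here, based on the
# fused attention allocating its own cache buffers (see the cache_k / cache_v
# shapes in FalconDecoderLayer._get_attention_shapes below).
def _example_run_block_stack(blocks, hidden_states, attention_mask=None):
    for block in blocks:
        hidden_states, _, _ = block(
            hidden_states,
            None,  # past_key_value (assumed optional, see note above)
            attention_mask=attention_mask,
        )
    return hidden_states
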

class FalconDecoderLayer(nn.Module):
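    """
    Fused Falcon decoder layer. Both variants add the attention and MLP outputs
    in parallel (Falcon's parallel residual design). With new_decoder_arch=True
    (e.g. Falcon-40B) the layer uses separate ln_attn / ln_mlp layernorms and
    8 KV heads; with new_decoder_arch=False it uses a single input_layernorm
    and custom attention shapes for Falcon's multi-query attention (one shared
    KV head).
    """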
    def __init__(
        self, hidden_size, n_heads, qkv_layer, o_proj, mlp, dev, max_seq_len,
        input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = 8 if new_decoder_arch else 0
        self.hidden_size = hidden_size
        self.new_decoder_arch = new_decoder_arch

        if new_decoder_arch:
            attention_shapes = None
        else:
            attention_shapes = self._get_attention_shapes(n_heads, max_seq_len, self.hidden_size // n_heads)
        
        # TODO: Falcon has ALiBi implemented but which model uses it?
        self.attn = QuantAttentionFused(
            hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj, 
            dev=dev, max_seq_len=max_seq_len, use_alibi=False,
            attention_shapes=attention_shapes
        ).to(dev)
        
        if new_decoder_arch:
            self.ln_attn = ln_attn # before attention
            self.ln_mlp = ln_mlp # before mlp
        else:
            self.input_layernorm = input_layernorm # before attention
        
        self.mlp = mlp
        self.device = dev
    
    def _get_attention_shapes(self, n_heads, max_seq_len, head_dim):
        batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
        
        self.attention_shapes = {
            # following fastertransformer definition
            "cache_v": (batch_size, 1, max_seq_len, head_dim,),
            # the trailing 8: FasterTransformer packs 8 fp16 values per group (use 4 for fp32)
            "cache_k": (batch_size, 1, head_dim // 8, max_seq_len, 8,),
            "xqkv_view": (n_heads+2, head_dim),
            "xq_slice": lambda xqkv: xqkv[:, :, :-2],
            "xk_slice": lambda xqkv: xqkv[:, :, [-2]],
            "xv_slice": lambda xqkv: xqkv[:, :, [-1]],
            "xq_view": (n_heads, head_dim),
            "xk_view": (1, head_dim),
            "xv_view": (1, head_dim),
            "xk_reshape": (1, head_dim // 8, 8),
            "single_xq_view": (n_heads, head_dim),
            "single_xk_view": (1, head_dim),
            "single_xv_view": (1, head_dim)
        }
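
        # Hedged illustration (assumed Falcon-7B-like dimensions, not from the
        # original file): with n_heads=71, head_dim=64, max_seq_len=2048 and
        # AWQ_BATCH_SIZE=1 the layout above gives
        #   cache_v:   (1, 1, 2048, 64)
        #   cache_k:   (1, 1, 8, 2048, 8)   # 64 // 8 = 8 groups of 8 packed fp16
        #   xqkv_view: (73, 64)             # 71 query heads + 1 key + 1 value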

        return self.attention_shapes

    def forward(
        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
    ):
        if self.new_decoder_arch:
            layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            layernorm_out = self.input_layernorm(hidden_states)
        
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=layernorm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True
        )

        h_attn = hidden_states.to(attn_output.device) + attn_output

        if self.new_decoder_arch:
            h_mlp = self.mlp.forward(mlp_layernorm_out)
        else:
            h_mlp = self.mlp.forward(layernorm_out)
        
        out = h_attn + h_mlp
        
        return out, None, past_key_value
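
# Hedged usage sketch (illustration only, not part of the original module):
# the two Falcon variants differ in which layernorms are supplied.
# `quantized_qkv`, `quantized_dense`, `quantized_mlp` and the entries of
# `norms` are hypothetical placeholders for modules built by the AWQ loaders.
def _example_build_falcon_layer(hidden_size, n_heads, quantized_qkv,
                                quantized_dense, quantized_mlp, device,
                                max_seq_len, new_decoder_arch, norms):
    if new_decoder_arch:
        # e.g. Falcon-40B style: separate pre-attention / pre-MLP layernorms
        return FalconDecoderLayer(
            hidden_size, n_heads, quantized_qkv, quantized_dense, quantized_mlp,
            device, max_seq_len,
            ln_attn=norms["ln_attn"], ln_mlp=norms["ln_mlp"],
            new_decoder_arch=True,
        )
    # e.g. Falcon-7B style: single input layernorm, multi-query attention
    return FalconDecoderLayer(
        hidden_size, n_heads, quantized_qkv, quantized_dense, quantized_mlp,
        device, max_seq_len,
        input_layernorm=norms["input_layernorm"],
        new_decoder_arch=False,
    )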