attn.py 9.88 KB
Newer Older
1
import os
Casper Hansen's avatar
Casper Hansen committed
2
import math
Haotian Tang's avatar
Haotian Tang committed
3
4
import torch
import torch.nn as nn
Casper Hansen's avatar
Casper Hansen committed
5
from torch.nn import functional as F
Casper Hansen's avatar
Casper Hansen committed
6
from awq.modules.fused.cache import WindowedCache
7
from awq.utils.fused_utils import get_attention_shapes
Casper Hansen's avatar
Casper Hansen committed
8

Casper's avatar
Casper committed
9
10
11
12
13
# Optional fused single-query attention kernel. Any failure to import it
# (missing module, or an installed-but-broken native extension) falls back to
# the pure-torch path. `except Exception` keeps that best-effort behaviour
# without also swallowing KeyboardInterrupt / SystemExit like a bare `except:`.
try:
    import ft_inference_engine
    FT_INSTALLED = True
except Exception:
    FT_INSTALLED = False
qwopqwop200's avatar
qwopqwop200 committed
14

Casper Hansen's avatar
Casper Hansen committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class RoPE(nn.Module):
    """Rotary position embedding applied to query and key tensors.

    A table of complex rotation factors (``freqs_cis``) with
    ``max_seq_len * 2`` rows and ``(hidden_size // n_heads) // 2`` columns is
    precomputed once and stored as a non-trainable parameter on ``device``.
    """

    def __init__(self, hidden_size, n_heads, max_seq_len, device, rope_theta=10000.0):
        """
        Args:
            hidden_size: model hidden size; per-head dim is ``hidden_size // n_heads``.
            n_heads: number of attention heads.
            max_seq_len: maximum sequence length; the table covers twice this.
            device: device the rotation table is moved to.
            rope_theta: base of the geometric frequency progression. Defaults to
                10000.0, the value that was previously hard-coded.
        """
        super(RoPE, self).__init__()

        self.freqs_cis = nn.Parameter(
            self.precompute_freqs_cis(
                hidden_size // n_heads, max_seq_len * 2, theta=rope_theta
            ).to(device),
            requires_grad=False,
        )

    @staticmethod
    def precompute_freqs_cis(dim: int, end: int, theta=10000.0):
        """Return an ``(end, dim // 2)`` complex tensor of unit-magnitude rotations.

        Row ``t`` holds ``exp(i * t * theta ** (-2k / dim))`` for ``k`` in ``[0, dim // 2)``.
        """
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        positions = torch.arange(end)
        angles = torch.outer(positions, freqs).float()
        return torch.polar(torch.ones_like(angles), angles)

    @staticmethod
    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
        """View ``freqs_cis`` so it broadcasts against ``x``.

        ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. Every other
        axis of ``x`` gets a singleton dimension.
        """
        ndim = x.ndim
        assert ndim > 1
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)

    def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: int):
        """Rotate ``xq`` and ``xk`` for positions ``[start_pos, start_pos + seqlen)``.

        The caller in this file passes ``(bsz, seqlen, heads, head_dim)`` tensors.
        The last dimension is split into two halves that form the real and
        imaginary parts of each rotated pair (the "rotate-half" / NeoX layout).
        Outputs keep the input dtypes.
        """
        xq_ = torch.view_as_complex(
            xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
        )
        xk_ = torch.view_as_complex(
            xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
        )
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
        freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)

        # Undo the half-split: (..., d/2, 2) -> (..., 2, d/2) -> (..., d)
        xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)

        return xq_out.type_as(xq), xk_out.type_as(xk)
Casper Hansen's avatar
Casper Hansen committed
54

Casper Hansen's avatar
Casper Hansen committed
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
class ALiBi(nn.Module):
    """Attention with Linear Biases: adds a per-head linear distance penalty to scores.

    Stores ``slopes`` with shape ``(n_heads,)`` and ``bias`` with shape
    ``(1, n_heads, 1, max_seq_len)`` as non-trainable parameters on ``device``.
    ``bias[..., j] = slope * (j - (max_seq_len - 1))``, so the last column is 0.
    """

    def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8):
        super(ALiBi, self).__init__()

        # Initialize ALiBi slopes and bias
        slopes, bias = self.build_alibi_bias(n_heads, max_seq_len, alibi_bias_max=alibi_bias_max)
        self.slopes = nn.Parameter(slopes.float().to(device), requires_grad=False)
        self.bias = nn.Parameter(bias.float().to(device), requires_grad=False)

    @staticmethod
    def gen_slopes(n_heads, alibi_bias_max=8):
        """Return per-head slopes shaped ``(1, n_heads, 1, 1)``.

        Slopes are ``2 ** -(k * alibi_bias_max / N)`` for ``k = 1..N``, where ``N`` is
        ``n_heads`` rounded up to a power of two. For non-power-of-two head
        counts, the odd-indexed slopes come first, then the even-indexed ones,
        truncated to ``n_heads``.
        """
        _n_heads = 2 ** math.ceil(math.log2(n_heads))
        m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
        m = m.mul(alibi_bias_max / _n_heads)
        slopes = 1.0 / torch.pow(2, m)

        if _n_heads != n_heads:
            slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads]

        return slopes.view(1, n_heads, 1, 1)

    @staticmethod
    def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8, dtype=torch.float32):
        """Return ``(slopes, bias)`` with shapes ``(n_heads,)`` and ``(1, n_heads, 1, seq_len)``."""
        alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
        slopes = ALiBi.gen_slopes(n_heads, alibi_bias_max)
        alibi_bias = alibi_bias * slopes
        slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
        return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)

    def forward(self, scores, seqlen):
        """Add the ALiBi bias to ``scores`` in place and return it.

        Args:
            scores: raw attention scores whose last axis indexes key positions.
                The caller in this file passes ``(bsz, n_heads, q_len, kv_len)``.
            seqlen: number of query positions. Kept for API compatibility; the
                bias width is taken from ``scores.shape[-1]`` instead.
        """
        kv_len = scores.shape[-1]
        # Use the trailing kv_len columns so the newest key gets bias 0 and a key
        # d positions earlier gets -d * slope. Slicing by the query length instead
        # has two effects. During single-token decoding it broadcasts one constant
        # over every key and drops all positional information. During prefill it
        # adds a large per-head constant that only costs low-precision accuracy.
        # max(..., 0) keeps an oversize kv_len from wrapping to a short slice, so
        # the broadcast fails loudly instead.
        start = max(self.bias.shape[-1] - kv_len, 0)
        scores += self.bias[..., start:]
        return scores
Casper Hansen's avatar
Casper Hansen committed
87
88

class QuantAttentionFused(nn.Module):
    """Fused self-attention over quantized projections with a rolling KV cache.

    Positional information comes from ALiBi (``use_alibi=True``) or RoPE
    otherwise. Single-token steps use ``ft_inference_engine`` when it is
    installed; everything else runs a pure-torch attention path.
    """

    def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len,
                 use_alibi=False, attention_shapes=None):
        """
        Args:
            hidden_size: model hidden size.
            n_heads: number of query heads.
            n_kv_heads: number of key/value heads; must divide ``n_heads`` when non-zero.
            qkv_layer: fused projection producing Q, K and V from hidden states.
            o_proj: output projection applied to the attention output.
            dev: device for the KV cache and positional tables.
            max_seq_len: capacity of the KV cache.
            use_alibi: use ALiBi biases instead of rotary embeddings.
            attention_shapes: optional precomputed shape/slicing table forwarded
                to ``get_attention_shapes``.

        The KV-cache batch size is read from the ``AWQ_BATCH_SIZE`` environment
        variable (default ``"1"``).
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        # Number of query heads sharing each KV head; 0 disables KV repetition.
        self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
        self.head_dim = self.hidden_size // n_heads
        self.qkv_proj = qkv_layer
        self.o_proj = o_proj
        # Next write position in the KV cache; advanced by `seqlen` every forward.
        self.start_pos = 0
        self.use_alibi = use_alibi
        self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
        self.max_seq_len = max_seq_len

        # attention shapes for self attention
        self.attention_shapes = get_attention_shapes(
            attention_shapes, max_seq_len, self.cache_batch_size, n_heads, n_kv_heads, self.head_dim
        )
        # cache store that rolls cache
        self.cache = WindowedCache(
            self.attention_shapes["cache_v"], self.attention_shapes["cache_k"], self.max_seq_len, dev
        )

        if use_alibi:
            self.alibi = ALiBi(n_heads, max_seq_len, dev)
            # Passed to the FT kernel: no rotary embedding when ALiBi is used.
            self.rotary_dim = 0
            self.is_neox = False
        else:
            self.alibi = None
            self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev)
            self.rotary_dim = self.head_dim
            self.is_neox = True

    def forward(self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs):
        """Run one attention step and advance ``self.start_pos``.

        Args:
            hidden_states: input unpacked as ``(bsz, seqlen, hidden)``.
            attention_mask: additive mask added to the scores, only when more
                than one query position is processed.
            *args: ignored.
            **kwargs: only ``past_key_value`` is inspected. When it is not None,
                only the last position of ``hidden_states`` is processed.

        Returns:
            ``(attn_output, attention_weight, past_key_value)``.
            ``attention_weight`` is the concatenated per-head attention output
            *before* ``o_proj``, not softmax probabilities. ``past_key_value`` is
            a dummy placeholder; the real KV state lives in ``self.cache``.

        Raises:
            RuntimeError: if the input batch size differs from the KV-cache batch size.
        """
        bsz, seqlen, _ = hidden_states.shape

        # Check if we are under transformers caching regime
        has_past_key_value = kwargs.get("past_key_value") is not None

        if has_past_key_value:
            # In newest transformers version, when using caching the input hidden states do not consist of
            # the last generated token only, but of the whole sequence - past-kvlength. We need to slice the last token
            # and set `seqlen=1`
            if seqlen > 1:
                seqlen = 1
                hidden_states = hidden_states[:, -1:]

        if bsz != self.cache_batch_size:
            raise RuntimeError(
                f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
                f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
            )

        will_cache_be_exceeded = self.start_pos + seqlen > self.max_seq_len

        # Reset and avoid retaining state when processing context
        if will_cache_be_exceeded and seqlen > 1:
            self.start_pos = self.cache.roll_kv_n_steps(self.start_pos, n=self.start_pos)
        # Slowly roll out old tokens without performance hit if exceeded during decoding
        elif will_cache_be_exceeded and seqlen == 1:
            self.start_pos = self.cache.roll_kv_n_steps(self.start_pos, n=100)

        xqkv = self.qkv_proj(hidden_states)
        xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])

        xq = self.attention_shapes["xq_slice"](xqkv)
        xk = self.attention_shapes["xk_slice"](xqkv)
        xv = self.attention_shapes["xv_slice"](xqkv)

        if seqlen > 1 or not FT_INSTALLED:
            # Pure-torch path: context processing, or decoding without the FT kernel.
            xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
            xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
            xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])

            if not self.use_alibi:
                xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)

            self.cache.to(xq)

            values_store = xv.transpose(2, 1)
            keys_store = (
                xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
                .permute(0, 2, 3, 1, 4)
                .contiguous()
            )

            self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)

            # Only necessary to retrieve from cache when we are not processing context
            if seqlen == 1:
                xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)

            keys = xk
            values = xv

            # Expand KV heads to match query heads (grouped-query attention).
            if self.n_kv_groups != 0:
                keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
                values = torch.repeat_interleave(values, dim=2, repeats=self.n_kv_groups)

            xq = xq.transpose(1, 2)
            keys = keys.transpose(1, 2)
            values = values.transpose(1, 2)
            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

            if self.use_alibi:
                scores = self.alibi.forward(scores, seqlen)

            # When seqlen is 1, there is nothing else to attend to
            if attention_mask is not None and seqlen > 1:
                scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
            # Softmax in float32 for stability, then back to the query dtype.
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
            attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        else:
            # FT kernel path: single-token decoding; the kernel reads and writes
            # the cache and applies rotary/ALiBi itself.
            xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
            xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
            xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])

            alibi_slopes = self.alibi.slopes if self.alibi is not None else None
            attention_weight = ft_inference_engine.single_query_attention(
                xq, # query
                xk, # key
                xv, # value
                self.cache.k, # key cache
                self.cache.v, # value cache
                None, # length per sample
                alibi_slopes, # alibi slopes
                self.start_pos, # timestep
                self.rotary_dim, # rotary embedding dimension
                10000, # rotary embedding base
                self.is_neox, # is neox
            )
            attention_weight = attention_weight.reshape(bsz, 1, -1)

        attn_output = self.o_proj(attention_weight)
        self.start_pos += seqlen

        # past_key_value is replaced with cache_v, cache_k, returning empty data
        past_key_value = [torch.Tensor([ [ [[0]], [[0]], [[0]] ] ])]
        return attn_output, attention_weight, past_key_value