import math
import os

import torch
import torch.nn as nn
from torch.nn import functional as F

from awq.modules.fused.cache import WindowedCache

# Optional FasterTransformer-style kernel used for single-token decoding.
# When it cannot be imported (not installed, or its compiled library fails to
# load), fall back to the pure-PyTorch attention path. `except Exception`
# instead of a bare `except:` so KeyboardInterrupt / SystemExit still propagate.
try:
    import ft_inference_engine
    FT_INSTALLED = True
except Exception:
    FT_INSTALLED = False


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the table of unit complex rotary factors.

    Entry ``[t, j]`` equals ``exp(1j * t * theta ** (-2j / dim))`` for
    positions ``t`` in ``[0, end)`` and ``j`` in ``[0, dim // 2)``.

    Returns:
        A complex64 tensor of shape ``(end, dim // 2)``.
    """
    half = dim // 2
    exponents = torch.arange(0, dim, 2)[:half].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    # Unit magnitude, angle = position * frequency -> complex64.
    return torch.polar(torch.ones_like(angles), angles)

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so it broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The returned
    view keeps those sizes on dim 1 and the last dim of ``x``, and puts size 1
    on every other dim.

    Raises:
        AssertionError: if ``x`` has fewer than 2 dims or the shapes disagree.
    """
    rank = x.ndim
    assert rank > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    last = rank - 1
    target = [1] * rank
    target[1] = x.shape[1]
    target[last] = x.shape[last]
    return freqs_cis.view(*target)

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """Apply rotary position embeddings to query and key tensors.

    The last dimension (size ``d``) of each input is split into two halves.
    Element ``i`` is paired with element ``i + d/2`` as the real and imaginary
    parts of one complex number. Each pair is rotated by the matching entry of
    ``freqs_cis``, and the result is written back in the same half-split layout.
    The math runs in float32 and is cast back to the input dtypes.

    Args:
        xq: query tensor of rank >= 3 with positions on dim 1 and ``d`` on
            the last dim. The caller presumably passes
            (bsz, seqlen, n_heads, head_dim) -- TODO confirm.
        xk: key tensor with the same layout. Its head count (dim 2) may differ
            from ``xq``'s, since ``freqs_cis`` is broadcast over that dim.
        freqs_cis: complex tensor of shape ``(xq.shape[1], d // 2)``;
            ``reshape_for_broadcast`` asserts this.

    Returns:
        ``(xq_out, xk_out)`` with the same shapes and dtypes as ``xq``/``xk``.
    """
    # (..., d) -> (..., 2, d/2) -> (..., d/2, 2): pairs x[..., i] with x[..., i + d/2].
    xq_ = torch.view_as_complex(
        xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
    )
    xk_ = torch.view_as_complex(
        xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
    )

    freqs_cis = reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)

    # Undo the pairing: (..., d/2) complex -> (..., d/2, 2) -> (..., 2, d/2) -> (..., d).
    # flatten(-2) rather than a fixed start dim so inputs of any rank >= 3 get
    # their original shape back (flatten(3) was a no-op for rank-3 inputs).
    xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)

class ALiBi(nn.Module):
    """Attention with Linear Biases (ALiBi).

    Precomputes per-head slopes and one row of key-position biases of length
    ``max_seq_len``, and adds the slice matching the current key length to
    attention scores.
    """

    def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8):
        super(ALiBi, self).__init__()

        # Initialize ALiBi slopes and bias; kept as frozen (requires_grad=False) parameters.
        slopes, bias = self.build_alibi_bias(n_heads, max_seq_len, alibi_bias_max=alibi_bias_max)
        self.slopes = nn.Parameter(slopes.float().to(device), requires_grad=False)
        self.bias = nn.Parameter(bias.float().to(device), requires_grad=False)

    @staticmethod
    def gen_slopes(n_heads, alibi_bias_max=8):
        """Return per-head slopes with shape ``(1, n_heads, 1, 1)``.

        With ``P`` the smallest power of two >= ``n_heads``, the base slopes are
        ``2 ** (-k * alibi_bias_max / P)`` for ``k = 1..P``. When ``n_heads`` is
        not a power of two, the odd-indexed slopes come first, then the
        even-indexed ones, and the first ``n_heads`` are kept.
        """
        _n_heads = 2 ** math.ceil(math.log2(n_heads))
        m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
        m = m.mul(alibi_bias_max / _n_heads)
        slopes = 1.0 / torch.pow(2, m)

        if _n_heads != n_heads:
            slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads]

        return slopes.view(1, n_heads, 1, 1)

    @staticmethod
    def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8, dtype=torch.float32):
        """Return ``(slopes, bias)``.

        ``slopes`` has shape ``(n_heads,)``. ``bias`` has shape
        ``(1, n_heads, 1, seq_len)`` with
        ``bias[0, h, 0, j] = slopes[h] * (j - (seq_len - 1))``, so the last key
        position gets 0 and earlier positions get increasingly negative values.
        """
        alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
        slopes = ALiBi.gen_slopes(n_heads, alibi_bias_max)
        alibi_bias = alibi_bias * slopes
        slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
        return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)

    def forward(self, scores, seqlen):
        """Add the ALiBi bias to ``scores`` in place and return the same tensor.

        Args:
            scores: attention scores whose last dimension indexes key positions.
            seqlen: query length. It is accepted for backward compatibility but
                not used for slicing. The bias must span the KEY length
                ``scores.shape[-1]``, which exceeds ``seqlen`` during cached
                decoding. Slicing by ``seqlen`` there collapsed the bias to one
                constant per head, which softmax ignores.
        """
        key_len = scores.shape[-1]
        # Take the trailing key_len entries: the newest key gets bias 0 and older
        # keys get slope * (distance) negative offsets. max() keeps key_len == 0
        # from selecting the whole row.
        start = max(self.bias.shape[-1] - key_len, 0)
        scores += self.bias[..., start:]
        return scores


def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim):
    """Return the cache/view shapes and QKV slicing functions for fused attention.

    Args:
        attention_shapes: caller-supplied mapping. When not None it is returned
            unchanged and every other argument is ignored.
        max_seq_len: number of positions the KV cache holds.
        cache_batch_size: batch dimension of the KV cache.
        n_heads: number of query heads.
        n_kv_heads: number of key/value heads. ``0`` selects the layout where
            the fused QKV output is viewed as ``(-1, n_heads, head_dim)`` and
            q, k, v are taken at indices 0, 1, 2 of dim 2.
        head_dim: per-head feature size. The key cache splits it into groups
            of 8.

    Returns:
        A dict with:
        - cache shapes ("cache_v", "cache_k");
        - the view of the fused QKV output ("xqkv_view");
        - callables "xq_slice" / "xk_slice" / "xv_slice" that split that view;
        - view/reshape shapes for q, k and v.
    """
    if attention_shapes is not None:
        return attention_shapes

    if n_kv_heads == 0:
        # q, k and v each have n_heads heads, stacked along dim 2 of the view.
        kv_heads = n_heads
        qkv_view = (-1, n_heads, head_dim)

        def q_slice(xqkv):
            return xqkv[:, :, 0]

        def k_slice(xqkv):
            return xqkv[:, :, 1]

        def v_slice(xqkv):
            return xqkv[:, :, 2]
    else:
        # Head axis holds n_heads query heads, then n_kv_heads key heads,
        # then n_kv_heads value heads.
        kv_heads = n_kv_heads
        qkv_view = (n_heads + 2 * n_kv_heads, head_dim)

        def q_slice(xqkv):
            return xqkv[:, :, 0:n_heads]

        def k_slice(xqkv):
            return xqkv[:, :, n_heads:n_heads + n_kv_heads]

        def v_slice(xqkv):
            return xqkv[:, :, -n_kv_heads:]

    q_shape = (n_heads, head_dim)
    kv_shape = (kv_heads, head_dim)
    return {
        # following fastertransformer definition
        "cache_v": (cache_batch_size, kv_heads, max_seq_len, head_dim),
        # 8: pack 8 fp16 in FT, if fp32 then use 4
        "cache_k": (cache_batch_size, kv_heads, head_dim // 8, max_seq_len, 8),
        "xqkv_view": qkv_view,
        "xq_slice": q_slice,
        "xk_slice": k_slice,
        "xv_slice": v_slice,
        "xq_view": q_shape,
        "xk_view": kv_shape,
        "xv_view": kv_shape,
        "xk_reshape": (kv_heads, head_dim // 8, 8),
        "single_xq_view": q_shape,
        "single_xk_view": kv_shape,
        "single_xv_view": kv_shape,
    }


class QuantAttentionFused(nn.Module):
    """Fused self-attention over a quantized QKV projection with a windowed KV cache.

    Positions are encoded either with rotary embeddings (default) or with ALiBi
    biases (``use_alibi=True``). Single-token steps (``seqlen == 1``) go to
    ``ft_inference_engine.single_query_attention`` when that extension was
    importable; every other call uses the PyTorch path in ``forward``.
    """

    def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len,
                 use_alibi=False, attention_shapes=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        # Query heads per key/value head; 0 when n_kv_heads == 0 (no repeat needed).
        self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
        self.head_dim = self.hidden_size // n_heads
        self.qkv_proj = qkv_layer
        self.o_proj = o_proj
        # Next cache position to write; advanced by seqlen on every forward call.
        self.start_pos = 0
        self.use_alibi = use_alibi
        # The KV cache is allocated for a fixed batch size read from the environment.
        self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
        self.max_seq_len = max_seq_len

        # attention shapes for self attention
        self.attention_shapes = get_attention_shapes(
            attention_shapes, max_seq_len, self.cache_batch_size, n_heads, n_kv_heads, self.head_dim
        )
        # cache store that rolls cache
        self.cache = WindowedCache(
            self.attention_shapes["cache_v"], self.attention_shapes["cache_k"], dev
        )

        if use_alibi:
            self.alibi = ALiBi(n_heads, max_seq_len, dev)
            # Tell the FT kernel not to apply rotary embeddings.
            self.rotary_dim = 0
            self.is_neox = False
        else:
            # Rotary table covers twice max_seq_len positions.
            self.freqs_cis = precompute_freqs_cis(
                hidden_size // n_heads,
                max_seq_len * 2,
            ).to(dev)
            self.rotary_dim = self.head_dim
            self.alibi_slopes = None
            self.is_neox = True

    def forward(self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs):
        """Run attention for ``hidden_states`` and append its keys/values to the cache.

        Args:
            hidden_states: tensor of shape ``(bsz, seqlen, hidden_size)``.
                ``bsz`` must equal the cache batch size.
            attention_mask: optional additive mask, added to the raw scores
                ``(bsz, n_heads, seqlen, key_len)`` before softmax on the PyTorch
                path. It is not used on the FT single-token path.

        Returns:
            ``(attn_output, attention_weight, past_key_value)``:
            - ``attention_weight`` is the per-head attention output reshaped to
              ``(bsz, seqlen, -1)`` before ``o_proj``;
            - ``attn_output`` is ``o_proj(attention_weight)``;
            - ``past_key_value`` is a placeholder, since the real cache lives in
              ``self.cache``.

        Raises:
            RuntimeError: if the input batch size differs from the cache batch size.
        """
        bsz, seqlen, _ = hidden_states.shape
        if bsz != self.cache_batch_size:
            raise RuntimeError(
                f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
                f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
            )

        # Roll the windowed cache when this step would run past max_seq_len
        # (this also covers start_pos already exceeding it, since seqlen >= 0).
        if self.start_pos + seqlen > self.max_seq_len:
            excess_length = self.start_pos + seqlen - self.max_seq_len
            self.start_pos = self.cache.roll_kv(excess_length, self.start_pos)

        xqkv = self.qkv_proj(hidden_states)
        xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])

        xq = self.attention_shapes["xq_slice"](xqkv)
        xk = self.attention_shapes["xk_slice"](xqkv)
        xv = self.attention_shapes["xv_slice"](xqkv)

        if seqlen > 1 or not FT_INSTALLED:
            xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
            xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
            xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])

            if not self.use_alibi:
                xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])

            self.cache.to(xq)

            # Rearrange to the cache layouts declared by "cache_v" / "cache_k".
            values_store = xv.transpose(2, 1)
            keys_store = (
                xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
                .permute(0, 2, 3, 1, 4)
                .contiguous()
            )

            self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)

            # Single-token step: attend over the cached keys/values.
            # NOTE(review): for seqlen > 1 only the current chunk's keys/values
            # are used, so earlier cached positions are not attended to.
            if seqlen == 1:
                xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)

            keys = xk
            values = xv

            # Expand key/value heads so each query head has a matching one.
            if self.n_kv_groups != 0:
                keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
                values = torch.repeat_interleave(values, dim=2, repeats=self.n_kv_groups)

            xq = xq.transpose(1, 2)
            keys = keys.transpose(1, 2)
            values = values.transpose(1, 2)
            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

            if self.use_alibi:
                scores = self.alibi.forward(scores, seqlen)

            if attention_mask is not None:
                scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)

            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
            attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        else:
            xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
            xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
            xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])

            # self.alibi only exists when use_alibi is set. Rotary models pass
            # None here instead of raising AttributeError.
            alibi_slopes = self.alibi.slopes if self.use_alibi else None

            attention_weight = ft_inference_engine.single_query_attention(
                xq, # query
                xk, # key
                xv, # value
                self.cache.k, # key cache
                self.cache.v, # value cache
                None, # length per sample
                alibi_slopes, # alibi slopes
                self.start_pos, # timestep
                self.rotary_dim, # rotary embedding dimension
                10000, # rotary embedding base (same as precompute_freqs_cis's default theta)
                self.is_neox, # is neox
            )
            attention_weight = attention_weight.reshape(bsz, 1, -1)

        attn_output = self.o_proj(attention_weight)
        self.start_pos += seqlen

        # past_key_value is replaced with cache_v, cache_k, returning empty data
        past_key_value = [torch.Tensor([ [ [[0]], [[0]], [[0]] ] ])]
        return attn_output, attention_weight, past_key_value