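"""Fused attention for AWQ: rotary-embedding (RoPE) helpers, an ALiBi bias module,
and QuantAttentionFused, a quantized attention block with a rolling windowed KV cache."""
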
import os
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from awq.modules.fused.cache import WindowedCache
from awq.utils.fused_utils import get_attention_shapes

# Optional fused single-query attention kernel (FasterTransformer bindings).
try:
    import ft_inference_engine
    FT_INSTALLED = True
except:
    FT_INSTALLED = False

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
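    """Precompute complex rotary-embedding frequencies.

    Returns a complex64 tensor of shape (end, dim // 2) whose entry (t, k) is
    exp(1j * t * theta ** (-2k / dim)), i.e. the rotation applied at position t.
    """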
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
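    """Reshape freqs_cis of shape (seqlen, head_dim // 2) so it broadcasts against x
    over the sequence dimension (dim 1) and the last dimension."""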
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
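    """Rotate queries and keys by freqs_cis, treating the first and second halves of
    head_dim as the real/imaginary parts of complex pairs (GPT-NeoX style), and return
    the results in the input dtype."""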
    xq_ = torch.view_as_complex(
        xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
    )
    xk_ = torch.view_as_complex(
        xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
    )
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
    xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

class ALiBi(nn.Module):
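    """Attention with Linear Biases (ALiBi): adds a fixed, per-head linear bias over
    key positions to the attention scores in place of positional embeddings."""
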
    def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8):
        super(ALiBi, self).__init__()
        
        # Initialize ALiBi slopes and bias
        slopes, bias = self.build_alibi_bias(n_heads, max_seq_len, alibi_bias_max=alibi_bias_max)
        self.slopes = nn.Parameter(slopes.float().to(device), requires_grad=False)
        self.bias = nn.Parameter(bias.float().to(device), requires_grad=False)

    @staticmethod
    def gen_slopes(n_heads, alibi_bias_max=8):
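        """Geometric per-head slopes; for non-power-of-two head counts, slopes are
        generated for the next power of two, then interleaved and truncated."""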
        _n_heads = 2 ** math.ceil(math.log2(n_heads))
        m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
        m = m.mul(alibi_bias_max / _n_heads)
        slopes = 1.0 / torch.pow(2, m)
        
        if _n_heads != n_heads:
            slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads]
            
        return slopes.view(1, n_heads, 1, 1)

    @staticmethod
    def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8, dtype=torch.float32):
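        """Return (slopes, bias), where bias has shape (1, n_heads, 1, seq_len) and
        holds slope * (key position offset) for each head."""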
        alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
        slopes = ALiBi.gen_slopes(n_heads, alibi_bias_max)
        alibi_bias = alibi_bias * slopes
        slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
        return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)
    
    def forward(self, scores, seqlen):
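        """Add the precomputed bias for the first `seqlen` key positions to the scores."""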
        scores += self.bias[..., :seqlen]
        return scores

class QuantAttentionFused(nn.Module):
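    """Fused attention for AWQ-quantized models: a single quantized QKV projection,
    a rolling windowed KV cache, RoPE or ALiBi positional handling, and an optional
    fused single-query kernel (ft_inference_engine) for token-by-token decoding."""
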
    def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len, 
                       use_alibi=False, attention_shapes=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
        self.head_dim = self.hidden_size // n_heads
        self.qkv_proj = qkv_layer
        self.o_proj = o_proj
        self.start_pos = 0
        self.use_alibi = use_alibi
        self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
        self.max_seq_len = max_seq_len

        # attention shapes for self attention
        self.attention_shapes = get_attention_shapes(
            attention_shapes, max_seq_len, self.cache_batch_size, n_heads, n_kv_heads, self.head_dim
        )
        # cache store that rolls cache
        self.cache = WindowedCache(
            self.attention_shapes["cache_v"], self.attention_shapes["cache_k"], dev
        )

        if use_alibi:
            self.alibi = ALiBi(n_heads, max_seq_len, dev)
            self.rotary_dim = 0
            self.is_neox = False
        else:
            self.freqs_cis = precompute_freqs_cis(
                hidden_size // n_heads,
                max_seq_len * 2,
            ).to(dev)
            self.rotary_dim = self.head_dim
            self.alibi_slopes = None
            self.is_neox = True
    
    def forward(self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs):
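        """Attend over a chunk of `seqlen` tokens starting at `self.start_pos`.

        Returns (attn_output, attention_weight, dummy past_key_value); the real KV state
        lives in this module's windowed cache rather than in past_key_value.
        """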
        bsz, seqlen, _ = hidden_states.shape
        if bsz != self.cache_batch_size:
            raise RuntimeError(
                f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
                f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
            )

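        # Roll the windowed KV cache when the incoming tokens would overflow max_seq_len.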
        if self.start_pos > self.max_seq_len or self.start_pos + seqlen > self.max_seq_len:
            excess_length = self.start_pos + seqlen - self.max_seq_len
            self.start_pos = self.cache.roll_kv(excess_length, self.start_pos)
            
        xqkv = self.qkv_proj(hidden_states)
        xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
        
        xq = self.attention_shapes["xq_slice"](xqkv)
        xk = self.attention_shapes["xk_slice"](xqkv)
        xv = self.attention_shapes["xv_slice"](xqkv)

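        # Prefill (seqlen > 1), or any forward pass when the fused kernel is unavailable:
        # run standard PyTorch attention against the windowed KV cache.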
        if seqlen > 1 or not FT_INSTALLED:
            xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
            xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
            xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])

            if not self.use_alibi:
                xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])

            self.cache.to(xq)

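            # Write the new keys/values into the cache: values as
            # (bsz, n_kv_heads, seqlen, head_dim), keys permuted into the cache's
            # packed key layout.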
            values_store = xv.transpose(2, 1)
            keys_store = (
                xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
                .permute(0, 2, 3, 1, 4)
                .contiguous()
            )
            
            self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)

            if seqlen == 1:
                xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)
            
            keys = xk
            values = xv

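            # Grouped-query attention: repeat KV heads so they line up with the query heads.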
            if self.n_kv_groups != 0:
                keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
                values = torch.repeat_interleave(values, dim=2, repeats=self.n_kv_groups)
            
            xq = xq.transpose(1, 2)
            keys = keys.transpose(1, 2)
            values = values.transpose(1, 2)
            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

            if self.use_alibi:
                scores = self.alibi.forward(scores, seqlen)

            if attention_mask is not None:
                scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
                
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
            attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        else:
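            # Single-token decode via the fused kernel, which applies rotary embeddings
            # and reads/updates the KV cache internally.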
            xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
            xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
            xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])

            attention_weight = ft_inference_engine.single_query_attention(
                xq, # query
                xk, # key
                xv, # value
                self.cache.k, # key cache
                self.cache.v, # value cache
                None, # length per sample
                self.alibi.slopes if self.use_alibi else None, # alibi slopes
                self.start_pos, # timestep
                self.rotary_dim, # rotary embedding dimension
                10000, # rotary embedding base
                self.is_neox, # is neox
            )
            attention_weight = attention_weight.reshape(bsz, 1, -1)
        
        attn_output = self.o_proj(attention_weight)
        self.start_pos += seqlen

        # past_key_value is replaced with cache_v, cache_k, returning empty data
        past_key_value = [torch.Tensor([ [ [[0]], [[0]], [[0]] ] ])]
        return attn_output, attention_weight, past_key_value
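
# Rough usage sketch (hypothetical shapes and names; in practice this module is wired up
# by AWQ's fused-layer replacement rather than constructed by hand):
#
#   attn = QuantAttentionFused(
#       hidden_size=4096, n_heads=32, n_kv_heads=32,
#       qkv_layer=fused_qkv_proj, o_proj=out_proj,
#       dev="cuda:0", max_seq_len=2048,
#   )
#   out, _, _ = attn(hidden_states)  # hidden_states: (batch, seqlen, hidden_size)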