import os
import math
import torch
import torch.nn as nn
import awq_inference_engine
from torch.nn import functional as F

try:
    import ft_inference_engine
    FT_INSTALLED = True
except ImportError:
    FT_INSTALLED = False

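# Rotary positional embedding (RoPE) helpers, following the LLaMA reference
# implementation: precompute_freqs_cis builds a complex table of unit rotations
# exp(i * t * theta_j), with theta_j = theta^(-2j/dim), for every position t, and
# apply_rotary_emb rotates query/key vectors by the entries matching their positions.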
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    xq_ = torch.view_as_complex(
        xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
    )
    xk_ = torch.view_as_complex(
        xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
    )
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

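# ALiBi (Attention with Linear Biases) helpers, following the MPT implementation:
# each head gets a fixed slope (n_heads is rounded up to the next power of two when
# choosing slopes) and attention scores are biased linearly in the negative distance
# between positions, instead of applying rotary embeddings.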
def gen_slopes(n_heads, alibi_bias_max=8):
    _n_heads = 2 ** math.ceil(math.log2(n_heads))
    m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
    m = m.mul(alibi_bias_max / _n_heads)
    slopes = 1.0 / torch.pow(2, m)
    if _n_heads != n_heads:
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
    return slopes.view(1, n_heads, 1, 1)


def build_alibi_bias(
    n_heads, seq_len, full=False, alibi_bias_max=8, dtype=torch.float32
):
    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
    if full:
        alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32).view(
            1, 1, seq_len, 1
        )
        alibi_bias = alibi_bias.abs().mul(-1)
    slopes = gen_slopes(n_heads, alibi_bias_max)
    alibi_bias = alibi_bias * slopes
    slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
    return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)


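# Fused attention module used for AWQ-quantized models: qkv_layer and o_proj are the
# (quantized) projection layers, and a static KV cache sized by AWQ_BATCH_SIZE and
# max_seq_len is pre-allocated in the FasterTransformer layout. Multi-token prefill
# runs attention in plain PyTorch; single-token decode dispatches to the fused
# single_query_attention kernel when ft_inference_engine is available.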
class QuantAttentionFused(nn.Module):
    def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len, 
                       use_alibi=False, attention_shapes=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
        self.head_dim = self.hidden_size // n_heads
        self.qkv_proj = qkv_layer
        self.o_proj = o_proj
        self.start_pos = 0
        self.use_alibi = use_alibi
        self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))

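        # attention_shapes controls how the fused QKV projection output is split into
        # Q/K/V and how the KV cache is laid out. The n_kv_heads == 0 branch treats the
        # model as standard multi-head attention (KV heads == query heads); otherwise
        # the shapes account for multi/grouped-query attention with n_kv_heads KV heads.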
        if attention_shapes is not None:
            self.attention_shapes = attention_shapes

        elif self.n_kv_heads == 0:
            self.attention_shapes = {
                # following fastertransformer definition
                "cache_v": (self.cache_batch_size, self.n_heads, max_seq_len, self.head_dim,),
                # 8: pack 8 fp16 in FT, if fp32 then use 4
                "cache_k": (self.cache_batch_size, self.n_heads, self.head_dim // 8, max_seq_len, 8,),
                "xqkv_view": (-1, self.n_heads, self.head_dim),
                "xq_slice": lambda xqkv: xqkv[:, :, 0],
                "xk_slice": lambda xqkv: xqkv[:, :, 1],
                "xv_slice": lambda xqkv: xqkv[:, :, 2],
                "xq_view": (self.n_heads, self.head_dim),
                "xk_view": (self.n_heads, self.head_dim),
                "xv_view": (self.n_heads, self.head_dim),
                "xk_reshape": (self.n_heads, self.head_dim // 8, 8),
                "single_xq_view": (self.n_heads, self.head_dim),
                "single_xk_view": (self.n_heads, self.head_dim),
                "single_xv_view": (self.n_heads, self.head_dim)
            }

        else:
            self.attention_shapes = {
                # following fastertransformer definition
                "cache_v": (self.cache_batch_size, self.n_kv_heads, max_seq_len, self.head_dim,),
                # 8: pack 8 fp16 in FT, if fp32 then use 4
                "cache_k": (self.cache_batch_size, self.n_kv_heads, self.head_dim // 8, max_seq_len, 8,),
                "xqkv_view": (self.n_heads + self.n_kv_heads * 2, self.head_dim),
                "xq_slice": lambda xqkv: xqkv[:, :, 0 : self.n_heads],
                "xk_slice": lambda xqkv: xqkv[:, :, self.n_heads : (self.n_heads + self.n_kv_heads)],
                "xv_slice": lambda xqkv: xqkv[:, :, -self.n_kv_heads :],
                "xq_view": (self.n_heads, self.head_dim),
                "xk_view": (self.n_kv_heads, self.head_dim),
                "xv_view": (self.n_kv_heads, self.head_dim),
                "xk_reshape": (self.n_kv_heads, self.head_dim // 8, 8),
                "single_xq_view": (self.n_heads, self.head_dim),
                "single_xk_view": (self.n_kv_heads, self.head_dim),
                "single_xv_view": (self.n_kv_heads, self.head_dim)
            }

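        # Pre-allocate the static fp16 KV cache. cache_k uses the packed
        # FasterTransformer layout (head_dim split into groups of 8 fp16 values)
        # so the fused decode kernel can read and update it in place.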
        self.cache_v = (
            torch.zeros(self.attention_shapes["cache_v"]).to(dev).half()
        )
        
        self.cache_k = (
            torch.zeros(self.attention_shapes["cache_k"]).to(dev).half()
        )

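        # Positional-encoding setup: ALiBi models carry precomputed slopes/bias and
        # disable rotary embeddings (rotary_dim = 0); otherwise RoPE tables are
        # precomputed for up to 2 * max_seq_len positions.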
        if use_alibi:
            alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
            self.alibi_slopes = alibi_slopes.float().to(dev)
            self.alibi_bias = alibi_bias.float().to(dev)
            self.rotary_dim = 0
            self.is_neox = False
        else:
            self.freqs_cis = precompute_freqs_cis(
                hidden_size // n_heads,
                max_seq_len * 2,
            ).to(dev)
            self.rotary_dim = self.head_dim
            self.alibi_slopes = None
            self.is_neox = True
    
    def forward(
        self,
        hidden_states,
        past_key_value=None,
        attention_mask=None,
        position_ids=None,
        output_attentions=False,
        use_cache=False,
    ):
        bsz, seqlen, _ = hidden_states.shape
        if bsz != self.cache_batch_size:
            raise RuntimeError(
                f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
                f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
            )
        xqkv = self.qkv_proj(hidden_states)
        xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
        
        xq = self.attention_shapes["xq_slice"](xqkv)
        xk = self.attention_shapes["xk_slice"](xqkv)
        xv = self.attention_shapes["xv_slice"](xqkv)

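        # Prefill path (multi-token input), also used as a fallback when the fused
        # kernel extension is not installed: apply RoPE in PyTorch, write the new
        # keys/values into the cache, and compute standard scaled-dot-product attention.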
        if seqlen > 1 or not FT_INSTALLED:
            xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
            xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
            xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])

            if not self.use_alibi:
                xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])

            self.cache_k = self.cache_k.to(xq)
            self.cache_v = self.cache_v.to(xq)

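            # Reshape the new keys/values into the cache layouts: values as
            # (bsz, n_kv_heads, seqlen, head_dim) and keys packed into groups of 8
            # along head_dim to match the FasterTransformer key-cache layout, then
            # write them at positions [start_pos, start_pos + seqlen).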
            values_store = xv.transpose(2, 1)
            keys_store = (
                xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
                .permute(0, 2, 3, 1, 4)
                .contiguous()
            )

            self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
            self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store

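            # Single-token step without the fused kernel: gather all cached
            # keys/values up to the current position and unpack the keys back to
            # (bsz, seq, n_kv_heads, head_dim) before attention.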
            if seqlen == 1:
                xv = self.cache_v[:bsz, :, : self.start_pos + seqlen, :].transpose(1, 2).contiguous()
                xk = self.cache_k[:bsz, :, :, : self.start_pos + seqlen, :].transpose(2, 3).contiguous()
                xk = xk.reshape(xk.shape[:-2] + (self.head_dim,)).transpose(1, 2).contiguous()
            
            keys = xk
            values = xv

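            # Grouped/multi-query attention: repeat each KV head n_kv_groups times so
            # keys and values match the number of query heads.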
            if self.n_kv_groups != 0:
                keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
                values = torch.repeat_interleave(values, dim=2, repeats=self.n_kv_groups)
            
            past_key_value = (xk, xv) if use_cache else None
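
            # Standard softmax attention over the gathered keys/values; scores are
            # computed in fp32 for numerical stability and cast back afterwards.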
            xq = xq.transpose(1, 2)
            keys = keys.transpose(1, 2)
            values = values.transpose(1, 2)
            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

            if self.use_alibi:
                scores += self.alibi_bias[..., :seqlen]

            if attention_mask is not None:
                scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
                
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
            attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        else:
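            # Decode path (seqlen == 1 with the fused kernel available): the
            # single_query_attention kernel applies rotary/ALiBi, updates the packed
            # KV cache in place, and returns the attention output for the new token.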
            # xq = xq[:, 0, :, :]
            # xk = xk[:, 0, :, :]
            # xv = xv[:, 0, :, :]
            xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
            xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
            xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])

            past_key_value = (xk, xv) if use_cache else None
            attention_weight = awq_inference_engine.single_query_attention(
                xq, # query
                xk, # key
                xv, # value
                self.cache_k, # key cache
                self.cache_v, # value cache
                None, # length per sample
                self.alibi_slopes, # alibi slopes
                self.start_pos, # timestep
                self.rotary_dim, # rotary embedding dimension
                10000, # rotary embedding base
                self.is_neox, # is neox
            )
            attention_weight = attention_weight.reshape(bsz, 1, -1)
        
        attn_output = self.o_proj(attention_weight)
        
        if use_cache:
            self.start_pos += seqlen
        else:
            self.start_pos = 0

        return attn_output, attention_weight, past_key_value
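

if __name__ == "__main__":
    # Minimal smoke test of the PyTorch prefill path (illustrative sketch only, not
    # part of the library API): plain nn.Linear layers stand in for the quantized
    # WQLinear projections, so only the seqlen > 1 branch is exercised and no fused
    # kernels are called. Note that awq_inference_engine must still be installed,
    # since it is imported at module level, and no causal attention_mask is passed
    # here; a real forward pass supplies one.
    hidden_size, n_heads, n_kv_heads, max_seq_len = 128, 8, 8, 64
    head_dim = hidden_size // n_heads
    qkv = nn.Linear(hidden_size, (n_heads + 2 * n_kv_heads) * head_dim, bias=False)
    o = nn.Linear(n_heads * head_dim, hidden_size, bias=False)

    attn = QuantAttentionFused(
        hidden_size, n_heads, n_kv_heads, qkv, o,
        dev="cpu", max_seq_len=max_seq_len, use_alibi=False,
    )
    x = torch.randn(1, 16, hidden_size)
    out, _, _ = attn(x, use_cache=False)
    print(out.shape)  # torch.Size([1, 16, 128])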