attn.py 11.1 KB
Newer Older
1
import os
Casper Hansen's avatar
Casper Hansen committed
2
import math
Haotian Tang's avatar
Haotian Tang committed
3
4
5
import torch
import torch.nn as nn
import awq_inference_engine
Casper Hansen's avatar
Casper Hansen committed
6
from torch.nn import functional as F
Casper Hansen's avatar
Casper Hansen committed
7

Casper's avatar
Casper committed
8
9
10
11
12
try:
    import ft_inference_engine
    FT_INSTALLED = True
except:
    FT_INSTALLED = False
qwopqwop200's avatar
qwopqwop200 committed
13

Casper Hansen's avatar
Casper Hansen committed
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Return the complex rotary-embedding table of shape ``(end, dim // 2)``.

    Entry ``[t, i]`` equals ``exp(1j * t * theta ** (-2 * i / dim))`` stored as
    complex64: a unit-magnitude rotation for position ``t`` at frequency ``i``.
    """
    half = dim // 2
    exponents = torch.arange(0, dim, 2)[:half].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    # polar(1, angle) -> cos(angle) + 1j * sin(angle), complex64
    return torch.polar(torch.ones_like(angles), angles)

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so that it broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The returned view
    keeps those two sizes at dims 1 and -1 of ``x`` and has size 1 everywhere else.

    Raises:
        AssertionError: if ``x`` has fewer than two dims or the shapes disagree.
    """
    rank = x.ndim
    assert rank > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    target = [1] * rank
    target[1] = x.shape[1]
    target[-1] = x.shape[-1]
    return freqs_cis.view(*target)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """Rotate queries and keys by the rotary-position table ``freqs_cis``.

    The last dim of each input is split into two halves ``[a, b]``. Element ``i``
    of each half forms the complex pair ``a[i] + 1j * b[i]``. That pair is
    multiplied by ``freqs_cis`` and written back in the same half-split layout.
    The math runs in float32, and each result is cast back to its input's dtype.

    Args:
        xq: query tensor whose dim 1 indexes positions and whose last dim has even
            size (presumably ``(batch, seq, heads, head_dim)`` at the call site --
            any rank >= 3 is handled).
        xk: key tensor laid out like ``xq`` (its head count may differ; the table
            broadcasts over that dim).
        freqs_cis: complex table of shape ``(xq.shape[1], xq.shape[-1] // 2)``.

    Returns:
        ``(xq_out, xk_out)`` with the same shapes and dtypes as ``xq`` and ``xk``.
    """
    # (..., d) -> (..., 2, d/2) -> (..., d/2, 2) -> complex (..., d/2)
    xq_ = torch.view_as_complex(
        xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
    )
    xk_ = torch.view_as_complex(
        xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
    )
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
    # Undo the pairing: complex (..., d/2) -> (..., d/2, 2) -> (..., 2, d/2) -> (..., d).
    # Flatten the last two dims (not "from dim 3") so inputs of any rank keep their shape.
    xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)

Casper Hansen's avatar
Casper Hansen committed
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def gen_slopes(n_heads, alibi_bias_max=8):
    """Return ALiBi per-head slopes as a float32 tensor of shape ``(1, n_heads, 1, 1)``.

    Let ``p`` be ``n_heads`` rounded up to a power of two. The raw slopes are
    ``2 ** (-alibi_bias_max * k / p)`` for ``k = 1..p``. When ``n_heads`` is not a
    power of two, the odd-indexed slopes are placed first, followed by the
    even-indexed ones. The result is then truncated to ``n_heads`` entries.
    """
    padded = 2 ** math.ceil(math.log2(n_heads))
    exponents = torch.arange(1, padded + 1, dtype=torch.float32)
    exponents = exponents * (alibi_bias_max / padded)
    slopes = 1.0 / torch.pow(2, exponents)
    if padded == n_heads:
        return slopes.view(1, n_heads, 1, 1)
    reordered = torch.cat((slopes[1::2], slopes[::2]))
    return reordered[:n_heads].view(1, n_heads, 1, 1)


def build_alibi_bias(
    n_heads, seq_len, full=False, alibi_bias_max=8, dtype=torch.float32
):
    """Build ALiBi slopes and the additive attention bias.

    Returns:
        ``(slopes, bias)``, both cast to ``dtype``:

        * ``slopes`` has shape ``(n_heads,)``.
        * When ``full`` is false, ``bias`` has shape ``(1, n_heads, 1, seq_len)``,
          and column ``j`` holds ``(j - (seq_len - 1)) * slope``.
        * When ``full`` is true, ``bias`` has shape ``(1, n_heads, seq_len, seq_len)``,
          and entry ``[i, j]`` holds ``-|j - i| * slope``.
    """
    offsets = torch.arange(1 - seq_len, 1, dtype=torch.int32)
    bias = offsets.view(1, 1, 1, seq_len)
    if full:
        # column offset minus row offset -> (j - i), then negated distance
        bias = -(bias - offsets.view(1, 1, seq_len, 1)).abs()
    head_slopes = gen_slopes(n_heads, alibi_bias_max)
    bias = bias * head_slopes
    flat_slopes = head_slopes.view(-1)
    return flat_slopes.to(dtype=dtype), bias.to(dtype=dtype)

Haotian Tang's avatar
Haotian Tang committed
68

Casper Hansen's avatar
Casper Hansen committed
69
class QuantAttentionFused(nn.Module):
    """Fused quantized attention with a preallocated, in-module KV cache.

    A single ``qkv_layer`` projects the hidden states to concatenated query, key
    and value heads, which are split according to ``attention_shapes``.
    ``forward`` has two execution paths:

    * a plain PyTorch path, used for multi-token inputs (``seqlen > 1``) or
      whenever ``ft_inference_engine`` is not installed;
    * ``ft_inference_engine.single_query_attention``, used for single-token
      steps when that extension is available.

    Positions are encoded with rotary embeddings (``use_alibi=False``) or ALiBi
    biases (``use_alibi=True``). The cache batch size comes from the
    ``AWQ_BATCH_SIZE`` environment variable (default ``1``). Every input must
    have exactly that batch size.
    """

    def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len,
                 use_alibi=False, attention_shapes=None):
        """
        Args:
            hidden_size: model width; ``head_dim = hidden_size // n_heads``.
            n_heads: number of query heads.
            n_kv_heads: number of key/value heads.
                * ``0`` selects the layout in which query, key and value each
                  have ``n_heads`` heads.
                * Any other value makes each key/value head repeat
                  ``n_heads // n_kv_heads`` times to match the query heads.
            qkv_layer: module mapping hidden states to the concatenated q/k/v
                projection.
            o_proj: module applied to the attention output.
            dev: device holding the KV cache and the positional tables.
            max_seq_len: number of positions the KV cache can hold.
            use_alibi: use ALiBi biases instead of rotary embeddings.
            attention_shapes: optional dict replacing the default shapes and
                slicing functions; it must provide the same keys.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
        self.head_dim = self.hidden_size // n_heads
        self.qkv_proj = qkv_layer
        self.o_proj = o_proj
        self.start_pos = 0  # next cache position to write
        self.use_alibi = use_alibi
        self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
        self.max_seq_len = max_seq_len
        # Needs n_heads / n_kv_heads / head_dim / cache_batch_size set above.
        self.attention_shapes = self._get_attention_shapes(attention_shapes, max_seq_len)

        # Allocate the fp16 caches directly on `dev` (no fp32 CPU staging copy).
        self.cache_v = torch.zeros(self.attention_shapes["cache_v"], device=dev, dtype=torch.float16)
        self.cache_k = torch.zeros(self.attention_shapes["cache_k"], device=dev, dtype=torch.float16)

        if use_alibi:
            alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
            self.alibi_slopes = alibi_slopes.float().to(dev)
            self.alibi_bias = alibi_bias.float().to(dev)
            self.rotary_dim = 0
            self.is_neox = False
        else:
            # Twice max_seq_len rows of rotary angles.
            self.freqs_cis = precompute_freqs_cis(
                self.head_dim,
                max_seq_len * 2,
            ).to(dev)
            self.rotary_dim = self.head_dim
            self.alibi_slopes = None
            self.is_neox = True

    def _get_attention_shapes(self, attention_shapes, max_seq_len):
        """Return cache shapes, view shapes and q/k/v slicing functions.

        A caller-supplied ``attention_shapes`` is returned unchanged. Otherwise
        the defaults depend on ``n_kv_heads``:

        * ``n_kv_heads == 0``: the projection is viewed as
          ``(bsz, seqlen, -1, n_heads, head_dim)``, and q/k/v are taken at
          indices 0/1/2 of dim 2.
        * Any other value: the projection is viewed as
          ``(bsz, seqlen, n_heads + 2 * n_kv_heads, head_dim)``, holding query
          heads first, then key heads, then value heads.
        """
        if attention_shapes is not None:
            return attention_shapes

        if self.n_kv_heads == 0:
            return {
                # following fastertransformer definition
                "cache_v": (self.cache_batch_size, self.n_heads, max_seq_len, self.head_dim,),
                # 8: pack 8 fp16 in FT, if fp32 then use 4
                "cache_k": (self.cache_batch_size, self.n_heads, self.head_dim // 8, max_seq_len, 8,),
                "xqkv_view": (-1, self.n_heads, self.head_dim),
                "xq_slice": lambda xqkv: xqkv[:, :, 0],
                "xk_slice": lambda xqkv: xqkv[:, :, 1],
                "xv_slice": lambda xqkv: xqkv[:, :, 2],
                "xq_view": (self.n_heads, self.head_dim),
                "xk_view": (self.n_heads, self.head_dim),
                "xv_view": (self.n_heads, self.head_dim),
                "xk_reshape": (self.n_heads, self.head_dim // 8, 8),
                "single_xq_view": (self.n_heads, self.head_dim),
                "single_xk_view": (self.n_heads, self.head_dim),
                "single_xv_view": (self.n_heads, self.head_dim)
            }

        return {
            # following fastertransformer definition
            "cache_v": (self.cache_batch_size, self.n_kv_heads, max_seq_len, self.head_dim,),
            # 8: pack 8 fp16 in FT, if fp32 then use 4
            "cache_k": (self.cache_batch_size, self.n_kv_heads, self.head_dim // 8, max_seq_len, 8,),
            "xqkv_view": (self.n_heads + self.n_kv_heads * 2, self.head_dim),
            "xq_slice": lambda xqkv: xqkv[:, :, 0 : self.n_heads],
            "xk_slice": lambda xqkv: xqkv[:, :, self.n_heads : (self.n_heads + self.n_kv_heads)],
            "xv_slice": lambda xqkv: xqkv[:, :, -self.n_kv_heads :],
            "xq_view": (self.n_heads, self.head_dim),
            "xk_view": (self.n_kv_heads, self.head_dim),
            "xv_view": (self.n_kv_heads, self.head_dim),
            "xk_reshape": (self.n_kv_heads, self.head_dim // 8, 8),
            "single_xq_view": (self.n_heads, self.head_dim),
            "single_xk_view": (self.n_kv_heads, self.head_dim),
            "single_xv_view": (self.n_kv_heads, self.head_dim)
        }

    def forward(
        self,
        hidden_states: torch.Tensor, past_key_value=None, attention_mask=None, position_ids=None,
        output_attentions=False, use_cache=False, *args, **kwargs
    ):
        """Attend over ``hidden_states`` using (and updating) the internal KV cache.

        Args:
            hidden_states: tensor of shape ``(bsz, seqlen, hidden_size)``.
            past_key_value: ignored; the cache lives in ``self.cache_k`` and
                ``self.cache_v``.
            attention_mask: optional additive mask added to the scores on the
                PyTorch path; the FT single-token path does not use it.
            position_ids, output_attentions, *args, **kwargs: accepted for call
                compatibility and ignored.
            use_cache: if true, ``start_pos`` advances by ``seqlen`` after this
                call; otherwise it is reset to 0.

        Returns:
            ``(attn_output, attention_weight, None)``. ``attention_weight`` is the
            per-head attention output flattened to ``(bsz, seqlen, -1)`` before
            ``o_proj``; it is not the softmax probabilities.

        Raises:
            RuntimeError: if ``bsz`` differs from the cache batch size.
        """
        bsz, seqlen, _ = hidden_states.shape

        if bsz != self.cache_batch_size:
            raise RuntimeError(
                f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
                f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
            )

        # seqlen is never negative, so this also covers start_pos > max_seq_len.
        if self.start_pos + seqlen > self.max_seq_len:
            # Despite the rotation, this discards all cached context:
            # * Rolling by -start_pos moves the valid entries [0, start_pos) to
            #   the tail, where they are zeroed. When start_pos == 0, the `-0:`
            #   slice zeroes the whole cache.
            # * Writing then restarts at position 0.
            roll_len = self.start_pos
            self.cache_v = torch.roll(self.cache_v, shifts=-roll_len, dims=2)
            self.cache_k = torch.roll(self.cache_k, shifts=-roll_len, dims=3)
            self.cache_v[:, :, -roll_len:, :] = 0
            self.cache_k[:, :, :, -roll_len:, :] = 0
            self.start_pos = 0

        xqkv = self.qkv_proj(hidden_states)
        xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])

        xq = self.attention_shapes["xq_slice"](xqkv)
        xk = self.attention_shapes["xk_slice"](xqkv)
        xv = self.attention_shapes["xv_slice"](xqkv)

        if seqlen > 1 or not FT_INSTALLED:
            xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
            xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
            xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])

            if not self.use_alibi:
                xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])

            self.cache_k = self.cache_k.to(xq)
            self.cache_v = self.cache_v.to(xq)

            # Write the new entries into the caches. With the default attention_shapes:
            # * values are stored as (bsz, kv_heads, seq, head_dim);
            # * keys are packed as (bsz, kv_heads, head_dim // 8, seq, 8).
            values_store = xv.transpose(2, 1)
            keys_store = (
                xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
                .permute(0, 2, 3, 1, 4)
                .contiguous()
            )

            self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
            self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store

            if seqlen == 1:
                # Decoding without FT: attend over every cached position. Both
                # caches are unpacked back to (bsz, kv_len, kv_heads, head_dim).
                kv_len = self.start_pos + seqlen
                xv = self.cache_v[:bsz, :, :kv_len, :].transpose(1, 2).contiguous()
                xk = self.cache_k[:bsz, :, :, :kv_len, :].transpose(2, 3).contiguous()
                xk = xk.reshape(xk.shape[:-2] + (self.head_dim,)).transpose(1, 2).contiguous()

            keys = xk
            values = xv

            if self.n_kv_groups != 0:
                # Grouped-query attention: repeat each key/value head for its query group.
                keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
                values = torch.repeat_interleave(values, dim=2, repeats=self.n_kv_groups)

            xq = xq.transpose(1, 2)
            keys = keys.transpose(1, 2)
            values = values.transpose(1, 2)
            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

            if self.use_alibi:
                # alibi_bias comes from build_alibi_bias(n_heads, max_seq_len), shape
                # (1, n_heads, 1, max_seq_len); column j holds (j - (max_seq_len - 1)) * slope.
                # * Taking the LAST kv_len columns biases each key by its distance from
                #   the newest key.
                # * Taking the first `seqlen` columns would, when decoding from the
                #   cache (seqlen == 1), broadcast one column over every key and lose
                #   the positional bias entirely.
                kv_len = scores.shape[-1]
                scores += self.alibi_bias[..., -kv_len:]

            if attention_mask is not None:
                scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)

            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
            attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        else:
            xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
            xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
            xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])

            # This path does no cache write or rotary step in Python.
            # NOTE(review): presumably the FT kernel applies rotary/ALiBi itself
            # and updates cache_k / cache_v in place -- confirm against
            # ft_inference_engine.
            attention_weight = ft_inference_engine.single_query_attention(
                xq, # query
                xk, # key
                xv, # value
                self.cache_k, # key cache
                self.cache_v, # value cache
                None, # length per sample
                self.alibi_slopes, # alibi slopes
                self.start_pos, # timestep
                self.rotary_dim, # rotary embedding dimension
                10000, # rotary embedding base
                self.is_neox, # is neox
            )
            attention_weight = attention_weight.reshape(bsz, 1, -1)

        attn_output = self.o_proj(attention_weight)

        if use_cache:
            self.start_pos += seqlen
        else:
            self.start_pos = 0

        # past_key_value is replaced with cache_v, cache_k, returning None
        return attn_output, attention_weight, None