import os
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from awq.modules.fused.cache import WindowedCache
from awq.utils.fused_utils import get_attention_shapes


try:
    import awq_ft_ext

    FT_INSTALLED = True
except Exception:
    FT_INSTALLED = False

HF_NEW_CACHE_FORMAT = False

import transformers

# https://github.com/huggingface/transformers/pull/26681 introduced a new cache format
HF_NEW_CACHE_FORMAT = hasattr(transformers, "cache_utils")
if HF_NEW_CACHE_FORMAT:
    from transformers.cache_utils import DynamicCache


class RoPE(nn.Module):
    def __init__(self, head_dim, max_seq_len, device, rope_theta):
        super(RoPE, self).__init__()

        self.freqs_cis = nn.Parameter(
            self.precompute_freqs_cis(head_dim, max_seq_len * 2, rope_theta).to(device),
            requires_grad=False,
        )

    @staticmethod
    def precompute_freqs_cis(dim: int, end: int, theta=10000.0):
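        # Rotary frequency table: channel pair j rotates at theta^(-2j / dim);
        # torch.polar turns (magnitude 1, angle t * freq) into the complex
        # rotation e^(i * t * freq) for every position t in [0, end).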
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis

    @staticmethod
    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)

    def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: int):
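        # Pair channel k with channel k + head_dim // 2 (GPT-NeoX style) into
        # complex numbers, rotate by the precomputed frequencies for positions
        # [start_pos, start_pos + seqlen), then unpack back to real values.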
        xq_ = torch.view_as_complex(
            xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
        )
        xk_ = torch.view_as_complex(
            xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
        )
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
        freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)

        xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)

        return xq_out.type_as(xq), xk_out.type_as(xk)


class ALiBi(nn.Module):
    def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8):
        super(ALiBi, self).__init__()

        # Initialize ALiBi slopes and bias
        slopes, bias = self.build_alibi_bias(
            n_heads, max_seq_len, alibi_bias_max=alibi_bias_max
        )
        self.slopes = nn.Parameter(slopes.float().to(device), requires_grad=False)
        self.bias = nn.Parameter(bias.float().to(device), requires_grad=False)

    @staticmethod
    def gen_slopes(n_heads, alibi_bias_max=8):
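        # ALiBi head slopes: a geometric sequence 2^(-alibi_bias_max * i / n) over
        # the next power-of-two head count n; for a non-power-of-two n_heads,
        # interleave (odd indices first, then even) and truncate to n_heads.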
        _n_heads = 2 ** math.ceil(math.log2(n_heads))
        m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
        m = m.mul(alibi_bias_max / _n_heads)
        slopes = 1.0 / torch.pow(2, m)

        if _n_heads != n_heads:
            slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads]

        return slopes.view(1, n_heads, 1, 1)

    @staticmethod
    def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8, dtype=torch.float32):
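        # The bias holds each key's relative position to the last query,
        # i.e. (1 - seq_len) ... 0, later scaled per head by its slope.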
        alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(
            1, 1, 1, seq_len
        )
        slopes = ALiBi.gen_slopes(n_heads, alibi_bias_max)
        alibi_bias = alibi_bias * slopes
        slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
        return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)

    def forward(self, scores, seqlen):
        scores += self.bias[..., :seqlen]
        return scores


class QuantAttentionFused(nn.Module):
    def __init__(
        self,
        hidden_size,
        n_heads,
        n_kv_heads,
        qkv_layer,
        o_proj,
        dev,
        max_seq_len=2048,
        use_alibi=False,
        attention_shapes=None,
        rope_theta=10000,
        partial_rotary_factor=1.0,
        head_dim=None,
        **kwargs
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
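        # Grouped-query attention: each KV head is shared by n_heads // n_kv_heads
        # query heads; a value of 0 disables the KV expansion in forward().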
        self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
        self.head_dim = head_dim

        if head_dim is None:
            self.head_dim = hidden_size // n_heads

        self.qkv_proj = qkv_layer
        self.o_proj = o_proj
        self.start_pos = 0
        self.use_alibi = use_alibi
        self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))

        if kwargs.get("max_new_tokens") is not None:
            max_seq_len = kwargs["max_new_tokens"]

        self.max_seq_len = max_seq_len
        self.is_hf_transformers = False
        self.rope_theta = rope_theta

        # attention shapes for self attention
        self.attention_shapes = get_attention_shapes(
            attention_shapes,
            max_seq_len,
            self.cache_batch_size,
            n_heads,
            n_kv_heads,
            self.head_dim,
        )
        # Windowed KV cache store that rolls out the oldest positions once max_seq_len is exceeded
        self.cache = WindowedCache(
            self.attention_shapes["cache_v"],
            self.attention_shapes["cache_k"],
            self.max_seq_len,
            dev,
        )

        if use_alibi:
            self.alibi = ALiBi(n_heads, max_seq_len, dev)
            self.rotary_dim = 0
            self.is_neox = False
        else:
            self.alibi = None
            self.partial_rotary_factor = partial_rotary_factor
            self.rotary_dim = int(self.head_dim * self.partial_rotary_factor)
            self.rope = RoPE(self.rotary_dim, max_seq_len, dev, rope_theta)
            self.is_neox = True

    def forward(
        self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs
    ):
        bsz, seqlen, _ = hidden_states.shape

        # Reallocate cache if batch size changes
        if bsz != self.cache_batch_size:
            if bsz > self.cache_batch_size:
                self.cache.increase_batch_size(bsz)
                self.cache_batch_size = bsz
            elif bsz < self.cache_batch_size:
                self.cache.decrease_batch_size(bsz)
                self.cache_batch_size = bsz

            # Always reset to 0
            self.start_pos = 0

        hf_is_generating = False

        if self.is_hf_transformers and "use_cache" in kwargs:
            hf_is_generating = kwargs["use_cache"]


        # In case we re-generate, we need to refresh the starting position
        # to 0. We detect it by checking if `past_key_value` is set to None,
        # which indicates that we are on the first step of `generate()`.
        # This is only applicable for `transformers` integration
        if (
            self.is_hf_transformers
            and "past_key_value" in kwargs
            and kwargs["past_key_value"] is None
        ) or (self.is_hf_transformers and not hf_is_generating):
            self.start_pos = 0
    

        xqkv = self.qkv_proj(hidden_states)
        xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])

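        # Slice the fused QKV projection output into separate Q, K and V tensors
        # using the per-architecture slicing helpers from get_attention_shapes().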
        xq = self.attention_shapes["xq_slice"](xqkv)
        xk = self.attention_shapes["xk_slice"](xqkv)
        xv = self.attention_shapes["xv_slice"](xqkv)

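        # Two execution paths: the generic PyTorch attention below handles prefill
        # (seqlen > 1), partial rotary embeddings, and the case where the fused
        # awq_ft_ext kernels are unavailable; otherwise single-token decoding is
        # delegated to awq_ft_ext.single_query_attention in the else branch.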
        if seqlen > 1 or self.partial_rotary_factor < 1 or not FT_INSTALLED:
            xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
            xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
            xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])

            if not self.use_alibi:
                # Partial rotary embedding
                if self.partial_rotary_factor < 1:
                    xq_rot, xq_pass = (
                        xq[..., : self.rotary_dim],
                        xq[..., self.rotary_dim :],
                    )
                    xk_rot, xk_pass = (
                        xk[..., : self.rotary_dim],
                        xk[..., self.rotary_dim :],
                    )
                    xq_rot, xk_rot = self.rope.forward(xq_rot, xk_rot, self.start_pos, seqlen)
                    xq = torch.cat((xq_rot, xq_pass), dim=-1)
                    xk = torch.cat((xk_rot, xk_pass), dim=-1)
                else:
                    xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)

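            # Store V as (bsz, n_kv_heads, seqlen, head_dim); K is reshaped and
            # permuted into the packed cache layout described by
            # attention_shapes["xk_reshape"], which the fused decode kernel expects.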
            values_store = xv.transpose(2, 1)
            keys_store = (
                xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
                .permute(0, 2, 3, 1, 4)
                .contiguous()
            )

            self.cache.to(xq)
            self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)

            # Only necessary to retrieve from cache when we are not processing context
            if seqlen == 1:
                xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)

            keys = xk
            values = xv

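            # Grouped-query attention: replicate each KV head across its query
            # group so Q, K and V carry matching head counts for the matmul below.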
            if self.n_kv_groups != 0:
                keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
                values = torch.repeat_interleave(
                    values, dim=2, repeats=self.n_kv_groups
                )

            xq = xq.transpose(1, 2)
            keys = keys.transpose(1, 2)
            values = values.transpose(1, 2)
            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

            if self.use_alibi:
                scores = self.alibi.forward(scores, seqlen)

            # When seqlen is 1, there is nothing else to attend to
            if attention_mask is not None and seqlen > 1:
                # For llama-arch, the causal mask is preallocated with bsz x 1 x max_seq_len x max_seq_len, thus we 
                # need to slice it
                if attention_mask.shape[-1] != seqlen:
                    attention_mask = attention_mask[:, :, :seqlen, :seqlen]

                scores = (
                    scores + attention_mask
                )  # (bs, n_local_heads, slen, cache_len + slen)
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
            attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        else:
            xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
            xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
            xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])

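            # Single-token decode: the fused kernel is expected to apply the rotary
            # embedding (rotary_dim / rope_theta) or ALiBi slopes itself and to
            # read and update the KV cache in place.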
            alibi_slopes = self.alibi.slopes if self.alibi is not None else None
            attention_weight = awq_ft_ext.single_query_attention(
                xq,  # query
                xk,  # key
                xv,  # value
                self.cache.k,  # key cache
                self.cache.v,  # value cache
                None,  # length per sample
                alibi_slopes,  # alibi slopes
                self.start_pos,  # timestep
                self.rotary_dim,  # rotary embedding dimension
                self.rope_theta,  # rotary embedding base
                self.is_neox,  # is neox
            )
            attention_weight = attention_weight.reshape(bsz, 1, -1)

        attn_output = self.o_proj(attention_weight)
        self.start_pos += seqlen

        if self.is_hf_transformers and not hf_is_generating:
            self.start_pos = 0

        # past_key_value is replaced with cache_v, cache_k, returning empty data
        # we pass a dummy past kv cache for transformers to be able to retrieve the correct info
        # about past key length
        past_key_value = [torch.zeros(1, 1, self.start_pos, 1)]


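        # Newer transformers versions expect a Cache object rather than a list of
        # tensors, so wrap the dummy tensor in a DynamicCache.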
        if HF_NEW_CACHE_FORMAT and self.is_hf_transformers:
            new_cache = DynamicCache()
            new_cache.update(past_key_value[0], past_key_value[0], layer_idx=0)
            past_key_value = new_cache

        return attn_output, attention_weight, past_key_value