import os
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from awq.modules.fused.cache import WindowedCache
from awq.utils.fused_utils import get_attention_shapes


try:
    import awq_ft_ext

    FT_INSTALLED = True
except Exception:
    # The fused kernel extension is optional; fall back to the PyTorch attention path.
    FT_INSTALLED = False

HF_NEW_CACHE_FORMAT = False

import transformers

# https://github.com/huggingface/transformers/pull/26681 introduced a new cache format
HF_NEW_CACHE_FORMAT = hasattr(transformers, "cache_utils")
if HF_NEW_CACHE_FORMAT:
    from transformers.cache_utils import DynamicCache


class RoPE(nn.Module):
    def __init__(self, head_dim, max_seq_len, device, rope_theta):
        super().__init__()

        self.freqs_cis = nn.Parameter(
            self.precompute_freqs_cis(head_dim, max_seq_len * 2, rope_theta).to(device),
            requires_grad=False,
        )

    @staticmethod
    def precompute_freqs_cis(dim: int, end: int, theta=10000.0):
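        # Rotation frequencies theta^(-2i/dim) for each rotary pair, expanded over
        # all positions and converted to unit complex numbers via torch.polar.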
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis

    @staticmethod
    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
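        # Reshape freqs_cis from (seqlen, head_dim // 2) to (1, seqlen, 1, head_dim // 2)
        # so it broadcasts over the batch and head dimensions of x.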
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)

    def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: int):
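        # Interpret the last dimension of q/k as complex pairs, rotate them by the
        # precomputed frequencies for positions [start_pos, start_pos + seqlen),
        # then convert back to real and restore the input dtype.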
        xq_ = torch.view_as_complex(
            xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
        )
        xk_ = torch.view_as_complex(
            xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
        )
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
        freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)

        xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)

        return xq_out.type_as(xq), xk_out.type_as(xk)


class ALiBi(nn.Module):
    def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8):
        super().__init__()

        # Initialize ALiBi slopes and bias
        slopes, bias = self.build_alibi_bias(
            n_heads, max_seq_len, alibi_bias_max=alibi_bias_max
        )
        self.slopes = nn.Parameter(slopes.float().to(device), requires_grad=False)
        self.bias = nn.Parameter(bias.float().to(device), requires_grad=False)

    @staticmethod
    def gen_slopes(n_heads, alibi_bias_max=8):
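        # Per-head slopes form a geometric sequence. n_heads is rounded up to the
        # next power of two; if it was not a power of two, the slopes are
        # interleaved and truncated back to n_heads.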
        _n_heads = 2 ** math.ceil(math.log2(n_heads))
        m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
        m = m.mul(alibi_bias_max / _n_heads)
        slopes = 1.0 / torch.pow(2, m)

        if _n_heads != n_heads:
            slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads]

        return slopes.view(1, n_heads, 1, 1)

    @staticmethod
    def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8, dtype=torch.float32):
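        # Linear relative-position bias: distances (1 - seq_len) .. 0 scaled by the
        # per-head slopes, shaped (1, n_heads, 1, seq_len).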
        alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(
            1, 1, 1, seq_len
        )
        slopes = ALiBi.gen_slopes(n_heads, alibi_bias_max)
        alibi_bias = alibi_bias * slopes
        slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
        return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)

    def forward(self, scores, seqlen):
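        # Add the precomputed bias, sliced to the current sequence length, to the
        # raw attention scores.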
        scores += self.bias[..., :seqlen]
        return scores


class QuantAttentionFused(nn.Module):
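    """
    Fused attention block for AWQ models: consumes a single packed QKV projection,
    maintains its own rolling KV cache (WindowedCache), and dispatches single-token
    decoding to the fused awq_ft_ext kernel when it is available.
    """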
    def __init__(
        self,
        hidden_size,
        n_heads,
        n_kv_heads,
        qkv_layer,
        o_proj,
        dev,
        max_seq_len=2048,
        use_alibi=False,
        attention_shapes=None,
        rope_theta=10000,
        partial_rotary_factor=1.0,
        head_dim=None,
        **kwargs
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
        self.head_dim = head_dim

        if head_dim is None:
            self.head_dim = hidden_size // n_heads

        self.qkv_proj = qkv_layer
        self.o_proj = o_proj
        self.start_pos = 0
        self.use_alibi = use_alibi
        self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))

        if kwargs.get("max_new_tokens") is not None:
            max_seq_len = kwargs["max_new_tokens"]

        self.max_seq_len = max_seq_len
        self.is_hf_transformers = False
        self.rope_theta = rope_theta

        # attention shapes for self attention
        self.attention_shapes = get_attention_shapes(
            attention_shapes,
            max_seq_len,
            self.cache_batch_size,
            n_heads,
            n_kv_heads,
            self.head_dim,
        )
        # Windowed KV cache that rolls over once max_seq_len is reached
        self.cache = WindowedCache(
            self.attention_shapes["cache_v"],
            self.attention_shapes["cache_k"],
            self.max_seq_len,
            dev,
        )

        if use_alibi:
            self.alibi = ALiBi(n_heads, max_seq_len, dev)
            # No rotary embedding with ALiBi; define partial_rotary_factor for forward()
            self.partial_rotary_factor = 1.0
            self.rotary_dim = 0
            self.is_neox = False
        else:
            self.alibi = None
            self.partial_rotary_factor = partial_rotary_factor
            self.rotary_dim = int(self.head_dim * self.partial_rotary_factor)
            self.rope = RoPE(self.rotary_dim, max_seq_len, dev, rope_theta)
            self.is_neox = True

    def forward(
        self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs
    ):
        bsz, seqlen, _ = hidden_states.shape

        # Reallocate cache if batch size changes
        if bsz != self.cache_batch_size:
            if bsz > self.cache_batch_size:
                self.cache.increase_batch_size(bsz)
                self.cache_batch_size = bsz
            elif bsz < self.cache_batch_size:
                self.cache.decrease_batch_size(bsz)
                self.cache_batch_size = bsz

            # Always reset the start position after reallocating the cache
            self.start_pos = 0

        hf_is_generating = False
        hf_is_first_forward = (
            "past_key_value" in kwargs and kwargs["past_key_value"] is None
        )
        hf_is_new_cache_first_forward = (
            HF_NEW_CACHE_FORMAT
            and "past_key_value" in kwargs
            and isinstance(kwargs["past_key_value"], DynamicCache)
            and kwargs["past_key_value"].get_seq_length() == 0
        )

        if self.is_hf_transformers and "use_cache" in kwargs:
            hf_is_generating = kwargs["use_cache"]


        # When a new `generate()` call starts, the starting position must be reset
        # to 0. We detect this by checking whether `past_key_value` is None (or an
        # empty DynamicCache), which indicates the first step of `generate()`.
        # This is only applicable for the `transformers` integration.
        if (
            self.is_hf_transformers
            and (hf_is_first_forward or hf_is_new_cache_first_forward)
        ) or (self.is_hf_transformers and not hf_is_generating):
            self.start_pos = 0

        xqkv = self.qkv_proj(hidden_states)
        xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])

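        # Slice the packed QKV projection into separate query, key and value tensors.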
        xq = self.attention_shapes["xq_slice"](xqkv)
        xk = self.attention_shapes["xk_slice"](xqkv)
        xv = self.attention_shapes["xv_slice"](xqkv)

        if seqlen > 1 or self.partial_rotary_factor < 1 or not FT_INSTALLED:
            xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
            xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
            xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])

            if not self.use_alibi:
                # Partial rotary embedding
                if self.partial_rotary_factor < 1:
                    xq_rot, xq_pass = (
                        xq[..., : self.rotary_dim],
                        xq[..., self.rotary_dim :],
                    )
                    xk_rot, xk_pass = (
                        xk[..., : self.rotary_dim],
                        xk[..., self.rotary_dim :],
                    )
                    xq_rot, xk_rot = self.rope.forward(xq_rot, xk_rot, self.start_pos, seqlen)
                    xq = torch.cat((xq_rot, xq_pass), dim=-1)
                    xk = torch.cat((xk_rot, xk_pass), dim=-1)
                else:
                    xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)

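            # Rearrange K and V into the layouts expected by the KV cache (V as
            # (bsz, n_kv_heads, seqlen, head_dim); K packed for the fused kernel).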
            values_store = xv.transpose(2, 1)
            keys_store = (
                xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
                .permute(0, 2, 3, 1, 4)
                .contiguous()
            )

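            # Match the cache to the activations' device/dtype, then write the new
            # K/V entries at the current start position.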
            self.cache.to(xq)
            self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)

            # Only necessary to retrieve from cache when we are not processing context
            if seqlen == 1:
                xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)

            keys = xk
            values = xv

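            # Grouped-query attention: repeat each KV head n_kv_groups times so the
            # keys/values match the number of query heads.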
            if self.n_kv_groups != 0:
                keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
                values = torch.repeat_interleave(
                    values, dim=2, repeats=self.n_kv_groups
                )

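            # Standard scaled dot-product attention over the (cached) keys/values.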
            xq = xq.transpose(1, 2)
            keys = keys.transpose(1, 2)
            values = values.transpose(1, 2)
            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

            if self.use_alibi:
                scores = self.alibi.forward(scores, seqlen)

            # When seqlen is 1 there are no future positions to mask, so skip the mask
            if attention_mask is not None and seqlen > 1:
                # For llama-arch models the causal mask is preallocated with shape
                # (bsz, 1, max_seq_len, max_seq_len), so slice it to the current seqlen
                if attention_mask.shape[-1] != seqlen:
                    attention_mask = attention_mask[:, :, :seqlen, :seqlen]

                scores = (
                    scores + attention_mask
                )  # (bs, n_local_heads, slen, cache_len + slen)
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
            attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        else:
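            # Single-token decode with full rotary dims: the fused awq_ft_ext kernel
            # handles rotary embedding, KV cache update and attention in one call.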
            xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
            xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
            xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])

            alibi_slopes = self.alibi.slopes if self.alibi is not None else None
            attention_weight = awq_ft_ext.single_query_attention(
                xq,  # query
                xk,  # key
                xv,  # value
                self.cache.k,  # key cache
                self.cache.v,  # value cache
                None,  # length per sample
                alibi_slopes,  # alibi slopes
                self.start_pos,  # timestep
                self.rotary_dim,  # rotary embedding dimension
                self.rope_theta,  # rotary embedding base
                self.is_neox,  # is neox
            )
            attention_weight = attention_weight.reshape(bsz, 1, -1)

        attn_output = self.o_proj(attention_weight)
        self.start_pos += seqlen

        if self.is_hf_transformers and not hf_is_generating:
            self.start_pos = 0

        # past_key_value is replaced internally by cache_k / cache_v; return a dummy
        # past KV cache so transformers can still infer the correct past key length
        past_key_value = [torch.zeros(1, 1, self.start_pos, 1)]


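        # Newer transformers versions expect a Cache object rather than a tuple.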
        if HF_NEW_CACHE_FORMAT and self.is_hf_transformers:
            new_cache = DynamicCache()
            new_cache.update(past_key_value[0], past_key_value[0], layer_idx=0)
            past_key_value = new_cache

        return attn_output, attention_weight, past_key_value