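"""Fused attention modules for AWQ-quantized models.

Provides RoPE and ALiBi position embeddings plus `QuantAttentionFused`, a
multi-head attention layer that keeps its own windowed key/value cache and,
when available, dispatches single-token decoding to the fused `awq_ft_ext`
kernel.
"""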
import os
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from awq.modules.fused.cache import WindowedCache
from awq.utils.fused_utils import get_attention_shapes


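# `awq_ft_ext` provides a fused single-query attention kernel; if it is not
# installed, the pure-PyTorch path in QuantAttentionFused.forward is used.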
try:
    import awq_ft_ext

    FT_INSTALLED = True
except Exception:
    FT_INSTALLED = False

HF_NEW_CACHE_FORMAT = False

import transformers

# https://github.com/huggingface/transformers/pull/26681 introduced a new cache format
HF_NEW_CACHE_FORMAT = hasattr(transformers, "cache_utils")
if HF_NEW_CACHE_FORMAT:
    from transformers.cache_utils import DynamicCache


class RoPE(nn.Module):
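    """Rotary position embedding (RoPE) applied via precomputed complex frequencies."""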
    def __init__(self, head_dim, max_seq_len, device, rope_theta):
        super(RoPE, self).__init__()

        self.freqs_cis = nn.Parameter(
            self.precompute_freqs_cis(
                head_dim, max_seq_len * 2, rope_theta
            ).to(device),
            requires_grad=False,
        )

    @staticmethod
    def precompute_freqs_cis(dim: int, end: int, theta=10000.0):
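        """Precompute complex rotation factors exp(i * t * freq_j) for positions t in [0, end)."""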
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis

    @staticmethod
    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
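        """Reshape `freqs_cis` so it broadcasts over the batch and head dims of `x`."""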
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)

    def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: int):
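        """Rotate queries and keys for positions [start_pos, start_pos + seqlen)."""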
        xq_ = torch.view_as_complex(
            xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
        )
        xk_ = torch.view_as_complex(
            xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
        )
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
        freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)

        xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)

        return xq_out.type_as(xq), xk_out.type_as(xk)


class ALiBi(nn.Module):
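    """Attention with Linear Biases (ALiBi): adds a fixed per-head linear bias to attention scores."""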
    def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8):
        super(ALiBi, self).__init__()

        # Initialize ALiBi slopes and bias
        slopes, bias = self.build_alibi_bias(
            n_heads, max_seq_len, alibi_bias_max=alibi_bias_max
        )
        self.slopes = nn.Parameter(slopes.float().to(device), requires_grad=False)
        self.bias = nn.Parameter(bias.float().to(device), requires_grad=False)

    @staticmethod
    def gen_slopes(n_heads, alibi_bias_max=8):
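        """Per-head geometric slopes; pads `n_heads` to the next power of two and
        interleaves the result when `n_heads` is not itself a power of two."""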
        _n_heads = 2 ** math.ceil(math.log2(n_heads))
        m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
        m = m.mul(alibi_bias_max / _n_heads)
        slopes = 1.0 / torch.pow(2, m)

        if _n_heads != n_heads:
            slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads]

        return slopes.view(1, n_heads, 1, 1)

    @staticmethod
    def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8, dtype=torch.float32):
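        """Return per-head slopes and an additive bias of shape (1, n_heads, 1, seq_len)."""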
        alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(
            1, 1, 1, seq_len
        )
        slopes = ALiBi.gen_slopes(n_heads, alibi_bias_max)
        alibi_bias = alibi_bias * slopes
        slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
        return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)

    def forward(self, scores, seqlen):
        scores += self.bias[..., :seqlen]
        return scores


class QuantAttentionFused(nn.Module):
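    """Multi-head attention over quantized, fused QKV and output projections.

    The layer owns a `WindowedCache` for keys/values and tracks `start_pos`
    itself; the `past_key_value`/`use_cache` inputs from `transformers` are
    only used to detect when a new `generate()` call starts, and a dummy
    cache is returned purely for API compatibility.
    """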
    def __init__(
        self,
        hidden_size,
        n_heads,
        n_kv_heads,
        qkv_layer,
        o_proj,
        dev,
        max_seq_len=2048,
        use_alibi=False,
        attention_shapes=None,
        rope_theta=10000,
        head_dim=None,
        **kwargs
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
        self.head_dim = head_dim
        
        if head_dim is None:
            self.head_dim = hidden_size // n_heads

        self.qkv_proj = qkv_layer
        self.o_proj = o_proj
        self.start_pos = 0
        self.use_alibi = use_alibi
        self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))

        if kwargs.get("max_new_tokens") is not None:
            max_seq_len = kwargs["max_new_tokens"]

        self.max_seq_len = max_seq_len
        self.is_hf_transformers = False
        self.rope_theta = rope_theta

        # attention shapes for self attention
        self.attention_shapes = get_attention_shapes(
            attention_shapes,
            max_seq_len,
            self.cache_batch_size,
            n_heads,
            n_kv_heads,
            self.head_dim,
        )
        # cache store that rolls cache
        self.cache = WindowedCache(
            self.attention_shapes["cache_v"],
            self.attention_shapes["cache_k"],
            self.max_seq_len,
            dev,
        )

        if use_alibi:
            self.alibi = ALiBi(n_heads, max_seq_len, dev)
            self.rotary_dim = 0
            self.is_neox = False
        else:
            self.alibi = None
            self.rope = RoPE(self.head_dim, max_seq_len, dev, rope_theta)
            self.rotary_dim = self.head_dim
            self.is_neox = True

    def forward(
        self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs
    ):
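        """Compute attention and return `(attn_output, attention_weight, past_key_value)`."""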
        bsz, seqlen, _ = hidden_states.shape

        # Reallocate cache if batch size changes
        if bsz != self.cache_batch_size:
            if bsz > self.cache_batch_size:
                self.cache.increase_batch_size(bsz)
                self.cache_batch_size = bsz
            elif bsz < self.cache_batch_size:
                self.cache.decrease_batch_size(bsz)
                self.cache_batch_size = bsz

            # Always reset to 0
            self.start_pos = 0

        hf_is_generating = False

        if self.is_hf_transformers and "use_cache" in kwargs:
            hf_is_generating = kwargs["use_cache"]


        # In case we re-generate, we need to refresh the starting position
        # to 0. We detect it by checking if `past_key_values` is set to None,
        # which indicates that we are on the first step of `generate()`.
        # This is only applicable for `transformers` integration
        if (
            self.is_hf_transformers
            and "past_key_value" in kwargs
            and kwargs["past_key_value"] is None
        ) or (self.is_hf_transformers and not hf_is_generating):
            self.start_pos = 0

        xqkv = self.qkv_proj(hidden_states)
        xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])

        xq = self.attention_shapes["xq_slice"](xqkv)
        xk = self.attention_shapes["xk_slice"](xqkv)
        xv = self.attention_shapes["xv_slice"](xqkv)

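        # Context (prompt) phase, or fallback when the fused kernel is unavailable:
        # compute attention in plain PyTorch against the windowed KV cache.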
        if seqlen > 1 or not FT_INSTALLED:
            xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
            xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
            xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])

            if not self.use_alibi:
                xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)

            values_store = xv.transpose(2, 1)
            keys_store = (
                xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
                .permute(0, 2, 3, 1, 4)
                .contiguous()
            )

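            # Match the cache to the query tensor and write the new keys/values
            # at the current position.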
            self.cache.to(xq)
            self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)

            # Only necessary to retrieve from cache when we are not processing context
            if seqlen == 1:
                xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)

            keys = xk
            values = xv

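            # Grouped-query attention: repeat KV heads to match the number of query heads.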
            if self.n_kv_groups != 0:
                keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
                values = torch.repeat_interleave(
                    values, dim=2, repeats=self.n_kv_groups
                )

            xq = xq.transpose(1, 2)
            keys = keys.transpose(1, 2)
            values = values.transpose(1, 2)
            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

            if self.use_alibi:
                scores = self.alibi.forward(scores, seqlen)

            # When seqlen is 1, there is nothing else to attend to
            if attention_mask is not None and seqlen > 1:
                # For llama-arch models, the causal mask is preallocated with shape
                # (bsz, 1, max_seq_len, max_seq_len), so we need to slice it
                if attention_mask.shape[-1] != seqlen:
                    attention_mask = attention_mask[:, :, :seqlen, :seqlen]

                scores = (
                    scores + attention_mask
                )  # (bs, n_local_heads, slen, cache_len + slen)
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
            attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        else:
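            # Single-token decode path: the fused kernel applies rotary embeddings
            # (or ALiBi slopes) and attends over the KV cache in a single call.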
            xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
            xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
            xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])

            alibi_slopes = self.alibi.slopes if self.alibi is not None else None
            attention_weight = awq_ft_ext.single_query_attention(
                xq,  # query
                xk,  # key
                xv,  # value
                self.cache.k,  # key cache
                self.cache.v,  # value cache
                None,  # length per sample
                alibi_slopes,  # alibi slopes
                self.start_pos,  # timestep
                self.rotary_dim,  # rotary embedding dimension
                self.rope_theta,  # rotary embedding base
                self.is_neox,  # is neox
            )
            attention_weight = attention_weight.reshape(bsz, 1, -1)

        attn_output = self.o_proj(attention_weight)
        self.start_pos += seqlen

        if self.is_hf_transformers and not hf_is_generating:
            self.start_pos = 0

        # past_key_value is handled by our own cache_k/cache_v; we return a dummy
        # tensor so transformers can still read the correct past key length from it
        past_key_value = [torch.zeros(1, 1, self.start_pos, 1)]

        if HF_NEW_CACHE_FORMAT and self.is_hf_transformers:
            new_cache = DynamicCache()
            new_cache.update(past_key_value[0], past_key_value[0], layer_idx=0)
            past_key_value = new_cache

        return attn_output, attention_weight, past_key_value
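
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; `qkv_layer`, `o_proj` and
# `hidden_states` are placeholders for already-quantized projections and an
# input of shape (bsz, seqlen, hidden_size) built elsewhere in AutoAWQ):
#
#   attn = QuantAttentionFused(
#       hidden_size=4096,
#       n_heads=32,
#       n_kv_heads=8,
#       qkv_layer=qkv_layer,
#       o_proj=o_proj,
#       dev="cuda:0",
#       max_seq_len=2048,
#       use_alibi=False,  # llama-style RoPE positions
#   )
#   attn_output, _, past_key_value = attn(hidden_states)
# ---------------------------------------------------------------------------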