import os
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from awq.modules.fused.cache import WindowedCache
from awq.utils.fused_utils import get_attention_shapes


try:
    import awq_ft_ext

    FT_INSTALLED = True
except:
    FT_INSTALLED = False

HF_NEW_CACHE_FORMAT = False

import transformers

# https://github.com/huggingface/transformers/pull/26681 introduced a new cache format
HF_NEW_CACHE_FORMAT = hasattr(transformers, "cache_utils")
if HF_NEW_CACHE_FORMAT:
    from transformers.cache_utils import DynamicCache


class RoPE(nn.Module):
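    """
    Rotary positional embedding (RoPE) applied in the complex domain.

    The rotation frequencies are precomputed once (see `precompute_freqs_cis`)
    and stored as a non-trainable parameter; `forward` rotates the query and
    key tensors by multiplying them with the frequency slice for the current
    positions.
    """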
    def __init__(self, hidden_size, n_heads, max_seq_len, device, rope_theta):
        super(RoPE, self).__init__()

        self.freqs_cis = nn.Parameter(
            self.precompute_freqs_cis(
                hidden_size // n_heads, max_seq_len * 2, rope_theta
            ).to(device),
            requires_grad=False,
        )

    @staticmethod
    def precompute_freqs_cis(dim: int, end: int, theta=10000.0):
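        # Inverse frequencies for each pair of channels, outer-multiplied with
        # the position indices and converted to unit complex numbers so the
        # rotation can later be applied as a single complex multiplication.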
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis

    @staticmethod
    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)

    def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: int):
        xq_ = torch.view_as_complex(
            xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
        )
        xk_ = torch.view_as_complex(
            xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
        )
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
        freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)

        xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)

        return xq_out.type_as(xq), xk_out.type_as(xk)


class ALiBi(nn.Module):
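    """
    Attention with Linear Biases (ALiBi).

    Precomputes per-head slopes and a (1, n_heads, 1, seq_len) bias that is
    added directly to the attention scores in `forward`, replacing positional
    embeddings.
    """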
    def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8):
        super(ALiBi, self).__init__()

        # Initialize ALiBi slopes and bias
        slopes, bias = self.build_alibi_bias(
            n_heads, max_seq_len, alibi_bias_max=alibi_bias_max
        )
        self.slopes = nn.Parameter(slopes.float().to(device), requires_grad=False)
        self.bias = nn.Parameter(bias.float().to(device), requires_grad=False)

    @staticmethod
    def gen_slopes(n_heads, alibi_bias_max=8):
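        # Slopes form a geometric sequence computed for the next power of two
        # of n_heads; when n_heads is not a power of two, the interleaved
        # selection below reproduces the standard ALiBi slope schedule.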
        _n_heads = 2 ** math.ceil(math.log2(n_heads))
        m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
        m = m.mul(alibi_bias_max / _n_heads)
        slopes = 1.0 / torch.pow(2, m)

        if _n_heads != n_heads:
            slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads]

        return slopes.view(1, n_heads, 1, 1)

    @staticmethod
    def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8, dtype=torch.float32):
        alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(
            1, 1, 1, seq_len
        )
        slopes = ALiBi.gen_slopes(n_heads, alibi_bias_max)
        alibi_bias = alibi_bias * slopes
        slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
        return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)

    def forward(self, scores, seqlen):
        scores += self.bias[..., :seqlen]
        return scores


class QuantAttentionFused(nn.Module):
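    """
    Fused attention block for AWQ-quantized models.

    Uses a single fused QKV projection, a rolling `WindowedCache` for the KV
    cache, and either RoPE or ALiBi for positional information. Single-token
    decoding dispatches to the `awq_ft_ext` fused kernel when it is installed;
    otherwise a plain PyTorch attention path is used.
    """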
    def __init__(
        self,
        hidden_size,
        n_heads,
        n_kv_heads,
        qkv_layer,
        o_proj,
        dev,
        max_seq_len=2048,
        use_alibi=False,
        attention_shapes=None,
        rope_theta=10000,
        **kwargs
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
        self.head_dim = self.hidden_size // n_heads
        self.qkv_proj = qkv_layer
        self.o_proj = o_proj
        self.start_pos = 0
        self.use_alibi = use_alibi
        self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))

        if kwargs.get("max_new_tokens") is not None:
            max_seq_len = kwargs["max_new_tokens"]

        self.max_seq_len = max_seq_len
        self.is_hf_transformers = False
        self.rope_theta = rope_theta

        # attention shapes for self attention
        self.attention_shapes = get_attention_shapes(
            attention_shapes,
            max_seq_len,
            self.cache_batch_size,
            n_heads,
            n_kv_heads,
            self.head_dim,
        )
        # cache store that rolls cache
        self.cache = WindowedCache(
            self.attention_shapes["cache_v"],
            self.attention_shapes["cache_k"],
            self.max_seq_len,
            dev,
        )

        if use_alibi:
            self.alibi = ALiBi(n_heads, max_seq_len, dev)
            self.rotary_dim = 0
            self.is_neox = False
        else:
            self.alibi = None
            self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev, rope_theta)
            self.rotary_dim = self.head_dim
            self.is_neox = True

    def forward(
        self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs
    ):
        bsz, seqlen, _ = hidden_states.shape

        # Reallocate cache if batch size changes
        if bsz != self.cache_batch_size:
            if bsz > self.cache_batch_size:
                self.cache.increase_batch_size(bsz)
                self.cache_batch_size = bsz
            elif bsz < self.cache_batch_size:
                self.cache.decrease_batch_size(bsz)
                self.cache_batch_size = bsz

            # Always reset to 0
            self.start_pos = 0

        # In case we re-generate, we need to refresh the starting position
        # to 0. We detect it by checking if `past_key_value` is set to None,
        # which indicates that we are on the first step of `generate()`.
        # This is only applicable for the `transformers` integration.
        if (
            self.is_hf_transformers
            and "past_key_value" in kwargs
            and kwargs["past_key_value"] is None
        ):
            self.start_pos = 0

        xqkv = self.qkv_proj(hidden_states)
        xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])

        xq = self.attention_shapes["xq_slice"](xqkv)
        xk = self.attention_shapes["xk_slice"](xqkv)
        xv = self.attention_shapes["xv_slice"](xqkv)

        if seqlen > 1 or not FT_INSTALLED:
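            # Prefill (seqlen > 1), or no fused kernel available: run attention
            # in plain PyTorch and maintain the rolling KV cache manually.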
            xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
            xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
            xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])

            if not self.use_alibi:
                xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)

            self.cache.to(xq)

            values_store = xv.transpose(2, 1)
            keys_store = (
                xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
                .permute(0, 2, 3, 1, 4)
                .contiguous()
            )

            self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)

            # Only necessary to retrieve from cache when we are not processing context
            if seqlen == 1:
                xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)

            keys = xk
            values = xv

            if self.n_kv_groups != 0:
                keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
                values = torch.repeat_interleave(
                    values, dim=2, repeats=self.n_kv_groups
                )

            xq = xq.transpose(1, 2)
            keys = keys.transpose(1, 2)
            values = values.transpose(1, 2)
            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

            if self.use_alibi:
                scores = self.alibi.forward(scores, seqlen)

            # When seqlen is 1, there is nothing else to attend to
            if attention_mask is not None and seqlen > 1:
                scores = (
                    scores + attention_mask
                )  # (bs, n_local_heads, slen, cache_len + slen)
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
            attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        else:
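            # Single-token decode through the fused kernel: single_query_attention
            # applies the rotary embedding and reads/updates the KV cache itself,
            # so the PyTorch bookkeeping above is skipped.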
            xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
            xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
            xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])

            alibi_slopes = self.alibi.slopes if self.alibi is not None else None
            attention_weight = awq_ft_ext.single_query_attention(
                xq,  # query
                xk,  # key
                xv,  # value
                self.cache.k,  # key cache
                self.cache.v,  # value cache
                None,  # length per sample
                alibi_slopes,  # alibi slopes
                self.start_pos,  # timestep
                self.rotary_dim,  # rotary embedding dimension
                self.rope_theta,  # rotary embedding base
                self.is_neox,  # is neox
            )
            attention_weight = attention_weight.reshape(bsz, 1, -1)

        attn_output = self.o_proj(attention_weight)
        self.start_pos += seqlen

        # The real KV cache lives in self.cache (cache_v / cache_k); we return
        # a dummy past_key_value so that `transformers` can still infer the
        # correct past key length.
        past_key_value = [torch.zeros(1, 1, self.start_pos, 1)]

        if HF_NEW_CACHE_FORMAT and self.is_hf_transformers:
            new_cache = DynamicCache()
            new_cache.update(past_key_value[0], past_key_value[0], layer_idx=0)
            past_key_value = new_cache

        return attn_output, attention_weight, past_key_value
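

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed). The projection widths assume
# the common layout where the fused QKV layer maps
# hidden_size -> (n_heads + 2 * n_kv_heads) * head_dim; in AutoAWQ the real
# qkv_layer / o_proj are fused, quantized linear modules, so plain nn.Linear
# is only a stand-in here:
#
#   device = "cuda:0"  # or "cpu" for a quick shape check
#   hidden_size, n_heads, n_kv_heads = 4096, 32, 32
#   head_dim = hidden_size // n_heads
#   qkv = nn.Linear(hidden_size, (n_heads + 2 * n_kv_heads) * head_dim, bias=False).to(device)
#   out = nn.Linear(hidden_size, hidden_size, bias=False).to(device)
#   attn = QuantAttentionFused(hidden_size, n_heads, n_kv_heads, qkv, out, device)
#   x = torch.randn(1, 16, hidden_size, device=device)  # (batch, seqlen, hidden_size)
#   attn_output, _, _ = attn(x)  # prefill; subsequent seqlen == 1 calls reuse the cache
# ---------------------------------------------------------------------------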