# attn.py
import os
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from awq.modules.fused.cache import WindowedCache
from awq.utils.fused_utils import get_attention_shapes


try:
    import awq_ft_ext

    FT_INSTALLED = True
except Exception:
    FT_INSTALLED = False

import transformers

# https://github.com/huggingface/transformers/pull/26681 introduced a new cache format
HF_NEW_CACHE_FORMAT = hasattr(transformers, "cache_utils")
if HF_NEW_CACHE_FORMAT:
    from transformers.cache_utils import DynamicCache


class RoPE(nn.Module):
    def __init__(self, head_dim, max_seq_len, device, rope_theta):
        super(RoPE, self).__init__()

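        # Precompute the complex rotary factors for up to 2 * max_seq_len positions
        # and keep them on-device as a non-trainable parameter.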
        self.freqs_cis = nn.Parameter(
            self.precompute_freqs_cis(
                head_dim, max_seq_len * 2, rope_theta
            ).to(device),
            requires_grad=False,
        )

    @staticmethod
    def precompute_freqs_cis(dim: int, end: int, theta=10000.0):
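        # Inverse frequencies theta^(-2j/dim) for each rotary pair, multiplied by
        # every position t and converted to unit complex numbers e^(i * t * freq)
        # via torch.polar (magnitude 1, angle t * freq).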
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis

    @staticmethod
    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
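        # Reshape freqs_cis from (seqlen, head_dim // 2) to (1, seqlen, 1, head_dim // 2)
        # so it broadcasts over the batch and head dimensions of the complex-valued x.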
        ndim = x.ndim
        assert 0 <= 1 < ndim
        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)

    def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: int):
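        # View each head as head_dim // 2 complex numbers pairing element j with
        # element j + head_dim // 2 (NeoX-style rotate-half), rotate them by the
        # precomputed unit factors for positions [start_pos, start_pos + seqlen),
        # then convert back to real and restore the original layout.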
        xq_ = torch.view_as_complex(
            xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
        )
        xk_ = torch.view_as_complex(
            xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
        )
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
        freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)

        xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)

        return xq_out.type_as(xq), xk_out.type_as(xk)


class ALiBi(nn.Module):
    def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8):
        super(ALiBi, self).__init__()

        # Initialize ALiBi slopes and bias
        slopes, bias = self.build_alibi_bias(
            n_heads, max_seq_len, alibi_bias_max=alibi_bias_max
        )
        self.slopes = nn.Parameter(slopes.float().to(device), requires_grad=False)
        self.bias = nn.Parameter(bias.float().to(device), requires_grad=False)

    @staticmethod
    def gen_slopes(n_heads, alibi_bias_max=8):
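        # Per-head slopes form a geometric sequence 2^(-k * alibi_bias_max / n),
        # computed for n = next power of two >= n_heads. For non-power-of-two head
        # counts, the odd-indexed slopes are taken first, then the even-indexed
        # ones, truncated to n_heads (the standard ALiBi construction).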
        _n_heads = 2 ** math.ceil(math.log2(n_heads))
        m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
        m = m.mul(alibi_bias_max / _n_heads)
        slopes = 1.0 / torch.pow(2, m)

        if _n_heads != n_heads:
            slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads]

        return slopes.view(1, n_heads, 1, 1)

    @staticmethod
    def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8, dtype=torch.float32):
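        # Linear bias over the last seq_len key positions: distances (1 - seq_len)..0
        # are scaled per head by the slopes, giving a (1, n_heads, 1, seq_len) tensor
        # that is later added to the attention scores.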
        alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(
            1, 1, 1, seq_len
        )
        slopes = ALiBi.gen_slopes(n_heads, alibi_bias_max)
        alibi_bias = alibi_bias * slopes
        slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
        return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)

    def forward(self, scores, seqlen):
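        # Add the precomputed per-head linear position bias to the raw attention scores.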
        scores += self.bias[..., :seqlen]
        return scores


class QuantAttentionFused(nn.Module):
    def __init__(
        self,
        hidden_size,
        n_heads,
        n_kv_heads,
        qkv_layer,
        o_proj,
        dev,
        max_seq_len=2048,
        use_alibi=False,
        attention_shapes=None,
        rope_theta=10000,
        head_dim=None,
        **kwargs
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
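        # Grouped-query attention: number of query heads that share each KV head
        # (n_kv_heads == 0 disables the grouping logic).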
        self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
        self.head_dim = head_dim
        
        if head_dim is None:
            self.head_dim = hidden_size // n_heads

        self.qkv_proj = qkv_layer
        self.o_proj = o_proj
        self.start_pos = 0
        self.use_alibi = use_alibi
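        # Batch size used to preallocate the KV cache; can be overridden via the
        # AWQ_BATCH_SIZE environment variable (the cache is reallocated in
        # forward() if the runtime batch size differs).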
        self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))

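        # If `max_new_tokens` is passed, it overrides max_seq_len for sizing the
        # cache and the RoPE/ALiBi tables.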
        if kwargs.get("max_new_tokens") is not None:
            max_seq_len = kwargs["max_new_tokens"]

        self.max_seq_len = max_seq_len
        self.is_hf_transformers = False
        self.rope_theta = rope_theta

        # attention shapes for self attention
        self.attention_shapes = get_attention_shapes(
            attention_shapes,
            max_seq_len,
            self.cache_batch_size,
            n_heads,
            n_kv_heads,
            self.head_dim,
        )
        # Rolling (windowed) KV cache store
        self.cache = WindowedCache(
            self.attention_shapes["cache_v"],
            self.attention_shapes["cache_k"],
            self.max_seq_len,
            dev,
        )

        if use_alibi:
            self.alibi = ALiBi(n_heads, max_seq_len, dev)
            self.rotary_dim = 0
            self.is_neox = False
        else:
            self.alibi = None
            self.rope = RoPE(self.head_dim, max_seq_len, dev, rope_theta)
            self.rotary_dim = self.head_dim
            self.is_neox = True

    def forward(
        self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs
    ):
        bsz, seqlen, _ = hidden_states.shape

        # Reallocate cache if batch size changes
        if bsz != self.cache_batch_size:
            if bsz > self.cache_batch_size:
                self.cache.increase_batch_size(bsz)
                self.cache_batch_size = bsz
            elif bsz < self.cache_batch_size:
                self.cache.decrease_batch_size(bsz)
                self.cache_batch_size = bsz

            # Always reset the start position to 0 after reallocating the cache
            self.start_pos = 0

        # In case we re-generate, we need to reset the starting position to 0.
        # We detect this by checking whether `past_key_value` is None, which
        # indicates the first step of `generate()`. This only applies to the
        # `transformers` integration.
        if (
            self.is_hf_transformers
            and "past_key_value" in kwargs
            and kwargs["past_key_value"] is None
        ):
            self.start_pos = 0

        xqkv = self.qkv_proj(hidden_states)
        xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])

        xq = self.attention_shapes["xq_slice"](xqkv)
        xk = self.attention_shapes["xk_slice"](xqkv)
        xv = self.attention_shapes["xv_slice"](xqkv)

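        # Two paths: prefill (seqlen > 1) or a missing kernel extension fall back to
        # plain PyTorch attention; single-token decode with awq_ft_ext installed
        # uses the fused single-query attention kernel below.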
        if seqlen > 1 or not FT_INSTALLED:
            xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
            xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
            xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])

            if not self.use_alibi:
                xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)

            self.cache.to(xq)

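            # Store new values as (bsz, n_kv_heads, seqlen, head_dim); new keys are
            # permuted into the packed layout expected by the fused kernel's key
            # cache (head_dim split into fixed-size groups by xk_reshape + permute).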
            values_store = xv.transpose(2, 1)
            keys_store = (
                xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
                .permute(0, 2, 3, 1, 4)
                .contiguous()
            )

            self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)

            # Only necessary to retrieve from cache when we are not processing context
            if seqlen == 1:
                xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)

            keys = xk
            values = xv

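            # Grouped-query attention: repeat each KV head n_kv_groups times along
            # the head dimension so keys/values line up with the query heads.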
            if self.n_kv_groups != 0:
                keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
                values = torch.repeat_interleave(
                    values, dim=2, repeats=self.n_kv_groups
                )

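            # Plain scaled dot-product attention: move heads before the sequence
            # dimension, scale QK^T by 1/sqrt(head_dim), softmax in float32, then
            # apply the weights to the values.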
            xq = xq.transpose(1, 2)
            keys = keys.transpose(1, 2)
            values = values.transpose(1, 2)
            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

            if self.use_alibi:
                scores = self.alibi.forward(scores, seqlen)

            # When seqlen is 1, there is nothing else to attend to
            if attention_mask is not None and seqlen > 1:
                scores = (
                    scores + attention_mask
                )  # (bs, n_local_heads, slen, cache_len + slen)
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
            attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        else:
            xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
            xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
            xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])

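            # Fused single-token decode: the kernel receives the cache tensors and
            # the rotary parameters (rotary_dim / rope_theta / is_neox), so rotary
            # embedding and the cache update happen inside awq_ft_ext.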
            alibi_slopes = self.alibi.slopes if self.alibi is not None else None
            attention_weight = awq_ft_ext.single_query_attention(
                xq,  # query
                xk,  # key
                xv,  # value
                self.cache.k,  # key cache
                self.cache.v,  # value cache
                None,  # length per sample
                alibi_slopes,  # alibi slopes
                self.start_pos,  # timestep
                self.rotary_dim,  # rotary embedding dimension
                self.rope_theta,  # rotary embedding base
                self.is_neox,  # is neox
            )
            attention_weight = attention_weight.reshape(bsz, 1, -1)

        attn_output = self.o_proj(attention_weight)
        self.start_pos += seqlen

        # The real KV cache lives in self.cache (cache_k / cache_v), so we return a
        # dummy past_key_value whose sequence dimension lets `transformers` recover
        # the correct past key length.
        past_key_value = [torch.zeros(1, 1, self.start_pos, 1)]

        if HF_NEW_CACHE_FORMAT and self.is_hf_transformers:
            new_cache = DynamicCache()
            new_cache.update(past_key_value[0], past_key_value[0], layer_idx=0)
            past_key_value = new_cache

        return attn_output, attention_weight, past_key_value
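

if __name__ == "__main__":
    # Minimal sanity-check sketch, not part of the library API: it only exercises
    # the standalone RoPE module defined above on CPU, assuming this file's own
    # dependencies (torch, awq, transformers) are importable. Shapes follow the
    # module's convention of (batch, seqlen, n_heads, head_dim).
    torch.manual_seed(0)
    rope = RoPE(head_dim=8, max_seq_len=16, device="cpu", rope_theta=10000)
    xq = torch.randn(1, 4, 2, 8)
    xk = torch.randn(1, 4, 2, 8)
    xq_rot, xk_rot = rope.forward(xq, xk, start_pos=0, seqlen=4)

    # Rotation by unit-magnitude complex factors preserves shapes and per-head norms.
    assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape
    assert torch.allclose(xq.norm(dim=-1), xq_rot.norm(dim=-1), atol=1e-5)
    print("RoPE sanity check passed:", tuple(xq_rot.shape), tuple(xk_rot.shape))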