import math
import torch
import torch.nn as nn
import awq_inference_engine
from torch.nn import functional as F

# flash-attn is optional; TorchAttention falls back to PyTorch's
# scaled_dot_product_attention when it is not installed.
try:
    from flash_attn import flash_attn_func
    FLASH_INSTALLED = True
except ImportError:
    FLASH_INSTALLED = False

class QuantLlamaRotary(nn.Module):
    def __init__(self, dim=4096, max_position_embeddings=2048, base=10000, device=None, 
                       is_neox=True, num_heads=None, num_kv_heads=None):
        super().__init__()
        self.dim = dim
        self.is_neox = is_neox
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads

        # create cache
        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2, device=device) / dim))
        t = torch.arange(max_position_embeddings, device=device).float()
        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1).to(torch.get_default_dtype())

        # Embedding size: [max_position, rotary_dim]
        self.register_buffer("cos_sin_cache", cache.half(), persistent=False)
    
    def forward(
        self,
        qkv_states: torch.Tensor,
        position_ids: torch.Tensor,
        batch_size: int, 
        q_len: int
    ):
        # Split the fused QKV projection into three equal chunks; this assumes
        # num_kv_heads == num_heads (no grouped-query attention in this path).
        query, key, value = qkv_states.chunk(chunks=3, dim=-1)
        del qkv_states

        # [num_tokens, num_heads * head_size]
        query_batch_size, query_len, _ = query.shape
        query = query.view(query_len*query_batch_size, self.num_heads * self.dim)

        # [num_tokens, num_kv_heads * head_size]
        key_batch_size, key_len, _ = key.shape
        key = key.view(key_len*key_batch_size, self.num_kv_heads * self.dim)

        # [num_tokens]
        positions = position_ids.view(-1).to(query.device)

        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        query = query.contiguous()
        key = key.contiguous()

        awq_inference_engine.rotary_embedding(
            positions,
            query,
            key,
            self.dim,
            self.cos_sin_cache,
            self.is_neox
        )
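        # the fused kernel writes the rotated query/key back in place, so the
        # tensors above can be reshaped directly for attention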

        query = query.view(batch_size, q_len, self.num_heads, self.dim).transpose(1, 2)
        key = key.view(batch_size, q_len, self.num_kv_heads, self.dim).transpose(1, 2)
        value = value.view(batch_size, q_len, self.num_kv_heads, self.dim).transpose(1, 2)

        return query, key, value


class QuantLlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq)
        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)

        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)

        # self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        # self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("cos_sin_cache", cache.half(), persistent=False)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        positions: torch.Tensor,
    ):
        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        # print(positions.shape, query.shape, key.shape, self.cos_sin_cache.shape)
        query = query.contiguous()
        key = key.contiguous()
        awq_inference_engine.rotary_embedding(
            positions,
            query,
            key,
            self.dim,
            self.cos_sin_cache,
            True,  # is_neox
        )
        return query, key


class TorchAttention(nn.Module):
    def __init__(self, hidden_size, use_flash=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.use_flash = use_flash
    
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        use_cache: bool,
        past_key_value: torch.Tensor,
        batch_size: int, 
        q_len: int
    ):
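        # A causal mask is only needed for the prefill pass; when a KV cache is
        # present, the new query token may attend to every cached position.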
        is_causal = past_key_value is None

        kv_seq_len = q_len
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        
        value = value.to(key.device)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key = torch.cat([past_key_value[0], key], dim=2)
            value = torch.cat([past_key_value[1], value], dim=2)

        if use_cache:
            # Since qkv_proj is fused, query etc will hold a reference to the original qkv_states tensor
            # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
            key = key.contiguous()
            value = value.contiguous()
            query = query.contiguous()

        past_key_value = (key, value) if use_cache else None

        if self.use_flash and FLASH_INSTALLED:
            # flash_attn_func expects (batch, seqlen, num_heads, head_dim)
            query = query.transpose(1,2)
            key = key.transpose(1,2)
            value = value.transpose(1,2)
            attn_output = flash_attn_func(query, key, value, causal=is_causal)
            # transpose back to (batch, num_heads, seqlen, head_dim) so the
            # common reshape below matches the non-flash path
            attn_output = attn_output.transpose(1, 2)
        else:
            attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
        
        del query, key, value

        attn_output = attn_output.transpose(1, 2).reshape(batch_size, q_len, self.hidden_size)

        return attn_output, past_key_value

class QuantLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        hidden_size,
        num_heads,
        num_kv_heads,
        qkv_proj,
        o_proj,
        dev,
        max_new_tokens
    ):
        super().__init__()
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj
        self.attn = TorchAttention(hidden_size)

        self.rotary_emb = QuantLlamaRotary(
            dim=hidden_size // num_heads,
            max_position_embeddings=max_new_tokens,
            device=dev,
            is_neox=True,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads  # the fused QKV chunking assumes this equals num_heads
        )

    def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
        """Input shape: Batch x Time x Channel"""

        batch_size, q_len, _ = hidden_states.size()
        qkv_states = self.qkv_proj(hidden_states)
        query, key, value = self.rotary_emb(qkv_states, position_ids, batch_size, q_len)
        attn_output, past_key_value = self.attn(query, key, value, use_cache, past_key_value, batch_size, q_len)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
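    # Precompute the complex rotary phases exp(i * t * theta_j) for positions
    # 0..end-1; returns a complex64 tensor of shape [end, dim // 2].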
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
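    # Rotate query/key by viewing channel pairs (j, j + head_dim // 2) as complex
    # numbers and multiplying them by the precomputed phases in freqs_cis.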
    xq_ = torch.view_as_complex(
        xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
    )
    xk_ = torch.view_as_complex(
        xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
    )
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class QuantLlamaAttentionFused(nn.Module):
    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_position_embeddings):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_local_heads = num_heads
        self.head_dim = self.hidden_size // num_heads
        self.qkv_proj = qkv_layer
        self.o_proj = o_proj
        self.start_pos = 0

        self.freqs_cis = precompute_freqs_cis(
            self.head_dim,
            max_position_embeddings * 2,
        )

        # following fastertransformer definition
        self.cache_v = (
            torch.zeros(
                (
                    1,
                    self.n_local_heads,
                    max_position_embeddings,
                    self.head_dim,
                )
            )
            .to(dev)
            .half()
        )  # added to half
        # Key cache uses FasterTransformer's packed layout: the head_dim axis is
        # split into groups of 8 fp16 values (use 4 for fp32), giving shape
        # [1, n_heads, head_dim // 8, max_position, 8].
        self.cache_k = (
            torch.zeros(
                (
                    1,
                    self.n_local_heads,
                    self.head_dim // 8,
                    max_position_embeddings,
                    8,
                )
            )
            .to(dev)
            .half()
        )  # added to half

        # rotary embedding module used by the multi-token (prefill) path in forward()
        self.rotary_emb = QuantLlamaRotaryEmbedding(
            hidden_size // num_heads, max_position_embeddings=max_position_embeddings, base=10000, device=dev
        )
        
    def forward(
        self,
        hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False
    ):
        bsz, seqlen, _ = hidden_states.shape
        xqkv = self.qkv_proj(hidden_states)
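        # split the fused QKV projection: view as (bsz, seqlen, 3, n_heads, head_dim)
        # and index the Q/K/V planes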
        xqkv = xqkv.view(bsz, seqlen, -1, self.n_local_heads, self.head_dim)
        xq = xqkv[:, :, 0]
        xk = xqkv[:, :, 1]
        xv = xqkv[:, :, 2]

        if seqlen > 1:
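            # Prefill path: apply rotary embeddings via the fused kernel wrapper,
            # store this chunk's keys/values into the FasterTransformer-layout
            # caches, then run standard attention over the current chunk in PyTorch.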
            xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

            xq, xk = self.rotary_emb(xq, xk, position_ids)

            self.cache_k = self.cache_k.to(xq)
            self.cache_v = self.cache_v.to(xq)

            values_store = xv.transpose(2, 1)
            keys_store = (
                xk.reshape(bsz, seqlen, self.n_local_heads, self.head_dim // 8, 8)
                .permute(0, 2, 3, 1, 4)
                .contiguous()
            )

            self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
            self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store

            keys = xk
            values = xv
            past_key_value = (xk, xv) if use_cache else None

            xq = xq.transpose(1, 2)
            keys = keys.transpose(1, 2)
            values = values.transpose(1, 2)
            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
            if attention_mask is not None:
                scores = scores + attention_mask  # (bs, n_local_heads, slen, slen)
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
            output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        else:
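            # Single-token decode path: the fused single_query_attention kernel
            # attends over the cached keys/values and updates the caches in place;
            # rotary embedding for the new token is presumably handled inside the
            # kernel (note the rope base 10000 passed below), so self.rotary_emb
            # is not called here.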
            xq = xq[:, 0, :, :]
            xk = xk[:, 0, :, :]
            xv = xv[:, 0, :, :]
            past_key_value = (xk, xv) if use_cache else None
            output = awq_inference_engine.single_query_attention(
                xq,
                xk,
                xv,
                self.cache_k,
                self.cache_v,
                None,
                None,
                self.start_pos,
                self.head_dim,
                10000,
                True,
            )
            output = output.reshape(bsz, 1, -1)
        
        attn_output = self.o_proj(output)
        
        if use_cache:
            self.start_pos += seqlen
        else:
            self.start_pos = 0

        return attn_output, None, past_key_value