'''
Description  : Custom attention operators (DeepSeek MLA and Llama) for ktransformers.
Author       : Boxin Zhang
Version      : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved. 
'''
import torch
from torch import nn
import warnings
import torch.nn.functional as F
from ktransformers.operators.models import KLlamaModel
from ktransformers.models.configuration_deepseek import DeepseekV2Config
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.models.modeling_llama import LlamaRotaryEmbedding
from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_rotary_pos_emb
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import get_compute_capability
import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache
from flash_attn import flash_attn_func
from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
if flashinfer_enabled:
    from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton, attention_ref

logger = logging.getLogger("attention")

# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
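
# Illustrative example: with head_dim = 4 and x = [x0, x1, x2, x3],
# rotate_half(x) = [-x2, -x3, x0, x1], so x * cos + rotate_half(x) * sin
# rotates each (x_i, x_{i + d/2}) pair by the positional angle.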

# V3 MLA is the same as V2
class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
    """Multi-headed attention from 'Attention Is All You Need' paper"""
    attn_mask: Optional[torch.Tensor] = None

    def __init__(self,
                 key: str,
                 gguf_loader : GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 chunck_size: int = 1000,
                 absorb_for_prefill: bool = False,
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
            orig_module.layer_idx)
        self.chunck_size = chunck_size # TODO: derive the chunk size automatically.
        self.mla_wrapper = None
        self.absorb_for_prefill = absorb_for_prefill

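    # MLA weight absorption: kv_b_proj maps the kv_lora_rank-dim latent to the
    # per-head [k_nope; v] features. Splitting its weight into q_absorb and
    # out_absorb lets attention run directly on the compressed latent cache:
    #   scores = (q_nope @ q_absorb) @ c_kv^T + q_pe @ k_pe^T
    #   output = (softmax(scores) @ c_kv) @ out_absorb^T
    # so the full keys/values are never materialized.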
    def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
        if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
            kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
            self.q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
            self.out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
            
        return self.q_absorb, self.out_absorb

    def forward_chunck(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()
        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )
        # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
        # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        compressed_kv = self.kv_a_layernorm(compressed_kv)
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)

        kv_seq_len = k_pe.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        cos, sin = self.rotary_emb(q_pe, position_ids)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            
            # compressed_kv [bsz, q_len, self.kv_lora_rank]
            # k_pe [bsz, 1, q_len, self.qk_rope_head_dim]
            k_pe = k_pe.transpose(1,2)
            compressed_kv = compressed_kv.unsqueeze(2)
            compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
            compressed_kv, k_pe = torch.split(
                compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
            )
            # k_pe [pages, page_size, 1, self.qk_rope_head_dim]
            # compressed_kv [pages, page_size, 1, self.kv_lora_rank]
            
        q_absorb, out_absorb = self.get_absorbed()

        # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
        # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
        k_pe = k_pe.view(bsz, 1, -1, self.qk_rope_head_dim)[:,:,:attention_mask.size(-1),:]
        compressed_kv = compressed_kv.view(bsz, 1, -1, self.kv_lora_rank)[:,:,:attention_mask.size(-1),:]
        # k_pe [bsz, 1, cache_len, self.qk_rope_head_dim]
        # compressed_kv [bsz, 1, cache_len,self.kv_lora_rank]
        q_nope = torch.matmul(q_nope, q_absorb)
        #print(q_pe.shape)
        #print(k_pe.shape)
        #print(q_nope.shape)
        #print(compressed_kv.shape)
        
        attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale
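        # The score splits into a RoPE term (q_pe @ k_pe^T) and a latent term
        # ((q_nope @ q_absorb) @ compressed_kv^T); their sum equals the full
        # q @ k^T score of unabsorbed MLA attention.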
        #attn_weights [bsz, self.num_heads, q_len, kv_seq_len]
        compressed_kv = compressed_kv.squeeze(1)
        """
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )
        assert attention_mask is not None
        """
        if attention_mask is not None:
            """
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            """
            #causal_mask = attention_mask[:, :, :, : kv_seq_len]
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(q_pe.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
        attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
        attn_output = torch.matmul(attn_output, out_absorb.mT) 

        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

    def forward_linux_triton(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value: Optional[Cache] = None,
            output_attentions: bool = False,
            use_cache: bool = False,
            cache_position: Optional[torch.LongTensor] = None,
            **kwargs,
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        bsz, q_len, _ = hidden_states.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        compressed_kv = self.kv_a_layernorm(compressed_kv)
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
        compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)

        kv_seq_len = q_len
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        
        cos, sin = self.rotary_emb(q_pe, position_ids)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
        # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
        
        # decode
        if q_len == 1:
            if past_key_value is not None:
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
                compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
                compressed_kv = compressed_kv_with_k_pe[:, :, :, :self.kv_lora_rank] # slice instead of split, for speed
                # compressed_kv_with_k_pe [bsz, q_len, 1, self.kv_lora_rank + self.qk_rope_head_dim]
                # compressed_kv [bsz, q_len, 1, self.kv_lora_rank]

            # q_nope [bsz, q_len, self.num_heads, self.qk_nope_head_dim]
            # q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank]
            q_absorb, out_absorb = self.get_absorbed()
            q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
            q_nope = torch.matmul(q_nope, q_absorb) # batched MM
            q_nope = q_nope.transpose(1, 2)
            #assert q_nope.is_contiguous()
            
            # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
            # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
            query_states = torch.cat([q_nope, q_pe], dim=-1)
            
            query_states = query_states.squeeze(1)
            attn_output = torch.zeros_like(q_nope) # [bsz, q_len, self.num_heads, self.kv_lora_rank]
            
            attn_logits = torch.empty(
                    (
                        bsz,
                        self.num_heads,
                        4, # num_kv_splits; follows vLLM, TODO: tune automatically
                        self.kv_lora_rank + 1, 
                    ),
                    dtype=torch.float32,
                    device = attn_output.device
                )
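            # attn_logits is the split-KV workspace for the Triton kernel: one
            # partial result per kv split, holding kv_lora_rank output dims
            # plus one slot for that split's log-sum-exp, used when the
            # partials are merged.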

            """
            print("query_states", torch.isnan(query_states).any())
            print("compressed_kv_with_k_pe", torch.isnan(compressed_kv_with_k_pe[:,:,0,:]).any())
            print("compressed_kv", torch.isnan(compressed_kv[:,:,0,:]).any())
            print("position_ids", torch.isnan(position_ids).any())
            """

            # flash attn doesn't support head_dim larger than 256, so we
            # use triton attention kernel adapted from vLLM and SGLang for MQA
            decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
                             page_table,
                             position_ids.squeeze(0).to(torch.int32)+1, attn_logits,
                             4, # num_kv_splits; follows vLLM, TODO: tune automatically
                             self.softmax_scale,
                             past_key_value.page_size)
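            # The kernel gathers the paged latent cache through page_table and
            # reduces the num_kv_splits partial results from attn_logits into
            # attn_output ([bsz, q_len, self.num_heads, self.kv_lora_rank]).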
            
            # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
            # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
            attn_output = attn_output.transpose(1, 2)
            attn_output = torch.matmul(attn_output, out_absorb.mT)
            attn_output = attn_output.transpose(1, 2)
            
            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
            attn_output = self.o_proj(attn_output)
            
            #print("attn_output", torch.isnan(attn_output).any())
            return attn_output, None, past_key_value
        else:
            if past_key_value is not None:
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
                compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
                compressed_kv, k_pe = torch.split(
                    compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
                )
            k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim)
            k_pe = k_pe[:, :kv_seq_len]
            compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank)
            compressed_kv = compressed_kv[:, :kv_seq_len]
            kv = (
                self.kv_b_proj(compressed_kv)
                .view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            )
            k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
            query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
            query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

            key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
            key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
            
            value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
            value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
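            # flash_attn_func requires Q, K and V to share one head_dim, so V
            # is zero-padded from v_head_dim up to q_head_dim here; the extra
            # dims are sliced off the output below.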

            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states_padded,
                softmax_scale=self.softmax_scale,
                causal=True,
            )

            if self.q_head_dim != self.v_head_dim:
                attn_output = attn_output[:, :, :, : self.v_head_dim]

            attn_output = attn_output.reshape(
                bsz, q_len, self.num_heads * self.v_head_dim
            ).contiguous()
            attn_output = self.o_proj(attn_output)
            return attn_output, None, past_key_value

    def forward_linux_flashinfer(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            past_key_value: Optional[Cache] = None,
            output_attentions: bool = False,
            use_cache: bool = False,
            cache_position: Optional[torch.Tensor] = None,
            **kwargs,
        ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        bsz, q_len, _ = hidden_states.size()

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)
        else:
            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q = q.view(bsz, q_len, self.num_heads, self.q_head_dim)
        q_nope, q_pe = torch.split(
            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        compressed_kv, k_pe = torch.split(
            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
        )
        compressed_kv = self.kv_a_layernorm(compressed_kv)
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
        compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)

        kv_seq_len = q_len
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version transformer verision v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        
        cos, sin = self.rotary_emb(q_pe, position_ids)
        q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
        # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
        
        # decode
        if q_len == 1 or self.absorb_for_prefill:
            if past_key_value is not None:
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
                compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
                compressed_kv = compressed_kv_with_k_pe[:, :, :, :self.kv_lora_rank].view(-1, past_key_value.page_size, self.kv_lora_rank)
                k_pe = compressed_kv_with_k_pe[:, :, :, self.kv_lora_rank:].view(-1, past_key_value.page_size, self.qk_rope_head_dim)
                # k_pe [max_pages, page_size, self.qk_rope_head_dim]
                # compressed_kv [max_pages, page_size, self.kv_lora_rank]

            # q_nope [bsz, q_len, self.num_heads, self.qk_nope_head_dim]
            # q_absorb [self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank]
            q_absorb, out_absorb = self.get_absorbed()
            q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
            q_nope = torch.matmul(q_nope, q_absorb) # batched MM
            q_nope = q_nope.transpose(1, 2)
            q_nope = q_nope.contiguous()
            #assert q_nope.is_contiguous()
            
            # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
            # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
            q_nope.squeeze_(0)
            q_pe.squeeze_(0)

            # flash attn doesn't support head_dim larger than 256, so use flashinfer
            if self.mla_wrapper is None:
                self.mla_wrapper = MLAWrapperSingleton.get_instance(self.device, 1, past_key_value.max_pages, use_cuda_graph=True)
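                # Lazily fetch the per-device MLA wrapper, sized to the cache's
                # max_pages; judging by the flag, use_cuda_graph=True allows
                # decode steps to be captured in a CUDA graph.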
            if self.mla_wrapper.need_plan:
                self.mla_wrapper.need_plan = False
                if q_len == 1:
                    self.mla_wrapper.plan(None, None, None,
                                        position_ids.squeeze(1)+1,
                                        self.num_heads,
                                        self.kv_lora_rank,
                                        self.qk_rope_head_dim,
                                        past_key_value.page_size,
                                        self.softmax_scale,
                                        q_nope.dtype,
                                        compressed_kv.dtype)
                else:
                    qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=self.device)
                    kv_len_arr = torch.tensor([position_ids[0, -1].item()+1], dtype=torch.int32, device=self.device)
                    self.mla_wrapper.plan(qo_indptr, None, None,
                                        kv_len_arr,
                                        self.num_heads,
                                        self.kv_lora_rank,
                                        self.qk_rope_head_dim,
                                        past_key_value.page_size,
                                        self.softmax_scale,
                                        q_nope.dtype,
                                        compressed_kv.dtype)
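            # plan() precomputes the kernel schedule for the current sequence
            # lengths and is re-run only when need_plan is set; run() below
            # executes the planned MLA attention over the paged latent cache.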
            attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank)

            """
            k = (
                torch.cat([compressed_kv, k_pe], dim=-1)
                .view(-1, 1, 512 + 64)
                .repeat_interleave(self.num_heads, dim=1)
            )
            v = compressed_kv.view(-1, 1, 512).repeat_interleave(self.num_heads, dim=1)
            lens = position_ids.item() + 1
            #print("lens", lens)
            attn_ref, lse_ref = attention_ref(
                1,
                torch.cat([q_nope, q_pe], dim=-1),
                k[:lens],
                v[:lens],
                False,
                self.softmax_scale
            )
            attn_output = attn_ref.view(bsz, q_len, self.num_heads, self.kv_lora_rank)
            """
            
            # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]
            # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
            # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
            attn_output = attn_output.transpose(1, 2) # [bsz, self.num_heads, q_len, self.kv_lora_rank]
            attn_output = torch.matmul(attn_output, out_absorb.mT) # [bsz, self.num_heads, q_len, self.v_head_dim]
            attn_output = attn_output.transpose(1, 2).contiguous() # [bsz, q_len, self.num_heads, self.v_head_dim]
            
            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim]
            attn_output = self.o_proj(attn_output)

            return attn_output, None, past_key_value
        else:
            if past_key_value is not None:
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
                compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
                compressed_kv, k_pe = torch.split(
                    compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
                )
            k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim)
            k_pe = k_pe[:, :kv_seq_len]
            compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank)
            compressed_kv = compressed_kv[:, :kv_seq_len]
            kv = (
                self.kv_b_proj(compressed_kv)
                .view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
            )
            k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
            query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
            query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
            query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

            key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
            key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
            
            value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
            value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states_padded,
                softmax_scale=self.softmax_scale,
                causal=True,
            )

            if self.q_head_dim != self.v_head_dim:
                attn_output = attn_output[:, :, :, : self.v_head_dim]

            attn_output = attn_output.reshape(
                bsz, q_len, self.num_heads * self.v_head_dim
            ).contiguous()
            attn_output = self.o_proj(attn_output)
            return attn_output, None, past_key_value
        
    def forward_windows(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )
        bsz, q_len, _ = hidden_states.size()

        if q_len <= self.chunck_size:
            return self.forward_chunck(
                            hidden_states,
                            attention_mask,
                            position_ids,
                            past_key_value,
                            output_attentions,
                            use_cache,
                            cache_position,
                            **kwargs
                        )

        assert not output_attentions, "output_attentions is not supported when using chunked attention"
        attn_output = None
        cur_idx = 0
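        # Process the sequence in chunck_size slices: each slice attends to all
        # cached positions before it plus a causal window over itself, and the
        # per-slice outputs are concatenated along the sequence dimension.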
        while cur_idx < q_len:
            if attention_mask is not None:
                chunk_mask = attention_mask[:, :, cur_idx:min(cur_idx + self.chunck_size, q_len), ...]
            else:
                # generate chunk_mask automatically.
                self.attn_mask = \
                    torch.zeros(1, 1, self.chunck_size, past_key_value.max_cache_len, device=hidden_states.device) \
                        if self.attn_mask is None \
                            else self.attn_mask
                self.attn_mask[:, :, :, cur_idx:min(cur_idx+self.chunck_size, past_key_value.max_cache_len)] = \
                    -1e+38 * torch.triu(torch.ones(self.chunck_size, self.chunck_size, device=hidden_states.device), diagonal=1)\
                        [:,:min(self.chunck_size, min(past_key_value.max_cache_len-cur_idx, self.chunck_size))]
                self.attn_mask[:, :, :, cur_idx+self.chunck_size:] = -1e+38
                self.attn_mask[:, :, :, :cur_idx] = 0
                chunk_mask = torch.narrow(self.attn_mask, 2, 0, min(self.chunck_size, q_len-cur_idx))
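                # The reused buffer forms a sliding causal mask: columns before
                # cur_idx are visible (0), the current chunk gets an upper-
                # triangular -1e38 block, and columns past the chunk stay
                # masked.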

            cur_output, _, _ = self.forward_chunck(
                            hidden_states[:, cur_idx:min(cur_idx + self.chunck_size, q_len), ...],
                            chunk_mask,
                            position_ids[:, cur_idx:min(cur_idx + self.chunck_size, q_len)],
                            past_key_value,
                            output_attentions,
                            use_cache,
                            cache_position[cur_idx:min(cur_idx + self.chunck_size, q_len)],
                            **kwargs
                        )
            cur_idx += self.chunck_size
            if attn_output is None:
                attn_output = cur_output
            else:
                attn_output = torch.cat((attn_output, cur_output), dim=-2)
                
        return attn_output, None, past_key_value

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if os.name == 'nt' or get_compute_capability() < 8:
            print("For Windows or pre-Ampere GPUs, falling back to forward_windows")
            return self.forward_windows(
                hidden_states,
                attention_mask,
                position_ids,
                past_key_value,
                output_attentions,
                use_cache,
                cache_position,
                **kwargs,
            )
        else:
            if flashinfer_enabled:
                return self.forward_linux_flashinfer(
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_value,
                    output_attentions,
                    use_cache,
                    cache_position,
                    **kwargs,
                )
            else:
                return self.forward_linux_triton(
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_value,
                    output_attentions,
                    use_cache,
                    cache_position,
                    **kwargs,
                )


class KLlamaAttention(BaseInjectedModule):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self,
                 key: str,
                 gguf_loader : GGUFLoader,
                 config: PretrainedConfig,
                 orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda",
                 **kwargs):
        BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
        self.orig_module.__init__(orig_module.config,
            orig_module.layer_idx)

    def apply_rotary_pos_emb(self, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        """Applies Rotary Position Embedding to the query and key tensors.

        Args:
            q (`torch.Tensor`): The query tensor.
            k (`torch.Tensor`): The key tensor.
            cos (`torch.Tensor`): The cosine part of the rotary embedding.
            sin (`torch.Tensor`): The sine part of the rotary embedding.
            position_ids (`torch.Tensor`, *optional*):
                Deprecated and unused.
            unsqueeze_dim (`int`, *optional*, defaults to 1):
                The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
                sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
                that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
                k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
                cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
                the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
        Returns:
            `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
        """
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:

            logger.warning(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = self.apply_rotary_pos_emb(query_states, key_states, cos, sin)
        if q_len == 1:
            position_ids = position_ids[0][-1].unsqueeze(0).unsqueeze(0)
            query_states = query_states[:, :, -1:]
            key_states = key_states[:, :, -1:]

        attn_output = KLlamaModel.dynamic_sdpa.apply(
            self.layer_idx,
            bsz,
            position_ids[0][0],
            query_states.transpose(1, 2).to(torch.float16),
            key_states.transpose(1, 2).to(torch.float16),
            value_states.transpose(1, 2).to(torch.float16),
            mode="prefill" if q_len > 1 else "generate",
        )
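        # KLlamaModel.dynamic_sdpa is the shared SDPA implementation; mode
        # selects the prefill path (q_len > 1) or single-token generation, and
        # inputs are cast to fp16 for the kernel.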


        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, -1)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value