# flash_mla_triton.py
import torch
import triton
import triton.language as tl
import math
from typing import Optional, Tuple

# Simplified implementation of flash attention with a paged KV cache for gfx926
def flash_mla_with_kvcache_triton(
    q: torch.Tensor,  # batch_size x seqlen_q x num_heads_q x head_size_k
    k_cache: torch.Tensor,  # num_blocks x page_block_size x num_heads_k x head_size_k
    v_cache: torch.Tensor,  # num_blocks x page_block_size x num_heads_k x head_size_v
    block_table: torch.Tensor,  # batch_size x max_num_blocks_per_seq
    cache_seqlens: torch.Tensor,  # batch_size
    head_dim_v: int,
    softmax_scale: Optional[float] = None,
    causal: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Implementation of flash attention with KV cache for gfx926 architecture
    """
    # Check inputs
    assert q.dim() == 4, "q must be 4-dimensional"
    assert k_cache.dim() == 4, "k_cache must be 4-dimensional"
    assert v_cache.dim() == 4, "v_cache must be 4-dimensional"
    assert block_table.dim() == 2, "block_table must be 2-dimensional"
    assert cache_seqlens.dim() == 1, "cache_seqlens must be 1-dimensional"
    
    # Get dimensions
    batch_size, seqlen_q, num_heads_q, head_size_k = q.shape
    num_blocks, page_block_size, num_heads_k, _ = k_cache.shape
    max_num_blocks_per_seq = block_table.shape[1]
    
    # Check head dimensions
    assert head_size_k in (576, 512), "Only head_size_k == 576 or 512 is supported"
    assert head_dim_v == 512, "Only head_dim_v == 512 is supported"
    assert num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"
    assert page_block_size == 64, "Currently page_block_size must be 64"
    
    # Set default softmax scale
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_size_k)
    
    # Create output tensors
    out = torch.empty((batch_size, seqlen_q, num_heads_q, head_dim_v), dtype=q.dtype, device=q.device)
    lse = torch.empty((batch_size, num_heads_q, seqlen_q), dtype=torch.float32, device=q.device)
    
    # Use a simplified per-batch PyTorch implementation that works on all architectures
    for b in range(batch_size):
        seq_len_k = cache_seqlens[b].item()
        
        # Get query for this batch
        q_batch = q[b]  # seqlen_q x num_heads_q x head_size_k
        
        # Calculate attention scores using the provided k_cache and block_table
        # For gfx926, we'll use a simplified approach
        
        # Get the relevant blocks from the block table
        num_k_blocks = (seq_len_k + page_block_size - 1) // page_block_size
        blocks = block_table[b, :num_k_blocks].long()
        
        # Ensure blocks are within bounds
        blocks = blocks % num_blocks
        
        # Gather the relevant key and value blocks
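        # k_cache[blocks] selects this sequence's pages; flattening the page and position
        # dimensions gives a contiguous view that is then truncated to the true sequence length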
        k = k_cache[blocks].reshape(-1, num_heads_k, head_size_k)[:seq_len_k]
        v = v_cache[blocks].reshape(-1, num_heads_k, head_dim_v)[:seq_len_k]
        
        # Handle NaN values
        k[k != k] = 0.0
        v[v != v] = 0.0
        
        # Expand k and v if needed
        if num_heads_k < num_heads_q:
            k = k.repeat_interleave(num_heads_q // num_heads_k, dim=1)
            v = v.repeat_interleave(num_heads_q // num_heads_k, dim=1)
        
        # Calculate attention scores
        # Reshape k for correct matrix multiplication
        k_reshaped = k.permute(1, 0, 2)  # num_heads_q x seq_len_k x head_size_k
        scores = torch.einsum('qhd,hkd->qhk', q_batch, k_reshaped)  # seqlen_q x num_heads_q x seq_len_k
        scores *= softmax_scale
        
        # Apply causal mask if needed
        if causal and seqlen_q > 1:
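            # tril with a diagonal offset of seq_len_k - seqlen_q lets query i attend to keys
            # [0, seq_len_k - seqlen_q + i], aligning the last query with the last cached key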
            mask = torch.ones(seqlen_q, seq_len_k, device=q.device, dtype=torch.bool)
            mask = mask.tril(diagonal=seq_len_k - seqlen_q)
            scores = scores.masked_fill(mask.logical_not().unsqueeze(1), -float('inf'))
        
        # Apply a numerically stable softmax (subtract the per-row max before exponentiating)
        max_scores = scores.max(dim=-1, keepdim=True)[0]
        # Fully masked rows have a max of -inf; clamp it to 0 so exp(scores - max) becomes 0 instead of NaN
        max_scores = max_scores.masked_fill(max_scores == -float('inf'), 0.0)
        exp_scores = torch.exp(scores - max_scores)
        sum_exp = exp_scores.sum(dim=-1, keepdim=True)
        
        # Log-sum-exp per (query, head): lse = max + log(sum(exp(scores - max)))
        current_lse = torch.log(sum_exp.squeeze(-1)) + max_scores.squeeze(-1)
        lse[b] = current_lse.transpose(0, 1)
        
        # Calculate attention weights
        attention = exp_scores / sum_exp
        attention = attention.to(torch.float32)
        
        # Calculate output in float32 for accuracy
        # Reshape v for correct matrix multiplication
        v_reshaped = v.permute(1, 0, 2)  # num_heads_q x seq_len_k x head_dim_v
        v_reshaped = v_reshaped.to(torch.float32)
        # Assigning into the slice casts the float32 result back to q.dtype
        out[b] = torch.einsum('qhk,hkd->qhd', attention, v_reshaped)  # seqlen_q x num_heads_q x head_dim_v
        
        # Correct for q tokens which have no attendable k
        lonely_q_mask = (current_lse == -float('inf'))
        out[b][lonely_q_mask.unsqueeze(-1).broadcast_to(out[b].shape)] = 0.0
        lse[b][lonely_q_mask.transpose(0, 1)] = float('inf')
    
    return out, lse

def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata=None,
    num_splits=None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
    attn_sink: Optional[torch.Tensor] = None,
    extra_k_cache: Optional[torch.Tensor] = None,
    extra_indices_in_kvcache: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
    extra_topk_length: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Wrapper function to match the original flash_mla interface
    """
    # For dense attention (no indices provided)
    if indices is None:
        # Use the first head_dim_v dimensions of k_cache as v_cache
        # This matches the reference implementation
        head_dim_v_int = head_dim_v.item() if isinstance(head_dim_v, torch.Tensor) else head_dim_v
        v_cache = k_cache[..., :head_dim_v_int]
        
        out, lse = flash_mla_with_kvcache_triton(
            q, k_cache, v_cache, block_table, cache_seqlens, head_dim_v, softmax_scale, causal
        )
        
        return out, lse
    else:
        # Sparse attention not implemented yet
        raise NotImplementedError("Sparse attention is not implemented in Triton version")

def get_mla_metadata(*args, **kwargs) -> Tuple[dict, None]:
    """
    Returns a dummy metadata object to match the original interface
    """
    return {}, None
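

# Minimal smoke test (a sketch, not part of the original module). The shapes below are
# illustrative assumptions chosen only to satisfy the asserts above (head_size_k = 576,
# head_dim_v = 512, page_block_size = 64); they are not taken from any real model config.
if __name__ == "__main__":
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    batch_size, seqlen_q, num_heads_q, num_heads_k = 2, 1, 16, 1
    head_size_k, head_dim_v, page_block_size, num_blocks = 576, 512, 64, 8

    q = torch.randn(batch_size, seqlen_q, num_heads_q, head_size_k, dtype=dtype, device=device)
    k_cache = torch.randn(num_blocks, page_block_size, num_heads_k, head_size_k, dtype=dtype, device=device)
    # Toy paged layout: each sequence owns a disjoint, contiguous set of pages
    block_table = torch.arange(num_blocks, dtype=torch.int32, device=device).reshape(batch_size, -1)
    cache_seqlens = torch.tensor([100, 200], dtype=torch.int32, device=device)

    out, lse = flash_mla_with_kvcache(q, k_cache, block_table, cache_seqlens, head_dim_v, causal=True)
    print("out:", tuple(out.shape), out.dtype)  # (2, 1, 16, 512)
    print("lse:", tuple(lse.shape), lse.dtype)  # (2, 16, 1)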