Commit 4b485dd1 authored by wangkx1's avatar wangkx1
Browse files

init

parent 5c5de278
cp /usr/local/lib/python3.10/dist-packages/flash_mla/flash_mla_interface.py.org /usr/local/lib/python3.10/dist-packages/flash_mla/flash_mla_interface.py
\ No newline at end of file
cp /usr/local/lib/python3.10/dist-packages/flash_mla/flash_mla_interface.py /usr/local/lib/python3.10/dist-packages/flash_mla/flash_mla_interface.py.org
cp src/flash_mla_interface_torch.py /usr/local/lib/python3.10/dist-packages/flash_mla/flash_mla_interface.py
\ No newline at end of file
cp /usr/local/lib/python3.10/dist-packages/flash_mla/flash_mla_interface.py /usr/local/lib/python3.10/dist-packages/flash_mla/flash_mla_interface.py.org
cp src/flash_mla_interface_triton.py /usr/local/lib/python3.10/dist-packages/flash_mla/flash_mla_interface.py
\ No newline at end of file
from typing import Optional, Tuple
import dataclasses
import torch
import flash_mla.cuda as flash_mla_cuda
import torch
import triton
import triton.language as tl
import math
from typing import Optional, Tuple
from dataclasses import dataclass
# Simple implementation of flash attention for gfx926
def flash_mla_with_kvcache_torch(
    q: torch.Tensor,  # batch_size x seqlen_q x num_heads_q x head_size_k
    k_cache: torch.Tensor,  # num_blocks x page_block_size x num_heads_k x head_size_k
    v_cache: torch.Tensor,  # num_blocks x page_block_size x num_heads_k x head_size_v
    block_table: torch.Tensor,  # batch_size x max_num_blocks_per_seq
    cache_seqlens: torch.Tensor,  # batch_size
    head_dim_v: int,
    softmax_scale: Optional[float] = None,
    causal: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Pure-PyTorch fallback implementation of flash attention over a paged KV
    cache (written for the gfx926 architecture, which lacks the compiled
    kernel).

    Processes one batch element at a time: gathers the pages referenced by
    block_table, computes scaled dot-product attention with max-subtracted
    softmax for numerical stability, and also returns the log-sum-exp of the
    attention logits.

    Args:
        q: (batch_size, seqlen_q, num_heads_q, head_size_k) queries.
        k_cache: (num_blocks, page_block_size, num_heads_k, head_size_k) paged keys.
        v_cache: (num_blocks, page_block_size, num_heads_k, head_size_v) paged values.
        block_table: (batch_size, max_num_blocks_per_seq) page indices per sequence.
        cache_seqlens: (batch_size,) valid KV length of each sequence.
        head_dim_v: head dimension of v; must be 512.
        softmax_scale: scale applied to QK^T; defaults to 1/sqrt(head_size_k).
        causal: apply a causal mask (only takes effect when seqlen_q > 1).

    Returns:
        out: (batch_size, seqlen_q, num_heads_q, head_dim_v), same dtype as q.
        lse: (batch_size, num_heads_q, seqlen_q), float32 log-sum-exp of the
            (scaled) attention logits; +inf for q tokens with no attendable k.
    """
    # Check inputs
    assert q.dim() == 4, "q must be 4-dimensional"
    assert k_cache.dim() == 4, "k_cache must be 4-dimensional"
    assert v_cache.dim() == 4, "v_cache must be 4-dimensional"
    assert block_table.dim() == 2, "block_table must be 2-dimensional"
    assert cache_seqlens.dim() == 1, "cache_seqlens must be 1-dimensional"
    # Get dimensions
    batch_size, seqlen_q, num_heads_q, head_size_k = q.shape
    num_blocks, page_block_size, num_heads_k, _ = k_cache.shape
    max_num_blocks_per_seq = block_table.shape[1]
    # Check head dimensions (576 = 512 NoPE + 64 RoPE per the module's MLA layout)
    assert head_size_k == 576 or head_size_k == 512, "Only head_size_k == 576 or 512 is supported"
    assert head_dim_v == 512, "Only head_size_v == 512 is supported"
    assert num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"
    assert page_block_size == 64, "Currently page_block_size must be 64"
    # Set default softmax scale
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_size_k)
    # Create output tensors
    out = torch.empty((batch_size, seqlen_q, num_heads_q, head_dim_v), dtype=q.dtype, device=q.device)
    lse = torch.empty((batch_size, num_heads_q, seqlen_q), dtype=torch.float32, device=q.device)
    # Use simplified implementation that works on all architectures
    for b in range(batch_size):
        seq_len_k = cache_seqlens[b].item()
        # Get query for this batch
        q_batch = q[b]  # seqlen_q x num_heads_q x head_size_k
        # Calculate attention scores using the provided k_cache and block_table
        # For gfx926, we'll use a simplified approach
        # Get the relevant blocks from the block table
        num_k_blocks = (seq_len_k + page_block_size - 1) // page_block_size
        blocks = block_table[b, :num_k_blocks].long()
        # Ensure blocks are within bounds
        # NOTE(review): the modulo silently remaps out-of-range page ids instead
        # of raising — presumably to tolerate padded block tables; confirm.
        blocks = blocks % num_blocks
        # Gather the relevant key and value blocks. Advanced indexing copies the
        # pages, so the in-place NaN scrub below never touches the cache itself.
        k = k_cache[blocks].reshape(-1, num_heads_k, head_size_k)[:seq_len_k]
        v = v_cache[blocks].reshape(-1, num_heads_k, head_dim_v)[:seq_len_k]
        # Handle NaN values (x != x is true only for NaN)
        k[k != k] = 0.0
        v[v != v] = 0.0
        # Expand k and v if needed (grouped-query attention: replicate each KV
        # head for its group of query heads)
        if num_heads_k < num_heads_q:
            k = k.repeat_interleave(num_heads_q // num_heads_k, dim=1)
            v = v.repeat_interleave(num_heads_q // num_heads_k, dim=1)
        # Calculate attention scores
        # Reshape k for correct matrix multiplication
        k_reshaped = k.permute(1, 0, 2)  # num_heads_q x seq_len_k x head_size_k
        scores = torch.einsum('qhd,hkd->qhk', q_batch, k_reshaped)  # seqlen_q x num_heads_q x seq_len_k
        scores *= softmax_scale
        # Apply causal mask if needed. tril(diagonal=seq_len_k - seqlen_q)
        # aligns the q tokens with the *end* of the KV sequence: query i may
        # attend to keys 0 .. seq_len_k - seqlen_q + i.
        if causal and seqlen_q > 1:
            mask = torch.ones(seqlen_q, seq_len_k, device=q.device, dtype=torch.bool)
            mask = mask.tril(diagonal=seq_len_k - seqlen_q)
            scores = scores.masked_fill(mask.logical_not().unsqueeze(1), -float('inf'))
        # Apply softmax (max-subtracted for numerical stability)
        max_scores = scores.max(dim=-1, keepdim=True)[0]
        exp_scores = torch.exp(scores - max_scores)
        sum_exp = exp_scores.sum(dim=-1, keepdim=True)
        # Calculate lse: log(sum(exp(scores))) recovered from the shifted sum
        current_lse = torch.log(sum_exp.squeeze(-1)) + max_scores.squeeze(-1)
        lse[b] = current_lse.transpose(0, 1)
        # Calculate attention weights
        attention = exp_scores / sum_exp
        attention = attention.to(torch.float32)
        # Calculate output
        # Reshape v for correct matrix multiplication; accumulate in float32
        v_reshaped = v.permute(1, 0, 2)  # num_heads_q x seq_len_k x head_dim_v
        v_reshaped = v_reshaped.to(torch.float32)
        out[b] = torch.einsum('qhk,hkd->qhd', attention, v_reshaped)  # seqlen_q x num_heads_q x head_dim_v
        out[b] = out[b].to(q.dtype)
        # Correct for q tokens which has no attendable k: rows whose whole score
        # row was -inf get zero output and lse = +inf.
        lonely_q_mask = (current_lse == -float('inf'))
        out[b][lonely_q_mask.unsqueeze(-1).broadcast_to(out[b].shape)] = 0.0
        lse[b][lonely_q_mask.transpose(0, 1)] = float('inf')
    return out, lse
def flash_mla_with_kvcache_torch_interface(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata=None,
    num_splits=None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
    attn_sink: Optional[torch.Tensor] = None,
    extra_k_cache: Optional[torch.Tensor] = None,
    extra_indices_in_kvcache: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
    extra_topk_length: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Adapter exposing the pure-PyTorch dense kernel under the original
    flash_mla call signature.

    Only dense attention (indices is None) is supported; the scheduler
    metadata and split arguments are accepted for compatibility but unused.
    """
    # Sparse attention (explicit top-k indices) has no fallback here.
    if indices is not None:
        raise NotImplementedError("Sparse attention is not implemented in Triton version")
    # head_dim_v may arrive as a 0-dim tensor; normalize to a plain number
    # before slicing.
    dim_v = head_dim_v.item() if isinstance(head_dim_v, torch.Tensor) else head_dim_v
    # The value cache is the leading head_dim_v slice of the key cache,
    # matching the reference implementation.
    v_cache = k_cache[..., :dim_v]
    return flash_mla_with_kvcache_torch(
        q, k_cache, v_cache, block_table, cache_seqlens, head_dim_v, softmax_scale, causal
    )
@dataclass
class FlashMLASchedMeta:
    """
    Container for FlashMLA tile-scheduler state.

    An instance starts empty; its fields are populated lazily by the first
    call that uses it, and later calls are validated against the recorded
    configuration.
    """

    @dataclass
    class Config:
        # Shape/flag snapshot taken on first use, for consistency checks.
        b: int
        s_q: int
        h_q: int
        page_block_size: int
        h_k: int
        causal: bool
        is_fp8_kvcache: bool
        topk: Optional[int]
        extra_page_block_size: Optional[int]
        extra_topk: Optional[int]

    have_initialized: bool = False
    config: Optional[Config] = None
    # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
    tile_scheduler_metadata: Optional[torch.Tensor] = None
    # (1), dtype torch.int32.
    num_splits: Optional[torch.Tensor] = None
def get_mla_metadata(
    *args,
    **kwargs
) -> Tuple[FlashMLASchedMeta, None]:
    """
    Create a fresh, uninitialized FlashMLASchedMeta.

    Every argument is ignored: *args/**kwargs exist solely so that old call
    sites keep working. The actual scheduling metadata is generated lazily
    on the first flash_mla_with_kvcache invocation.

    Returns:
        (FlashMLASchedMeta, None) — for historical reasons a 2-tuple is
        returned; only the first element is meaningful.
    """
    meta = FlashMLASchedMeta()
    return meta, None
def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: Optional[torch.Tensor],
    cache_seqlens: Optional[torch.Tensor],
    head_dim_v: int,
    tile_scheduler_metadata: FlashMLASchedMeta,
    num_splits: None = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
    attn_sink: Optional[torch.Tensor] = None,
    extra_k_cache: Optional[torch.Tensor] = None,
    extra_indices_in_kvcache: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
    extra_topk_length: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
            Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details.
            The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks.
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
        cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
        head_dim_v: Head_dim of v. Must be 512
        sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
        num_splits_placeholder: must be "None" (to be compatible with the old interface).
        softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
        causal: bool. Whether to apply causal attention mask. Only valid for dense attention
        is_fp8_kvcache: bool.
        indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled.
            Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block),
            where t is the k-th token of the j-th q-sequence in the i-th batch.
        attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0.
        extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. Their format requirements are the same as k_cache and indices_in_kvcache respectively.
        topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking.
    For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2:
        head_dim should be 576 while head_dim_v should be 512.
        In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as:
        - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1.
        - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values.
        - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on.
        - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy.
    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    # Rename for clarity: these parameters keep their old public names only
    # for backward compatibility.
    sched_meta = tile_scheduler_metadata
    indices_in_kvcache = indices
    assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
    assert num_splits is None, "num_splits must be None"
    # Derive the optional sparse-attention dimensions from the tensors that
    # carry them (None when the corresponding feature is unused).
    topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None
    extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None
    extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if not sched_meta.have_initialized:
        # Sanity check. We only perform sanity check during the first invocation to save CPU time.
        if indices_in_kvcache is not None:
            assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)"
        # Initialize the tile scheduler metadata during the first invocation.
        sched_meta.have_initialized = True
        sched_meta.config = FlashMLASchedMeta.Config(
            q.shape[0],
            q.shape[1],
            q.shape[2],
            k_cache.shape[1],
            k_cache.shape[2],
            causal,
            is_fp8_kvcache,
            topk,
            extra_k_page_block_size,
            extra_topk,
        )
    else:
        # Check whether the input arguments are consistent with sched_meta
        helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
        assert sched_meta.config is not None
        assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
        assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
        assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
        assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
        assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
        assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
        assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
        assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg
        assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg
        assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg
    if topk is not None:
        # Sparse attention: delegate to the compiled kernel.
        assert not causal, "causal must be False when sparse attention is enabled"
        assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled"
        out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd(
            q, k_cache, indices_in_kvcache, topk_length, attn_sink,
            sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
            extra_k_cache, extra_indices_in_kvcache, extra_topk_length,
            head_dim_v, softmax_scale
        )
    else:
        # Dense attention
        assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used."
        assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
        # NOTE(review): the compiled dense kernel is disabled here; the
        # pure-torch fallback below is used instead. The sched_meta scheduler
        # tensors are therefore never refreshed on the dense path.
        # out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd(
        #     q, k_cache, head_dim_v,
        #     cache_seqlens, block_table,
        #     softmax_scale, causal,
        #     sched_meta.tile_scheduler_metadata, sched_meta.num_splits
        # )
        out, lse = flash_mla_with_kvcache_torch_interface(
            q, k_cache, block_table, cache_seqlens, head_dim_v,
            sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
            softmax_scale, causal
        )
    # sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
    # sched_meta.num_splits = new_num_splits
    return (out, lse)
def flash_mla_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
    attn_sink: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sparse-attention prefill: each q token attends only to the k/v tokens
    selected by `indices`.

    Args:
        q: [s_q, h_q, d_qk], bfloat16.
        kv: [s_kv, h_kv, d_qk], bfloat16.
        indices: [s_q, h_kv, topk], int32. Invalid entries must be -1 or >= s_kv.
        sm_scale: softmax scaling factor.
        d_v: dimension of the value vectors; can only be 512.
        attn_sink: optional [h_q] float32. When given, the output is scaled by
            exp(lse) / (exp(lse) + exp(attn_sink)); lse and max_logits are
            unaffected. -inf behaves as a no-op, +inf zeroes the output.
        topk_length: optional [s_q] int32. When given, the i-th q token attends
            only to indices[i, :, :topk_length[i]], ignoring later entries.
            NOTE: in the rare case that a valid index beyond topk_length[i]
            points at a NaN k token, the output may contain NaN — avoid this.

    Returns:
        (output, max_logits, lse); see tests/ref.py for precise definitions.
        - output: [s_q, h_q, d_v], bfloat16.
        - max_logits: [s_q, h_q], float.
        - lse: [s_q, h_q], float, log-sum-exp of the attention scores.
    """
    # Thin pass-through: all heavy lifting happens in the compiled kernel.
    return flash_mla_cuda.sparse_prefill_fwd(
        q, kv, indices, sm_scale, d_v, attn_sink, topk_length
    )
def get_mla_decoding_metadata_dense_fp8(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build tile-scheduler metadata for dense FP8 decoding.

    Args:
        cache_seqlens: (batch_size,), torch.int32.
        num_heads_per_head_k: equals seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: number of KV heads.

    Returns:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32.
        num_splits: (batch_size + 1,), torch.int32.
    """
    # Delegates directly to the compiled scheduler.
    return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(
        cache_seqlens, num_heads_per_head_k, num_heads_k
    )
def flash_mla_with_kvcache_fp8(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    FP8 dense decode attention over a paged KV cache.

    Supported modes:
      1) q/k/v in fp8 e4m3 (gfx938);
      2) q in bf16/fp16 with KV cache in fp8 e5m2 (gfx936/gfx938).
    descale_q / descale_k are honored only in mode 1.

    Args:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size,), torch.int32.
        head_dim_v: head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            torch.int32, from get_mla_decoding_metadata_dense_fp8.
        num_splits: (batch_size + 1,), torch.int32, from
            get_mla_decoding_metadata_dense_fp8.
        softmax_scale: scale applied to QK^T; defaults to 1/sqrt(head_dim).
        causal: whether to apply a causal attention mask.
        descale_q: (batch_size,), torch.float32 fp8 descale factors for Q.
        descale_k: (batch_size,), torch.float32 fp8 descale factors for K.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    # Default scaling is 1/sqrt(head_dim), taken from q's last dimension.
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
        q, k_cache, None, head_dim_v, cache_seqlens, block_table,
        softmax_scale, causal, tile_scheduler_metadata, num_splits,
        descale_q, descale_k
    )
    return out, softmax_lse
def flash_mla_with_kvcache_fp8_with_cat(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    FP8 dense decode attention where q arrives as separate NoPE/RoPE parts
    that are concatenated inside the kernel.

    Supported modes:
      1) q_nope/q_pe/k_cache fp8 e4m3 (gfx938);
      2) q_nope/q_pe bf16, k_cache fp8 e4m3 (gfx938);
      3) q_nope/q_pe bf16, k_cache fp8 e5m2 (gfx936/gfx938);
      4) q_nope/q_pe fp16, k_cache fp8 e5m2 (gfx936/gfx938).
    descale_q / descale_k are honored only in mode 1.

    Args:
        q_nope: (batch_size, seq_len_q, num_heads_q, 512).
        q_pe: (batch_size, seq_len_q, num_heads_q, 64).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size,), torch.int32.
        head_dim_v: head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            torch.int32, from get_mla_decoding_metadata_dense_fp8.
        num_splits: (batch_size + 1,), torch.int32, from
            get_mla_decoding_metadata_dense_fp8.
        softmax_scale: scale applied to QK^T; defaults to 1/sqrt(head_dim).
        causal: whether to apply a causal attention mask.
        descale_q: (batch_size,), torch.float32 fp8 descale factors for Q.
        descale_k: (batch_size,), torch.float32 fp8 descale factors for K.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    # The default scale uses the full head dim: NoPE part plus RoPE part.
    if softmax_scale is None:
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8_with_cat(
        q_nope, q_pe, k_cache, None, head_dim_v, cache_seqlens, block_table,
        softmax_scale, causal, tile_scheduler_metadata, num_splits,
        descale_q, descale_k
    )
    return out, softmax_lse
def flash_mla_with_kvcache_q_nope_pe(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Dense decode attention with q supplied as separate NoPE and RoPE parts.

    Args:
        q_nope: (batch_size, seq_len_q, num_heads_q, d_nope) NoPE part of q.
        q_pe: (batch_size, seq_len_q, num_heads_q, d_pe) RoPE part of q.
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size,), torch.int32.
        head_dim_v: head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            torch.int32, returned by get_mla_metadata.
        num_splits: (batch_size + 1,), torch.int32, returned by get_mla_metadata.
        softmax_scale: scale applied to QK^T; defaults to 1/sqrt(head_dim).
        causal: whether to apply a causal attention mask.

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    # The default scale is computed over the combined NoPE + RoPE head dim.
    if softmax_scale is None:
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_nope_pe(
        q_nope, q_pe, k_cache, None, head_dim_v, cache_seqlens, block_table,
        softmax_scale, causal, tile_scheduler_metadata, num_splits
    )
    return out, softmax_lse
def flash_mla_with_kvcache_quantization(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    k_scale = None,
    kv_cache_dtype = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Dense decode attention over a quantized KV cache.

    Args:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size,), torch.int32.
        head_dim_v: head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        num_splits: (batch_size + 1,), torch.int32, returned by
            get_mla_decoding_metadata_dense_fp8.
        softmax_scale: scale applied to QK^T; defaults to 1/sqrt(head_dim).
        causal: whether to apply a causal attention mask.
        k_scale: required; torch.float32 tensor of shape (1,) — dequantization
            scale for the KV cache.
        kv_cache_dtype: required; KV-cache dtype tag (only "fp8_e5m2" is
            supported).

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    # Both quantization parameters are mandatory. (Fixed: the original assert
    # message read "k_scale and kv_cache_dtype is not None", which stated the
    # condition rather than the requirement.)
    assert k_scale is not None and kv_cache_dtype is not None, \
        "k_scale and kv_cache_dtype must not be None"
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
        q,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        k_scale,
        kv_cache_dtype
    )
    return out, softmax_lse
def flash_mla_with_kvcache_quantization_q_nope_pe(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    k_scale = None,
    kv_cache_dtype = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Dense decode attention over a quantized KV cache, with q supplied as
    separate NoPE and RoPE parts.

    Args:
        q_nope: (batch_size, seq_len_q, num_heads_q, d_nope) NoPE part of q.
        q_pe: (batch_size, seq_len_q, num_heads_q, d_pe) RoPE part of q.
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size,), torch.int32.
        head_dim_v: head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize),
            torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
        num_splits: (batch_size + 1,), torch.int32, returned by
            get_mla_decoding_metadata_dense_fp8.
        softmax_scale: scale applied to QK^T; defaults to
            1/sqrt(d_nope + d_pe).
        causal: whether to apply a causal attention mask.
        k_scale: required; torch.float32 tensor of shape (1,) — dequantization
            scale for the KV cache.
        kv_cache_dtype: required; KV-cache dtype tag (only "fp8_e5m2" is
            supported).

    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    # Both quantization parameters are mandatory. (Fixed: the original assert
    # message read "k_scale and kv_cache_dtype is not None", which stated the
    # condition rather than the requirement. The docstring also wrongly
    # documented a single `q` argument.)
    assert k_scale is not None and kv_cache_dtype is not None, \
        "k_scale and kv_cache_dtype must not be None"
    if softmax_scale is None:
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_q_nope_pe_mla(
        q_nope,
        q_pe,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        k_scale,
        kv_cache_dtype
    )
    return out, softmax_lse
# def flash_mla_with_kvcache_qkvfp8(
# q: torch.Tensor,
# k_cache: torch.Tensor,
# block_table: Optional[torch.Tensor],
# cache_seqlens: Optional[torch.Tensor],
# head_dim_v: int,
# tile_scheduler_metadata: FlashMLASchedMeta,
# num_splits: None = None,
# softmax_scale: Optional[float] = None,
# causal: bool = False,
# descale_q: Optional[torch.Tensor] = None,
# descale_k: Optional[torch.Tensor] = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# """
# Arguments:
# q: (batch_size, seq_len_q, num_heads_q, head_dim).
# k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
# block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
# cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
# head_dim_v: Head_dim of v. Must be 512
# sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
# num_splits_placeholder: must be "None" (to be compatible with the old interface).
# softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
# causal: bool. Whether to apply causal attention mask. Only valid for dense attention
# descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
# descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
# Return:
# out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
# softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
# """
# sched_meta = tile_scheduler_metadata
# assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
# assert num_splits is None, "num_splits must be None"
# if softmax_scale is None:
# softmax_scale = q.shape[-1] ** (-0.5)
# if not sched_meta.have_initialized:
# # Initialize the tile scheduler metadata during the first invocation.
# sched_meta.have_initialized = True
# sched_meta.config = FlashMLASchedMeta.Config(
# q.shape[0],
# q.shape[1],
# q.shape[2],
# k_cache.shape[1],
# k_cache.shape[2],
# causal,
# False,
# 0,
# 0,
# 0
# )
# else:
# # Check whether the input arguments are consistent with sched_meta
# helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
# assert sched_meta.config is not None
# assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
# assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
# assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
# assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
# assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
# assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
# assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
# # Dense attention
# assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
# out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_qkvfp8(
# q, k_cache, head_dim_v,
# cache_seqlens, block_table,
# softmax_scale, causal,
# sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
# descale_q, descale_k
# )
# sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
# sched_meta.num_splits = new_num_splits
# return (out, lse)
# def flash_mla_with_kvcache_kvfp8(
# q: torch.Tensor,
# k_cache: torch.Tensor,
# block_table: Optional[torch.Tensor],
# cache_seqlens: Optional[torch.Tensor],
# head_dim_v: int,
# tile_scheduler_metadata: FlashMLASchedMeta,
# num_splits: None = None,
# softmax_scale: Optional[float] = None,
# causal: bool = False,
# descale_q: Optional[torch.Tensor] = None,
# descale_k: Optional[torch.Tensor] = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# """
# Arguments:
# q: (batch_size, seq_len_q, num_heads_q, head_dim).
# k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
# block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
# cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
# head_dim_v: Head_dim of v. Must be 512
# sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
# num_splits_placeholder: must be "None" (to be compatible with the old interface).
# softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
# causal: bool. Whether to apply causal attention mask. Only valid for dense attention
# descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
# descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
# Return:
# out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
# softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
# """
# sched_meta = tile_scheduler_metadata
# assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
# assert num_splits is None, "num_splits must be None"
# if softmax_scale is None:
# softmax_scale = q.shape[-1] ** (-0.5)
# if not sched_meta.have_initialized:
# # Initialize the tile scheduler metadata during the first invocation.
# sched_meta.have_initialized = True
# sched_meta.config = FlashMLASchedMeta.Config(
# q.shape[0],
# q.shape[1],
# q.shape[2],
# k_cache.shape[1],
# k_cache.shape[2],
# causal,
# False,
# 0,
# 0,
# 0
# )
# else:
# # Check whether the input arguments are consistent with sched_meta
# helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
# assert sched_meta.config is not None
# assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
# assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
# assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
# assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
# assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
# assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
# assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
# # Dense attention
# assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
# out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_kvfp8(
# q, k_cache, head_dim_v,
# cache_seqlens, block_table,
# softmax_scale, causal,
# sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
# descale_q, descale_k
# )
# sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
# sched_meta.num_splits = new_num_splits
# return (out, lse)
\ No newline at end of file
from typing import Optional, Tuple
import dataclasses
import torch
import flash_mla.cuda as flash_mla_cuda
import torch
import triton
import triton.language as tl
import math
from typing import Optional, Tuple
from dataclasses import dataclass
# Flash-attention decode kernel over a paged KV cache.
#
# Grid: one program per (batch, q-head, q-token). Each program streams the
# KV pages of its sequence, maintains an online-softmax running state
# (max, normalizer, weighted value accumulator), and writes
# out[b, t, h, :HEAD_DIM_V] plus the natural-log LSE to lse[b, h, t].
#
# NOTE(review): tl.arange requires power-of-two extents. HEAD_DIM_V_PAD is
# declared but never used — tl.arange(0, HEAD_DIM_V) below would fail to
# compile for a non-power-of-two head_dim_v. Confirm head_dim_v is always a
# power of two (512 for DeepSeek MLA).
@triton.jit
def _flash_mla_kernel(
    # Input / output pointers
    q_ptr,
    k_cache_ptr,
    v_cache_ptr,
    block_table_ptr,
    cache_seqlens_ptr,
    out_ptr,
    lse_ptr,
    # Shape parameters
    batch_size,
    seqlen_q,
    num_heads_q,
    head_size_k,
    head_dim_v,
    num_blocks,
    page_block_size,
    num_heads_k,
    max_num_blocks_per_seq,
    # Strides
    stride_q_batch,
    stride_q_seq,
    stride_q_head,
    stride_q_dim,
    stride_out_batch,
    stride_out_seq,
    stride_out_head,
    stride_out_dim,
    stride_lse_batch,
    stride_lse_head,
    stride_lse_seq,
    stride_k_cache_block,
    stride_k_cache_token,
    stride_k_cache_head,
    stride_k_cache_dim,
    stride_v_cache_block,
    stride_v_cache_token,
    stride_v_cache_head,
    stride_v_cache_dim,
    stride_block_table_batch,
    stride_block_table_block,
    # Other parameters
    softmax_scale,
    causal,
    BLOCK_SIZE: tl.constexpr,
    HEAD_DIM_K: tl.constexpr,      # actual head_size_k
    HEAD_DIM_V: tl.constexpr,      # actual head_dim_v
    HEAD_DIM_K_PAD: tl.constexpr,  # next_power_of_2(HEAD_DIM_K)
    HEAD_DIM_V_PAD: tl.constexpr,  # next_power_of_2(HEAD_DIM_V)
):
    # Which (batch, head, token) this program handles.
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    pid_t = tl.program_id(2)
    # Valid KV length for this sequence.
    seq_len_k = tl.load(cache_seqlens_ptr + pid_b)
    if seq_len_k == 0:
        # Nothing to attend to: zero output, LSE = +inf.
        out_offset = pid_b * stride_out_batch + pid_t * stride_out_seq + pid_h * stride_out_head
        out_offsets = tl.arange(0, HEAD_DIM_V)
        out_ptrs = out_ptr + out_offset + out_offsets * stride_out_dim
        tl.store(out_ptrs, 0.0, mask=out_offsets < HEAD_DIM_V)
        lse_offset = pid_b * stride_lse_batch + pid_h * stride_lse_head + pid_t * stride_lse_seq
        tl.store(lse_ptr + lse_offset, float('inf'))
        return
    # Grouped-query attention: map this query head onto its KV head.
    num_heads_q_per_k = num_heads_q // num_heads_k
    key_head_idx = pid_h // num_heads_q_per_k
    # Load the query vector; lanes past HEAD_DIM_K are masked to zero.
    q_offset = pid_b * stride_q_batch + pid_t * stride_q_seq + pid_h * stride_q_head
    offs_dk = tl.arange(0, HEAD_DIM_K_PAD)
    q_ptrs = q_ptr + q_offset + offs_dk * stride_q_dim
    mask_dk = offs_dk < HEAD_DIM_K
    q = tl.load(q_ptrs, mask=mask_dk, other=0.0)
    # Online-softmax running state: running max, normalizer, value accumulator.
    m_i = -float('inf')
    l_i = 0.0
    o_i = tl.zeros([HEAD_DIM_V], dtype=tl.float32)
    num_k_blocks = (seq_len_k + BLOCK_SIZE - 1) // BLOCK_SIZE
    # Alignment offset so the last q token lines up with the last KV token.
    offset_causal = seq_len_k - seqlen_q  # note: seqlen_q is a kernel argument
    for block_idx in range(num_k_blocks):
        # Translate the logical page index to a physical page.
        physical_block = tl.load(block_table_ptr + pid_b * stride_block_table_batch + block_idx * stride_block_table_block)
        # Number of valid tokens in this page (the last page may be ragged).
        if block_idx == num_k_blocks - 1:
            cur_block_size = seq_len_k - block_idx * BLOCK_SIZE
            if cur_block_size == 0:
                cur_block_size = BLOCK_SIZE
        else:
            cur_block_size = BLOCK_SIZE
        # Load K tile (invalid token/dim lanes zero-filled).
        k_block_ptr = k_cache_ptr + physical_block * stride_k_cache_block
        offs_t = tl.arange(0, BLOCK_SIZE)
        k_ptrs = (k_block_ptr + offs_t[:, None] * stride_k_cache_token +
                  key_head_idx * stride_k_cache_head + offs_dk[None, :] * stride_k_cache_dim)
        mask_k = (offs_t[:, None] < cur_block_size) & mask_dk[None, :]
        K_block = tl.load(k_ptrs, mask=mask_k, other=0.0)
        # Load V tile.
        v_block_ptr = v_cache_ptr + physical_block * stride_v_cache_block
        offs_dv = tl.arange(0, HEAD_DIM_V)
        v_ptrs = (v_block_ptr + offs_t[:, None] * stride_v_cache_token +
                  key_head_idx * stride_v_cache_head + offs_dv[None, :] * stride_v_cache_dim)
        mask_v = offs_t[:, None] < cur_block_size
        V_block = tl.load(v_ptrs, mask=mask_v, other=0.0)
        # Attention scores for this tile.
        scores = tl.sum(K_block * q[None, :], axis=1) * softmax_scale
        # NOTE(review): when causal is False, lanes past cur_block_size in a
        # ragged last page keep score 0 (K was zero-filled) and are never
        # masked to -inf, so they contribute exp(0 - m) to the normalizer.
        # Confirm callers only use causal=True or page-aligned seq_len_k.
        if causal:
            kv_pos = block_idx * BLOCK_SIZE + offs_t
            causal_mask = (kv_pos <= pid_t + offset_causal)
            mask = (offs_t < cur_block_size) & causal_mask
            scores = tl.where(mask, scores, -float('inf'))
        m_block = tl.max(scores, axis=0)
        # Fold the tile into the running state only if it has a finite score.
        if m_block != -float('inf'):
            m_new = tl.maximum(m_i, m_block)
            exp_scores = tl.exp(scores - m_new)
            l_i = l_i * tl.exp(m_i - m_new) + tl.sum(exp_scores, axis=0)
            o_i = o_i * tl.exp(m_i - m_new) + tl.sum(exp_scores[:, None] * V_block, axis=0)
            m_i = m_new
    # Final normalization (l_i == 0 means every score was masked out).
    if l_i == 0.0:
        out_val = tl.zeros([HEAD_DIM_V], dtype=tl.float32)
        lse_val = float('inf')
    else:
        out_val = o_i / l_i
        lse_val = m_i + tl.log(l_i)
    # Write back output and log-sum-exp.
    out_offset = pid_b * stride_out_batch + pid_t * stride_out_seq + pid_h * stride_out_head
    out_offsets = tl.arange(0, HEAD_DIM_V)
    out_ptrs = out_ptr + out_offset + out_offsets * stride_out_dim
    tl.store(out_ptrs, out_val, mask=out_offsets < HEAD_DIM_V)
    lse_offset = pid_b * stride_lse_batch + pid_h * stride_lse_head + pid_t * stride_lse_seq
    tl.store(lse_ptr + lse_offset, lse_val)
# Cache of compiled Triton kernels. Keyed by every value that participates in
# compilation (constexpr head dims, block size, input dtype) — see
# flash_mla_with_kvcache_triton below.
kernels = {}


def flash_mla_with_kvcache_triton(
    q: torch.Tensor,  # [batch, seqlen_q, num_heads_q, head_size_k]
    k_cache: torch.Tensor,  # [num_blocks, page_block_size, num_heads_k, head_size_k]
    v_cache: torch.Tensor,  # [num_blocks, page_block_size, num_heads_k, head_dim_v]
    block_table: torch.Tensor,  # [batch, max_num_blocks_per_seq]
    cache_seqlens: torch.Tensor,  # [batch]
    head_dim_v: int,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Triton implementation of Flash MLA attention over a paged KV cache,
    with optional causal masking.

    Arguments:
        q: (batch, seqlen_q, num_heads_q, head_size_k) queries.
        k_cache / v_cache: paged key / value caches, addressed through
            block_table.
        block_table: (batch, max_num_blocks_per_seq) int32 logical-to-physical
            page mapping.
        cache_seqlens: (batch,) int32 valid KV length per sequence.
        head_dim_v: value head dimension.
        softmax_scale: scale applied to QK^T; defaults to 1/sqrt(head_size_k).
        causal: align the last query token with the last KV token and mask
            future positions.

    Returns:
        (out (batch, seqlen_q, num_heads_q, head_dim_v),
         lse (batch, num_heads_q, seqlen_q) float32, natural-log LSE).
    """
    # Shape sanity checks.
    assert q.dim() == 4, "q must be 4D"
    assert k_cache.dim() == 4, "k_cache must be 4D"
    assert v_cache.dim() == 4, "v_cache must be 4D"
    assert block_table.dim() == 2, "block_table must be 2D"
    assert cache_seqlens.dim() == 1, "cache_seqlens must be 1D"
    assert k_cache.shape[2] == v_cache.shape[2], "num_heads_k mismatch"
    assert k_cache.shape[3] == q.shape[3], "head_size_k mismatch"
    assert v_cache.shape[3] == head_dim_v, "head_dim_v mismatch"
    batch, seqlen_q, num_heads_q, head_size_k = q.shape
    num_blocks, page_block_size, num_heads_k, _ = k_cache.shape
    max_num_blocks_per_seq = block_table.shape[1]
    assert page_block_size == 64, "Only page_block_size=64 is supported"
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(head_size_k)
    # tl.arange requires power-of-two extents, so pad the head dims.
    HEAD_DIM_K_PAD = triton.next_power_of_2(head_size_k)
    HEAD_DIM_V_PAD = triton.next_power_of_2(head_dim_v)
    # Output tensors.
    out = torch.empty((batch, seqlen_q, num_heads_q, head_dim_v), dtype=q.dtype, device=q.device)
    lse = torch.empty((batch, num_heads_q, seqlen_q), dtype=torch.float32, device=q.device)
    BLOCK_SIZE = page_block_size  # 64
    # Bug fix: Triton requires num_warps to be a power of two; the previous
    # (head_dim_v + 31) // 32 heuristic could produce e.g. 17. Round up.
    # (For head_dim_v = 512 this is 16, identical to the old value.)
    num_warps = triton.next_power_of_2(max(1, (head_dim_v + 31) // 32))
    num_stages = 4
    # One program per (batch, q-head, q-token).
    grid = (batch, num_heads_q, seqlen_q)
    # Positional arguments mirror the kernel signature exactly.
    args = (
        q, k_cache, v_cache, block_table, cache_seqlens, out, lse,
        batch, seqlen_q, num_heads_q, head_size_k, head_dim_v,
        num_blocks, page_block_size, num_heads_k, max_num_blocks_per_seq,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        out.stride(0), out.stride(1), out.stride(2), out.stride(3),
        lse.stride(0), lse.stride(1), lse.stride(2),
        k_cache.stride(0), k_cache.stride(1), k_cache.stride(2), k_cache.stride(3),
        v_cache.stride(0), v_cache.stride(1), v_cache.stride(2), v_cache.stride(3),
        block_table.stride(0), block_table.stride(1),
        softmax_scale, causal,
        BLOCK_SIZE, head_size_k, head_dim_v, HEAD_DIM_K_PAD, HEAD_DIM_V_PAD,
    )
    # Bug fix: the cache used to be keyed by BLOCK_SIZE alone, so a later call
    # with different head dims or dtype would relaunch a kernel compiled for
    # the wrong constexpr values. Key on everything compilation depends on
    # (the *_PAD values are functions of these and need not be in the key).
    cache_key = (BLOCK_SIZE, head_size_k, head_dim_v, q.dtype)
    kernel = kernels.get(cache_key)
    if kernel is None:
        # Compile once via warmup (grid=(1,)), then reuse the compiled kernel.
        kernel = _flash_mla_kernel.warmup(
            *args, num_warps=num_warps, num_stages=num_stages, grid=(1,)
        )
        kernels[cache_key] = kernel
    kernel[grid](*args)
    return out, lse
# The wrappers below and get_mla_metadata keep the original interface.
def flash_mla_with_kvcache_triton_interface(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata=None,
    num_splits=None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
    attn_sink: Optional[torch.Tensor] = None,
    extra_k_cache: Optional[torch.Tensor] = None,
    extra_indices_in_kvcache: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
    extra_topk_length: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Adapter exposing the Triton dense-decode path under the
    flash_mla_with_kvcache-style signature.

    Only dense attention is supported: ``indices`` (sparse top-k selection)
    must be None. The value cache is taken as the leading ``head_dim_v``
    channels of ``k_cache``; the scheduler-metadata / fp8 / extra-cache
    arguments are accepted purely for interface compatibility and ignored.
    """
    # Sparse top-k attention has no Triton implementation.
    if indices is not None:
        raise NotImplementedError("Sparse attention is not implemented in Triton version")
    # head_dim_v may arrive as a 0-dim tensor; slicing needs a plain int.
    if isinstance(head_dim_v, torch.Tensor):
        dim_v = head_dim_v.item()
    else:
        dim_v = head_dim_v
    # V shares storage with K: value vectors are the first dim_v channels.
    return flash_mla_with_kvcache_triton(
        q, k_cache, k_cache[..., :dim_v], block_table, cache_seqlens,
        head_dim_v, softmax_scale, causal
    )
@dataclass
class FlashMLASchedMeta:
    """Mutable holder for the FlashMLA tile-scheduler state.

    get_mla_metadata() hands out an empty instance; the first call to
    flash_mla_with_kvcache records the invocation shape in ``config`` and sets
    ``have_initialized``, and later calls validate their arguments against it.
    """

    @dataclass
    class Config:
        """Shape / mode snapshot taken at the first kernel invocation."""
        b: int                                # batch size
        s_q: int                              # query sequence length
        h_q: int                              # number of query heads
        page_block_size: int                  # KV-cache page size
        h_k: int                              # number of KV heads
        causal: bool                          # causal masking enabled
        is_fp8_kvcache: bool                  # KV cache stored as fp8
        topk: Optional[int]                   # sparse top-k (None = dense)
        extra_page_block_size: Optional[int]  # page size of extra_k_cache
        extra_topk: Optional[int]             # top-k of extra indices

    have_initialized: bool = False
    config: Optional[Config] = None
    # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
    tile_scheduler_metadata: Optional[torch.Tensor] = None
    # (1), dtype torch.int32.
    num_splits: Optional[torch.Tensor] = None
def get_mla_metadata(
    *args,
    **kwargs
) -> Tuple[FlashMLASchedMeta, None]:
    """
    Hand out a blank FlashMLASchedMeta. The real scheduling metadata is
    produced lazily by the first invocation of flash_mla_with_kvcache.

    Arguments:
        Ignored. *args / **kwargs exist purely for compatibility with the old
        interface.
    Return:
        (FlashMLASchedMeta, None). Only the first element carries information;
        the trailing None is kept for historical reasons.
    """
    fresh_meta = FlashMLASchedMeta()
    return fresh_meta, None
def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: Optional[torch.Tensor],
    cache_seqlens: Optional[torch.Tensor],
    head_dim_v: int,
    tile_scheduler_metadata: FlashMLASchedMeta,
    num_splits: None = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    is_fp8_kvcache: bool = False,
    indices: Optional[torch.Tensor] = None,
    attn_sink: Optional[torch.Tensor] = None,
    extra_k_cache: Optional[torch.Tensor] = None,
    extra_indices_in_kvcache: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
    extra_topk_length: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
            Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details.
            The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks.
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
        cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
        head_dim_v: Head_dim of v. Must be 512
        sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
        num_splits_placeholder: must be "None" (to be compatible with the old interface).
        softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
        causal: bool. Whether to apply causal attention mask. Only valid for dense attention
        is_fp8_kvcache: bool.
        indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled.
            Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block),
            where t is the k-th token of the j-th q-sequence in the i-th batch.
        attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0.
        extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. Their format requirements are the same as k_cache and indices_in_kvcache respectively.
        topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking.
    For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2:
        head_dim should be 576 while head_dim_v should be 512.
        In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as:
        - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1.
        - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values.
        - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on.
        - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy.
    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    # Historical aliases: the public parameter names differ from the internal ones.
    sched_meta = tile_scheduler_metadata
    indices_in_kvcache = indices
    assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
    assert num_splits is None, "num_splits must be None"
    # Mode-describing values; all None in plain dense mode.
    topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None
    extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None
    extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if not sched_meta.have_initialized:
        # Sanity check. We only perform sanity check during the first invocation to save CPU time.
        if indices_in_kvcache is not None:
            assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)"
        # Initialize the tile scheduler metadata during the first invocation.
        sched_meta.have_initialized = True
        sched_meta.config = FlashMLASchedMeta.Config(
            q.shape[0],
            q.shape[1],
            q.shape[2],
            k_cache.shape[1],
            k_cache.shape[2],
            causal,
            is_fp8_kvcache,
            topk,
            extra_k_page_block_size,
            extra_topk,
        )
    else:
        # Check whether the input arguments are consistent with sched_meta
        helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
        assert sched_meta.config is not None
        assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
        assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
        assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
        assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
        assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
        assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
        assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
        assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg
        assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg
        assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg
    if topk is not None:
        # Sparse attention
        assert not causal, "causal must be False when sparse attention is enabled"
        assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled"
        out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd(
            q, k_cache, indices_in_kvcache, topk_length, attn_sink,
            sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
            extra_k_cache, extra_indices_in_kvcache, extra_topk_length,
            head_dim_v, softmax_scale
        )
        # NOTE(review): new_tile_scheduler_metadata / new_num_splits are never
        # written back to sched_meta (the write-back below is commented out),
        # so the sparse path always passes None metadata on every call —
        # confirm this is intended for this port.
    else:
        # Dense attention
        assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used."
        assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
        # The CUDA dense decode kernel is replaced by the Triton implementation
        # on this platform; the original call is kept for reference.
        # out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd(
        #     q, k_cache, head_dim_v,
        #     cache_seqlens, block_table,
        #     softmax_scale, causal,
        #     sched_meta.tile_scheduler_metadata, sched_meta.num_splits
        # )
        out, lse = flash_mla_with_kvcache_triton_interface(
            q, k_cache, block_table, cache_seqlens, head_dim_v,
            sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
            softmax_scale, causal
        )
    # sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
    # sched_meta.num_splits = new_num_splits
    return (out, lse)
def flash_mla_sparse_fwd(
    q: torch.Tensor,
    kv: torch.Tensor,
    indices: torch.Tensor,
    sm_scale: float,
    d_v: int = 512,
    attn_sink: Optional[torch.Tensor] = None,
    topk_length: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Sparse-attention prefill kernel (thin wrapper over the CUDA op).

    Args:
        q: [s_q, h_q, d_qk], bfloat16 queries.
        kv: [s_kv, h_kv, d_qk], bfloat16 key/value tokens.
        indices: [s_q, h_kv, topk], int32 top-k token indices. Invalid entries
            must be -1 or >= s_kv.
        sm_scale: scale applied to QK^T before softmax.
        d_v: value head dimension; can only be 512.
        attn_sink: optional, [h_q], float32. When provided, the output is
            additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink));
            -inf has no effect, +inf zeroes the corresponding output. lse and
            max_logits are unaffected.
        topk_length: optional, [s_q], int32. The i-th query only attends to
            indices[i, :, :topk_length[i]], ignoring later entries. Beware:
            a valid index past topk_length[i] pointing at a NaN k token can
            still produce NaN output.

    Returns:
        (output [s_q, h_q, d_v] bfloat16,
         max_logits [s_q, h_q] float,
         lse [s_q, h_q] float — log-sum-exp of attention scores).
        See tests/ref.py for the precise definitions.
    """
    return flash_mla_cuda.sparse_prefill_fwd(
        q, kv, indices, sm_scale, d_v, attn_sink, topk_length
    )
def get_mla_decoding_metadata_dense_fp8(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute tile-scheduler metadata for the dense fp8 decode kernels.

    Args:
        cache_seqlens: (batch_size,), torch.int32 per-sequence KV lengths.
        num_heads_per_head_k: equals seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: number of KV heads.

    Returns:
        (tile_scheduler_metadata (num_sm_parts, TileSchedulerMetaDataSize) int32,
         num_splits (batch_size + 1,) int32).
    """
    return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(
        cache_seqlens, num_heads_per_head_k, num_heads_k
    )
def flash_mla_with_kvcache_fp8(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Dense MLA decode with an fp8 KV cache.

    Supported modes:
      1) q/k/v all fp8 e4m3 (gfx938)
      2) q bf16/fp16, kv fp8 e5m2 (gfx936, gfx938)
    descale_q / descale_k are only supported in mode 1.

    Args:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size,), torch.int32.
        head_dim_v: head dimension of v.
        tile_scheduler_metadata / num_splits: torch.int32 tensors returned by
            get_mla_decoding_metadata_dense_fp8.
        softmax_scale: QK^T scale; defaults to 1 / sqrt(head_dim).
        causal: whether to apply a causal attention mask.
        descale_q / descale_k: (batch_size,) float32 fp8 descaling factors.

    Returns:
        (out (batch_size, seq_len_q, num_heads_q, head_dim_v),
         softmax_lse (batch_size, num_heads_q, seq_len_q) float32).
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
        q, k_cache, None, head_dim_v,
        cache_seqlens, block_table, softmax_scale, causal,
        tile_scheduler_metadata, num_splits, descale_q, descale_k
    )
    return out, softmax_lse
def flash_mla_with_kvcache_fp8_with_cat(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    descale_q: Optional[torch.Tensor] = None,
    descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Dense fp8 MLA decode taking separate NoPE / RoPE query halves.

    Supported modes:
      1) q_nope/q_pe/k_cache fp8 e4m3 (gfx938)
      2) q_nope/q_pe bf16, k_cache fp8 e4m3 (gfx938)
      3) q_nope/q_pe bf16, k_cache fp8 e5m2 (gfx936, gfx938)
      4) q_nope/q_pe fp16, k_cache fp8 e5m2 (gfx936, gfx938)
    descale_q / descale_k are only supported in mode 1.

    Args:
        q_nope: (batch_size, seq_len_q, num_heads_q, 512).
        q_pe: (batch_size, seq_len_q, num_heads_q, 64).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size,), torch.int32.
        head_dim_v: head dimension of v.
        tile_scheduler_metadata / num_splits: torch.int32 tensors returned by
            get_mla_decoding_metadata_dense_fp8.
        softmax_scale: QK^T scale; defaults to 1 / sqrt(full query head_dim).
        causal: whether to apply a causal attention mask.
        descale_q / descale_k: (batch_size,) float32 fp8 descaling factors.

    Returns:
        (out (batch_size, seq_len_q, num_heads_q, head_dim_v),
         softmax_lse (batch_size, num_heads_q, seq_len_q) float32).
    """
    if softmax_scale is None:
        # Default scale uses the concatenated (NoPE + RoPE) query dimension.
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8_with_cat(
        q_nope, q_pe, k_cache, None, head_dim_v,
        cache_seqlens, block_table, softmax_scale, causal,
        tile_scheduler_metadata, num_splits, descale_q, descale_k
    )
    return out, softmax_lse
def flash_mla_with_kvcache_q_nope_pe(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Dense MLA decode taking separate NoPE / RoPE query halves.

    Args:
        q_nope / q_pe: NoPE and RoPE halves of the query,
            (batch_size, seq_len_q, num_heads_q, *).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size,), torch.int32.
        head_dim_v: head dimension of v.
        tile_scheduler_metadata / num_splits: torch.int32 tensors returned by
            get_mla_metadata.
        softmax_scale: QK^T scale; defaults to 1 / sqrt(full query head_dim).
        causal: whether to apply a causal attention mask.

    Returns:
        (out (batch_size, seq_len_q, num_heads_q, head_dim_v),
         softmax_lse (batch_size, num_heads_q, seq_len_q) float32).
    """
    if softmax_scale is None:
        # Default scale uses the concatenated (NoPE + RoPE) query dimension.
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_nope_pe(
        q_nope, q_pe, k_cache, None, head_dim_v,
        cache_seqlens, block_table, softmax_scale, causal,
        tile_scheduler_metadata, num_splits
    )
    return out, softmax_lse
def flash_mla_with_kvcache_quantization(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    k_scale = None,
    kv_cache_dtype = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Dense MLA decode over a quantized KV cache.

    Args:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size,), torch.int32.
        head_dim_v: head dimension of v.
        tile_scheduler_metadata / num_splits: torch.int32 tensors returned by
            get_mla_decoding_metadata_dense_fp8.
        softmax_scale: QK^T scale; defaults to 1 / sqrt(head_dim).
        causal: whether to apply a causal attention mask.
        k_scale: shape-(1,) torch.float32 dequantization scale; required.
        kv_cache_dtype: only "fp8_e5m2" is supported; required.

    Returns:
        (out (batch_size, seq_len_q, num_heads_q, head_dim_v),
         softmax_lse (batch_size, num_heads_q, seq_len_q) float32).
    """
    assert k_scale is not None and kv_cache_dtype is not None, "k_scale and kv_cache_dtype is not None"
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
        q, k_cache, None, head_dim_v,
        cache_seqlens, block_table, softmax_scale, causal,
        tile_scheduler_metadata, num_splits, k_scale, kv_cache_dtype
    )
    return out, softmax_lse
def flash_mla_with_kvcache_quantization_q_nope_pe(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    k_scale: Optional[torch.Tensor] = None,
    kv_cache_dtype: Optional[str] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Decode-time MLA attention over a quantized (fp8_e5m2) paged KV cache,
    with the query split into its no-position (nope) and rotary (pe) parts.

    Arguments:
        q_nope: (batch_size, seq_len_q, num_heads_q, head_dim_nope).
        q_pe: (batch_size, seq_len_q, num_heads_q, head_dim_pe).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head dimension of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
            returned by get_mla_decoding_metadata_dense_fp8.
        num_splits: (batch_size + 1), torch.int32, returned by
            get_mla_decoding_metadata_dense_fp8.
        softmax_scale: float. The scale of QK^T before applying softmax.
            Default to 1 / sqrt(head_dim_nope + head_dim_pe).
        causal: bool. Whether to apply causal attention mask.
        k_scale: scalar tensor (shape [1], torch.float32) — dequantization scale for K.
        kv_cache_dtype: quantized cache dtype tag; only "fp8_e5m2" is supported.
    Returns:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    # Quantized path is only valid when both the scale and the dtype tag are given.
    assert k_scale is not None and kv_cache_dtype is not None, "k_scale and kv_cache_dtype must not be None"
    if softmax_scale is None:
        # Default scale uses the full query head dim, i.e. nope + pe parts combined.
        softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_q_nope_pe_mla(
        q_nope,
        q_pe,
        k_cache,
        None,  # v_cache is derived from k_cache inside the kernel
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
        k_scale,
        kv_cache_dtype
    )
    return out, softmax_lse
# def flash_mla_with_kvcache_qkvfp8(
# q: torch.Tensor,
# k_cache: torch.Tensor,
# block_table: Optional[torch.Tensor],
# cache_seqlens: Optional[torch.Tensor],
# head_dim_v: int,
# tile_scheduler_metadata: FlashMLASchedMeta,
# num_splits: None = None,
# softmax_scale: Optional[float] = None,
# causal: bool = False,
# descale_q: Optional[torch.Tensor] = None,
# descale_k: Optional[torch.Tensor] = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# """
# Arguments:
# q: (batch_size, seq_len_q, num_heads_q, head_dim).
# k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
# block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
# cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
# head_dim_v: Head_dim of v. Must be 512
# sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
# num_splits_placeholder: must be "None" (to be compatible with the old interface).
# softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
# causal: bool. Whether to apply causal attention mask. Only valid for dense attention
# descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
# descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
# Return:
# out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
# softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
# """
# sched_meta = tile_scheduler_metadata
# assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
# assert num_splits is None, "num_splits must be None"
# if softmax_scale is None:
# softmax_scale = q.shape[-1] ** (-0.5)
# if not sched_meta.have_initialized:
# # Initialize the tile scheduler metadata during the first invocation.
# sched_meta.have_initialized = True
# sched_meta.config = FlashMLASchedMeta.Config(
# q.shape[0],
# q.shape[1],
# q.shape[2],
# k_cache.shape[1],
# k_cache.shape[2],
# causal,
# False,
# 0,
# 0,
# 0
# )
# else:
# # Check whether the input arguments are consistent with sched_meta
# helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
# assert sched_meta.config is not None
# assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
# assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
# assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
# assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
# assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
# assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
# assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
# # Dense attention
# assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
# out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_qkvfp8(
# q, k_cache, head_dim_v,
# cache_seqlens, block_table,
# softmax_scale, causal,
# sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
# descale_q, descale_k
# )
# sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
# sched_meta.num_splits = new_num_splits
# return (out, lse)
# def flash_mla_with_kvcache_kvfp8(
# q: torch.Tensor,
# k_cache: torch.Tensor,
# block_table: Optional[torch.Tensor],
# cache_seqlens: Optional[torch.Tensor],
# head_dim_v: int,
# tile_scheduler_metadata: FlashMLASchedMeta,
# num_splits: None = None,
# softmax_scale: Optional[float] = None,
# causal: bool = False,
# descale_q: Optional[torch.Tensor] = None,
# descale_k: Optional[torch.Tensor] = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# """
# Arguments:
# q: (batch_size, seq_len_q, num_heads_q, head_dim).
# k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
# block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
# cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
# head_dim_v: Head_dim of v. Must be 512
# sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
# num_splits_placeholder: must be "None" (to be compatible with the old interface).
# softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
# causal: bool. Whether to apply causal attention mask. Only valid for dense attention
# descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
# descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
# Return:
# out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
# softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
# """
# sched_meta = tile_scheduler_metadata
# assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
# assert num_splits is None, "num_splits must be None"
# if softmax_scale is None:
# softmax_scale = q.shape[-1] ** (-0.5)
# if not sched_meta.have_initialized:
# # Initialize the tile scheduler metadata during the first invocation.
# sched_meta.have_initialized = True
# sched_meta.config = FlashMLASchedMeta.Config(
# q.shape[0],
# q.shape[1],
# q.shape[2],
# k_cache.shape[1],
# k_cache.shape[2],
# causal,
# False,
# 0,
# 0,
# 0
# )
# else:
# # Check whether the input arguments are consistent with sched_meta
# helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
# assert sched_meta.config is not None
# assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
# assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
# assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
# assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
# assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
# assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
# assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
# # Dense attention
# assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
# out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_kvfp8(
# q, k_cache, head_dim_v,
# cache_seqlens, block_table,
# softmax_scale, causal,
# sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
# descale_q, descale_k
# )
# sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
# sched_meta.num_splits = new_num_splits
# return (out, lse)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment