at::Tensor&q,// batch_size x seqlen_q x num_heads x head_size
constat::Tensor&kcache,// num_blocks x page_block_size x num_heads_k x head_size (when is_fp8 is False) or num_blocks x num_heads_k x (page_block_size*656) (when is_fp8 is True)
constinthead_size_v,
constat::Tensor&seqlens_k,// batch_size
constat::Tensor&block_table,// batch_size x max_num_blocks_per_seq
constfloatsoftmax_scale,
boolis_causal,
std::optional<at::Tensor>&tile_scheduler_metadata,// num_sm_parts x (DecodingSchedMetaSize/4)
Different modes (including fp8/bf16, and sparsity) have different KV cache layouts. See comments below for details.
The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is a valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks.
block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
head_dim_v: Head_dim of v. Must be 512
sched_meta: FlashMLASchedMeta, returned by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
num_splits_placeholder: must be "None" (to be compatible with the old interface).
softmax_scale: float. The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim_k).
causal: bool. Whether to apply causal attention mask. Only valid for dense attention
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v). Only bf16 output is supported.
assertisinstance(sched_meta,FlashMLASchedMeta),"tile_scheduler_metadata must be of type FlashMLASchedMeta"
assertnum_splitsisNone,"num_splits must be None"
ifsoftmax_scaleisNone:
softmax_scale=q.shape[-1]**(-0.5)
ifnotsched_meta.have_initialized:
# Initialize the tile scheduler metadata during the first invocation.
sched_meta.have_initialized=True
sched_meta.config=FlashMLASchedMeta.Config(
q.shape[0],
q.shape[1],
q.shape[2],
k_cache.shape[1],
k_cache.shape[2],
causal,
False,
0,
0,
0
)
else:
# Check whether the input arguments are consistent with sched_meta
helper_msg=" Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
assertsched_meta.configisnotNone
assertsched_meta.config.b==q.shape[0],"sched_meta.config.b must be equal to batch_size."+helper_msg
assertsched_meta.config.s_q==q.shape[1],"sched_meta.config.s_q must be equal to seq_len_q."+helper_msg
assertsched_meta.config.h_q==q.shape[2],"sched_meta.config.h_q must be equal to num_heads_q."+helper_msg
assertsched_meta.config.page_block_size==k_cache.shape[1],"sched_meta.config.page_block_size must be equal to page_block_size."+helper_msg
assertsched_meta.config.h_k==k_cache.shape[2],"sched_meta.config.h_k must be equal to num_heads_k."+helper_msg
assertsched_meta.config.causal==causal,"sched_meta.config.causal must be equal to causal."+helper_msg
assertsched_meta.config.is_fp8_kvcache==False,"sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache."+helper_msg
# Dense attention
assertblock_tableisnotNoneandcache_seqlensisnotNone,"block_table and cache_seqlens must be provided when dense attention is used."