cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
head_dim_v: Head dimension of v. Must be 512.
sched_meta: FlashMLASchedMeta, returned by get_mla_metadata. You may reuse the same sched_meta across invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
num_splits_placeholder: must be None (kept for compatibility with the old interface).
num_splits: optional override for BF16 sparse decode. Other paths keep using sched_meta.
softmax_scale: float. The scale applied to QK^T before softmax. Defaults to 1 / sqrt(head_dim_k).
causal: bool. Whether to apply a causal attention mask. Only valid for dense attention.
is_fp8_kvcache: bool.
...
...
@@ -104,7 +104,6 @@ def flash_mla_with_kvcache(
sched_meta = tile_scheduler_metadata
indices_in_kvcache = indices
assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
assert num_splits is None, "num_splits override is only supported by BF16 sparse decode"
assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used."
assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
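
The dense-path checks above can be tried in isolation with a small sketch. It assumes nothing about the real library: `FlashMLASchedMeta` here is a hypothetical stand-in dataclass, `check_dense_args` is an illustrative helper name that does not appear in the diff, and plain Python lists stand in for tensors. Only the assertion logic and messages follow the hunk.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FlashMLASchedMeta:
    """Hypothetical stand-in for the real scheduler metadata class."""
    payload: Any = None


def check_dense_args(
    sched_meta: Any,
    num_splits: Optional[int] = None,
    block_table: Any = None,
    cache_seqlens: Any = None,
    **sparse_only: Any,
) -> None:
    """Mirror the dense-attention assertions on plain Python values."""
    assert isinstance(sched_meta, FlashMLASchedMeta), \
        "tile_scheduler_metadata must be of type FlashMLASchedMeta"
    assert num_splits is None, \
        "num_splits override is only supported by BF16 sparse decode"
    # Sparse-only arguments (indices_in_kvcache, attn_sink, ...) must stay None.
    offending = sorted(k for k, v in sparse_only.items() if v is not None)
    assert not offending, \
        f"{', '.join(offending)} must be None when dense attention is used."
    assert block_table is not None and cache_seqlens is not None, \
        "block_table and cache_seqlens must be provided when dense attention is used."


# A valid dense call passes silently.
check_dense_args(FlashMLASchedMeta(), block_table=[[0]], cache_seqlens=[1])
```

Unlike the single combined assert in the hunk, this sketch names only the offending arguments in its message; that is a design choice of the example, not something the diff does.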