Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details.
Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details.
The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks.
The KV cache must be contiguously valid for sparse attention on all arch. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks.
block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
# block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
# cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
# head_dim_v: Head_dim of v. Must be 512
# sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
# num_splits_placeholder: must be "None" (to be compatible with the old interface).
# softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
# causal: bool. Whether to apply causal attention mask. Only valid for dense attention
# descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
# descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
# Return:
# out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
# assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
# assert num_splits is None, "num_splits must be None"
# if softmax_scale is None:
# softmax_scale = q.shape[-1] ** (-0.5)
# if not sched_meta.have_initialized:
# # Initialize the tile scheduler metadata during the first invocation.
# sched_meta.have_initialized = True
# sched_meta.config = FlashMLASchedMeta.Config(
# q.shape[0],
# q.shape[1],
# q.shape[2],
# k_cache.shape[1],
# k_cache.shape[2],
# causal,
# False,
# 0,
# 0,
# 0
# )
# else:
# # Check whether the input arguments are consistent with sched_meta
# helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
# assert sched_meta.config is not None
# assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
# assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
# assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
# assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
# assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
# assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
# assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
# # Dense attention
# assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
# block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
# cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
# head_dim_v: Head_dim of v. Must be 512
# sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
# num_splits_placeholder: must be "None" (to be compatible with the old interface).
# softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
# causal: bool. Whether to apply causal attention mask. Only valid for dense attention
# descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
# descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
# Return:
# out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
# assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
# assert num_splits is None, "num_splits must be None"
# if softmax_scale is None:
# softmax_scale = q.shape[-1] ** (-0.5)
# if not sched_meta.have_initialized:
# # Initialize the tile scheduler metadata during the first invocation.
# sched_meta.have_initialized = True
# sched_meta.config = FlashMLASchedMeta.Config(
# q.shape[0],
# q.shape[1],
# q.shape[2],
# k_cache.shape[1],
# k_cache.shape[2],
# causal,
# False,
# 0,
# 0,
# 0
# )
# else:
# # Check whether the input arguments are consistent with sched_meta
# helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
# assert sched_meta.config is not None
# assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
# assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
# assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
# assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
# assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
# assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
# assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
# # Dense attention
# assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."