block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
head_dim_v: Head_dim of v. Must be 512
sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
num_splits_placeholder: must be "None" (to be compatible with the old interface).
softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
causal: bool. Whether to apply causal attention mask. Only valid for dense attention
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
# Check whether the input arguments are consistent with sched_meta
helper_msg=" Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
assertsched_meta.configisnotNone
assertsched_meta.config.b==q.shape[0],"sched_meta.config.b must be equal to batch_size."+helper_msg
assertsched_meta.config.s_q==q.shape[1],"sched_meta.config.s_q must be equal to seq_len_q."+helper_msg
assertsched_meta.config.h_q==q.shape[2],"sched_meta.config.h_q must be equal to num_heads_q."+helper_msg
assertsched_meta.config.page_block_size==k_cache.shape[1],"sched_meta.config.page_block_size must be equal to page_block_size."+helper_msg
assertsched_meta.config.h_k==k_cache.shape[2],"sched_meta.config.h_k must be equal to num_heads_k."+helper_msg
assertsched_meta.config.causal==causal,"sched_meta.config.causal must be equal to causal."+helper_msg
assertsched_meta.config.is_fp8_kvcache==False,"sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache."+helper_msg
block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
head_dim_v: Head_dim of v. Must be 512
sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
num_splits_placeholder: must be "None" (to be compatible with the old interface).
softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
causal: bool. Whether to apply causal attention mask. Only valid for dense attention
# Check whether the input arguments are consistent with sched_meta
helper_msg=" Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
assertsched_meta.configisnotNone
assertsched_meta.config.b==q.shape[0],"sched_meta.config.b must be equal to batch_size."+helper_msg
assertsched_meta.config.s_q==q.shape[1],"sched_meta.config.s_q must be equal to seq_len_q."+helper_msg
assertsched_meta.config.h_q==q.shape[2],"sched_meta.config.h_q must be equal to num_heads_q."+helper_msg
assertsched_meta.config.page_block_size==k_cache.shape[1],"sched_meta.config.page_block_size must be equal to page_block_size."+helper_msg
assertsched_meta.config.h_k==k_cache.shape[2],"sched_meta.config.h_k must be equal to num_heads_k."+helper_msg
assertsched_meta.config.causal==causal,"sched_meta.config.causal must be equal to causal."+helper_msg
assertsched_meta.config.is_fp8_kvcache==False,"sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache."+helper_msg
# block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
# cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
# head_dim_v: Head_dim of v. Must be 512
# sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
# num_splits_placeholder: must be "None" (to be compatible with the old interface).
# softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
# causal: bool. Whether to apply causal attention mask. Only valid for dense attention
# descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
# descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
# Return:
# out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
# assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
# assert num_splits is None, "num_splits must be None"
# if softmax_scale is None:
# softmax_scale = q.shape[-1] ** (-0.5)
# if not sched_meta.have_initialized:
# # Initialize the tile scheduler metadata during the first invocation.
# sched_meta.have_initialized = True
# sched_meta.config = FlashMLASchedMeta.Config(
# q.shape[0],
# q.shape[1],
# q.shape[2],
# k_cache.shape[1],
# k_cache.shape[2],
# causal,
# False,
# 0,
# 0,
# 0
# )
# else:
# # Check whether the input arguments are consistent with sched_meta
# helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
# assert sched_meta.config is not None
# assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
# assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
# assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
# assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
# assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
# assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
# assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
# # Dense attention
# assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
# block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
# cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
# head_dim_v: Head_dim of v. Must be 512
# sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
# num_splits_placeholder: must be "None" (to be compatible with the old interface).
# softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
# causal: bool. Whether to apply causal attention mask. Only valid for dense attention
# descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
# descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
# Return:
# out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
# assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
# assert num_splits is None, "num_splits must be None"
# if softmax_scale is None:
# softmax_scale = q.shape[-1] ** (-0.5)
# if not sched_meta.have_initialized:
# # Initialize the tile scheduler metadata during the first invocation.
# sched_meta.have_initialized = True
# sched_meta.config = FlashMLASchedMeta.Config(
# q.shape[0],
# q.shape[1],
# q.shape[2],
# k_cache.shape[1],
# k_cache.shape[2],
# causal,
# False,
# 0,
# 0,
# 0
# )
# else:
# # Check whether the input arguments are consistent with sched_meta
# helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
# assert sched_meta.config is not None
# assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
# assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
# assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
# assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
# assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
# assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
# assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
# # Dense attention
# assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."