Commit 76bb5d10 authored by zhanghj2's avatar zhanghj2
Browse files

delete dead code

parent a206ecac
Pipeline #3735 failed with stages
in 0 seconds
......@@ -33,15 +33,6 @@ struct Arch {
archName = device_prop->gcnArchName;
}
bool is_sm90a() const {
return major == 9 && minor == 3;
// return major == 9 && minor == 0;
}
bool is_sm100f() const {
return major == 10;
}
bool is_gfx938() const {
return archName.substr(0, archName.find(':')) == "gfx938";
}
......
......@@ -291,7 +291,7 @@ sparse_attn_decode_interface(
}
DecodeImplBase* impl;
if (arch.is_sm90a()) {
if (arch.is_gfx93x()) {
impl = new Decode_Sm90_Impl();
} else {
TORCH_CHECK(false, "Unsupported architecture for sparse decode fwd");
......
......@@ -73,7 +73,7 @@ def flash_mla_with_kvcache(
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details.
The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks.
The KV cache must be contiguously valid for sparse attention on all arch. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks.
block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
head_dim_v: Head_dim of v. Must be 512
......@@ -487,163 +487,3 @@ def flash_mla_with_kvcache_quantization_q_nope_pe(
)
return out, softmax_lse
# def flash_mla_with_kvcache_qkvfp8(
# q: torch.Tensor,
# k_cache: torch.Tensor,
# block_table: Optional[torch.Tensor],
# cache_seqlens: Optional[torch.Tensor],
# head_dim_v: int,
# tile_scheduler_metadata: FlashMLASchedMeta,
# num_splits: None = None,
# softmax_scale: Optional[float] = None,
# causal: bool = False,
# descale_q: Optional[torch.Tensor] = None,
# descale_k: Optional[torch.Tensor] = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# """
# Arguments:
# q: (batch_size, seq_len_q, num_heads_q, head_dim).
# k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
# block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
# cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
# head_dim_v: Head_dim of v. Must be 512
# sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
# num_splits_placeholder: must be "None" (to be compatible with the old interface).
# softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
# causal: bool. Whether to apply causal attention mask. Only valid for dense attention
# descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
# descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
# Return:
# out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
# softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
# """
# sched_meta = tile_scheduler_metadata
# assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
# assert num_splits is None, "num_splits must be None"
# if softmax_scale is None:
# softmax_scale = q.shape[-1] ** (-0.5)
# if not sched_meta.have_initialized:
# # Initialize the tile scheduler metadata during the first invocation.
# sched_meta.have_initialized = True
# sched_meta.config = FlashMLASchedMeta.Config(
# q.shape[0],
# q.shape[1],
# q.shape[2],
# k_cache.shape[1],
# k_cache.shape[2],
# causal,
# False,
# 0,
# 0,
# 0
# )
# else:
# # Check whether the input arguments are consistent with sched_meta
# helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
# assert sched_meta.config is not None
# assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
# assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
# assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
# assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
# assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
# assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
# assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
# # Dense attention
# assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
# out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_qkvfp8(
# q, k_cache, head_dim_v,
# cache_seqlens, block_table,
# softmax_scale, causal,
# sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
# descale_q, descale_k
# )
# sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
# sched_meta.num_splits = new_num_splits
# return (out, lse)
# def flash_mla_with_kvcache_kvfp8(
# q: torch.Tensor,
# k_cache: torch.Tensor,
# block_table: Optional[torch.Tensor],
# cache_seqlens: Optional[torch.Tensor],
# head_dim_v: int,
# tile_scheduler_metadata: FlashMLASchedMeta,
# num_splits: None = None,
# softmax_scale: Optional[float] = None,
# causal: bool = False,
# descale_q: Optional[torch.Tensor] = None,
# descale_k: Optional[torch.Tensor] = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# """
# Arguments:
# q: (batch_size, seq_len_q, num_heads_q, head_dim).
# k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
# block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
# cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
# head_dim_v: Head_dim of v. Must be 512
# sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
# num_splits_placeholder: must be "None" (to be compatible with the old interface).
# softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
# causal: bool. Whether to apply causal attention mask. Only valid for dense attention
# descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
# descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
# Return:
# out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
# softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
# """
# sched_meta = tile_scheduler_metadata
# assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
# assert num_splits is None, "num_splits must be None"
# if softmax_scale is None:
# softmax_scale = q.shape[-1] ** (-0.5)
# if not sched_meta.have_initialized:
# # Initialize the tile scheduler metadata during the first invocation.
# sched_meta.have_initialized = True
# sched_meta.config = FlashMLASchedMeta.Config(
# q.shape[0],
# q.shape[1],
# q.shape[2],
# k_cache.shape[1],
# k_cache.shape[2],
# causal,
# False,
# 0,
# 0,
# 0
# )
# else:
# # Check whether the input arguments are consistent with sched_meta
# helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
# assert sched_meta.config is not None
# assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
# assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
# assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
# assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
# assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
# assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
# assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
# # Dense attention
# assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
# out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_kvfp8(
# q, k_cache, head_dim_v,
# cache_seqlens, block_table,
# softmax_scale, causal,
# sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
# descale_q, descale_k
# )
# sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
# sched_meta.num_splits = new_num_splits
# return (out, lse)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment