Commit c353b35b authored by zhanghj2's avatar zhanghj2
Browse files

恢复支持旧接口

parent c566af36
......@@ -210,93 +210,276 @@ def flash_mla_sparse_fwd(
)
return results
def get_mla_decoding_metadata_dense_fp8(
cache_seqlens: torch.Tensor,
num_heads_per_head_k: int,
num_heads_k: int,
num_heads_q : int = 16,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
cache_seqlens: (batch_size), dtype torch.int32.
num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
num_heads_k: num_heads_k.
Returns:
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
num_splits: (batch_size + 1), dtype torch.int32.
"""
return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(cache_seqlens, num_heads_per_head_k, num_heads_k, num_heads_q)
def flash_mla_with_kvcache_qkvfp8(
def flash_mla_with_kvcache_quantization(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: Optional[torch.Tensor],
cache_seqlens: Optional[torch.Tensor],
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: FlashMLASchedMeta,
num_splits: None = None,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
k_scale = None,
kv_cache_dtype = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
head_dim_v: Head_dim of v. Must be 512
sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
num_splits_placeholder: must be "None" (to be compatible with the old interface).
softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
causal: bool. Whether to apply causal attention mask. Only valid for dense attention
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
k_scale: {1, torch.float32}, tensor shape is 1
kv_cache_dtype: "only support fp8_e4m3"
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
sched_meta = tile_scheduler_metadata
assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
assert num_splits is None, "num_splits must be None"
assert k_scale is not None and kv_cache_dtype is not None, "k_scale and kv_cache_dtype is not None"
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
k_scale,
kv_cache_dtype
)
return out, softmax_lse
if not sched_meta.have_initialized:
# Initialize the tile scheduler metadata during the first invocation.
sched_meta.have_initialized = True
sched_meta.config = FlashMLASchedMeta.Config(
q.shape[0],
q.shape[1],
q.shape[2],
k_cache.shape[1],
k_cache.shape[2],
def flash_mla_with_kvcache_q_nope_pe(
q_nope: torch.Tensor,
q_pe: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_nope_pe(
q_nope,
q_pe,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
False,
0,
0,
0
tile_scheduler_metadata,
num_splits
)
else:
# Check whether the input arguments are consistent with sched_meta
helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
assert sched_meta.config is not None
assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
return out, softmax_lse
def flash_mla_with_kvcache_quantization_q_nope_pe(
q_nope: torch.Tensor,
q_pe: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
k_scale = None,
kv_cache_dtype = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
k_scale: {1, torch.float32}, tensor shape is 1
kv_cache_dtype: "only support fp8_e4m3"
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
assert k_scale is not None and kv_cache_dtype is not None, "k_scale and kv_cache_dtype is not None"
if softmax_scale is None:
softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_q_nope_pe_mla(
q_nope,
q_pe,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
k_scale,
kv_cache_dtype
)
return out, softmax_lse
# Dense attention
assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_qkvfp8(
q, k_cache, head_dim_v,
cache_seqlens, block_table,
softmax_scale, causal,
sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
descale_q, descale_k
def flash_mla_with_kvcache_q_nope_pe(
q_nope: torch.Tensor,
q_pe: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_nope_pe(
q_nope,
q_pe,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits
)
sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
sched_meta.num_splits = new_num_splits
return (out, lse)
return out, softmax_lse
def flash_mla_with_kvcache_kvfp8(
def flash_mla_with_kvcache_quantization_q_nope_pe(
q_nope: torch.Tensor,
q_pe: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
k_scale = None,
kv_cache_dtype = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
k_scale: {1, torch.float32}, tensor shape is 1
kv_cache_dtype: "only support fp8_e4m3"
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
assert k_scale is not None and kv_cache_dtype is not None, "k_scale and kv_cache_dtype is not None"
if softmax_scale is None:
softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_q_nope_pe_mla(
q_nope,
q_pe,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
k_scale,
kv_cache_dtype
)
return out, softmax_lse
def flash_mla_with_kvcache_fp8(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: Optional[torch.Tensor],
cache_seqlens: Optional[torch.Tensor],
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: FlashMLASchedMeta,
num_splits: None = None,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[torch.Tensor] = None,
......@@ -306,63 +489,244 @@ def flash_mla_with_kvcache_kvfp8(
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
head_dim_v: Head_dim of v. Must be 512
sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
num_splits_placeholder: must be "None" (to be compatible with the old interface).
softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
causal: bool. Whether to apply causal attention mask. Only valid for dense attention
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
sched_meta = tile_scheduler_metadata
assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
assert num_splits is None, "num_splits must be None"
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
if not sched_meta.have_initialized:
# Initialize the tile scheduler metadata during the first invocation.
sched_meta.have_initialized = True
sched_meta.config = FlashMLASchedMeta.Config(
q.shape[0],
q.shape[1],
q.shape[2],
k_cache.shape[1],
k_cache.shape[2],
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
False,
0,
0,
0
tile_scheduler_metadata,
num_splits,
descale_q,
descale_k
)
else:
# Check whether the input arguments are consistent with sched_meta
helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
assert sched_meta.config is not None
assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
return out, softmax_lse
def flash_mla_with_kvcache_fp8_with_cat(
q_nope: torch.Tensor,
q_pe: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q_nope: (batch_size, seq_len_q, num_heads_q, 512).
q_pe: (batch_size, seq_len_q, num_heads_q, 64).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
# Dense attention
assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_kvfp8(
q, k_cache, head_dim_v,
cache_seqlens, block_table,
softmax_scale, causal,
sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
descale_q, descale_k
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8_with_cat(
q_nope,
q_pe,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
descale_q,
descale_k
)
sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
sched_meta.num_splits = new_num_splits
return (out, lse)
\ No newline at end of file
return out, softmax_lse
# def flash_mla_with_kvcache_qkvfp8(
# q: torch.Tensor,
# k_cache: torch.Tensor,
# block_table: Optional[torch.Tensor],
# cache_seqlens: Optional[torch.Tensor],
# head_dim_v: int,
# tile_scheduler_metadata: FlashMLASchedMeta,
# num_splits: None = None,
# softmax_scale: Optional[float] = None,
# causal: bool = False,
# descale_q: Optional[torch.Tensor] = None,
# descale_k: Optional[torch.Tensor] = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# """
# Arguments:
# q: (batch_size, seq_len_q, num_heads_q, head_dim).
# k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
# block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
# cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
# head_dim_v: Head_dim of v. Must be 512
# sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
# num_splits_placeholder: must be "None" (to be compatible with the old interface).
# softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
# causal: bool. Whether to apply causal attention mask. Only valid for dense attention
# descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
# descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
# Return:
# out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
# softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
# """
# sched_meta = tile_scheduler_metadata
# assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
# assert num_splits is None, "num_splits must be None"
# if softmax_scale is None:
# softmax_scale = q.shape[-1] ** (-0.5)
# if not sched_meta.have_initialized:
# # Initialize the tile scheduler metadata during the first invocation.
# sched_meta.have_initialized = True
# sched_meta.config = FlashMLASchedMeta.Config(
# q.shape[0],
# q.shape[1],
# q.shape[2],
# k_cache.shape[1],
# k_cache.shape[2],
# causal,
# False,
# 0,
# 0,
# 0
# )
# else:
# # Check whether the input arguments are consistent with sched_meta
# helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
# assert sched_meta.config is not None
# assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
# assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
# assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
# assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
# assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
# assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
# assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
# # Dense attention
# assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
# out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_qkvfp8(
# q, k_cache, head_dim_v,
# cache_seqlens, block_table,
# softmax_scale, causal,
# sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
# descale_q, descale_k
# )
# sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
# sched_meta.num_splits = new_num_splits
# return (out, lse)
# def flash_mla_with_kvcache_kvfp8(
# q: torch.Tensor,
# k_cache: torch.Tensor,
# block_table: Optional[torch.Tensor],
# cache_seqlens: Optional[torch.Tensor],
# head_dim_v: int,
# tile_scheduler_metadata: FlashMLASchedMeta,
# num_splits: None = None,
# softmax_scale: Optional[float] = None,
# causal: bool = False,
# descale_q: Optional[torch.Tensor] = None,
# descale_k: Optional[torch.Tensor] = None,
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# """
# Arguments:
# q: (batch_size, seq_len_q, num_heads_q, head_dim).
# k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
# block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used.
# cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used.
# head_dim_v: Head_dim of v. Must be 512
# sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same.
# num_splits_placeholder: must be "None" (to be compatible with the old interface).
# softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k).
# causal: bool. Whether to apply causal attention mask. Only valid for dense attention
# descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
# descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
# Return:
# out: (batch_size, seq_len_q, num_heads_q, head_dim_v), only support bf16 output
# softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
# """
# sched_meta = tile_scheduler_metadata
# assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta"
# assert num_splits is None, "num_splits must be None"
# if softmax_scale is None:
# softmax_scale = q.shape[-1] ** (-0.5)
# if not sched_meta.have_initialized:
# # Initialize the tile scheduler metadata during the first invocation.
# sched_meta.have_initialized = True
# sched_meta.config = FlashMLASchedMeta.Config(
# q.shape[0],
# q.shape[1],
# q.shape[2],
# k_cache.shape[1],
# k_cache.shape[2],
# causal,
# False,
# 0,
# 0,
# 0
# )
# else:
# # Check whether the input arguments are consistent with sched_meta
# helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta."
# assert sched_meta.config is not None
# assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg
# assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg
# assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg
# assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg
# assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg
# assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg
# assert sched_meta.config.is_fp8_kvcache == False, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg
# # Dense attention
# assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used."
# out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd_kvfp8(
# q, k_cache, head_dim_v,
# cache_seqlens, block_table,
# softmax_scale, causal,
# sched_meta.tile_scheduler_metadata, sched_meta.num_splits,
# descale_q, descale_k
# )
# sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
# sched_meta.num_splits = new_num_splits
# return (out, lse)
\ No newline at end of file
......@@ -85,6 +85,12 @@ ext_modules.append(
"csrc/gfx93/prefill/sparse/instantiations/phase1_k576.cu",
"csrc/gfx93/prefill/sparse/instantiations/phase1_k576_topklen.cu",
"csrc/extension/flash_fwd_mla_bf16_gfx936.cu",
"csrc/extension/flash_fwd_mla_fp16_gfx936.cu",
"csrc/extension/flash_fwd_mla_fp8_gfx938.cu",
"csrc/extension/flash_fwd_mla_fp8_qbf16_gfx938.cu",
"csrc/extension/flash_fwd_mla_metadata.cu",
],
extra_compile_args={
"cxx": cxx_args + get_features_args(),
......@@ -98,7 +104,9 @@ ext_modules.append(
"-Rpass-analysis=kernel-resource-usage",
"-DDCU_ASM",
"--save-temps",
"-w"
"-w",
"-mllvm",
"-enable-num-vgprs-512=true",
] + get_features_args() + get_arch_flags()
},
include_dirs=[
......@@ -134,5 +142,6 @@ setup(
version=get_version(ROCM_HOME),
packages=find_packages(include=['flash_mla']),
ext_modules=ext_modules,
package_data={"flash_mla":["asm/*.co"]},
cmdclass={"build_ext": BuildExtension},
)
......@@ -5,7 +5,7 @@ import random
import torch
import triton
from flash_mla import flash_mla_with_kvcache_kvfp8, get_mla_metadata
from flash_mla import flash_mla_with_kvcache_quantization, get_mla_decoding_metadata_dense_fp8
torch.set_printoptions(precision=4, profile="default", sci_mode=False)
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False, k_scale=1.0):
......@@ -14,9 +14,6 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False,
value = value.float() * k_scale
key = key.repeat_interleave(h_q // h_kv, dim=0)
value = value.repeat_interleave(h_q // h_kv, dim=0)
# tmp = query @ key.transpose(-2, -1)
# print("tmp ", tmp.shape, tmp[0, 0, :16])
# print("tmp ", tmp.shape, tmp[0, 0, 16:32])
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
if is_causal:
s_q = query.shape[-2]
......@@ -65,18 +62,7 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
# blocked_k = torch.randint(low=0, high=4, size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8)
# blocked_k = torch.ones(size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8)
blocked_k = (torch.randn(block_table.numel(), block_size, h_kv, d))
# blocked_k = (torch.zeros(block_table.numel(), block_size, h_kv, d))
# blocked_k[:, 1:, :, :] = 0
# blocked_k[:, :, :, 1:] = 0
# blocked_k[0, 0:16, 0, 0] = 0
# blocked_k[0, 32:, 0, 0] = 0
# blocked_k[0, 0, 0, 1] = 2
# blocked_k[0, 0, 0, 2] = 3
# blocked_k[0, 0, 0, 3] = 4
# print(" blocked_k ", blocked_k[0, 0, 0, :])
blocked_k = blocked_k.to(torch.float8_e5m2)
# blocked_k = (torch.ones(block_table.numel(), block_size, h_kv, d)).to(torch.float8_e5m2)
blocked_k = (torch.randn(block_table.numel(), block_size, h_kv, d)).to(torch.half).to(torch.float8_e5m2)
# blocked_k[0, 0, 0, 56] = 1
# blocked_k[0, 1, 0, 8] = 2
# blocked_k[0, 2, 0, 8] = 5
......@@ -92,7 +78,9 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
# )
blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_metadata()
tile_scheduler_metadata, num_splits = get_mla_decoding_metadata_dense_fp8(
cache_seqlens, s_q * h_q // h_kv, h_kv, h_q
)
# print("q:", q.shape, q.dtype, q)
# print("cache_seqlens:", cache_seqlens.shape, cache_seqlens)
......@@ -106,10 +94,8 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
# k_scale = torch.tensor(1.0).to(torch.float32).to("cuda:0")
# k_scale = torch.tensor(2.1).to(torch.float32).to("cuda:0")
k_scale = torch.tensor(1.0).to(torch.float32).to("cuda:0")
descale_q = torch.ones((1), dtype=torch.float32)
descale_k = torch.ones((1), dtype=torch.float32)
def flash_mla():
return flash_mla_with_kvcache_kvfp8(
return flash_mla_with_kvcache_quantization(
q,
blocked_k,
block_table,
......@@ -118,8 +104,8 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
tile_scheduler_metadata,
num_splits,
causal=causal,
descale_q = descale_q,
descale_k = descale_k,
k_scale = k_scale,
kv_cache_dtype = "fp8_e5m2"
)
def ref_mla():
......@@ -148,9 +134,9 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
# print("lse_flash ", lse_flash[0, 0, 0:10])
# print("lse_torch ", lse_torch[0, 0, 0:10])
print("out max_diff ", (out_flash - out_torch).abs().max())
print("lse max_diff ", (lse_flash - lse_torch).abs().max())
# print(" out ", torch.nonzero((out_flash - out_torch).abs() > 0.1))
# print("out max_diff ", (out_flash - out_torch).abs().max())
# print("lse max_diff ", (lse_flash - lse_torch).abs().max())
# print(" out ", torch.nonzero((out_flash - out_torch).abs()))
# print(" out_torch", out_torch)
cal_diff(lse_flash, lse_torch, "lse")
cal_diff(out_flash, out_torch, "out")
......@@ -220,12 +206,6 @@ def main(torch_dtype, is_prof=False):
for s_q in [1]: # MTP = 1, 2
for varlen in [False]:
test_flash_mla_fp8_e5m2(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
# for b in [1]:
# for s in [64]:
# for h_q in [16]:
# for s_q in [1]: # MTP = 1, 2
# for varlen in [False]:
# test_flash_mla_fp8_e5m2(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
......
......@@ -4,9 +4,8 @@ import random
import torch
import triton
import kernelkit as kk
from flash_mla import flash_mla_with_kvcache_qkvfp8, get_mla_metadata
from flash_mla import flash_mla_with_kvcache_fp8, get_mla_decoding_metadata_dense_fp8
torch.set_printoptions(precision=4, profile="default", sci_mode=False)
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False, k_scale=1.0):
......@@ -88,7 +87,9 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
)
blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_metadata()
tile_scheduler_metadata, num_splits = get_mla_decoding_metadata_dense_fp8(
cache_seqlens, s_q * h_q // h_kv, h_kv, h_q
)
init_dtype = q.dtype
def prepare_fp8_input():
......@@ -116,7 +117,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
# print(" descale_q ", descale_q.shape, descale_q.stride())
# print(" blocked_k ", blocked_k.shape)
def flash_mla():
return flash_mla_with_kvcache_qkvfp8(
return flash_mla_with_kvcache_fp8(
q,
blocked_k,
block_table,
......@@ -159,7 +160,10 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
# print(" diff ", torch.nonzero((lse_flash - lse_torch).abs() > 0.1))
# print(" diff ", torch.nonzero((out_flash - out_torch).abs() > 0.1))
# print(" out_torch ", out_torch[0, 0, 0, 0:10])
# print(" out_flash ", out_flash[0, 0, 0, 0:10])
# print(" lse_flash ", lse_flash[0, 0:3, :1])
# print(" lse_torch ", lse_torch[0, 0:3, :1])
# print(" nan ", torch.nonzero(torch.isnan(out_flash)))
cal_diff(out_flash, out_torch, "out", use_fp8)
cal_diff(lse_flash, lse_torch, "lse")
......@@ -171,6 +175,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
)
def main(torch_dtype, is_prof=False):
device = torch.device("cuda:0")
init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
......@@ -210,14 +216,14 @@ def main(torch_dtype, is_prof=False):
# 压测
for b in [3, 6, 9, 12, 15, 18, 21, 24, 40, 41, 79, 80]:
for s in [111, 112, 123, 1234, 432, 4325, 4000, 8192, 12345, 45321]:
for h_q in [16]:
for h_q in [32, 16, 64, 128]:
for s_q in [1, 2, 3]: # MTP = 1, 2
for varlen in [False, True]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,True,torch_dtype)
for b in [3, 6, 9, 12, 15, 18, 21, 24, 32, 64, 128, 256]:
for s in [4000]:
for h_q in [16]:
for h_q in [32, 16, 64, 128]:
for s_q in [1]: # MTP = 1, 2
for varlen in [False]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,False,torch_dtype)
......
import argparse
import math
import random
import torch
import triton
from flash_mla import flash_mla_with_kvcache_fp8_with_cat, get_mla_decoding_metadata_dense_fp8
torch.set_printoptions(precision=4, profile="default", sci_mode=False)
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False, k_scale=1.0):
query = query.float()
key = key.float() * k_scale
value = value.float() * k_scale
key = key.repeat_interleave(h_q // h_kv, dim=0)
value = value.repeat_interleave(h_q // h_kv, dim=0)
tmp = query @ key.transpose(-2, -1)
# print("tmp s ", tmp[0, :4, :10])
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
if is_causal:
s_q = query.shape[-2]
s_k = key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
attn_weight += attn_bias
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
return attn_weight @ value, lse
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str, use_fp8: bool=False) -> None:
torch_dtype = x.dtype
x, y = x.double(), y.double()
RMSE = ((x - y) * (x - y)).mean().sqrt().item()
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
amax_diff = (x - y).abs().max().item()
print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
if use_fp8:
assert cos_diff < 1e-3
else:
assert cos_diff < (1e-4 if torch_dtype == torch.bfloat16 else 1e-5)
@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=False,torch_dtype=torch.float16, is_q_bf16=False):
print(
f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}"
)
# torch.cuda.empty_cache()
use_fp8 = torch_dtype == torch.float8_e4m3fn
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
if varlen:
for i in range(b):
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
total_seqlens = cache_seqlens.sum().item()
mean_seqlens = cache_seqlens.float().mean().int().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}, {max_seqlen_pad=}")
# q = torch.ones(b, s_q, h_q, d)
q = torch.randn(b, s_q, h_q, d)
# for i in range(576):
# q[:, :, :, i] = i
# q[:, :, 1:, :] = 0
# q = torch.ones(b, s_q, h_q, d)
# print("q ", q[0, 0, 0:3, :10])
block_size = 64
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32
).view(b, max_seqlen_pad // block_size)
# blocked_k = torch.ones(block_table.numel(), block_size, h_kv, d)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
# blocked_k[:, :, :, 32:] = 0.0
# blocked_k[:, 32:, :, :] = 0
# blocked_k[:, :, :, 4:] = 0
# blocked_k[:, :32, :, :] = 0
# blocked_k[:, 16:, :, :] = 0
for i in range(b):
blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
float("nan")
)
blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_decoding_metadata_dense_fp8(
cache_seqlens, s_q * h_q // h_kv, h_kv, h_q
)
init_dtype = q.dtype
def prepare_fp8_input():
q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = None, None, None, None, None
if use_fp8:
nonlocal q, blocked_k, blocked_v
fp8_dtype = torch.float8_e4m3fn
descale_q = torch.ones((1), dtype=torch.float32)
descale_k = torch.ones((1), dtype=torch.float32)
q_fp8 = q.to(fp8_dtype)
blocked_k_fp8 = blocked_k.to(fp8_dtype)
blocked_v_fp8 = blocked_k_fp8[..., :dv]
return q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k
q_fp8, blocked_k_fp8, blocked_v_fp8, descale_q, descale_k = prepare_fp8_input()
# print(blocked_v_fp8[0, 32:36, 0, :4])
if use_fp8:
q = q_fp8
blocked_k = blocked_k_fp8
blocked_v = blocked_v_fp8
# print(" descale_q ", descale_q.shape, descale_q.stride())
# print(" blocked_k ", blocked_k.shape)
q_nope = q[:, :, :, :512].contiguous()
q_pe = q[:, :, :, 512:].contiguous()
if is_q_bf16:
q_nope = q_nope.to(torch.bfloat16).contiguous()
q_pe = q_pe.to(torch.bfloat16).contiguous()
def flash_mla():
return flash_mla_with_kvcache_fp8_with_cat(
q_nope,
q_pe,
blocked_k,
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
descale_q=descale_q,
descale_k=descale_k,
)
def ref_mla():
q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
blocked_k_ = (blocked_k.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_k
blocked_v_ = (blocked_v.to(torch.float) * descale_k).to(init_dtype) if use_fp8 else blocked_v
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
O, LSE = scaled_dot_product_attention(
q_[i].transpose(0, 1),
blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
h_q=h_q,
h_kv=h_kv,
is_causal=causal,
)
out[i] = O.transpose(0, 1)
lse[i] = LSE
return out, lse
torch.cuda.synchronize()
out_flash, lse_flash = flash_mla()
torch.cuda.synchronize()
out_torch, lse_torch = ref_mla()
# print(" ", out_flash.shape, lse_flash.shape, q.shape)
print("out max_diff ", (out_flash - out_torch).abs().max())
print("lse max_diff ", (lse_flash - lse_torch).abs().max())
# print(" diff ", torch.nonzero((lse_flash - lse_torch).abs() > 0.1))
# print(" diff ", torch.nonzero((out_flash - out_torch).abs() > 0.1))
# print(" out_torch ", out_torch[0, 0, 0, 0:10])
# print(" out_flash ", out_flash[0, 0, 0, 0:10])
# print(" lse_flash ", lse_flash[0, 0:3, :1])
# print(" lse_torch ", lse_torch[0, 0:3, :1])
# print(" nan ", torch.nonzero(torch.isnan(out_flash)))
cal_diff(lse_flash, lse_torch, "lse")
cal_diff(out_flash, out_torch, "out", use_fp8)
if is_prof: return
t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
print(
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
)
def main(torch_dtype, is_prof=False):
device = torch.device("cuda:0")
init_dtype = torch.bfloat16 if torch_dtype == torch.float8_e4m3fn else torch_dtype
torch.set_default_dtype(init_dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(0)
random.seed(0)
'''
h_kv = 1
d, dv = 576, 512
causal = True
for b in [128]:
for s in [4096, 8192]:
for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
for s_q in [1, 2]: # MTP = 1, 2
for varlen in [False, True]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
# b, s_q, s, h_q, h_kv, d, dv, causal, varlen'''
# test_flash_mla( 1, 1, 64, 16, 1, 576, 512, True, False, is_prof=is_prof)
# test_flash_mla_fp8( 1, 1, 1000, 1, 1, 576, 512, True, False, is_prof=is_prof)
# test_flash_mla_fp8( 1, 1, 4096, 8, 1, 576, 512, True, False, is_prof=is_prof)
# test_flash_mla_fp8(32, 1, 4096, 16, 1, 576, 512, False, False, is_prof=is_prof)
# '''
h_kv = 1
d, dv = 576, 512
causal = True
# for b in [40, 80]:
# for s in [3500, 4000, 8192, 16384]:
# for h_q in [16]:
# for s_q in [1]: # MTP = 1, 2
# for varlen in [False]:
# test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,False,torch_dtype)
# 压测
for b in [3, 6, 9, 12, 15, 18, 21, 24, 40, 41, 79, 80]:
for s in [111, 112, 123, 1234, 432, 4325, 4000, 8192, 12345, 45321]:
for h_q in [16, 64]:
for s_q in [1, 2, 3]: # MTP = 1, 2
for varlen in [False, True]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,True,torch_dtype)
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,True,torch_dtype, True)
for b in [3, 6, 9, 12, 15, 18, 21, 24, 32, 64, 128, 256]:
for s in [4000]:
for h_q in [16, 64]:
for s_q in [1]: # MTP = 1, 2
for varlen in [False]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,False,torch_dtype)
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,False,torch_dtype, True)
# for b in [1]:
# for s in [128]:
# for h_q in [128]:
# for s_q in [2]: # MTP = 1, 2
# for varlen in [False]:
# test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,False,torch_dtype)
# for b in [1, 32]:
# for s in [200, 1002, 2002, 1024, 2000, 4000, 32768, 65536]:
# for h_q in [4, 16, 32, 64]:
# for s_q in [1, 2]: # MTP = 1, 2
# for varlen in [False]:
# test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen,False,torch_dtype)
# '''
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--dtype",
type=str,
choices=["bf16", "fp16","e4m3"],
default="bf16",
help="Data type to use for testing (bf16/fp16/e4m3)",
)
parser.add_argument('--prof', default=False, action='store_true', help='prof or not')
args = parser.parse_args()
torch_dtype = torch.float8_e4m3fn
if args.dtype == "fp16":
torch_dtype = torch.float16
elif args.dtype == "e4m3":
torch_dtype = torch.float8_e4m3fn
main(torch_dtype, args.prof)
import argparse
import math
import random
import torch
import triton
from flash_mla import flash_mla_with_kvcache, get_mla_metadata, flash_mla_with_kvcache_q_nope_pe
# from flash_mla import flash_mla_with_kvcache, get_mla_metadata
torch.set_printoptions(precision=4, profile="default", sci_mode=False)
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False, k_scale=1.0):
query = query.float()
key = key.float() * k_scale
value = value.float() * k_scale
key = key.repeat_interleave(h_q // h_kv, dim=0)
value = value.repeat_interleave(h_q // h_kv, dim=0)
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
if is_causal:
s_q = query.shape[-2]
s_k = key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
attn_weight += attn_bias
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
return attn_weight @ value, lse
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
torch_dtype = x.dtype
x, y = x.double(), y.double()
RMSE = ((x - y) * (x - y)).mean().sqrt().item()
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
amax_diff = (x - y).abs().max().item()
print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
assert cos_diff < (1e-4 if torch_dtype == torch.bfloat16 else 1e-5)
@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=False):
print(
f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}"
)
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
if varlen:
for i in range(b):
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
total_seqlens = cache_seqlens.sum().item()
mean_seqlens = cache_seqlens.float().mean().int().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}, {max_seqlen_pad=}")
q = torch.randn(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32
).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
for i in range(b):
blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
float("nan")
)
blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens, s_q * h_q // h_kv, h_kv
)
# print("q:", q.shape, q.dtype, q)
# print("cache_seqlens:", cache_seqlens.shape, cache_seqlens)
# print("block_table:", block_table.shape, block_table)
# print("blocked_k:", blocked_k.shape, blocked_k[0])
# print("blocked_v:", blocked_v.shape)
# torch.set_printoptions(precision=4, profile="full", sci_mode=False)
# print("tile_scheduler_metadata:", tile_scheduler_metadata.shape, tile_scheduler_metadata)
# torch.set_printoptions(precision=4, profile="default", sci_mode=False)
# print("num_splits:", num_splits.shape, num_splits)
q_nope = q[:, :, :, :512].contiguous()
q_pe = q[:, :, :, 512:].contiguous()
def flash_mla():
return flash_mla_with_kvcache_q_nope_pe(
q_nope,
q_pe,
blocked_k,
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
)
# return flash_mla_with_kvcache(
# q,
# blocked_k,
# block_table,
# cache_seqlens,
# dv,
# tile_scheduler_metadata,
# num_splits,
# causal=causal,
# )
def ref_mla():
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
O, LSE = scaled_dot_product_attention(
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
h_q=h_q,
h_kv=h_kv,
is_causal=causal,
)
out[i] = O.transpose(0, 1)
lse[i] = LSE
return out, lse
out_flash, lse_flash = flash_mla()
out_torch, lse_torch = ref_mla()
# print("out_flash:", out_flash)
cal_diff(out_flash, out_torch, "out")
cal_diff(lse_flash, lse_torch, "lse")
print("out max_diff ", (out_flash - out_torch).abs().max())
print("lse max_diff ", (lse_flash - lse_torch).abs().max())
if is_prof: return
t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
torch.finfo(q.dtype).bits // 8
)
print(
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
)
@torch.inference_mode()
def test_flash_mla_fp8(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=False):
print(
f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}"
)
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
if varlen:
for i in range(b):
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
total_seqlens = cache_seqlens.sum().item()
mean_seqlens = cache_seqlens.float().mean().int().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}, {max_seqlen_pad=}")
q = torch.randn(b, s_q, h_q, d)
# q = torch.ones(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32
).view(b, max_seqlen_pad // block_size)
# blocked_k = torch.randint(low=0, high=4, size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8)
# blocked_k = torch.ones(size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8)
blocked_k = (torch.randn(block_table.numel(), block_size, h_kv, d)).to(torch.half).to(torch.float8_e4m3fn)
# blocked_k[0, 0, 0, 56] = 1
# blocked_k[0, 1, 0, 8] = 2
# blocked_k[0, 2, 0, 8] = 5
# blocked_k[0, 3, 0, 8] = 4
# for i in range(64):
# for j in range(64):
# blocked_k[0, i, 0, j] = j
# blocked_k[0, i, 0, j] = (i * 50 + j) % 128
# for i in range(b):
# blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
# -128
# )
blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens, s_q * h_q // h_kv, h_kv
)
# print("q:", q.shape, q.dtype, q)
# print("cache_seqlens:", cache_seqlens.shape, cache_seqlens)
# print("block_table:", block_table.shape, block_table)
# print("blocked_k:", blocked_k.shape, blocked_k[0])
# print("blocked_v:", blocked_v.shape)
# torch.set_printoptions(precision=4, profile="full", sci_mode=False)
# print("tile_scheduler_metadata:", tile_scheduler_metadata.shape, tile_scheduler_metadata)
# torch.set_printoptions(precision=4, profile="default", sci_mode=False)
# print("num_splits:", num_splits.shape, num_splits)
k_scale = torch.tensor(0.17).to(torch.float32).to("cuda:0")
def flash_mla():
return flash_mla_with_kvcache(
q,
blocked_k,
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
k_scale = k_scale,
kv_cache_dtype = "fp8_e4m3"
)
def ref_mla():
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
O, LSE = scaled_dot_product_attention(
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
h_q=h_q,
h_kv=h_kv,
is_causal=causal,
k_scale = k_scale
)
out[i] = O.transpose(0, 1)
lse[i] = LSE
return out, lse
out_flash, lse_flash = flash_mla()
out_torch, lse_torch = ref_mla()
print("out_flash ", out_flash[0, 0, 0, 0:14])
print("out_torch ", out_torch[0, 0, 0, 0:14])
print("lse_flash ", lse_flash[0, 0, 0:10])
print("lse_torch ", lse_torch[0, 0, 0:10])
cal_diff(out_flash, out_torch, "out")
cal_diff(lse_flash, lse_torch, "lse")
print("out max_diff ", (out_flash - out_torch).abs().max())
print("lse max_diff ", (lse_flash - lse_torch).abs().max())
t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
torch.finfo(q.dtype).bits // 8
)
print(
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
)
def main(torch_dtype, is_prof=False):
device = torch.device("cuda:0")
torch.set_default_dtype(torch_dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(0)
random.seed(0)
'''
h_kv = 1
d, dv = 576, 512
causal = True
for b in [128]:
for s in [4096, 8192]:
for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
for s_q in [1, 2]: # MTP = 1, 2
for varlen in [False, True]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
# b, s_q, s, h_q, h_kv, d, dv, causal, varlen'''
# test_flash_mla( 1, 1, 64, 16, 1, 576, 512, True, False, is_prof=is_prof)
# test_flash_mla_fp8( 1, 1, 1000, 1, 1, 576, 512, True, False, is_prof=is_prof)
# test_flash_mla_fp8( 1, 1, 4096, 8, 1, 576, 512, True, False, is_prof=is_prof)
# test_flash_mla_fp8(32, 1, 4096, 16, 1, 576, 512, False, False, is_prof=is_prof)
# '''
h_kv = 1
d, dv = 576, 512
causal = True
for b in [3, 6, 9, 12, 15, 18, 21, 24]:
for s in [111, 112, 123, 1234, 432, 4325, 4000, 8192, 11111]:
for h_q in [16]:
for s_q in [1]: # MTP = 1, 2
for varlen in [False,True]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
for b in [3, 6, 9, 12, 15, 18, 21, 24, 32, 64, 128, 256]:
for s in [4000]:
for h_q in [16]:
for s_q in [1]: # MTP = 1, 2
for varlen in [False]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
# for b in [1, 32]:
# for s in [200, 1002, 2002, 1024, 2000, 4000, 32768, 65536]:
# for h_q in [4, 16, 32, 64]:
# for s_q in [1, 2]: # MTP = 1, 2
# for varlen in [True]:
# test_flash_mla_fp8(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
# '''
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--dtype",
type=str,
choices=["bf16", "fp16"],
default="bf16",
help="Data type to use for testing (bf16 or fp16)",
)
parser.add_argument('--prof', default=False, action='store_true', help='prof or not')
args = parser.parse_args()
torch_dtype = torch.bfloat16
if args.dtype == "fp16":
torch_dtype = torch.float16
main(torch_dtype, args.prof)
import argparse
import math
import random
import torch
import triton
# from flash_mla import flash_mla_with_kvcache_quantization, get_mla_metadata
from flash_mla import flash_mla_with_kvcache_quantization, get_mla_decoding_metadata_dense_fp8, flash_mla_with_kvcache_quantization_q_nope_pe
torch.set_printoptions(precision=4, profile="default", sci_mode=False)
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False, k_scale=1.0):
query = query.float()
key = key.float() * k_scale
value = value.float() * k_scale
key = key.repeat_interleave(h_q // h_kv, dim=0)
value = value.repeat_interleave(h_q // h_kv, dim=0)
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
if is_causal:
s_q = query.shape[-2]
s_k = key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
attn_weight += attn_bias
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
return attn_weight @ value, lse
def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
torch_dtype = x.dtype
x, y = x.double(), y.double()
RMSE = ((x - y) * (x - y)).mean().sqrt().item()
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
amax_diff = (x - y).abs().max().item()
print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}")
assert cos_diff < (1e-4 if torch_dtype == torch.bfloat16 else 1e-5)
@torch.inference_mode()
def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=False):
print(
f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}"
)
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
if varlen:
for i in range(b):
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
total_seqlens = cache_seqlens.sum().item()
mean_seqlens = cache_seqlens.float().mean().int().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}, {max_seqlen_pad=}")
q = torch.randn(b, s_q, h_q, d)
# q = torch.ones(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32
).view(b, max_seqlen_pad // block_size)
# blocked_k = torch.randint(low=0, high=4, size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8)
# blocked_k = torch.ones(size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8)
blocked_k = (torch.randn(block_table.numel(), block_size, h_kv, d)).to(torch.half).to(torch.float8_e5m2)
# blocked_k[0, 0, 0, 56] = 1
# blocked_k[0, 1, 0, 8] = 2
# blocked_k[0, 2, 0, 8] = 5
# blocked_k[0, 3, 0, 8] = 4
# for i in range(64):
# for j in range(64):
# blocked_k[0, i, 0, j] = j
# blocked_k[0, i, 0, j] = (i * 50 + j) % 128
# print("blocked_k ", blocked_k[0, 0, 0, 0:10])
# for i in range(b):
# blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
# -128
# )
blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_decoding_metadata_dense_fp8(
cache_seqlens, s_q * h_q // h_kv, h_kv
)
# print("q:", q.shape, q.dtype, q)
# print("cache_seqlens:", cache_seqlens.shape, cache_seqlens)
# print("block_table:", block_table.shape, block_table)
# print("blocked_k:", blocked_k.shape, blocked_k[0])
# print("blocked_v:", blocked_v.shape)
# torch.set_printoptions(precision=4, profile="full", sci_mode=False)
# print("tile_scheduler_metadata:", tile_scheduler_metadata.shape, tile_scheduler_metadata)
# torch.set_printoptions(precision=4, profile="default", sci_mode=False)
# print("num_splits:", num_splits.shape, num_splits)
k_scale = torch.tensor(1.0).to(torch.float32).to("cuda:0")
# k_scale = torch.tensor(2.1).to(torch.float32).to("cuda:0")
# k_scale = torch.tensor(2.5).to(torch.float32).to("cuda:0")
q_nope = q[:, :, :, :512].contiguous()
q_pe = q[:, :, :, 512:].contiguous()
def flash_mla():
return flash_mla_with_kvcache_quantization_q_nope_pe(
q_nope,
q_pe,
blocked_k,
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
k_scale = k_scale,
kv_cache_dtype = "fp8_e5m2"
)
# return flash_mla_with_kvcache_quantization(
# q,
# blocked_k,
# block_table,
# cache_seqlens,
# dv,
# tile_scheduler_metadata,
# num_splits,
# causal=causal,
# k_scale = k_scale,
# kv_cache_dtype = "fp8_e5m2"
# )
def ref_mla():
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
O, LSE = scaled_dot_product_attention(
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
h_q=h_q,
h_kv=h_kv,
is_causal=causal,
k_scale = k_scale
)
out[i] = O.transpose(0, 1)
lse[i] = LSE
return out, lse
out_flash, lse_flash = flash_mla()
out_torch, lse_torch = ref_mla()
# print("out_flash ", out_flash[0, 0, 0, 0:14])
# print("out_torch ", out_torch[0, 0, 0, 0:14])
# print("lse_flash ", lse_flash[0, 0, 0:10])
# print("lse_torch ", lse_torch[0, 0, 0:10])
print("out max_diff ", (out_flash - out_torch).abs().max())
print("lse max_diff ", (lse_flash - lse_torch).abs().max())
# print(" out ", torch.nonzero((out_flash - out_torch).abs()))
# print(" out_torch", out_torch)
cal_diff(lse_flash, lse_torch, "lse")
cal_diff(out_flash, out_torch, "out")
t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = ( b * s_q * h_q * d + b * s_q * h_q * dv) * (
torch.finfo(q.dtype).bits // 8
) + total_seqlens * h_kv * d
print(
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
)
def main(torch_dtype, is_prof=False):
device = torch.device("cuda:0")
torch.set_default_dtype(torch_dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(0)
random.seed(0)
'''
h_kv = 1
d, dv = 576, 512
causal = True
for b in [128]:
for s in [4096, 8192]:
for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1
for s_q in [1, 2]: # MTP = 1, 2
for varlen in [False, True]:
test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
# b, s_q, s, h_q, h_kv, d, dv, causal, varlen'''
# test_flash_mla( 1, 1, 64, 16, 1, 576, 512, True, False, is_prof=is_prof)
# test_flash_mla_fp8( 1, 1, 1000, 1, 1, 576, 512, True, False, is_prof=is_prof)
# test_flash_mla_fp8( 1, 1, 4096, 8, 1, 576, 512, True, False, is_prof=is_prof)
# test_flash_mla_fp8(32, 1, 4096, 16, 1, 576, 512, False, False, is_prof=is_prof)
# '''
h_kv = 1
d, dv = 576, 512
causal = True
# for b in [1, 32]:
# for s in [200, 1002, 2002, 1024, 2000, 4000, 32768, 65536]:
# for h_q in [4, 16, 32, 64]:
# for s_q in [1, 2]: # MTP = 1, 2
# for varlen in [True]:
# test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
# for b in [15]:
# for s in [4000]:
# for h_q in [16]:
# for s_q in [1]: # MTP = 1, 2
# for varlen in [False]:
# # for varlen in [True]:
# test_flash_mla_fp8_e5m2(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
# test_flash_mla_fp8_e4m3(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
# '''
for b in [3, 6, 9, 12, 15, 18, 21, 24]:
for s in [111, 112, 123, 1234, 432, 4325, 4000, 8192, 11111]:
for h_q in [16]:
for s_q in [1]: # MTP = 1, 2
for varlen in [False,True]:
test_flash_mla_fp8_e5m2(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
for b in [3, 6, 9, 12, 15, 18, 21, 24, 32, 64, 128, 256]:
for s in [4000]:
for h_q in [16]:
for s_q in [1]: # MTP = 1, 2
for varlen in [False]:
test_flash_mla_fp8_e5m2(b, s_q, s, h_q, h_kv, d, dv, causal, varlen)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--dtype",
type=str,
choices=["bf16", "fp16"],
default="bf16",
help="Data type to use for testing (bf16 or fp16)",
)
parser.add_argument('--prof', default=False, action='store_true', help='prof or not')
args = parser.parse_args()
torch_dtype = torch.bfloat16
if args.dtype == "fp16":
torch_dtype = torch.float16
main(torch_dtype, args.prof)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment