Commit ed798e2e authored by zhanghj2's avatar zhanghj2
Browse files

整理接口

parent 1b95bb9e
...@@ -337,7 +337,7 @@ mha_fwd_kvcache_mla_nope_pe( ...@@ -337,7 +337,7 @@ mha_fwd_kvcache_mla_nope_pe(
// bool is_sm90 = dprops->major == 9 && dprops->minor == 0; // bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
// TORCH_CHECK(is_sm90); // TORCH_CHECK(is_sm90);
Arch arch = Arch(); Arch arch = Arch();
if (!arch.is_gfx93x()) { if (!arch.is_gfx93x() || !arch.is_gfx928()) {
TORCH_CHECK(false, "Dense decode MLA is only supported on gfx936 or gfx938 architecture"); TORCH_CHECK(false, "Dense decode MLA is only supported on gfx936 or gfx938 architecture");
} }
at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache; at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
......
...@@ -227,9 +227,7 @@ def get_mla_decoding_metadata_dense_fp8( ...@@ -227,9 +227,7 @@ def get_mla_decoding_metadata_dense_fp8(
""" """
return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(cache_seqlens, num_heads_per_head_k, num_heads_k) return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(cache_seqlens, num_heads_per_head_k, num_heads_k)
def flash_mla_with_kvcache_fp8(
def flash_mla_with_kvcache_quantization(
q: torch.Tensor, q: torch.Tensor,
k_cache: torch.Tensor, k_cache: torch.Tensor,
block_table: torch.Tensor, block_table: torch.Tensor,
...@@ -239,30 +237,33 @@ def flash_mla_with_kvcache_quantization( ...@@ -239,30 +237,33 @@ def flash_mla_with_kvcache_quantization(
num_splits: torch.Tensor, num_splits: torch.Tensor,
softmax_scale: Optional[float] = None, softmax_scale: Optional[float] = None,
causal: bool = False, causal: bool = False,
k_scale = None, descale_q: Optional[torch.Tensor] = None,
kv_cache_dtype = None descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
support 1) qkv fp8 e4m3 gfx938
2) q bf16/fp16 kv fp8 e5m2 gfx936 gfx938
descale_q descale_k only support 1
Arguments: Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim). q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32. block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32. cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v. head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. num_splits: (batch_size + 1), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask. causal: bool. Whether to apply causal attention mask.
k_scale: {1, torch.float32}, tensor shape is 1 descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
kv_cache_dtype: "only support fp8_e4m3" descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Returns: Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v). out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
""" """
assert k_scale is not None and kv_cache_dtype is not None, "k_scale and kv_cache_dtype is not None"
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla( out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8(
q, q,
k_cache, k_cache,
None, None,
...@@ -273,57 +274,12 @@ def flash_mla_with_kvcache_quantization( ...@@ -273,57 +274,12 @@ def flash_mla_with_kvcache_quantization(
causal, causal,
tile_scheduler_metadata, tile_scheduler_metadata,
num_splits, num_splits,
k_scale, descale_q,
kv_cache_dtype descale_k
)
return out, softmax_lse
def flash_mla_with_kvcache_q_nope_pe(
q_nope: torch.Tensor,
q_pe: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None:
softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_nope_pe(
q_nope,
q_pe,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits
) )
return out, softmax_lse return out, softmax_lse
def flash_mla_with_kvcache_quantization_q_nope_pe( def flash_mla_with_kvcache_fp8_with_cat(
q_nope: torch.Tensor, q_nope: torch.Tensor,
q_pe: torch.Tensor, q_pe: torch.Tensor,
k_cache: torch.Tensor, k_cache: torch.Tensor,
...@@ -334,30 +290,36 @@ def flash_mla_with_kvcache_quantization_q_nope_pe( ...@@ -334,30 +290,36 @@ def flash_mla_with_kvcache_quantization_q_nope_pe(
num_splits: torch.Tensor, num_splits: torch.Tensor,
softmax_scale: Optional[float] = None, softmax_scale: Optional[float] = None,
causal: bool = False, causal: bool = False,
k_scale = None, descale_q: Optional[torch.Tensor] = None,
kv_cache_dtype = None descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
support 1) q_nope q_pe k_cache fp8 e4m3 gfx938
2) q_nope q_pe bf16 k_cache fp8 e4m3 gfx938
3) q_nope q_pe bf16 k_cache fp8 e5m2 gfx936 gfx938
4) q_nope q_pe fp16 k_cache fp8 e5m2 gfx936 gfx938
descale_q descale_k only support 1
Arguments: Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim). q_nope: (batch_size, seq_len_q, num_heads_q, 512).
q_pe: (batch_size, seq_len_q, num_heads_q, 64).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32. block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32. cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v. head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. num_splits: (batch_size + 1), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask. causal: bool. Whether to apply causal attention mask.
k_scale: {1, torch.float32}, tensor shape is 1 descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
kv_cache_dtype: "only support fp8_e4m3" descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Returns: Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v). out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
""" """
assert k_scale is not None and kv_cache_dtype is not None, "k_scale and kv_cache_dtype is not None"
if softmax_scale is None: if softmax_scale is None:
softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5) softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_q_nope_pe_mla( out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8_with_cat(
q_nope, q_nope,
q_pe, q_pe,
k_cache, k_cache,
...@@ -369,8 +331,8 @@ def flash_mla_with_kvcache_quantization_q_nope_pe( ...@@ -369,8 +331,8 @@ def flash_mla_with_kvcache_quantization_q_nope_pe(
causal, causal,
tile_scheduler_metadata, tile_scheduler_metadata,
num_splits, num_splits,
k_scale, descale_q,
kv_cache_dtype descale_k
) )
return out, softmax_lse return out, softmax_lse
...@@ -419,9 +381,8 @@ def flash_mla_with_kvcache_q_nope_pe( ...@@ -419,9 +381,8 @@ def flash_mla_with_kvcache_q_nope_pe(
) )
return out, softmax_lse return out, softmax_lse
def flash_mla_with_kvcache_quantization_q_nope_pe( def flash_mla_with_kvcache_quantization(
q_nope: torch.Tensor, q: torch.Tensor,
q_pe: torch.Tensor,
k_cache: torch.Tensor, k_cache: torch.Tensor,
block_table: torch.Tensor, block_table: torch.Tensor,
cache_seqlens: torch.Tensor, cache_seqlens: torch.Tensor,
...@@ -440,74 +401,20 @@ def flash_mla_with_kvcache_quantization_q_nope_pe( ...@@ -440,74 +401,20 @@ def flash_mla_with_kvcache_quantization_q_nope_pe(
block_table: (batch_size, max_num_blocks_per_seq), torch.int32. block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32. cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v. head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. num_splits: (batch_size + 1), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask. causal: bool. Whether to apply causal attention mask.
k_scale: {1, torch.float32}, tensor shape is 1 k_scale: {1, torch.float32}, tensor shape is 1
kv_cache_dtype: "only support fp8_e4m3" kv_cache_dtype: "only support fp8_e5m2"
Returns: Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v). out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
""" """
assert k_scale is not None and kv_cache_dtype is not None, "k_scale and kv_cache_dtype is not None" assert k_scale is not None and kv_cache_dtype is not None, "k_scale and kv_cache_dtype is not None"
if softmax_scale is None:
softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_q_nope_pe_mla(
q_nope,
q_pe,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits,
k_scale,
kv_cache_dtype
)
return out, softmax_lse
def flash_mla_with_kvcache_fp8(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata: torch.Tensor,
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
descale_q: Optional[torch.Tensor] = None,
descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
support 1) qkv fp8 e4m3 gfx938
2) q bf16/fp16 kv fp8 e5m2 gfx936 gfx938
descale_q descale_k only support 1
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8( out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
q, q,
k_cache, k_cache,
None, None,
...@@ -518,12 +425,12 @@ def flash_mla_with_kvcache_fp8( ...@@ -518,12 +425,12 @@ def flash_mla_with_kvcache_fp8(
causal, causal,
tile_scheduler_metadata, tile_scheduler_metadata,
num_splits, num_splits,
descale_q, k_scale,
descale_k kv_cache_dtype
) )
return out, softmax_lse return out, softmax_lse
def flash_mla_with_kvcache_fp8_with_cat( def flash_mla_with_kvcache_quantization_q_nope_pe(
q_nope: torch.Tensor, q_nope: torch.Tensor,
q_pe: torch.Tensor, q_pe: torch.Tensor,
k_cache: torch.Tensor, k_cache: torch.Tensor,
...@@ -534,36 +441,30 @@ def flash_mla_with_kvcache_fp8_with_cat( ...@@ -534,36 +441,30 @@ def flash_mla_with_kvcache_fp8_with_cat(
num_splits: torch.Tensor, num_splits: torch.Tensor,
softmax_scale: Optional[float] = None, softmax_scale: Optional[float] = None,
causal: bool = False, causal: bool = False,
descale_q: Optional[torch.Tensor] = None, k_scale = None,
descale_k: Optional[torch.Tensor] = None, kv_cache_dtype = None
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
support 1) q_nope q_pe k_cache fp8 e4m3 gfx938
2) q_nope q_pe bf16 k_cache fp8 e4m3 gfx938
3) q_nope q_pe bf16 k_cache fp8 e5m2 gfx936 gfx938
4) q_nope q_pe fp16 k_cache fp8 e5m2 gfx936 gfx938
descale_q descale_k only support 1
Arguments: Arguments:
q_nope: (batch_size, seq_len_q, num_heads_q, 512). q: (batch_size, seq_len_q, num_heads_q, head_dim).
q_pe: (batch_size, seq_len_q, num_heads_q, 64).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32. block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32. cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v. head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. num_splits: (batch_size + 1), torch.int32, returned by get_mla_decoding_metadata_dense_fp8.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask. causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization. k_scale: {1, torch.float32}, tensor shape is 1
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization. kv_cache_dtype: "only support fp8_e5m2"
Returns: Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v). out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
""" """
assert k_scale is not None and kv_cache_dtype is not None, "k_scale and kv_cache_dtype is not None"
if softmax_scale is None: if softmax_scale is None:
softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5) softmax_scale = (q_nope.shape[-1] + q_pe.shape[-1]) ** (-0.5)
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla_fp8_with_cat( out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_q_nope_pe_mla(
q_nope, q_nope,
q_pe, q_pe,
k_cache, k_cache,
...@@ -575,13 +476,16 @@ def flash_mla_with_kvcache_fp8_with_cat( ...@@ -575,13 +476,16 @@ def flash_mla_with_kvcache_fp8_with_cat(
causal, causal,
tile_scheduler_metadata, tile_scheduler_metadata,
num_splits, num_splits,
descale_q, k_scale,
descale_k kv_cache_dtype
) )
return out, softmax_lse return out, softmax_lse
# def flash_mla_with_kvcache_qkvfp8( # def flash_mla_with_kvcache_qkvfp8(
# q: torch.Tensor, # q: torch.Tensor,
# k_cache: torch.Tensor, # k_cache: torch.Tensor,
......
...@@ -89,7 +89,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa ...@@ -89,7 +89,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
blocked_v = blocked_k[..., :dv] blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_decoding_metadata_dense_fp8( tile_scheduler_metadata, num_splits = get_mla_decoding_metadata_dense_fp8(
cache_seqlens, s_q * h_q // h_kv, h_kv, h_q cache_seqlens, s_q * h_q // h_kv, h_kv
) )
init_dtype = q.dtype init_dtype = q.dtype
......
...@@ -5,7 +5,7 @@ import random ...@@ -5,7 +5,7 @@ import random
import torch import torch
import triton import triton
from flash_mla import flash_mla_with_kvcache, get_mla_metadata, flash_mla_with_kvcache_q_nope_pe from flash_mla import get_mla_decoding_metadata_dense_fp8, flash_mla_with_kvcache_q_nope_pe
# from flash_mla import flash_mla_with_kvcache, get_mla_metadata # from flash_mla import flash_mla_with_kvcache, get_mla_metadata
torch.set_printoptions(precision=4, profile="default", sci_mode=False) torch.set_printoptions(precision=4, profile="default", sci_mode=False)
...@@ -67,7 +67,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa ...@@ -67,7 +67,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
) )
blocked_v = blocked_k[..., :dv] blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_metadata( tile_scheduler_metadata, num_splits = get_mla_decoding_metadata_dense_fp8(
cache_seqlens, s_q * h_q // h_kv, h_kv cache_seqlens, s_q * h_q // h_kv, h_kv
) )
...@@ -141,113 +141,6 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa ...@@ -141,113 +141,6 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s" f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
) )
@torch.inference_mode()
def test_flash_mla_fp8(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=False):
print(
f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}"
)
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)
if varlen:
for i in range(b):
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
total_seqlens = cache_seqlens.sum().item()
mean_seqlens = cache_seqlens.float().mean().int().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}, {max_seqlen_pad=}")
q = torch.randn(b, s_q, h_q, d)
# q = torch.ones(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32
).view(b, max_seqlen_pad // block_size)
# blocked_k = torch.randint(low=0, high=4, size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8)
# blocked_k = torch.ones(size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8)
blocked_k = (torch.randn(block_table.numel(), block_size, h_kv, d)).to(torch.half).to(torch.float8_e4m3fn)
# blocked_k[0, 0, 0, 56] = 1
# blocked_k[0, 1, 0, 8] = 2
# blocked_k[0, 2, 0, 8] = 5
# blocked_k[0, 3, 0, 8] = 4
# for i in range(64):
# for j in range(64):
# blocked_k[0, i, 0, j] = j
# blocked_k[0, i, 0, j] = (i * 50 + j) % 128
# for i in range(b):
# blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
# -128
# )
blocked_v = blocked_k[..., :dv]
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens, s_q * h_q // h_kv, h_kv
)
# print("q:", q.shape, q.dtype, q)
# print("cache_seqlens:", cache_seqlens.shape, cache_seqlens)
# print("block_table:", block_table.shape, block_table)
# print("blocked_k:", blocked_k.shape, blocked_k[0])
# print("blocked_v:", blocked_v.shape)
# torch.set_printoptions(precision=4, profile="full", sci_mode=False)
# print("tile_scheduler_metadata:", tile_scheduler_metadata.shape, tile_scheduler_metadata)
# torch.set_printoptions(precision=4, profile="default", sci_mode=False)
# print("num_splits:", num_splits.shape, num_splits)
k_scale = torch.tensor(0.17).to(torch.float32).to("cuda:0")
def flash_mla():
return flash_mla_with_kvcache(
q,
blocked_k,
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
k_scale = k_scale,
kv_cache_dtype = "fp8_e4m3"
)
def ref_mla():
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
O, LSE = scaled_dot_product_attention(
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
h_q=h_q,
h_kv=h_kv,
is_causal=causal,
k_scale = k_scale
)
out[i] = O.transpose(0, 1)
lse[i] = LSE
return out, lse
out_flash, lse_flash = flash_mla()
out_torch, lse_torch = ref_mla()
print("out_flash ", out_flash[0, 0, 0, 0:14])
print("out_torch ", out_torch[0, 0, 0, 0:14])
print("lse_flash ", lse_flash[0, 0, 0:10])
print("lse_torch ", lse_torch[0, 0, 0:10])
cal_diff(out_flash, out_torch, "out")
cal_diff(lse_flash, lse_torch, "lse")
print("out max_diff ", (out_flash - out_torch).abs().max())
print("lse max_diff ", (lse_flash - lse_torch).abs().max())
t = triton.testing.do_bench(flash_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
torch.finfo(q.dtype).bits // 8
)
print(
f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s"
)
def main(torch_dtype, is_prof=False): def main(torch_dtype, is_prof=False):
device = torch.device("cuda:0") device = torch.device("cuda:0")
......
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
import triton import triton
# from flash_mla import flash_mla_with_kvcache_quantization, get_mla_metadata # from flash_mla import flash_mla_with_kvcache_quantization, get_mla_metadata
from flash_mla import flash_mla_with_kvcache_fp8_with_cat, get_mla_decoding_metadata_dense_fp8, flash_mla_with_kvcache_quantization_q_nope_pe from flash_mla import flash_mla_with_kvcache_fp8_with_cat, get_mla_decoding_metadata_dense_fp8
torch.set_printoptions(precision=4, profile="default", sci_mode=False) torch.set_printoptions(precision=4, profile="default", sci_mode=False)
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False, k_scale=1.0): def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False, k_scale=1.0):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment