Commit 1b95bb9e authored by zhanghj2's avatar zhanghj2
Browse files

融合类算子合并到flash_mla_with_kvcache_fp8_with_cat

parent 113ee450
......@@ -50,8 +50,10 @@ mha_fwd_kvcache_quantization_mla(
const at::Tensor &k_scale,
const std::string &kv_cache_dtype
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_gfx936 = dprops->major == 9 && dprops->minor == 3;
Arch arch = Arch();
if (!arch.is_gfx93x()) {
TORCH_CHECK(false, "Dense decode MLA is only supported on gfx936 or gfx938 architecture");
}
// TORCH_CHECK(is_sm90);
at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
......@@ -62,7 +64,7 @@ mha_fwd_kvcache_quantization_mla(
TORCH_CHECK(kcache.dtype() != q_dtype, "非量化情况下, query and key must have not the same dtype");
CHECK_DEVICE(k_scale);
TORCH_CHECK(k_scale.dtype() == torch::kFloat32, "非量化情况下, query and key must have the same dtype");
TORCH_CHECK(is_gfx936, "fp8_e4m3 and fp8_e5m2 Attention Forward Kernel (mha_fwd_kvcache_quantization_mla) is only supported on gfx936 architectures");
// TORCH_CHECK(is_gfx936, "fp8_e4m3 and fp8_e5m2 Attention Forward Kernel (mha_fwd_kvcache_quantization_mla) is only supported on gfx936 architectures");
}
else
{
......@@ -334,7 +336,10 @@ mha_fwd_kvcache_mla_nope_pe(
// auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
// TORCH_CHECK(is_sm90);
Arch arch = Arch();
if (!arch.is_gfx93x()) {
TORCH_CHECK(false, "Dense decode MLA is only supported on gfx936 or gfx938 architecture");
}
at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
auto q_dtype = q_nope.dtype();
......@@ -502,7 +507,10 @@ mha_fwd_kvcache_quantization_q_nope_pe_mla(
// auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
// TORCH_CHECK(is_sm90);
Arch arch = Arch();
if (!arch.is_gfx93x()) {
TORCH_CHECK(false, "Dense decode MLA is only supported on gfx936 or gfx938 architecture");
}
at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
auto q_dtype = q_nope.dtype();
......@@ -827,8 +835,10 @@ mha_fwd_kvcache_mla_fp8(
} else if ((q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf) && kcache.dtype() == torch::kFloat8_e5m2) {
if (q_dtype == torch::kBFloat16) {
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, "fp8_e5m2", stream);
} else {
} else if (q_dtype == torch::kHalf) {
run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, "fp8_e5m2", stream);
} else {
TORCH_CHECK(false, "Unsupported tensor dtype, q dtype " + getDtypeString(q) + " kvcache " + getDtypeString(kcache));
}
} else {
TORCH_CHECK(false, "Unsupported tensor dtype, q dtype " + getDtypeString(q) + " kvcache " + getDtypeString(kcache));
......@@ -860,11 +870,14 @@ mha_fwd_kvcache_mla_fp8_with_cat(
// auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
// TORCH_CHECK(is_sm90);
Arch arch = Arch();
if (!arch.is_gfx93x()) {
TORCH_CHECK(false, "Dense decode MLA is only supported on gfx936 or gfx938 architecture");
}
at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
auto q_dtype = q_nope.dtype();
// TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(q_pe.dtype() == q_dtype, "query nope and q_pe must have the same dtype");
CHECK_DEVICE(q_nope); CHECK_DEVICE(q_pe); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
......@@ -982,6 +995,7 @@ mha_fwd_kvcache_mla_fp8_with_cat(
params.descale_q_ptr = reinterpret_cast<float *>(descale_q.value().data_ptr());
params.descale_k_ptr = reinterpret_cast<float *>(descale_k.value().data_ptr());
params.k_scale_ptr = descale_k_.data_ptr();
TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
......@@ -1009,12 +1023,23 @@ mha_fwd_kvcache_mla_fp8_with_cat(
}
if (q_dtype == torch::kFloat8_e4m3fn && kcache.dtype() == torch::kFloat8_e4m3fn)
{
if (!arch.is_gfx938()) {
TORCH_CHECK(false, "Dense decode MLA is only supported on gfx938 architecture");
}
run_mha_fwd_splitkv_mla_fp8<cutlass::float_e4m3_t,cutlass::bfloat16_t, 576>(params,stream,true);
} else if (q_dtype == torch::kBFloat16 && kcache.dtype() == torch::kFloat8_e4m3fn) {
run_mha_fwd_splitkv_mla_fp8<cutlass::bfloat16_t,cutlass::bfloat16_t, 576>(params,stream,true);
} else if (kcache.dtype() == torch::kFloat8_e5m2) {
if (q_dtype == torch::kBFloat16) {
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, "fp8_e5m2", stream, true);
} else if (q_dtype == torch::kHalf) {
run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, "fp8_e5m2", stream, true);
} else {
TORCH_CHECK(false, "Unsupported tensor dtype, q dtype " + getDtypeString(q_nope) + " kvcache " + getDtypeString(kcache));
}
}
else {
TORCH_CHECK(false, "Unsupported tensor dtype for query");
TORCH_CHECK(false, "Unsupported tensor dtype, q dtype " + getDtypeString(q_pe) + " kvcache " + getDtypeString(kcache));
}
out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
......
......@@ -487,6 +487,7 @@ def flash_mla_with_kvcache_fp8(
"""
support 1) qkv fp8 e4m3 gfx938
2) q bf16/fp16 kv fp8 e5m2 gfx936 gfx938
descale_q descale_k only support 1
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
......@@ -537,6 +538,11 @@ def flash_mla_with_kvcache_fp8_with_cat(
descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
support 1) q_nope q_pe k_cache fp8 e4m3 gfx938
2) q_nope q_pe bf16 k_cache fp8 e4m3 gfx938
3) q_nope q_pe bf16 k_cache fp8 e5m2 gfx936 gfx938
4) q_nope q_pe fp16 k_cache fp8 e5m2 gfx936 gfx938
descale_q descale_k only support 1
Arguments:
q_nope: (batch_size, seq_len_q, num_heads_q, 512).
q_pe: (batch_size, seq_len_q, num_heads_q, 64).
......
......@@ -6,7 +6,7 @@ import torch
import triton
# from flash_mla import flash_mla_with_kvcache_quantization, get_mla_metadata
from flash_mla import flash_mla_with_kvcache_quantization, get_mla_decoding_metadata_dense_fp8, flash_mla_with_kvcache_quantization_q_nope_pe
from flash_mla import flash_mla_with_kvcache_fp8_with_cat, get_mla_decoding_metadata_dense_fp8, flash_mla_with_kvcache_quantization_q_nope_pe
torch.set_printoptions(precision=4, profile="default", sci_mode=False)
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False, k_scale=1.0):
......@@ -97,8 +97,10 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
# k_scale = torch.tensor(2.5).to(torch.float32).to("cuda:0")
q_nope = q[:, :, :, :512].contiguous()
q_pe = q[:, :, :, 512:].contiguous()
descale_q = torch.ones((1), dtype=torch.float32)
descale_k = torch.ones((1), dtype=torch.float32)
def flash_mla():
return flash_mla_with_kvcache_quantization_q_nope_pe(
return flash_mla_with_kvcache_fp8_with_cat(
q_nope,
q_pe,
blocked_k,
......@@ -108,8 +110,8 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
tile_scheduler_metadata,
num_splits,
causal=causal,
k_scale = k_scale,
kv_cache_dtype = "fp8_e5m2"
descale_q=descale_q,
descale_k=descale_k,
)
# return flash_mla_with_kvcache_quantization(
# q,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment