Commit 702e8c22 authored by zhanghj2's avatar zhanghj2
Browse files

e5m2接口合并到flash_mla_with_kvcache_fp8

parent 7949f854
...@@ -271,3 +271,9 @@ public: ...@@ -271,3 +271,9 @@ public:
} }
}; };
std::string getDtypeString(const torch::Tensor& tensor) {
std::string dtype_str = c10::toString(tensor.scalar_type());
return dtype_str;
}
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include <cstdlib> #include <cstdlib>
#include "flash_mla.h" #include "flash_mla.h"
#include "static_switch.h" #include "static_switch.h"
#include "../api/common.h"
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
...@@ -677,18 +678,14 @@ mha_fwd_kvcache_mla_fp8( ...@@ -677,18 +678,14 @@ mha_fwd_kvcache_mla_fp8(
const std::optional<at::Tensor> &descale_q, // None or batch_size const std::optional<at::Tensor> &descale_q, // None or batch_size
const std::optional<at::Tensor> &descale_k // None or batch_size const std::optional<at::Tensor> &descale_k // None or batch_size
) { ) {
// auto dprops = at::cuda::getCurrentDeviceProperties(); Arch arch = Arch();
// bool is_sm90 = dprops->major == 9 && dprops->minor == 0; if (!arch.is_gfx93x()) {
// TORCH_CHECK(is_sm90); TORCH_CHECK(false, "Dense decode MLA is only supported on gfx936 or gfx938 architecture");
// static std::string FLASH_MLA_ROOT_DIR = execCommand("python -c 'import site; print(site.getsitepackages()[0])'"); }
// setenv("FLASH_MLA_ROOT_DIR", (FLASH_MLA_ROOT_DIR + "/flash_mla/asm/").c_str(), 1);
// std::cout << FLASH_MLA_ROOT_DIR << "\n";
// exit(-1);
at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache; at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache;
auto q_dtype = q.dtype(); auto q_dtype = q.dtype();
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); // TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
...@@ -796,6 +793,7 @@ mha_fwd_kvcache_mla_fp8( ...@@ -796,6 +793,7 @@ mha_fwd_kvcache_mla_fp8(
params.descale_q_ptr = reinterpret_cast<float *>(descale_q.value().data_ptr()); params.descale_q_ptr = reinterpret_cast<float *>(descale_q.value().data_ptr());
params.descale_k_ptr = reinterpret_cast<float *>(descale_k.value().data_ptr()); params.descale_k_ptr = reinterpret_cast<float *>(descale_k.value().data_ptr());
params.k_scale_ptr = descale_k_.data_ptr();
TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32");
TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize);
...@@ -821,11 +819,19 @@ mha_fwd_kvcache_mla_fp8( ...@@ -821,11 +819,19 @@ mha_fwd_kvcache_mla_fp8(
batch_size, seqlen_q_ori, num_heads_ori, head_size, batch_size, max_num_blocks_per_seq, batch_size, seqlen_q_ori, num_heads_ori, head_size, batch_size, max_num_blocks_per_seq,
num_blocks, page_block_size, num_heads_k, head_size_k, is_causal, softmax_scale); num_blocks, page_block_size, num_heads_k, head_size_k, is_causal, softmax_scale);
} }
if (q_dtype == torch::kFloat8_e4m3fn) { if (q_dtype == torch::kFloat8_e4m3fn && kcache.dtype() == q_dtype) {
if (!arch.is_gfx938()) {
TORCH_CHECK(false, "Dense decode MLA is only supported on gfx938 architecture");
}
run_mha_fwd_splitkv_mla_fp8<cutlass::float_e4m3_t,cutlass::bfloat16_t, 576>(params,stream,false); run_mha_fwd_splitkv_mla_fp8<cutlass::float_e4m3_t,cutlass::bfloat16_t, 576>(params,stream,false);
} } else if ((q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf) && kcache.dtype() == torch::kFloat8_e5m2) {
else { if (q_dtype == torch::kBFloat16) {
TORCH_CHECK(false, "Unsupported tensor dtype for query"); run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(params, "fp8_e5m2", stream);
} else {
run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(params, "fp8_e5m2", stream);
}
} else {
TORCH_CHECK(false, "Unsupported tensor dtype, q dtype " + getDtypeString(q) + " kvcache " + getDtypeString(kcache));
} }
out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3) out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
......
...@@ -485,6 +485,8 @@ def flash_mla_with_kvcache_fp8( ...@@ -485,6 +485,8 @@ def flash_mla_with_kvcache_fp8(
descale_k: Optional[torch.Tensor] = None, descale_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
support 1) qkv fp8 e4m3 gfx938
2) q bf16/fp16 kv fp8 e5m2 gfx936 gfx938
Arguments: Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim). q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
......
...@@ -5,7 +5,7 @@ import random ...@@ -5,7 +5,7 @@ import random
import torch import torch
import triton import triton
from flash_mla import flash_mla_with_kvcache_quantization, get_mla_decoding_metadata_dense_fp8 from flash_mla import flash_mla_with_kvcache_fp8, get_mla_decoding_metadata_dense_fp8
torch.set_printoptions(precision=4, profile="default", sci_mode=False) torch.set_printoptions(precision=4, profile="default", sci_mode=False)
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False, k_scale=1.0): def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False, k_scale=1.0):
...@@ -62,7 +62,7 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i ...@@ -62,7 +62,7 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
# blocked_k = torch.randint(low=0, high=4, size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8) # blocked_k = torch.randint(low=0, high=4, size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8)
# blocked_k = torch.ones(size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8) # blocked_k = torch.ones(size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8)
blocked_k = (torch.randn(block_table.numel(), block_size, h_kv, d)).to(torch.half).to(torch.float8_e5m2) blocked_k = (torch.randn(block_table.numel(), block_size, h_kv, d)).to(torch.float8_e5m2)
# blocked_k[0, 0, 0, 56] = 1 # blocked_k[0, 0, 0, 56] = 1
# blocked_k[0, 1, 0, 8] = 2 # blocked_k[0, 1, 0, 8] = 2
# blocked_k[0, 2, 0, 8] = 5 # blocked_k[0, 2, 0, 8] = 5
...@@ -93,9 +93,10 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i ...@@ -93,9 +93,10 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
# print("num_splits:", num_splits.shape, num_splits) # print("num_splits:", num_splits.shape, num_splits)
# k_scale = torch.tensor(1.0).to(torch.float32).to("cuda:0") # k_scale = torch.tensor(1.0).to(torch.float32).to("cuda:0")
# k_scale = torch.tensor(2.1).to(torch.float32).to("cuda:0") # k_scale = torch.tensor(2.1).to(torch.float32).to("cuda:0")
k_scale = torch.tensor(1.0).to(torch.float32).to("cuda:0") descale_q = torch.ones((1), dtype=torch.float32)
descale_k = torch.ones((1), dtype=torch.float32)
def flash_mla(): def flash_mla():
return flash_mla_with_kvcache_quantization( return flash_mla_with_kvcache_fp8(
q, q,
blocked_k, blocked_k,
block_table, block_table,
...@@ -104,8 +105,8 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i ...@@ -104,8 +105,8 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
tile_scheduler_metadata, tile_scheduler_metadata,
num_splits, num_splits,
causal=causal, causal=causal,
k_scale = k_scale, descale_q = descale_q,
kv_cache_dtype = "fp8_e5m2" descale_k = descale_k,
) )
def ref_mla(): def ref_mla():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment