Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
702e8c22
Commit
702e8c22
authored
Mar 05, 2026
by
zhanghj2
Browse files
e5m2接口合并到flash_mla_with_kvcache_fp8
parent
7949f854
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
34 additions
and
19 deletions
+34
-19
csrc/api/common.h
csrc/api/common.h
+6
-0
csrc/extension/flash_api.h
csrc/extension/flash_api.h
+19
-13
flash_mla/flash_mla_interface.py
flash_mla/flash_mla_interface.py
+2
-0
tests/test_flash_mla_fp8.py
tests/test_flash_mla_fp8.py
+7
-6
No files found.
csrc/api/common.h
View file @
702e8c22
...
@@ -271,3 +271,9 @@ public:
...
@@ -271,3 +271,9 @@ public:
}
}
};
};
/// @brief Return the textual name of a tensor's scalar type.
///
/// Looks up `tensor.scalar_type()` and converts it to text with
/// `c10::toString`, wrapping the result in a `std::string` so callers can
/// concatenate it into diagnostic messages.
///
/// This definition lives in a header, so it is declared `inline`: without it,
/// every translation unit that includes the header would emit its own
/// external definition and the link would fail with duplicate symbols.
///
/// @param tensor Tensor whose dtype should be described.
/// @return The dtype name produced by `c10::toString`.
inline std::string getDtypeString(const torch::Tensor& tensor) {
    return std::string(c10::toString(tensor.scalar_type()));
}
csrc/extension/flash_api.h
View file @
702e8c22
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#include <cstdlib>
#include <cstdlib>
#include "flash_mla.h"
#include "flash_mla.h"
#include "static_switch.h"
#include "static_switch.h"
#include "../api/common.h"
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
...
@@ -677,18 +678,14 @@ mha_fwd_kvcache_mla_fp8(
...
@@ -677,18 +678,14 @@ mha_fwd_kvcache_mla_fp8(
const
std
::
optional
<
at
::
Tensor
>
&
descale_q
,
// None or batch_size
const
std
::
optional
<
at
::
Tensor
>
&
descale_q
,
// None or batch_size
const
std
::
optional
<
at
::
Tensor
>
&
descale_k
// None or batch_size
const
std
::
optional
<
at
::
Tensor
>
&
descale_k
// None or batch_size
)
{
)
{
// auto dprops = at::cuda::getCurrentDeviceProperties();
Arch
arch
=
Arch
();
// bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
if
(
!
arch
.
is_gfx93x
())
{
// TORCH_CHECK(is_sm90);
TORCH_CHECK
(
false
,
"Dense decode MLA is only supported on gfx936 or gfx938 architecture"
);
// static std::string FLASH_MLA_ROOT_DIR = execCommand("python -c 'import site; print(site.getsitepackages()[0])'");
}
// setenv("FLASH_MLA_ROOT_DIR", (FLASH_MLA_ROOT_DIR + "/flash_mla/asm/").c_str(), 1);
// std::cout << FLASH_MLA_ROOT_DIR << "\n";
// exit(-1);
at
::
Tensor
vcache
=
vcache_
.
has_value
()
?
vcache_
.
value
()
:
kcache
;
at
::
Tensor
vcache
=
vcache_
.
has_value
()
?
vcache_
.
value
()
:
kcache
;
auto
q_dtype
=
q
.
dtype
();
auto
q_dtype
=
q
.
dtype
();
TORCH_CHECK
(
kcache
.
dtype
()
==
q_dtype
,
"query and key must have the same dtype"
);
//
TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
CHECK_DEVICE
(
q
);
CHECK_DEVICE
(
kcache
);
CHECK_DEVICE
(
vcache
);
CHECK_DEVICE
(
q
);
CHECK_DEVICE
(
kcache
);
CHECK_DEVICE
(
vcache
);
...
@@ -796,6 +793,7 @@ mha_fwd_kvcache_mla_fp8(
...
@@ -796,6 +793,7 @@ mha_fwd_kvcache_mla_fp8(
params
.
descale_q_ptr
=
reinterpret_cast
<
float
*>
(
descale_q
.
value
().
data_ptr
());
params
.
descale_q_ptr
=
reinterpret_cast
<
float
*>
(
descale_q
.
value
().
data_ptr
());
params
.
descale_k_ptr
=
reinterpret_cast
<
float
*>
(
descale_k
.
value
().
data_ptr
());
params
.
descale_k_ptr
=
reinterpret_cast
<
float
*>
(
descale_k
.
value
().
data_ptr
());
params
.
k_scale_ptr
=
descale_k_
.
data_ptr
();
TORCH_CHECK
(
tile_scheduler_metadata
.
dtype
()
==
torch
::
kInt32
,
"tile_scheduler_metadata must have dtype int32"
);
TORCH_CHECK
(
tile_scheduler_metadata
.
dtype
()
==
torch
::
kInt32
,
"tile_scheduler_metadata must have dtype int32"
);
TORCH_CHECK
(
tile_scheduler_metadata
.
size
(
1
)
==
TileSchedulerMetaDataSize
);
TORCH_CHECK
(
tile_scheduler_metadata
.
size
(
1
)
==
TileSchedulerMetaDataSize
);
...
@@ -821,11 +819,19 @@ mha_fwd_kvcache_mla_fp8(
...
@@ -821,11 +819,19 @@ mha_fwd_kvcache_mla_fp8(
batch_size
,
seqlen_q_ori
,
num_heads_ori
,
head_size
,
batch_size
,
max_num_blocks_per_seq
,
batch_size
,
seqlen_q_ori
,
num_heads_ori
,
head_size
,
batch_size
,
max_num_blocks_per_seq
,
num_blocks
,
page_block_size
,
num_heads_k
,
head_size_k
,
is_causal
,
softmax_scale
);
num_blocks
,
page_block_size
,
num_heads_k
,
head_size_k
,
is_causal
,
softmax_scale
);
}
}
if
(
q_dtype
==
torch
::
kFloat8_e4m3fn
)
{
if
(
q_dtype
==
torch
::
kFloat8_e4m3fn
&&
kcache
.
dtype
()
==
q_dtype
)
{
if
(
!
arch
.
is_gfx938
())
{
TORCH_CHECK
(
false
,
"Dense decode MLA is only supported on gfx938 architecture"
);
}
run_mha_fwd_splitkv_mla_fp8
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
576
>
(
params
,
stream
,
false
);
run_mha_fwd_splitkv_mla_fp8
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
576
>
(
params
,
stream
,
false
);
}
else
if
((
q_dtype
==
torch
::
kBFloat16
||
q_dtype
==
torch
::
kHalf
)
&&
kcache
.
dtype
()
==
torch
::
kFloat8_e5m2
)
{
if
(
q_dtype
==
torch
::
kBFloat16
)
{
run_mha_fwd_splitkv_mla
<
cutlass
::
bfloat16_t
,
576
>
(
params
,
"fp8_e5m2"
,
stream
);
}
else
{
run_mha_fwd_splitkv_mla
<
cutlass
::
half_t
,
576
>
(
params
,
"fp8_e5m2"
,
stream
);
}
}
else
{
}
else
{
TORCH_CHECK
(
false
,
"Unsupported tensor dtype
for query"
);
TORCH_CHECK
(
false
,
"Unsupported tensor dtype
, q dtype "
+
getDtypeString
(
q
)
+
" kvcache "
+
getDtypeString
(
kcache
)
);
}
}
out
=
out
.
view
({
batch_size
,
seqlen_q_ori
,
ngroups
,
num_heads_k
,
head_size_v
}).
transpose
(
2
,
3
)
out
=
out
.
view
({
batch_size
,
seqlen_q_ori
,
ngroups
,
num_heads_k
,
head_size_v
}).
transpose
(
2
,
3
)
...
...
flash_mla/flash_mla_interface.py
View file @
702e8c22
...
@@ -485,6 +485,8 @@ def flash_mla_with_kvcache_fp8(
...
@@ -485,6 +485,8 @@ def flash_mla_with_kvcache_fp8(
descale_k
:
Optional
[
torch
.
Tensor
]
=
None
,
descale_k
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
support 1) qkv fp8 e4m3 gfx938
2) q bf16/fp16 kv fp8 e5m2 gfx936 gfx938
Arguments:
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
...
...
tests/test_flash_mla_fp8.py
View file @
702e8c22
...
@@ -5,7 +5,7 @@ import random
...
@@ -5,7 +5,7 @@ import random
import
torch
import
torch
import
triton
import
triton
from
flash_mla
import
flash_mla_with_kvcache_
quantization
,
get_mla_decoding_metadata_dense_fp8
from
flash_mla
import
flash_mla_with_kvcache_
fp8
,
get_mla_decoding_metadata_dense_fp8
torch
.
set_printoptions
(
precision
=
4
,
profile
=
"default"
,
sci_mode
=
False
)
torch
.
set_printoptions
(
precision
=
4
,
profile
=
"default"
,
sci_mode
=
False
)
def
scaled_dot_product_attention
(
query
,
key
,
value
,
h_q
,
h_kv
,
is_causal
=
False
,
k_scale
=
1.0
):
def
scaled_dot_product_attention
(
query
,
key
,
value
,
h_q
,
h_kv
,
is_causal
=
False
,
k_scale
=
1.0
):
...
@@ -62,7 +62,7 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
...
@@ -62,7 +62,7 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
# blocked_k = torch.randint(low=0, high=4, size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8)
# blocked_k = torch.randint(low=0, high=4, size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8)
# blocked_k = torch.ones(size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8)
# blocked_k = torch.ones(size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8)
blocked_k
=
(
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)).
to
(
torch
.
half
).
to
(
torch
.
float8_e5m2
)
blocked_k
=
(
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)).
to
(
torch
.
float8_e5m2
)
# blocked_k[0, 0, 0, 56] = 1
# blocked_k[0, 0, 0, 56] = 1
# blocked_k[0, 1, 0, 8] = 2
# blocked_k[0, 1, 0, 8] = 2
# blocked_k[0, 2, 0, 8] = 5
# blocked_k[0, 2, 0, 8] = 5
...
@@ -93,9 +93,10 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
...
@@ -93,9 +93,10 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
# print("num_splits:", num_splits.shape, num_splits)
# print("num_splits:", num_splits.shape, num_splits)
# k_scale = torch.tensor(1.0).to(torch.float32).to("cuda:0")
# k_scale = torch.tensor(1.0).to(torch.float32).to("cuda:0")
# k_scale = torch.tensor(2.1).to(torch.float32).to("cuda:0")
# k_scale = torch.tensor(2.1).to(torch.float32).to("cuda:0")
k_scale
=
torch
.
tensor
(
1.0
).
to
(
torch
.
float32
).
to
(
"cuda:0"
)
descale_q
=
torch
.
ones
((
1
),
dtype
=
torch
.
float32
)
descale_k
=
torch
.
ones
((
1
),
dtype
=
torch
.
float32
)
def
flash_mla
():
def
flash_mla
():
return
flash_mla_with_kvcache_
quantization
(
return
flash_mla_with_kvcache_
fp8
(
q
,
q
,
blocked_k
,
blocked_k
,
block_table
,
block_table
,
...
@@ -104,8 +105,8 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
...
@@ -104,8 +105,8 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
tile_scheduler_metadata
,
tile_scheduler_metadata
,
num_splits
,
num_splits
,
causal
=
causal
,
causal
=
causal
,
k_
scale
=
k_
scale
,
de
scale
_q
=
de
scale
_q
,
kv_cache_dtype
=
"fp8_e5m2"
descale_k
=
descale_k
,
)
)
def
ref_mla
():
def
ref_mla
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment