Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
ed798e2e
Commit
ed798e2e
authored
Mar 06, 2026
by
zhanghj2
Browse files
整理接口
parent
1b95bb9e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
60 additions
and
263 deletions
+60
-263
csrc/extension/flash_api.h
csrc/extension/flash_api.h
+1
-1
flash_mla/flash_mla_interface.py
flash_mla/flash_mla_interface.py
+55
-151
tests/test_flash_mla_qkvfp8_with_cat.py
tests/test_flash_mla_qkvfp8_with_cat.py
+1
-1
tests/test_flash_mla_with_q_concat.py
tests/test_flash_mla_with_q_concat.py
+2
-109
tests/test_flash_mla_with_q_concat_fp8.py
tests/test_flash_mla_with_q_concat_fp8.py
+1
-1
No files found.
csrc/extension/flash_api.h
View file @
ed798e2e
...
...
@@ -337,7 +337,7 @@ mha_fwd_kvcache_mla_nope_pe(
// bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
// TORCH_CHECK(is_sm90);
Arch
arch
=
Arch
();
if
(
!
arch
.
is_gfx93x
())
{
if
(
!
arch
.
is_gfx93x
()
||
!
arch
.
is_gfx928
()
)
{
TORCH_CHECK
(
false
,
"Dense decode MLA is only supported on gfx936 or gfx938 architecture"
);
}
at
::
Tensor
vcache
=
vcache_
.
has_value
()
?
vcache_
.
value
()
:
kcache
;
...
...
flash_mla/flash_mla_interface.py
View file @
ed798e2e
...
...
@@ -227,9 +227,7 @@ def get_mla_decoding_metadata_dense_fp8(
"""
return
flash_mla_cuda
.
get_mla_decoding_metadata_dense_fp8
(
cache_seqlens
,
num_heads_per_head_k
,
num_heads_k
)
def
flash_mla_with_kvcache_quantization
(
def
flash_mla_with_kvcache_fp8
(
q
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
...
...
@@ -239,30 +237,33 @@ def flash_mla_with_kvcache_quantization(
num_splits
:
torch
.
Tensor
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
k_
scale
=
None
,
kv_cache_dtype
=
None
de
scale
_q
:
Optional
[
torch
.
Tensor
]
=
None
,
descale_k
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
support 1) qkv fp8 e4m3 gfx938
2) q bf16/fp16 kv fp8 e5m2 gfx936 gfx938
descale_q descale_k only support 1
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_
metadata
.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_
metadata
.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_
decoding_metadata_dense_fp8
.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_
decoding_metadata_dense_fp8
.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
k_scale: {1, torch.float32}, tensor shape is 1
kv_cache_dtype: "only support fp8_e4m3"
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
assert
k_scale
is
not
None
and
kv_cache_dtype
is
not
None
,
"k_scale and kv_cache_dtype is not None"
if
softmax_scale
is
None
:
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_
quantization_mla
(
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_
mla_fp8
(
q
,
k_cache
,
None
,
...
...
@@ -273,57 +274,12 @@ def flash_mla_with_kvcache_quantization(
causal
,
tile_scheduler_metadata
,
num_splits
,
k_scale
,
kv_cache_dtype
)
return
out
,
softmax_lse
def
flash_mla_with_kvcache_q_nope_pe
(
q_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
cache_seqlens
:
torch
.
Tensor
,
head_dim_v
:
int
,
tile_scheduler_metadata
:
torch
.
Tensor
,
num_splits
:
torch
.
Tensor
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if
softmax_scale
is
None
:
softmax_scale
=
(
q_nope
.
shape
[
-
1
]
+
q_pe
.
shape
[
-
1
])
**
(
-
0.5
)
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla_nope_pe
(
q_nope
,
q_pe
,
k_cache
,
None
,
head_dim_v
,
cache_seqlens
,
block_table
,
softmax_scale
,
causal
,
tile_scheduler_metadata
,
num_splits
descale_q
,
descale_k
)
return
out
,
softmax_lse
def
flash_mla_with_kvcache_
quantization_q_nope_pe
(
def
flash_mla_with_kvcache_
fp8_with_cat
(
q_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
...
...
@@ -334,30 +290,36 @@ def flash_mla_with_kvcache_quantization_q_nope_pe(
num_splits
:
torch
.
Tensor
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
k_
scale
=
None
,
kv_cache_dtype
=
None
de
scale
_q
:
Optional
[
torch
.
Tensor
]
=
None
,
descale_k
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
support 1) q_nope q_pe k_cache fp8 e4m3 gfx938
2) q_nope q_pe bf16 k_cache fp8 e4m3 gfx938
3) q_nope q_pe bf16 k_cache fp8 e5m2 gfx936 gfx938
4) q_nope q_pe fp16 k_cache fp8 e5m2 gfx936 gfx938
descale_q descale_k only support 1
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
q_nope: (batch_size, seq_len_q, num_heads_q, 512).
q_pe: (batch_size, seq_len_q, num_heads_q, 64).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_
metadata
.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_
metadata
.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_
decoding_metadata_dense_fp8
.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_
decoding_metadata_dense_fp8
.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
k_scale: {1, torch.float32}, tensor shape is 1
kv_cache_dtype: "only support fp8_e4m3"
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
assert
k_scale
is
not
None
and
kv_cache_dtype
is
not
None
,
"k_scale and kv_cache_dtype is not None"
if
softmax_scale
is
None
:
softmax_scale
=
(
q_nope
.
shape
[
-
1
]
+
q_pe
.
shape
[
-
1
])
**
(
-
0.5
)
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_
quantization_q_nope_pe_mla
(
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_
mla_fp8_with_cat
(
q_nope
,
q_pe
,
k_cache
,
...
...
@@ -369,8 +331,8 @@ def flash_mla_with_kvcache_quantization_q_nope_pe(
causal
,
tile_scheduler_metadata
,
num_splits
,
k_
scale
,
kv_cache_dtype
de
scale
_q
,
descale_k
)
return
out
,
softmax_lse
...
...
@@ -419,9 +381,8 @@ def flash_mla_with_kvcache_q_nope_pe(
)
return
out
,
softmax_lse
def
flash_mla_with_kvcache_quantization_q_nope_pe
(
q_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
def
flash_mla_with_kvcache_quantization
(
q
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
cache_seqlens
:
torch
.
Tensor
,
...
...
@@ -440,74 +401,20 @@ def flash_mla_with_kvcache_quantization_q_nope_pe(
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_
metadata
.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_
metadata
.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_
decoding_metadata_dense_fp8
.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_
decoding_metadata_dense_fp8
.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
k_scale: {1, torch.float32}, tensor shape is 1
kv_cache_dtype: "only support fp8_e
4m3
"
kv_cache_dtype: "only support fp8_e
5m2
"
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
assert
k_scale
is
not
None
and
kv_cache_dtype
is
not
None
,
"k_scale and kv_cache_dtype is not None"
if
softmax_scale
is
None
:
softmax_scale
=
(
q_nope
.
shape
[
-
1
]
+
q_pe
.
shape
[
-
1
])
**
(
-
0.5
)
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_quantization_q_nope_pe_mla
(
q_nope
,
q_pe
,
k_cache
,
None
,
head_dim_v
,
cache_seqlens
,
block_table
,
softmax_scale
,
causal
,
tile_scheduler_metadata
,
num_splits
,
k_scale
,
kv_cache_dtype
)
return
out
,
softmax_lse
def
flash_mla_with_kvcache_fp8
(
q
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
cache_seqlens
:
torch
.
Tensor
,
head_dim_v
:
int
,
tile_scheduler_metadata
:
torch
.
Tensor
,
num_splits
:
torch
.
Tensor
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
descale_q
:
Optional
[
torch
.
Tensor
]
=
None
,
descale_k
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
support 1) qkv fp8 e4m3 gfx938
2) q bf16/fp16 kv fp8 e5m2 gfx936 gfx938
descale_q descale_k only support 1
Arguments:
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if
softmax_scale
is
None
:
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_
mla_fp8
(
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_
quantization_mla
(
q
,
k_cache
,
None
,
...
...
@@ -518,12 +425,12 @@ def flash_mla_with_kvcache_fp8(
causal
,
tile_scheduler_metadata
,
num_splits
,
de
scale
_q
,
descale_k
k_
scale
,
kv_cache_dtype
)
return
out
,
softmax_lse
def
flash_mla_with_kvcache_
fp8_with_cat
(
def
flash_mla_with_kvcache_
quantization_q_nope_pe
(
q_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
...
...
@@ -534,36 +441,30 @@ def flash_mla_with_kvcache_fp8_with_cat(
num_splits
:
torch
.
Tensor
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
de
scale
_q
:
Optional
[
torch
.
Tensor
]
=
None
,
descale_k
:
Optional
[
torch
.
Tensor
]
=
None
,
k_
scale
=
None
,
kv_cache_dtype
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
support 1) q_nope q_pe k_cache fp8 e4m3 gfx938
2) q_nope q_pe bf16 k_cache fp8 e4m3 gfx938
3) q_nope q_pe bf16 k_cache fp8 e5m2 gfx936 gfx938
4) q_nope q_pe fp16 k_cache fp8 e5m2 gfx936 gfx938
descale_q descale_k only support 1
Arguments:
q_nope: (batch_size, seq_len_q, num_heads_q, 512).
q_pe: (batch_size, seq_len_q, num_heads_q, 64).
q: (batch_size, seq_len_q, num_heads_q, head_dim).
k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
cache_seqlens: (batch_size), torch.int32.
head_dim_v: Head dimension of v.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_
metadata
.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_
metadata
.
tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_
decoding_metadata_dense_fp8
.
num_splits: (batch_size + 1), torch.int32, returned by get_mla_
decoding_metadata_dense_fp8
.
softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization.
descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization.
k_scale: {1, torch.float32}, tensor shape is 1
kv_cache_dtype: "only support fp8_e5m2"
Returns:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
assert
k_scale
is
not
None
and
kv_cache_dtype
is
not
None
,
"k_scale and kv_cache_dtype is not None"
if
softmax_scale
is
None
:
softmax_scale
=
(
q_nope
.
shape
[
-
1
]
+
q_pe
.
shape
[
-
1
])
**
(
-
0.5
)
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_
mla_fp8_with_cat
(
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_
quantization_q_nope_pe_mla
(
q_nope
,
q_pe
,
k_cache
,
...
...
@@ -575,13 +476,16 @@ def flash_mla_with_kvcache_fp8_with_cat(
causal
,
tile_scheduler_metadata
,
num_splits
,
de
scale
_q
,
descale_k
k_
scale
,
kv_cache_dtype
)
return
out
,
softmax_lse
# def flash_mla_with_kvcache_qkvfp8(
# q: torch.Tensor,
# k_cache: torch.Tensor,
...
...
tests/test_flash_mla_qkvfp8_with_cat.py
View file @
ed798e2e
...
...
@@ -89,7 +89,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
blocked_v
=
blocked_k
[...,
:
dv
]
tile_scheduler_metadata
,
num_splits
=
get_mla_decoding_metadata_dense_fp8
(
cache_seqlens
,
s_q
*
h_q
//
h_kv
,
h_kv
,
h_q
cache_seqlens
,
s_q
*
h_q
//
h_kv
,
h_kv
)
init_dtype
=
q
.
dtype
...
...
tests/test_flash_mla_with_q_concat.py
View file @
ed798e2e
...
...
@@ -5,7 +5,7 @@ import random
import
torch
import
triton
from
flash_mla
import
flash_mla_with_kvcache
,
get_mla_metadata
,
flash_mla_with_kvcache_q_nope_pe
from
flash_mla
import
get_mla_decoding_metadata_dense_fp8
,
flash_mla_with_kvcache_q_nope_pe
# from flash_mla import flash_mla_with_kvcache, get_mla_metadata
torch
.
set_printoptions
(
precision
=
4
,
profile
=
"default"
,
sci_mode
=
False
)
...
...
@@ -67,7 +67,7 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
)
blocked_v
=
blocked_k
[...,
:
dv
]
tile_scheduler_metadata
,
num_splits
=
get_mla_
metadata
(
tile_scheduler_metadata
,
num_splits
=
get_mla_
decoding_metadata_dense_fp8
(
cache_seqlens
,
s_q
*
h_q
//
h_kv
,
h_kv
)
...
...
@@ -141,113 +141,6 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
@
torch
.
inference_mode
()
def
test_flash_mla_fp8
(
b
,
s_q
,
mean_sk
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
varlen
,
is_prof
=
False
):
print
(
f
"
{
b
=
}
,
{
s_q
=
}
,
{
mean_sk
=
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
varlen
=
}
"
)
cache_seqlens
=
torch
.
full
((
b
,),
mean_sk
,
dtype
=
torch
.
int32
)
if
varlen
:
for
i
in
range
(
b
):
cache_seqlens
[
i
]
=
max
(
random
.
normalvariate
(
mean_sk
,
mean_sk
/
2
),
s_q
)
total_seqlens
=
cache_seqlens
.
sum
().
item
()
mean_seqlens
=
cache_seqlens
.
float
().
mean
().
int
().
item
()
max_seqlen
=
cache_seqlens
.
max
().
item
()
max_seqlen_pad
=
triton
.
cdiv
(
max_seqlen
,
256
)
*
256
print
(
f
"
{
total_seqlens
=
}
,
{
mean_seqlens
=
}
,
{
max_seqlen
=
}
,
{
max_seqlen_pad
=
}
"
)
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
# q = torch.ones(b, s_q, h_q, d)
block_size
=
64
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
# blocked_k = torch.randint(low=0, high=4, size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8)
# blocked_k = torch.ones(size = (block_table.numel(), block_size, h_kv, d), dtype = torch.int8)
blocked_k
=
(
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)).
to
(
torch
.
half
).
to
(
torch
.
float8_e4m3fn
)
# blocked_k[0, 0, 0, 56] = 1
# blocked_k[0, 1, 0, 8] = 2
# blocked_k[0, 2, 0, 8] = 5
# blocked_k[0, 3, 0, 8] = 4
# for i in range(64):
# for j in range(64):
# blocked_k[0, i, 0, j] = j
# blocked_k[0, i, 0, j] = (i * 50 + j) % 128
# for i in range(b):
# blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = (
# -128
# )
blocked_v
=
blocked_k
[...,
:
dv
]
tile_scheduler_metadata
,
num_splits
=
get_mla_metadata
(
cache_seqlens
,
s_q
*
h_q
//
h_kv
,
h_kv
)
# print("q:", q.shape, q.dtype, q)
# print("cache_seqlens:", cache_seqlens.shape, cache_seqlens)
# print("block_table:", block_table.shape, block_table)
# print("blocked_k:", blocked_k.shape, blocked_k[0])
# print("blocked_v:", blocked_v.shape)
# torch.set_printoptions(precision=4, profile="full", sci_mode=False)
# print("tile_scheduler_metadata:", tile_scheduler_metadata.shape, tile_scheduler_metadata)
# torch.set_printoptions(precision=4, profile="default", sci_mode=False)
# print("num_splits:", num_splits.shape, num_splits)
k_scale
=
torch
.
tensor
(
0.17
).
to
(
torch
.
float32
).
to
(
"cuda:0"
)
def
flash_mla
():
return
flash_mla_with_kvcache
(
q
,
blocked_k
,
block_table
,
cache_seqlens
,
dv
,
tile_scheduler_metadata
,
num_splits
,
causal
=
causal
,
k_scale
=
k_scale
,
kv_cache_dtype
=
"fp8_e4m3"
)
def
ref_mla
():
out
=
torch
.
empty
(
b
,
s_q
,
h_q
,
dv
,
dtype
=
torch
.
float32
)
lse
=
torch
.
empty
(
b
,
h_q
,
s_q
,
dtype
=
torch
.
float32
)
for
i
in
range
(
b
):
begin
=
i
*
max_seqlen_pad
end
=
begin
+
cache_seqlens
[
i
]
O
,
LSE
=
scaled_dot_product_attention
(
q
[
i
].
transpose
(
0
,
1
),
blocked_k
.
view
(
-
1
,
h_kv
,
d
)[
begin
:
end
].
transpose
(
0
,
1
),
blocked_v
.
view
(
-
1
,
h_kv
,
dv
)[
begin
:
end
].
transpose
(
0
,
1
),
h_q
=
h_q
,
h_kv
=
h_kv
,
is_causal
=
causal
,
k_scale
=
k_scale
)
out
[
i
]
=
O
.
transpose
(
0
,
1
)
lse
[
i
]
=
LSE
return
out
,
lse
out_flash
,
lse_flash
=
flash_mla
()
out_torch
,
lse_torch
=
ref_mla
()
print
(
"out_flash "
,
out_flash
[
0
,
0
,
0
,
0
:
14
])
print
(
"out_torch "
,
out_torch
[
0
,
0
,
0
,
0
:
14
])
print
(
"lse_flash "
,
lse_flash
[
0
,
0
,
0
:
10
])
print
(
"lse_torch "
,
lse_torch
[
0
,
0
,
0
:
10
])
cal_diff
(
out_flash
,
out_torch
,
"out"
)
cal_diff
(
lse_flash
,
lse_torch
,
"lse"
)
print
(
"out max_diff "
,
(
out_flash
-
out_torch
).
abs
().
max
())
print
(
"lse max_diff "
,
(
lse_flash
-
lse_torch
).
abs
().
max
())
t
=
triton
.
testing
.
do_bench
(
flash_mla
)
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
q
.
dtype
).
bits
//
8
)
print
(
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
def
main
(
torch_dtype
,
is_prof
=
False
):
device
=
torch
.
device
(
"cuda:0"
)
...
...
tests/test_flash_mla_with_q_concat_fp8.py
View file @
ed798e2e
...
...
@@ -6,7 +6,7 @@ import torch
import
triton
# from flash_mla import flash_mla_with_kvcache_quantization, get_mla_metadata
from
flash_mla
import
flash_mla_with_kvcache_fp8_with_cat
,
get_mla_decoding_metadata_dense_fp8
,
flash_mla_with_kvcache_quantization_q_nope_pe
from
flash_mla
import
flash_mla_with_kvcache_fp8_with_cat
,
get_mla_decoding_metadata_dense_fp8
torch
.
set_printoptions
(
precision
=
4
,
profile
=
"default"
,
sci_mode
=
False
)
def
scaled_dot_product_attention
(
query
,
key
,
value
,
h_q
,
h_kv
,
is_causal
=
False
,
k_scale
=
1.0
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment