Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
65fb7732
Commit
65fb7732
authored
Feb 24, 2025
by
Sijia Chen
Browse files
support fp16
parent
15a82b81
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
139 additions
and
91 deletions
+139
-91
README.md
README.md
+1
-1
csrc/flash_api.cpp
csrc/flash_api.cpp
+7
-2
csrc/flash_fwd_mla_fp16_sm90.cu
csrc/flash_fwd_mla_fp16_sm90.cu
+3
-0
csrc/flash_fwd_mla_kernel.h
csrc/flash_fwd_mla_kernel.h
+0
-76
csrc/flash_fwd_mla_metadata.cu
csrc/flash_fwd_mla_metadata.cu
+77
-0
setup.py
setup.py
+2
-0
tests/test_flash_mla.py
tests/test_flash_mla.py
+49
-12
No files found.
README.md
View file @
65fb7732
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving.
FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving.
Currently released:
Currently released:
-
BF16
-
BF16
, FP16
-
Paged kvcache with block size of 64
-
Paged kvcache with block size of 64
## Quick start
## Quick start
...
...
csrc/flash_api.cpp
View file @
65fb7732
...
@@ -77,7 +77,7 @@ mha_fwd_kvcache_mla(
...
@@ -77,7 +77,7 @@ mha_fwd_kvcache_mla(
at
::
Tensor
vcache
=
vcache_
.
has_value
()
?
vcache_
.
value
()
:
kcache
;
at
::
Tensor
vcache
=
vcache_
.
has_value
()
?
vcache_
.
value
()
:
kcache
;
auto
q_dtype
=
q
.
dtype
();
auto
q_dtype
=
q
.
dtype
();
TORCH_CHECK
(
q_dtype
==
torch
::
kBFloat16
);
TORCH_CHECK
(
q_dtype
==
torch
::
kBFloat16
||
q_dtype
==
torch
::
kFloat16
);
TORCH_CHECK
(
kcache
.
dtype
()
==
q_dtype
,
"query and key must have the same dtype"
);
TORCH_CHECK
(
kcache
.
dtype
()
==
q_dtype
,
"query and key must have the same dtype"
);
CHECK_DEVICE
(
q
);
CHECK_DEVICE
(
kcache
);
CHECK_DEVICE
(
vcache
);
CHECK_DEVICE
(
q
);
CHECK_DEVICE
(
kcache
);
CHECK_DEVICE
(
vcache
);
...
@@ -186,7 +186,12 @@ mha_fwd_kvcache_mla(
...
@@ -186,7 +186,12 @@ mha_fwd_kvcache_mla(
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
TORCH_CHECK
(
head_size
==
576
);
TORCH_CHECK
(
head_size
==
576
);
run_mha_fwd_splitkv_mla
<
cutlass
::
bfloat16_t
,
576
>
(
params
,
stream
);
if
(
q_dtype
==
torch
::
kBFloat16
)
{
run_mha_fwd_splitkv_mla
<
cutlass
::
bfloat16_t
,
576
>
(
params
,
stream
);
}
else
{
run_mha_fwd_splitkv_mla
<
cutlass
::
half_t
,
576
>
(
params
,
stream
);
}
out
=
out
.
view
({
batch_size
,
seqlen_q_ori
,
ngroups
,
num_heads_k
,
head_size_v
}).
transpose
(
2
,
3
)
out
=
out
.
view
({
batch_size
,
seqlen_q_ori
,
ngroups
,
num_heads_k
,
head_size_v
}).
transpose
(
2
,
3
)
.
reshape
({
batch_size
,
seqlen_q_ori
,
num_heads_ori
,
head_size_v
});
.
reshape
({
batch_size
,
seqlen_q_ori
,
num_heads_ori
,
head_size_v
});
...
...
csrc/flash_fwd_mla_fp16_sm90.cu
0 → 100644
View file @
65fb7732
#include "flash_fwd_mla_kernel.h"
template
void
run_mha_fwd_splitkv_mla
<
cutlass
::
half_t
,
576
>(
Flash_fwd_mla_params
&
params
,
cudaStream_t
stream
);
csrc/flash_fwd_mla_kernel.h
View file @
65fb7732
...
@@ -601,79 +601,3 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream)
...
@@ -601,79 +601,3 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream)
using
Kernel_traits
=
Flash_fwd_kernel_traits_mla
<
576
,
64
,
64
,
8
,
T
,
512
>
;
using
Kernel_traits
=
Flash_fwd_kernel_traits_mla
<
576
,
64
,
64
,
8
,
T
,
512
>
;
run_flash_splitkv_fwd_mla
<
Kernel_traits
,
flash
::
SharedStorageMLA
<
Kernel_traits
>>
(
params
,
stream
);
run_flash_splitkv_fwd_mla
<
Kernel_traits
,
flash
::
SharedStorageMLA
<
Kernel_traits
>>
(
params
,
stream
);
}
}
static
constexpr
int
MaxBatchSize
=
4096
;
__global__
void
__launch_bounds__
(
256
,
1
,
1
)
get_mla_metadata_kernel
(
__grid_constant__
const
Mla_metadata_params
params
)
{
int
*
seqlens_k_ptr
=
params
.
seqlens_k_ptr
;
int
*
tile_scheduler_metadata_ptr
=
params
.
tile_scheduler_metadata_ptr
;
int
*
num_splits_ptr
=
params
.
num_splits_ptr
;
int
batch_size
=
params
.
batch_size
;
int
block_size_n
=
params
.
block_size_n
;
int
fixed_overhead_num_blocks
=
params
.
fixed_overhead_num_blocks
;
int
num_sm_parts
=
params
.
num_sm_parts
;
__shared__
int
num_blocks_shared
[
MaxBatchSize
];
__shared__
int
num_splits_shared
[
MaxBatchSize
];
int
total_num_blocks
=
0
;
for
(
int
i
=
threadIdx
.
x
;
i
<
batch_size
;
i
+=
32
)
{
int
num_blocks
=
cutlass
::
ceil_div
(
seqlens_k_ptr
[
i
],
block_size_n
);
total_num_blocks
+=
num_blocks
+
fixed_overhead_num_blocks
;
num_blocks_shared
[
i
]
=
num_blocks
;
}
for
(
int
offset
=
16
;
offset
>=
1
;
offset
/=
2
)
{
total_num_blocks
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
total_num_blocks
,
offset
);
}
__syncwarp
();
if
(
threadIdx
.
x
==
0
)
{
int
payload
=
cutlass
::
ceil_div
(
total_num_blocks
,
num_sm_parts
)
+
fixed_overhead_num_blocks
;
int
now_idx
=
0
,
now_block
=
0
,
now_n_split_idx
=
0
,
cum_num_splits
=
0
;
num_splits_shared
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
num_sm_parts
;
++
i
)
{
int
tile_scheduler_metadata0
[
4
],
tile_scheduler_metadata1
;
tile_scheduler_metadata0
[
0
]
=
now_idx
;
tile_scheduler_metadata0
[
1
]
=
now_block
*
block_size_n
;
tile_scheduler_metadata1
=
now_n_split_idx
;
int
remain_payload
=
payload
;
while
(
now_idx
<
batch_size
)
{
int
num_blocks
=
num_blocks_shared
[
now_idx
];
int
now_remain_blocks
=
num_blocks
-
now_block
;
if
(
remain_payload
>=
now_remain_blocks
+
fixed_overhead_num_blocks
)
{
cum_num_splits
+=
now_n_split_idx
+
1
;
num_splits_shared
[
now_idx
+
1
]
=
cum_num_splits
;
remain_payload
-=
now_remain_blocks
+
fixed_overhead_num_blocks
;
++
now_idx
;
now_block
=
0
;
now_n_split_idx
=
0
;
}
else
{
if
(
remain_payload
-
fixed_overhead_num_blocks
>
0
)
{
now_block
+=
remain_payload
-
fixed_overhead_num_blocks
;
++
now_n_split_idx
;
remain_payload
=
0
;
}
break
;
}
}
tile_scheduler_metadata0
[
2
]
=
now_block
>
0
?
now_idx
:
now_idx
-
1
;
tile_scheduler_metadata0
[
3
]
=
now_block
>
0
?
now_block
*
block_size_n
:
seqlens_k_ptr
[
now_idx
-
1
];
*
reinterpret_cast
<
int4
*>
(
tile_scheduler_metadata_ptr
+
i
*
TileSchedulerMetaDataSize
)
=
*
reinterpret_cast
<
int4
*>
(
tile_scheduler_metadata0
);
tile_scheduler_metadata_ptr
[
i
*
TileSchedulerMetaDataSize
+
4
]
=
tile_scheduler_metadata1
;
}
FLASH_DEVICE_ASSERT
(
now_idx
==
batch_size
&&
now_block
==
0
&&
now_n_split_idx
==
0
);
}
__syncwarp
();
for
(
int
i
=
threadIdx
.
x
;
i
<=
batch_size
;
i
+=
32
)
{
num_splits_ptr
[
i
]
=
num_splits_shared
[
i
];
}
}
void
get_mla_metadata_func
(
Mla_metadata_params
&
params
,
cudaStream_t
stream
)
{
FLASH_ASSERT
(
params
.
batch_size
<
MaxBatchSize
);
get_mla_metadata_kernel
<<<
1
,
32
,
0
,
stream
>>>
(
params
);
CHECK_CUDA_KERNEL_LAUNCH
();
}
csrc/flash_fwd_mla_metadata.cu
0 → 100644
View file @
65fb7732
#include "flash_fwd_mla_kernel.h"
static
constexpr
int
MaxBatchSize
=
4096
;
__global__
void
__launch_bounds__
(
256
,
1
,
1
)
get_mla_metadata_kernel
(
__grid_constant__
const
Mla_metadata_params
params
)
{
int
*
seqlens_k_ptr
=
params
.
seqlens_k_ptr
;
int
*
tile_scheduler_metadata_ptr
=
params
.
tile_scheduler_metadata_ptr
;
int
*
num_splits_ptr
=
params
.
num_splits_ptr
;
int
batch_size
=
params
.
batch_size
;
int
block_size_n
=
params
.
block_size_n
;
int
fixed_overhead_num_blocks
=
params
.
fixed_overhead_num_blocks
;
int
num_sm_parts
=
params
.
num_sm_parts
;
__shared__
int
num_blocks_shared
[
MaxBatchSize
];
__shared__
int
num_splits_shared
[
MaxBatchSize
];
int
total_num_blocks
=
0
;
for
(
int
i
=
threadIdx
.
x
;
i
<
batch_size
;
i
+=
32
)
{
int
num_blocks
=
cutlass
::
ceil_div
(
seqlens_k_ptr
[
i
],
block_size_n
);
total_num_blocks
+=
num_blocks
+
fixed_overhead_num_blocks
;
num_blocks_shared
[
i
]
=
num_blocks
;
}
for
(
int
offset
=
16
;
offset
>=
1
;
offset
/=
2
)
{
total_num_blocks
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
total_num_blocks
,
offset
);
}
__syncwarp
();
if
(
threadIdx
.
x
==
0
)
{
int
payload
=
cutlass
::
ceil_div
(
total_num_blocks
,
num_sm_parts
)
+
fixed_overhead_num_blocks
;
int
now_idx
=
0
,
now_block
=
0
,
now_n_split_idx
=
0
,
cum_num_splits
=
0
;
num_splits_shared
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
num_sm_parts
;
++
i
)
{
int
tile_scheduler_metadata0
[
4
],
tile_scheduler_metadata1
;
tile_scheduler_metadata0
[
0
]
=
now_idx
;
tile_scheduler_metadata0
[
1
]
=
now_block
*
block_size_n
;
tile_scheduler_metadata1
=
now_n_split_idx
;
int
remain_payload
=
payload
;
while
(
now_idx
<
batch_size
)
{
int
num_blocks
=
num_blocks_shared
[
now_idx
];
int
now_remain_blocks
=
num_blocks
-
now_block
;
if
(
remain_payload
>=
now_remain_blocks
+
fixed_overhead_num_blocks
)
{
cum_num_splits
+=
now_n_split_idx
+
1
;
num_splits_shared
[
now_idx
+
1
]
=
cum_num_splits
;
remain_payload
-=
now_remain_blocks
+
fixed_overhead_num_blocks
;
++
now_idx
;
now_block
=
0
;
now_n_split_idx
=
0
;
}
else
{
if
(
remain_payload
-
fixed_overhead_num_blocks
>
0
)
{
now_block
+=
remain_payload
-
fixed_overhead_num_blocks
;
++
now_n_split_idx
;
remain_payload
=
0
;
}
break
;
}
}
tile_scheduler_metadata0
[
2
]
=
now_block
>
0
?
now_idx
:
now_idx
-
1
;
tile_scheduler_metadata0
[
3
]
=
now_block
>
0
?
now_block
*
block_size_n
:
seqlens_k_ptr
[
now_idx
-
1
];
*
reinterpret_cast
<
int4
*>
(
tile_scheduler_metadata_ptr
+
i
*
TileSchedulerMetaDataSize
)
=
*
reinterpret_cast
<
int4
*>
(
tile_scheduler_metadata0
);
tile_scheduler_metadata_ptr
[
i
*
TileSchedulerMetaDataSize
+
4
]
=
tile_scheduler_metadata1
;
}
FLASH_DEVICE_ASSERT
(
now_idx
==
batch_size
&&
now_block
==
0
&&
now_n_split_idx
==
0
);
}
__syncwarp
();
for
(
int
i
=
threadIdx
.
x
;
i
<=
batch_size
;
i
+=
32
)
{
num_splits_ptr
[
i
]
=
num_splits_shared
[
i
];
}
}
void
get_mla_metadata_func
(
Mla_metadata_params
&
params
,
cudaStream_t
stream
)
{
FLASH_ASSERT
(
params
.
batch_size
<
MaxBatchSize
);
get_mla_metadata_kernel
<<<
1
,
32
,
0
,
stream
>>>
(
params
);
CHECK_CUDA_KERNEL_LAUNCH
();
}
\ No newline at end of file
setup.py
View file @
65fb7732
...
@@ -37,6 +37,8 @@ ext_modules.append(
...
@@ -37,6 +37,8 @@ ext_modules.append(
sources
=
[
sources
=
[
"csrc/flash_api.cpp"
,
"csrc/flash_api.cpp"
,
"csrc/flash_fwd_mla_bf16_sm90.cu"
,
"csrc/flash_fwd_mla_bf16_sm90.cu"
,
"csrc/flash_fwd_mla_fp16_sm90.cu"
,
"csrc/flash_fwd_mla_metadata.cu"
,
],
],
extra_compile_args
=
{
extra_compile_args
=
{
"cxx"
:
cxx_args
,
"cxx"
:
cxx_args
,
...
...
tests/test_flash_mla.py
View file @
65fb7732
import
argparse
import
math
import
math
import
random
import
random
import
torch
import
torch
import
triton
import
triton
from
flash_mla
import
get_mla_metadata
,
flash_mla_with_kvcache
from
flash_mla
import
flash_mla_with_kvcache
,
get_mla_metadata
def
scaled_dot_product_attention
(
query
,
key
,
value
,
h_q
,
h_kv
,
is_causal
=
False
):
def
scaled_dot_product_attention
(
query
,
key
,
value
,
h_q
,
h_kv
,
is_causal
=
False
):
...
@@ -38,7 +39,9 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
...
@@ -38,7 +39,9 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_flash_mla
(
b
,
s_q
,
mean_sk
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
varlen
):
def
test_flash_mla
(
b
,
s_q
,
mean_sk
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
varlen
):
print
(
f
"
{
b
=
}
,
{
s_q
=
}
,
{
mean_sk
=
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
varlen
=
}
"
)
print
(
f
"
{
b
=
}
,
{
s_q
=
}
,
{
mean_sk
=
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
varlen
=
}
"
)
cache_seqlens
=
torch
.
full
((
b
,),
mean_sk
,
dtype
=
torch
.
int32
)
cache_seqlens
=
torch
.
full
((
b
,),
mean_sk
,
dtype
=
torch
.
int32
)
if
varlen
:
if
varlen
:
...
@@ -52,18 +55,30 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
...
@@ -52,18 +55,30 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_size
=
64
block_size
=
64
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
for
i
in
range
(
b
):
for
i
in
range
(
b
):
blocked_k
.
view
(
b
,
max_seqlen_pad
,
h_kv
,
d
)[
i
,
cache_seqlens
[
i
].
item
():]
=
float
(
"nan"
)
blocked_k
.
view
(
b
,
max_seqlen_pad
,
h_kv
,
d
)[
i
,
cache_seqlens
[
i
].
item
()
:]
=
(
float
(
"nan"
)
)
blocked_v
=
blocked_k
[...,
:
dv
]
blocked_v
=
blocked_k
[...,
:
dv
]
tile_scheduler_metadata
,
num_splits
=
get_mla_metadata
(
cache_seqlens
,
s_q
*
h_q
//
h_kv
,
h_kv
)
tile_scheduler_metadata
,
num_splits
=
get_mla_metadata
(
cache_seqlens
,
s_q
*
h_q
//
h_kv
,
h_kv
)
def
flash_mla
():
def
flash_mla
():
return
flash_mla_with_kvcache
(
return
flash_mla_with_kvcache
(
q
,
blocked_k
,
block_table
,
cache_seqlens
,
dv
,
q
,
tile_scheduler_metadata
,
num_splits
,
causal
=
causal
,
blocked_k
,
block_table
,
cache_seqlens
,
dv
,
tile_scheduler_metadata
,
num_splits
,
causal
=
causal
,
)
)
def
ref_mla
():
def
ref_mla
():
...
@@ -91,14 +106,17 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
...
@@ -91,14 +106,17 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
t
=
triton
.
testing
.
do_bench
(
flash_mla
)
t
=
triton
.
testing
.
do_bench
(
flash_mla
)
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
print
(
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
torch
.
finfo
(
q
.
dtype
).
bits
//
8
)
print
(
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
if
__name__
==
"__main__"
:
def
main
(
torch_dtype
):
dtype
=
torch
.
bfloat16
device
=
torch
.
device
(
"cuda:0"
)
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_dtype
(
dtype
)
torch
.
set_default_dtype
(
torch_
dtype
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
...
@@ -114,3 +132,22 @@ if __name__ == "__main__":
...
@@ -114,3 +132,22 @@ if __name__ == "__main__":
for
s_q
in
[
1
,
2
]:
# MTP = 1, 2
for
s_q
in
[
1
,
2
]:
# MTP = 1, 2
for
varlen
in
[
False
,
True
]:
for
varlen
in
[
False
,
True
]:
test_flash_mla
(
b
,
s_q
,
s
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
varlen
)
test_flash_mla
(
b
,
s_q
,
s
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
varlen
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
choices
=
[
"bf16"
,
"fp16"
],
default
=
"bf16"
,
help
=
"Data type to use for testing (bf16 or fp16)"
,
)
args
=
parser
.
parse_args
()
torch_dtype
=
torch
.
bfloat16
if
args
.
dtype
==
"fp16"
:
torch_dtype
=
torch
.
float16
main
(
torch_dtype
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment