Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
b549289f
Unverified
Commit
b549289f
authored
Feb 25, 2025
by
Jiashi Li
Committed by
GitHub
Feb 25, 2025
Browse files
Merge pull request #32 from sijiac/fp16-support
Support FP16 dtype in FlashMLA kernel
parents
18e32770
e1e9fa98
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
164 additions
and
98 deletions
+164
-98
README.md
README.md
+1
-1
csrc/flash_api.cpp
csrc/flash_api.cpp
+13
-3
csrc/flash_fwd_mla_fp16_sm90.cu
csrc/flash_fwd_mla_fp16_sm90.cu
+3
-0
csrc/flash_fwd_mla_kernel.h
csrc/flash_fwd_mla_kernel.h
+0
-76
csrc/flash_fwd_mla_metadata.cu
csrc/flash_fwd_mla_metadata.cu
+77
-0
setup.py
setup.py
+21
-6
tests/test_flash_mla.py
tests/test_flash_mla.py
+49
-12
No files found.
README.md
View file @
b549289f
...
...
@@ -3,7 +3,7 @@
FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving.
Currently released:
-
BF16
-
BF16
, FP16
-
Paged kvcache with block size of 64
## Quick start
...
...
csrc/flash_api.cpp
View file @
b549289f
...
...
@@ -61,7 +61,7 @@ std::vector<at::Tensor>
mha_fwd_kvcache_mla
(
at
::
Tensor
&
q
,
// batch_size x seqlen_q x num_heads x head_size
const
at
::
Tensor
&
kcache
,
// num_blocks x page_block_size x num_heads_k x head_size
c10
::
optional
<
const
at
::
Tensor
>
&
vcache_
,
// num_blocks x page_block_size x num_heads_k x head_size_v
std
::
optional
<
const
at
::
Tensor
>
&
vcache_
,
// num_blocks x page_block_size x num_heads_k x head_size_v
const
int
head_size_v
,
const
at
::
Tensor
&
seqlens_k
,
// batch_size
const
at
::
Tensor
&
block_table
,
// batch_size x max_num_blocks_per_seq
...
...
@@ -77,7 +77,6 @@ mha_fwd_kvcache_mla(
at
::
Tensor
vcache
=
vcache_
.
has_value
()
?
vcache_
.
value
()
:
kcache
;
auto
q_dtype
=
q
.
dtype
();
TORCH_CHECK
(
q_dtype
==
torch
::
kBFloat16
);
TORCH_CHECK
(
kcache
.
dtype
()
==
q_dtype
,
"query and key must have the same dtype"
);
CHECK_DEVICE
(
q
);
CHECK_DEVICE
(
kcache
);
CHECK_DEVICE
(
vcache
);
...
...
@@ -186,7 +185,18 @@ mha_fwd_kvcache_mla(
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
TORCH_CHECK
(
head_size
==
576
);
run_mha_fwd_splitkv_mla
<
cutlass
::
bfloat16_t
,
576
>
(
params
,
stream
);
if
(
q_dtype
==
torch
::
kBFloat16
)
{
run_mha_fwd_splitkv_mla
<
cutlass
::
bfloat16_t
,
576
>
(
params
,
stream
);
}
#ifndef FLASH_MLA_DISABLE_FP16
else
if
(
q_dtype
==
torch
::
kHalf
)
{
run_mha_fwd_splitkv_mla
<
cutlass
::
half_t
,
576
>
(
params
,
stream
);
}
#endif
else
{
TORCH_CHECK
(
false
,
"Unsupported tensor dtype for query"
);
}
out
=
out
.
view
({
batch_size
,
seqlen_q_ori
,
ngroups
,
num_heads_k
,
head_size_v
}).
transpose
(
2
,
3
)
.
reshape
({
batch_size
,
seqlen_q_ori
,
num_heads_ori
,
head_size_v
});
...
...
csrc/flash_fwd_mla_fp16_sm90.cu
0 → 100644
View file @
b549289f
#include "flash_fwd_mla_kernel.h"
// Explicit instantiation of run_mha_fwd_splitkv_mla for the FP16 element type
// (cutlass::half_t) with the second template argument fixed to 576.
// Keeping this instantiation in its own translation unit lets the build leave
// the FP16 kernel out simply by not compiling this file.
template void run_mha_fwd_splitkv_mla<cutlass::half_t, 576>(Flash_fwd_mla_params &params, cudaStream_t stream);
csrc/flash_fwd_mla_kernel.h
View file @
b549289f
...
...
@@ -601,79 +601,3 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream)
using
Kernel_traits
=
Flash_fwd_kernel_traits_mla
<
576
,
64
,
64
,
8
,
T
,
512
>
;
run_flash_splitkv_fwd_mla
<
Kernel_traits
,
flash
::
SharedStorageMLA
<
Kernel_traits
>>
(
params
,
stream
);
}
static
constexpr
int
MaxBatchSize
=
4096
;
__global__
void
__launch_bounds__
(
256
,
1
,
1
)
get_mla_metadata_kernel
(
__grid_constant__
const
Mla_metadata_params
params
)
{
int
*
seqlens_k_ptr
=
params
.
seqlens_k_ptr
;
int
*
tile_scheduler_metadata_ptr
=
params
.
tile_scheduler_metadata_ptr
;
int
*
num_splits_ptr
=
params
.
num_splits_ptr
;
int
batch_size
=
params
.
batch_size
;
int
block_size_n
=
params
.
block_size_n
;
int
fixed_overhead_num_blocks
=
params
.
fixed_overhead_num_blocks
;
int
num_sm_parts
=
params
.
num_sm_parts
;
__shared__
int
num_blocks_shared
[
MaxBatchSize
];
__shared__
int
num_splits_shared
[
MaxBatchSize
];
int
total_num_blocks
=
0
;
for
(
int
i
=
threadIdx
.
x
;
i
<
batch_size
;
i
+=
32
)
{
int
num_blocks
=
cutlass
::
ceil_div
(
seqlens_k_ptr
[
i
],
block_size_n
);
total_num_blocks
+=
num_blocks
+
fixed_overhead_num_blocks
;
num_blocks_shared
[
i
]
=
num_blocks
;
}
for
(
int
offset
=
16
;
offset
>=
1
;
offset
/=
2
)
{
total_num_blocks
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
total_num_blocks
,
offset
);
}
__syncwarp
();
if
(
threadIdx
.
x
==
0
)
{
int
payload
=
cutlass
::
ceil_div
(
total_num_blocks
,
num_sm_parts
)
+
fixed_overhead_num_blocks
;
int
now_idx
=
0
,
now_block
=
0
,
now_n_split_idx
=
0
,
cum_num_splits
=
0
;
num_splits_shared
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
num_sm_parts
;
++
i
)
{
int
tile_scheduler_metadata0
[
4
],
tile_scheduler_metadata1
;
tile_scheduler_metadata0
[
0
]
=
now_idx
;
tile_scheduler_metadata0
[
1
]
=
now_block
*
block_size_n
;
tile_scheduler_metadata1
=
now_n_split_idx
;
int
remain_payload
=
payload
;
while
(
now_idx
<
batch_size
)
{
int
num_blocks
=
num_blocks_shared
[
now_idx
];
int
now_remain_blocks
=
num_blocks
-
now_block
;
if
(
remain_payload
>=
now_remain_blocks
+
fixed_overhead_num_blocks
)
{
cum_num_splits
+=
now_n_split_idx
+
1
;
num_splits_shared
[
now_idx
+
1
]
=
cum_num_splits
;
remain_payload
-=
now_remain_blocks
+
fixed_overhead_num_blocks
;
++
now_idx
;
now_block
=
0
;
now_n_split_idx
=
0
;
}
else
{
if
(
remain_payload
-
fixed_overhead_num_blocks
>
0
)
{
now_block
+=
remain_payload
-
fixed_overhead_num_blocks
;
++
now_n_split_idx
;
remain_payload
=
0
;
}
break
;
}
}
tile_scheduler_metadata0
[
2
]
=
now_block
>
0
?
now_idx
:
now_idx
-
1
;
tile_scheduler_metadata0
[
3
]
=
now_block
>
0
?
now_block
*
block_size_n
:
seqlens_k_ptr
[
now_idx
-
1
];
*
reinterpret_cast
<
int4
*>
(
tile_scheduler_metadata_ptr
+
i
*
TileSchedulerMetaDataSize
)
=
*
reinterpret_cast
<
int4
*>
(
tile_scheduler_metadata0
);
tile_scheduler_metadata_ptr
[
i
*
TileSchedulerMetaDataSize
+
4
]
=
tile_scheduler_metadata1
;
}
FLASH_DEVICE_ASSERT
(
now_idx
==
batch_size
&&
now_block
==
0
&&
now_n_split_idx
==
0
);
}
__syncwarp
();
for
(
int
i
=
threadIdx
.
x
;
i
<=
batch_size
;
i
+=
32
)
{
num_splits_ptr
[
i
]
=
num_splits_shared
[
i
];
}
}
void
get_mla_metadata_func
(
Mla_metadata_params
&
params
,
cudaStream_t
stream
)
{
FLASH_ASSERT
(
params
.
batch_size
<
MaxBatchSize
);
get_mla_metadata_kernel
<<<
1
,
32
,
0
,
stream
>>>
(
params
);
CHECK_CUDA_KERNEL_LAUNCH
();
}
csrc/flash_fwd_mla_metadata.cu
0 → 100644
View file @
b549289f
#include "flash_fwd_mla_kernel.h"
// Capacity of the per-sequence shared-memory scratch arrays used by
// get_mla_metadata_kernel. num_splits_shared holds batch_size + 1 entries, so
// the batch size must be strictly smaller than this value.
static constexpr int MaxBatchSize = 4096;

/**
 * Computes the tile-scheduler metadata and the cumulative split counts for a batch.
 *
 * The kernel is written for exactly one warp: the per-sequence loops stride by
 * 32, the reduction uses full-mask warp shuffles, and only __syncwarp() is used
 * for synchronization. Lane 0 performs the partitioning serially.
 *
 * Inputs (read from `params`):
 *   - seqlens_k_ptr[b]           : key length of sequence b, for b in [0, batch_size)
 *   - block_size_n               : number of keys per block
 *   - fixed_overhead_num_blocks  : extra cost, in blocks, charged per sequence (or per split)
 *   - num_sm_parts               : number of parts to distribute the work over
 *
 * Outputs:
 *   - tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 0..4] for each part i:
 *       [0] sequence index at which the part starts
 *       [1] key offset (blocks * block_size_n) at which the part starts
 *       [2] sequence index at which the part ends
 *       [3] key offset at which the part ends
 *       [4] split index of the first sequence handled by the part
 *   - num_splits_ptr[0..batch_size] : running (prefix) sum of the number of splits
 *     per sequence; entry 0 is always 0.
 */
__global__ void __launch_bounds__(256, 1, 1)
get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
    int *seqlens_k_ptr = params.seqlens_k_ptr;
    int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
    int *num_splits_ptr = params.num_splits_ptr;
    int batch_size = params.batch_size;
    int block_size_n = params.block_size_n;
    int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
    int num_sm_parts = params.num_sm_parts;

    __shared__ int num_blocks_shared[MaxBatchSize];
    __shared__ int num_splits_shared[MaxBatchSize];

    // Each lane accumulates the cost of the sequences it owns and records the
    // per-sequence block count for the serial pass below.
    int total_num_blocks = 0;
    for (int i = threadIdx.x; i < batch_size; i += 32) {
        int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
        total_num_blocks += num_blocks + fixed_overhead_num_blocks;
        num_blocks_shared[i] = num_blocks;
    }
    // Butterfly reduction: afterwards every lane holds the warp-wide total.
    for (int offset = 16; offset >= 1; offset /= 2) {
        total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
    }
    // Make the num_blocks_shared writes of all lanes visible to lane 0.
    __syncwarp();

    if (threadIdx.x == 0) {
        // Work budget of a single part.
        int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;

        int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
        num_splits_shared[0] = 0;
        for (int i = 0; i < num_sm_parts; ++i) {
            // 16-byte aligned so that the int4 access below is a well-defined
            // aligned vector load (a plain int[4] only guarantees 4-byte alignment).
            alignas(16) int tile_scheduler_metadata0[4];
            int tile_scheduler_metadata1;
            tile_scheduler_metadata0[0] = now_idx;
            tile_scheduler_metadata0[1] = now_block * block_size_n;
            tile_scheduler_metadata1 = now_n_split_idx;
            int remain_payload = payload;
            while (now_idx < batch_size) {
                int num_blocks = num_blocks_shared[now_idx];
                int now_remain_blocks = num_blocks - now_block;
                if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
                    // The rest of this sequence fits: close it and record its split count.
                    cum_num_splits += now_n_split_idx + 1;
                    num_splits_shared[now_idx + 1] = cum_num_splits;
                    remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
                    ++now_idx;
                    now_block = 0;
                    now_n_split_idx = 0;
                } else {
                    // Only take a partial chunk if it contains at least one
                    // block beyond the fixed overhead.
                    if (remain_payload - fixed_overhead_num_blocks > 0) {
                        now_block += remain_payload - fixed_overhead_num_blocks;
                        ++now_n_split_idx;
                        remain_payload = 0;
                    }
                    break;
                }
            }
            tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
            if (now_block > 0) {
                tile_scheduler_metadata0[3] = now_block * block_size_n;
            } else {
                // now_idx can only still be 0 here when batch_size == 0; guard
                // the lookup so that case does not read seqlens_k_ptr[-1].
                tile_scheduler_metadata0[3] = now_idx > 0 ? seqlens_k_ptr[now_idx - 1] : 0;
            }
            // NOTE(review): the int4 store assumes the destination is 16-byte
            // aligned, i.e. an aligned base pointer and TileSchedulerMetaDataSize
            // being a multiple of 4 -- confirm against its definition.
            *reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) =
                *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
            tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
        }
        // All sequences must have been fully assigned.
        FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
    }
    // Make lane 0's num_splits_shared writes visible to the whole warp.
    __syncwarp();

    // batch_size + 1 entries: the leading 0 plus one running total per sequence.
    for (int i = threadIdx.x; i <= batch_size; i += 32) {
        num_splits_ptr[i] = num_splits_shared[i];
    }
}
// Host-side launcher for get_mla_metadata_kernel.
// Asserts that the batch fits into the kernel's fixed-size shared-memory
// scratch arrays, then enqueues the kernel on `stream` with a single block of
// 32 threads and no dynamic shared memory, and checks the launch status.
void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream) {
    FLASH_ASSERT(params.batch_size < MaxBatchSize);

    const dim3 grid_dim(1);
    const dim3 block_dim(32);
    const size_t dynamic_smem_bytes = 0;
    get_mla_metadata_kernel<<<grid_dim, block_dim, dynamic_smem_bytes, stream>>>(params);

    CHECK_CUDA_KERNEL_LAUNCH();
}
\ No newline at end of file
setup.py
View file @
b549289f
...
...
@@ -11,11 +11,29 @@ from torch.utils.cpp_extension import (
IS_WINDOWS
,
)
def env_flag(name, default="FALSE"):
    """Return True when environment variable `name` is set to a truthy value.

    "TRUE" (in any letter case, surrounding whitespace ignored) and "1" count
    as truthy; anything else, including an unset variable with the default
    "FALSE", is falsy. The previously required exact value "TRUE" keeps
    working, but spellings such as "true" or "1" are no longer silently ignored.
    """
    return os.getenv(name, default).strip().upper() in ("TRUE", "1")


# Build-time switch: when set, the FP16 kernel source is not compiled and
# -DFLASH_MLA_DISABLE_FP16 is passed to the compilers.
DISABLE_FP16 = env_flag("FLASH_MLA_DISABLE_FP16")
def append_nvcc_threads(nvcc_extra_args):
    """Return `nvcc_extra_args` extended with a `--threads <n>` option.

    The thread count is taken from the NVCC_THREADS environment variable and
    falls back to "32" when the variable is unset or empty. The input list is
    not modified; a new list is returned.
    """
    thread_count = os.getenv("NVCC_THREADS") or "32"
    threads_option = ["--threads", thread_count]
    return nvcc_extra_args + threads_option
def get_sources():
    """Return the source files to compile into the extension.

    The API wrapper, the BF16 kernel and the metadata kernel are always
    included; the FP16 kernel source is appended last unless DISABLE_FP16 is set.
    A fresh list is built on every call.
    """
    fp16_sources = [] if DISABLE_FP16 else ["csrc/flash_fwd_mla_fp16_sm90.cu"]
    return [
        "csrc/flash_api.cpp",
        "csrc/flash_fwd_mla_bf16_sm90.cu",
        "csrc/flash_fwd_mla_metadata.cu",
        *fp16_sources,
    ]
def get_features_args():
    """Return the preprocessor defines that mirror the build-time feature switches.

    Yields ["-DFLASH_MLA_DISABLE_FP16"] when DISABLE_FP16 is set and an empty
    list otherwise. A fresh list is returned on every call.
    """
    return ["-DFLASH_MLA_DISABLE_FP16"] if DISABLE_FP16 else []
subprocess
.
run
([
"git"
,
"submodule"
,
"update"
,
"--init"
,
"csrc/cutlass"
])
...
...
@@ -34,12 +52,9 @@ ext_modules = []
ext_modules
.
append
(
CUDAExtension
(
name
=
"flash_mla_cuda"
,
sources
=
[
"csrc/flash_api.cpp"
,
"csrc/flash_fwd_mla_bf16_sm90.cu"
,
],
sources
=
get_sources
(),
extra_compile_args
=
{
"cxx"
:
cxx_args
,
"cxx"
:
cxx_args
+
get_features_args
()
,
"nvcc"
:
append_nvcc_threads
(
[
"-O3"
,
...
...
@@ -57,7 +72,7 @@ ext_modules.append(
"--ptxas-options=-v,--register-usage-level=10"
]
+
cc_flag
),
)
+
get_features_args
()
,
},
include_dirs
=
[
Path
(
this_dir
)
/
"csrc"
,
...
...
tests/test_flash_mla.py
View file @
b549289f
import
argparse
import
math
import
random
import
torch
import
triton
from
flash_mla
import
get_mla_metadata
,
flash_mla_with_kvcache
from
flash_mla
import
flash_mla_with_kvcache
,
get_mla_metadata
def
scaled_dot_product_attention
(
query
,
key
,
value
,
h_q
,
h_kv
,
is_causal
=
False
):
...
...
@@ -38,7 +39,9 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
@
torch
.
inference_mode
()
def
test_flash_mla
(
b
,
s_q
,
mean_sk
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
varlen
):
print
(
f
"
{
b
=
}
,
{
s_q
=
}
,
{
mean_sk
=
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
varlen
=
}
"
)
print
(
f
"
{
b
=
}
,
{
s_q
=
}
,
{
mean_sk
=
}
,
{
h_q
=
}
,
{
h_kv
=
}
,
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
varlen
=
}
"
)
cache_seqlens
=
torch
.
full
((
b
,),
mean_sk
,
dtype
=
torch
.
int32
)
if
varlen
:
...
...
@@ -52,18 +55,30 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_size
=
64
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
for
i
in
range
(
b
):
blocked_k
.
view
(
b
,
max_seqlen_pad
,
h_kv
,
d
)[
i
,
cache_seqlens
[
i
].
item
():]
=
float
(
"nan"
)
blocked_k
.
view
(
b
,
max_seqlen_pad
,
h_kv
,
d
)[
i
,
cache_seqlens
[
i
].
item
():]
=
(
float
(
"nan"
)
)
blocked_v
=
blocked_k
[...,
:
dv
]
tile_scheduler_metadata
,
num_splits
=
get_mla_metadata
(
cache_seqlens
,
s_q
*
h_q
//
h_kv
,
h_kv
)
tile_scheduler_metadata
,
num_splits
=
get_mla_metadata
(
cache_seqlens
,
s_q
*
h_q
//
h_kv
,
h_kv
)
def
flash_mla
():
return
flash_mla_with_kvcache
(
q
,
blocked_k
,
block_table
,
cache_seqlens
,
dv
,
tile_scheduler_metadata
,
num_splits
,
causal
=
causal
,
q
,
blocked_k
,
block_table
,
cache_seqlens
,
dv
,
tile_scheduler_metadata
,
num_splits
,
causal
=
causal
,
)
def
ref_mla
():
...
...
@@ -91,14 +106,17 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen):
t
=
triton
.
testing
.
do_bench
(
flash_mla
)
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
q
.
dtype
).
bits
//
8
)
print
(
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
if
__name__
==
"__main__"
:
dtype
=
torch
.
bfloat16
def
main
(
torch_dtype
):
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_dtype
(
dtype
)
torch
.
set_default_dtype
(
torch_
dtype
)
torch
.
set_default_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
manual_seed
(
0
)
...
...
@@ -114,3 +132,22 @@ if __name__ == "__main__":
for
s_q
in
[
1
,
2
]:
# MTP = 1, 2
for
varlen
in
[
False
,
True
]:
test_flash_mla
(
b
,
s_q
,
s
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
varlen
)
if __name__ == "__main__":
    # Command-line entry point: choose the tensor dtype for the test run.
    # argparse restricts --dtype to the keys of this table, and the default
    # "bf16" selects torch.bfloat16.
    dtype_by_flag = {"bf16": torch.bfloat16, "fp16": torch.float16}

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["bf16", "fp16"],
        default="bf16",
        help="Data type to use for testing (bf16 or fp16)",
    )
    args = parser.parse_args()

    torch_dtype = dtype_by_flag[args.dtype]
    main(torch_dtype)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment