Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
889722f3
Unverified
Commit
889722f3
authored
Jan 21, 2026
by
Lucas Wilkinson
Committed by
GitHub
Jan 21, 2026
Browse files
[FlashMLA] Update FlashMLA to expose new arguments (#32810)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
49d96538
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
133 additions
and
217 deletions
+133
-217
.gitignore
.gitignore
+3
-0
cmake/external_projects/flashmla.cmake
cmake/external_projects/flashmla.cmake
+19
-2
setup.py
setup.py
+10
-0
tests/v1/attention/test_sparse_mla_backends.py
tests/v1/attention/test_sparse_mla_backends.py
+0
-1
vllm/third_party/flashmla/__init__.py
vllm/third_party/flashmla/__init__.py
+1
-0
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+45
-50
vllm/v1/attention/backends/mla/flashmla_sparse.py
vllm/v1/attention/backends/mla/flashmla_sparse.py
+8
-29
vllm/v1/attention/ops/flashmla.py
vllm/v1/attention/ops/flashmla.py
+47
-135
No files found.
.gitignore
View file @
889722f3
...
@@ -7,6 +7,9 @@ vllm/vllm_flash_attn/*
...
@@ -7,6 +7,9 @@ vllm/vllm_flash_attn/*
# OpenAI triton kernels copied from source
# OpenAI triton kernels copied from source
vllm/third_party/triton_kernels/*
vllm/third_party/triton_kernels/*
# FlashMLA interface copied from source
vllm/third_party/flashmla/flash_mla_interface.py
# triton jit
# triton jit
.triton
.triton
...
...
cmake/external_projects/flashmla.cmake
View file @
889722f3
...
@@ -19,7 +19,7 @@ else()
...
@@ -19,7 +19,7 @@ else()
FetchContent_Declare
(
FetchContent_Declare
(
flashmla
flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
GIT_TAG
526781394b33d9888e4c41952e692266267dd8bf
GIT_TAG
c2afa9cb93e674d5a9120a170a6da57b89267208
GIT_PROGRESS TRUE
GIT_PROGRESS TRUE
CONFIGURE_COMMAND
""
CONFIGURE_COMMAND
""
BUILD_COMMAND
""
BUILD_COMMAND
""
...
@@ -30,6 +30,24 @@ endif()
...
@@ -30,6 +30,24 @@ endif()
FetchContent_MakeAvailable
(
flashmla
)
FetchContent_MakeAvailable
(
flashmla
)
message
(
STATUS
"FlashMLA is available at
${
flashmla_SOURCE_DIR
}
"
)
message
(
STATUS
"FlashMLA is available at
${
flashmla_SOURCE_DIR
}
"
)
# Vendor FlashMLA interface into vLLM with torch-ops shim.
set
(
FLASHMLA_VENDOR_DIR
"
${
CMAKE_SOURCE_DIR
}
/vllm/third_party/flashmla"
)
file
(
MAKE_DIRECTORY
"
${
FLASHMLA_VENDOR_DIR
}
"
)
file
(
READ
"
${
flashmla_SOURCE_DIR
}
/flash_mla/flash_mla_interface.py"
FLASHMLA_INTERFACE_CONTENT
)
string
(
REPLACE
"import flash_mla.cuda as flash_mla_cuda"
"import vllm._flashmla_C
\n
flash_mla_cuda = torch.ops._flashmla_C"
FLASHMLA_INTERFACE_CONTENT
"
${
FLASHMLA_INTERFACE_CONTENT
}
"
)
file
(
WRITE
"
${
FLASHMLA_VENDOR_DIR
}
/flash_mla_interface.py"
"
${
FLASHMLA_INTERFACE_CONTENT
}
"
)
# Install the generated flash_mla_interface.py to the wheel
# Use COMPONENT _flashmla_C to ensure it's installed with the C extension
install
(
FILES
"
${
FLASHMLA_VENDOR_DIR
}
/flash_mla_interface.py"
DESTINATION vllm/third_party/flashmla/
COMPONENT _flashmla_C
)
# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
# Only build FlashMLA kernels if we are building for something compatible with
# Only build FlashMLA kernels if we are building for something compatible with
# sm90a
# sm90a
...
@@ -79,7 +97,6 @@ if(FLASH_MLA_ARCHS)
...
@@ -79,7 +97,6 @@ if(FLASH_MLA_ARCHS)
# sm100 dense prefill & backward
# sm100 dense prefill & backward
${
flashmla_SOURCE_DIR
}
/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
# sm100 sparse prefill
# sm100 sparse prefill
${
flashmla_SOURCE_DIR
}
/csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu
${
flashmla_SOURCE_DIR
}
/csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu
...
...
setup.py
View file @
889722f3
...
@@ -646,6 +646,9 @@ class precompiled_wheel_utils:
...
@@ -646,6 +646,9 @@ class precompiled_wheel_utils:
triton_kernels_regex
=
re
.
compile
(
triton_kernels_regex
=
re
.
compile
(
r
"vllm/third_party/triton_kernels/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
r
"vllm/third_party/triton_kernels/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
)
)
flashmla_regex
=
re
.
compile
(
r
"vllm/third_party/flashmla/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py"
)
file_members
=
list
(
file_members
=
list
(
filter
(
lambda
x
:
x
.
filename
in
files_to_copy
,
wheel
.
filelist
)
filter
(
lambda
x
:
x
.
filename
in
files_to_copy
,
wheel
.
filelist
)
)
)
...
@@ -657,6 +660,9 @@ class precompiled_wheel_utils:
...
@@ -657,6 +660,9 @@ class precompiled_wheel_utils:
lambda
x
:
triton_kernels_regex
.
match
(
x
.
filename
),
wheel
.
filelist
lambda
x
:
triton_kernels_regex
.
match
(
x
.
filename
),
wheel
.
filelist
)
)
)
)
file_members
+=
list
(
filter
(
lambda
x
:
flashmla_regex
.
match
(
x
.
filename
),
wheel
.
filelist
)
)
for
file
in
file_members
:
for
file
in
file_members
:
print
(
f
"[extract]
{
file
.
filename
}
"
)
print
(
f
"[extract]
{
file
.
filename
}
"
)
...
@@ -925,6 +931,10 @@ if _is_cuda():
...
@@ -925,6 +931,10 @@ if _is_cuda():
):
):
# FA3 requires CUDA 12.3 or later
# FA3 requires CUDA 12.3 or later
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm.vllm_flash_attn._vllm_fa3_C"
))
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm.vllm_flash_attn._vllm_fa3_C"
))
if
envs
.
VLLM_USE_PRECOMPILED
or
(
CUDA_HOME
and
get_nvcc_cuda_version
()
>=
Version
(
"12.9"
)
):
# FlashMLA requires CUDA 12.9 or later
# Optional since this doesn't get built (produce an .so file) when
# Optional since this doesn't get built (produce an .so file) when
# not targeting a hopper system
# not targeting a hopper system
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._flashmla_C"
,
optional
=
True
))
ext_modules
.
append
(
CMakeExtension
(
name
=
"vllm._flashmla_C"
,
optional
=
True
))
...
...
tests/v1/attention/test_sparse_mla_backends.py
View file @
889722f3
...
@@ -53,7 +53,6 @@ SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
...
@@ -53,7 +53,6 @@ SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
def
_float_to_e8m0_truncate
(
f
:
float
)
->
float
:
def
_float_to_e8m0_truncate
(
f
:
float
)
->
float
:
"""Simulate SM100's float -> e8m0 -> bf16 scale conversion.
"""Simulate SM100's float -> e8m0 -> bf16 scale conversion.
e8m0 format only stores the exponent (power of 2).
e8m0 format only stores the exponent (power of 2).
cudaRoundZero truncates toward zero, meaning we round down to the
cudaRoundZero truncates toward zero, meaning we round down to the
nearest power of 2.
nearest power of 2.
...
...
vllm/third_party/flashmla/__init__.py
0 → 100644
View file @
889722f3
# Sources copied from FlashMLA
vllm/v1/attention/backends/mla/flashmla.py
View file @
889722f3
...
@@ -32,8 +32,11 @@ from vllm.v1.attention.backends.utils import (
...
@@ -32,8 +32,11 @@ from vllm.v1.attention.backends.utils import (
reshape_query_for_spec_decode
,
reshape_query_for_spec_decode
,
)
)
from
vllm.v1.attention.ops.flashmla
import
(
from
vllm.v1.attention.ops.flashmla
import
(
FlashMLASchedMeta
,
flash_mla_with_kvcache
,
flash_mla_with_kvcache
,
flash_mla_with_kvcache_fp8
,
get_mla_metadata
,
get_mla_metadata
,
get_mla_metadata_dense_fp8
,
is_flashmla_dense_supported
,
is_flashmla_dense_supported
,
)
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
...
@@ -93,8 +96,7 @@ class FlashMLABackend(MLACommonBackend):
...
@@ -93,8 +96,7 @@ class FlashMLABackend(MLACommonBackend):
@
dataclass
@
dataclass
class
FlashMLADecodeMetadata
(
MLACommonDecodeMetadata
):
class
FlashMLADecodeMetadata
(
MLACommonDecodeMetadata
):
tile_scheduler_metadata
:
torch
.
Tensor
scheduler_metadata
:
FlashMLASchedMeta
num_splits
:
torch
.
Tensor
@
dataclass
@
dataclass
...
@@ -158,46 +160,25 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
...
@@ -158,46 +160,25 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
# we use the max but all should be the same due to uniform length requirement
# we use the max but all should be the same due to uniform length requirement
max_query_len
=
query_lens_cpu
.
max
().
item
()
max_query_len
=
query_lens_cpu
.
max
().
item
()
num_q_tokens_per_head_k
=
max_query_len
*
self
.
num_q_heads
//
1
num_q_tokens_per_head_k
=
max_query_len
*
self
.
num_q_heads
//
1
tile_
scheduler_metadata
,
num_splits
=
get_mla_metadata
(
scheduler_metadata
,
_
=
get_mla_metadata
(
seq_lens_device
,
seq_lens_device
,
num_q_tokens_per_head_k
,
num_q_tokens_per_head_k
,
1
,
# MQA for the decode path
1
,
# MQA for the decode path
is_fp8_kvcache
=
self
.
is_fp8_kvcache
,
is_fp8_kvcache
=
self
.
is_fp8_kvcache
,
)
)
if
self
.
is_fp8_kvcache
:
# TODO: we can disambiguate between decode and mixed-prefill decode here
tile_scheduler_metadata
,
num_splits
=
get_mla_metadata_dense_fp8
(
# so we can only use the persistent buffer if a cudagraph is actually
seq_lens_device
,
# being used.
num_q_tokens_per_head_k
,
if
self
.
compilation_config
.
cudagraph_mode
.
has_full_cudagraphs
():
1
,
# MQA for the decode path
assert
self
.
cg_buf_tile_scheduler_metadata
is
not
None
)
assert
self
.
cg_buf_num_splits
is
not
None
scheduler_metadata
.
tile_scheduler_metadata
=
tile_scheduler_metadata
scheduler_metadata
.
num_splits
=
num_splits
sm_parts
=
tile_scheduler_metadata
.
size
(
0
)
# Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
assert
sm_parts
<=
self
.
cg_buf_tile_scheduler_metadata
.
size
(
0
)
tile_scheduler_metadata_view
=
self
.
cg_buf_tile_scheduler_metadata
[
:
sm_parts
]
tile_scheduler_metadata_view
.
copy_
(
tile_scheduler_metadata
)
tile_scheduler_metadata
=
tile_scheduler_metadata_view
# Num splits is per-batch, varying size (batch_size,)
n
=
num_splits
.
size
(
0
)
# make sure static buffer is large enough
assert
n
<=
self
.
cg_buf_num_splits
.
size
(
0
)
num_splits_view
=
self
.
cg_buf_num_splits
[:
n
]
num_splits_view
.
copy_
(
num_splits
)
# Num splits needs to monotonically increasing
# (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
# it needs to monotonically increasing by 1)
self
.
cg_buf_num_splits
[
n
:].
fill_
(
num_splits
[
-
1
])
num_splits
=
num_splits_view
return
FlashMLADecodeMetadata
(
return
FlashMLADecodeMetadata
(
block_table
=
block_table_tensor
,
block_table
=
block_table_tensor
,
seq_lens
=
seq_lens_device
,
seq_lens
=
seq_lens_device
,
tile_scheduler_metadata
=
tile_scheduler_metadata
,
scheduler_metadata
=
scheduler_metadata
,
num_splits
=
num_splits
,
dcp_tot_seq_lens
=
dcp_tot_seq_lens_device
,
dcp_tot_seq_lens
=
dcp_tot_seq_lens_device
,
)
)
...
@@ -272,9 +253,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -272,9 +253,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_decodes
=
attn_metadata
.
num_decodes
num_decodes
=
attn_metadata
.
num_decodes
q
=
reshape_query_for_spec_decode
(
q
,
num_decodes
)
q
=
reshape_query_for_spec_decode
(
q
,
num_decodes
)
tile_scheduler_metadata
=
attn_metadata
.
decode
.
tile_scheduler_metadata
scheduler_metadata
=
attn_metadata
.
decode
.
scheduler_metadata
num_splits
=
attn_metadata
.
decode
.
num_splits
if
vllm_is_batch_invariant
()
and
not
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
vllm_is_batch_invariant
():
device
=
q
.
device
device
=
q
.
device
dtype
=
torch
.
int32
dtype
=
torch
.
int32
...
@@ -301,20 +281,35 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -301,20 +281,35 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
# Non-split path ignores num_splits, but the API requires it:
# Non-split path ignores num_splits, but the API requires it:
# zeros of length B+1
# zeros of length B+1
num_splits
=
torch
.
zeros
((
B
+
1
,),
dtype
=
dtype
,
device
=
device
)
num_splits
=
torch
.
zeros
((
B
+
1
,),
dtype
=
dtype
,
device
=
device
)
scheduler_metadata
.
tile_scheduler_metadata
=
tile_scheduler_metadata
scheduler_metadata
.
num_splits
=
num_splits
o
,
lse
=
flash_mla_with_kvcache
(
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
o
,
lse
=
flash_mla_with_kvcache_fp8
(
q
=
q
,
q
=
q
,
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
block_table
=
attn_metadata
.
decode
.
block_table
,
block_table
=
attn_metadata
.
decode
.
block_table
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
head_dim_v
=
self
.
kv_lora_rank
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
tile_scheduler_metadata
,
tile_scheduler_metadata
=
scheduler_metadata
.
tile_scheduler_metadata
,
num_splits
=
num_splits
,
num_splits
=
scheduler_metadata
.
num_splits
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
descale_q
=
layer
.
_q_scale
.
reshape
(
1
),
descale_q
=
layer
.
_q_scale
.
reshape
(
1
),
descale_k
=
layer
.
_k_scale
.
reshape
(
1
),
descale_k
=
layer
.
_k_scale
.
reshape
(
1
),
)
)
else
:
o
,
lse
=
flash_mla_with_kvcache
(
q
=
q
,
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
block_table
=
attn_metadata
.
decode
.
block_table
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
scheduler_metadata
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
is_fp8_kvcache
=
False
,
)
o
=
reshape_attn_output_for_spec_decode
(
o
)
o
=
reshape_attn_output_for_spec_decode
(
o
)
...
...
vllm/v1/attention/backends/mla/flashmla_sparse.py
View file @
889722f3
...
@@ -33,7 +33,8 @@ from vllm.v1.attention.backends.utils import (
...
@@ -33,7 +33,8 @@ from vllm.v1.attention.backends.utils import (
split_prefill_chunks
,
split_prefill_chunks
,
)
)
from
vllm.v1.attention.ops.flashmla
import
(
from
vllm.v1.attention.ops.flashmla
import
(
flash_mla_sparse_prefill
,
FlashMLASchedMeta
,
flash_mla_sparse_fwd
,
flash_mla_with_kvcache
,
flash_mla_with_kvcache
,
get_mla_metadata
,
get_mla_metadata
,
)
)
...
@@ -142,8 +143,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
...
@@ -142,8 +143,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
@
dataclass
@
dataclass
class
FP8KernelMetadata
:
class
FP8KernelMetadata
:
scheduler_metadata
:
torch
.
Tensor
|
None
scheduler_metadata
:
FlashMLASchedMeta
num_splits
:
torch
.
Tensor
dummy_block_table
:
torch
.
Tensor
dummy_block_table
:
torch
.
Tensor
cache_lens
:
torch
.
Tensor
cache_lens
:
torch
.
Tensor
...
@@ -468,7 +468,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
...
@@ -468,7 +468,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
padded_heads
=
self
.
fp8_decode_padded_heads
padded_heads
=
self
.
fp8_decode_padded_heads
# Build metadata for all tokens as a single batch
# Build metadata for all tokens as a single batch
tile_
scheduler_metadata
,
num_splits
=
get_mla_metadata
(
scheduler_metadata
,
_
=
get_mla_metadata
(
cache_seqlens
=
self
.
topk_tokens_tensor
[:
1
],
# Single batch
cache_seqlens
=
self
.
topk_tokens_tensor
[:
1
],
# Single batch
num_q_tokens_per_head_k
=
num_tokens
*
padded_heads
,
num_q_tokens_per_head_k
=
num_tokens
*
padded_heads
,
topk
=
self
.
topk_tokens
,
topk
=
self
.
topk_tokens
,
...
@@ -477,17 +477,8 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
...
@@ -477,17 +477,8 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
is_fp8_kvcache
=
True
,
is_fp8_kvcache
=
True
,
)
)
num_sm_parts
=
tile_scheduler_metadata
.
size
(
0
)
tile_scheduler_metadata_buffer
=
self
.
tile_scheduler_metadata_buffer
[
:
num_sm_parts
]
tile_scheduler_metadata_buffer
.
copy_
(
tile_scheduler_metadata
)
num_splits_view
=
self
.
num_splits_buffer
[:
2
]
num_splits_view
.
copy_
(
num_splits
)
fp8_metadata
=
FlashMLASparseMetadata
.
FP8KernelMetadata
(
fp8_metadata
=
FlashMLASparseMetadata
.
FP8KernelMetadata
(
scheduler_metadata
=
tile_scheduler_metadata_buffer
,
scheduler_metadata
=
scheduler_metadata
,
num_splits
=
num_splits_view
,
cache_lens
=
self
.
max_model_len_tensor
[:
1
],
cache_lens
=
self
.
max_model_len_tensor
[:
1
],
dummy_block_table
=
self
.
dummy_block_table
[:
1
],
dummy_block_table
=
self
.
dummy_block_table
[:
1
],
)
)
...
@@ -620,7 +611,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
...
@@ -620,7 +611,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
# Use padded head count since that's what the kernel will see
# Use padded head count since that's what the kernel will see
padded_heads
=
self
.
fp8_decode_padded_heads
padded_heads
=
self
.
fp8_decode_padded_heads
tile_
scheduler_metadata
,
num_splits
=
get_mla_metadata
(
scheduler_metadata
,
_
=
get_mla_metadata
(
cache_seqlens
=
self
.
topk_tokens_tensor
[:
num_decodes
],
cache_seqlens
=
self
.
topk_tokens_tensor
[:
num_decodes
],
num_q_tokens_per_head_k
=
decode_query_len
*
padded_heads
,
num_q_tokens_per_head_k
=
decode_query_len
*
padded_heads
,
topk
=
self
.
topk_tokens
,
topk
=
self
.
topk_tokens
,
...
@@ -629,19 +620,8 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
...
@@ -629,19 +620,8 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
is_fp8_kvcache
=
True
,
is_fp8_kvcache
=
True
,
)
)
num_sm_parts
=
tile_scheduler_metadata
.
size
(
0
)
# Copy to persistent buffer for full-CG support
tile_scheduler_metadata_buffer
=
self
.
tile_scheduler_metadata_buffer
[
:
num_sm_parts
]
tile_scheduler_metadata_buffer
.
copy_
(
tile_scheduler_metadata
)
# num_splits has size [num_decodes + 1]
num_splits_view
=
self
.
num_splits_buffer
[:
num_decodes
+
1
]
num_splits_view
.
copy_
(
num_splits
)
kernel_meta
=
FlashMLASparseMetadata
.
FP8KernelMetadata
(
kernel_meta
=
FlashMLASparseMetadata
.
FP8KernelMetadata
(
scheduler_metadata
=
tile_scheduler_metadata_buffer
,
scheduler_metadata
=
scheduler_metadata
,
num_splits
=
num_splits_view
,
dummy_block_table
=
self
.
dummy_block_table
[:
num_decodes
],
dummy_block_table
=
self
.
dummy_block_table
[:
num_decodes
],
cache_lens
=
self
.
max_model_len_tensor
[:
num_decodes
],
cache_lens
=
self
.
max_model_len_tensor
[:
num_decodes
],
)
)
...
@@ -949,7 +929,6 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
...
@@ -949,7 +929,6 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
head_dim_v
=
512
,
head_dim_v
=
512
,
cache_seqlens
=
kernel_metadata
.
cache_lens
,
cache_seqlens
=
kernel_metadata
.
cache_lens
,
tile_scheduler_metadata
=
kernel_metadata
.
scheduler_metadata
,
tile_scheduler_metadata
=
kernel_metadata
.
scheduler_metadata
,
num_splits
=
kernel_metadata
.
num_splits
,
is_fp8_kvcache
=
True
,
is_fp8_kvcache
=
True
,
indices
=
topk_indices
,
indices
=
topk_indices
,
softmax_scale
=
self
.
softmax_scale
,
softmax_scale
=
self
.
softmax_scale
,
...
@@ -985,7 +964,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
...
@@ -985,7 +964,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
q
=
q_padded
q
=
q_padded
topk_indices
=
topk_indices
.
view
(
num_tokens
,
1
,
-
1
)
topk_indices
=
topk_indices
.
view
(
num_tokens
,
1
,
-
1
)
output
=
flash_mla_sparse_
prefill
(
output
=
flash_mla_sparse_
fwd
(
q
,
kv_c_and_k_pe_cache
,
topk_indices
,
self
.
softmax_scale
q
,
kv_c_and_k_pe_cache
,
topk_indices
,
self
.
softmax_scale
)[
0
]
)[
0
]
output
=
output
[:,
:
self
.
num_heads
,
:]
output
=
output
[:,
:
self
.
num_heads
,
:]
...
...
vllm/v1/attention/ops/flashmla.py
View file @
889722f3
...
@@ -78,50 +78,49 @@ def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
...
@@ -78,50 +78,49 @@ def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
return
True
,
None
return
True
,
None
def
get_mla_metadata
(
def
_raise_flashmla_unavailable
(
*
_args
,
**
_kwargs
):
_
,
reason
=
_is_flashmla_available
()
raise
RuntimeError
(
reason
or
"FlashMLA is not available"
)
if
_is_flashmla_available
()[
0
]:
from
vllm.third_party.flashmla.flash_mla_interface
import
(
# noqa: F401
FlashMLASchedMeta
,
flash_attn_varlen_func
,
flash_attn_varlen_kvpacked_func
,
flash_attn_varlen_qkvpacked_func
,
flash_mla_sparse_fwd
,
flash_mla_with_kvcache
,
get_mla_metadata
,
)
else
:
class
FlashMLASchedMeta
:
# type: ignore[no-redef]
pass
flash_attn_varlen_func
=
_raise_flashmla_unavailable
# type: ignore[assignment]
flash_attn_varlen_kvpacked_func
=
_raise_flashmla_unavailable
# type: ignore[assignment]
flash_attn_varlen_qkvpacked_func
=
_raise_flashmla_unavailable
# type: ignore[assignment]
flash_mla_sparse_fwd
=
_raise_flashmla_unavailable
# type: ignore[assignment]
flash_mla_with_kvcache
=
_raise_flashmla_unavailable
# type: ignore[assignment]
get_mla_metadata
=
_raise_flashmla_unavailable
# type: ignore[assignment]
def
get_mla_metadata_dense_fp8
(
cache_seqlens
:
torch
.
Tensor
,
cache_seqlens
:
torch
.
Tensor
,
num_q_tokens_per_head_k
:
int
,
num_q_tokens_per_head_k
:
int
,
num_heads_k
:
int
,
num_heads_k
:
int
,
num_heads_q
:
int
|
None
=
None
,
is_fp8_kvcache
:
bool
=
False
,
topk
:
int
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
if
not
_is_flashmla_available
()[
0
]:
Arguments:
_raise_flashmla_unavailable
()
- cache_seqlens: (batch_size), dtype torch.int32.
- num_q_tokens_per_head_k:
Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k.
- num_heads_k: The number of k heads.
- num_heads_q:
The number of q heads.
This argument is optional when sparse attention is not enabled
- is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format.
- topk: If not None, sparse attention will be enabled,
and only tokens in the `indices` array
passed to `flash_mla_with_kvcache_sm90` will be attended to.
Returns:
- tile_scheduler_metadata:
(num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
- num_splits: (batch_size + 1), dtype torch.int32.
"""
if
is_fp8_kvcache
and
topk
is
None
:
return
torch
.
ops
.
_flashmla_extension_C
.
get_mla_decoding_metadata_dense_fp8
(
return
torch
.
ops
.
_flashmla_extension_C
.
get_mla_decoding_metadata_dense_fp8
(
cache_seqlens
,
cache_seqlens
,
num_q_tokens_per_head_k
,
num_q_tokens_per_head_k
,
num_heads_k
,
num_heads_k
,
)
)
return
torch
.
ops
.
_flashmla_C
.
get_mla_decoding_metadata
(
cache_seqlens
,
num_q_tokens_per_head_k
,
num_heads_k
,
num_heads_q
,
is_fp8_kvcache
,
topk
,
)
def
flash_mla_with_kvcache
(
def
flash_mla_with_kvcache
_fp8
(
q
:
torch
.
Tensor
,
q
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
...
@@ -133,54 +132,11 @@ def flash_mla_with_kvcache(
...
@@ -133,54 +132,11 @@ def flash_mla_with_kvcache(
causal
:
bool
=
False
,
causal
:
bool
=
False
,
descale_q
:
torch
.
Tensor
|
None
=
None
,
descale_q
:
torch
.
Tensor
|
None
=
None
,
descale_k
:
torch
.
Tensor
|
None
=
None
,
descale_k
:
torch
.
Tensor
|
None
=
None
,
is_fp8_kvcache
:
bool
=
False
,
indices
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
if
not
_is_flashmla_available
()[
0
]:
Arguments:
_raise_flashmla_unavailable
()
- q: (batch_size, seq_len_q, num_heads_q, head_dim).
- k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
- block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
- cache_seqlens: (batch_size), torch.int32.
- head_dim_v: Head dimension of v.
- tile_scheduler_metadata:
(num_sm_parts, TileSchedulerMetaDataSize), torch.int32,
returned by get_mla_metadata.
- num_splits:
(batch_size + 1), torch.int32, returned by get_mla_metadata.
- softmax_scale: float.
The scale of QK^T before applying softmax.
Default to 1 / sqrt(head_dim).
- causal: bool. Whether to apply causal attention mask.
- descale_q: (batch_size),
torch.float32. Descaling factors for Q, used for fp8 quantization.
- descale_k: (batch_size),
torch.float32. Descaling factors for K, used for fp8 quantization.
- is_fp8_kvcache: bool.
Whether the k_cache and v_cache are in fp8 format.
For the format of FP8 KV cache, please refer to README.md
- indices: (batch_size, seq_len_q, topk), torch.int32.
If not None, sparse attention will be enabled,
and only tokens in the `indices` array will be attended to.
Invalid indices should be set to -1 or numbers >= total_seq_len_kv.
For details about how to set up `indices`, please refer to README.md.
Returns:
- out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
- softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
"""
if
softmax_scale
is
None
:
if
softmax_scale
is
None
:
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
if
indices
is
not
None
:
# NOTE (zyongye): sparse attention is also causal
# since it only attend to the tokens before
# but here `causal` should not be specified
assert
not
causal
,
"causal must be `false` if sparse attention is enabled."
assert
(
descale_q
is
None
)
==
(
descale_k
is
None
),
(
"descale_q and descale_k should be both None or both not None"
)
if
indices
is
None
and
q
.
element_size
()
==
1
:
out
,
softmax_lse
=
torch
.
ops
.
_flashmla_extension_C
.
fwd_kvcache_mla_fp8
(
out
,
softmax_lse
=
torch
.
ops
.
_flashmla_extension_C
.
fwd_kvcache_mla_fp8
(
q
,
q
,
k_cache
,
k_cache
,
...
@@ -194,53 +150,9 @@ def flash_mla_with_kvcache(
...
@@ -194,53 +150,9 @@ def flash_mla_with_kvcache(
descale_q
,
descale_q
,
descale_k
,
descale_k
,
)
)
else
:
out
,
softmax_lse
=
torch
.
ops
.
_flashmla_C
.
fwd_kvcache_mla
(
q
,
k_cache
,
head_dim_v
,
cache_seqlens
,
block_table
,
softmax_scale
,
causal
,
tile_scheduler_metadata
,
num_splits
,
is_fp8_kvcache
,
indices
,
)
return
out
,
softmax_lse
return
out
,
softmax_lse
def
flash_mla_sparse_prefill
(
q
:
torch
.
Tensor
,
kv
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
,
sm_scale
:
float
,
d_v
:
int
=
512
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Sparse attention prefill kernel
Args:
- q: [s_q, h_q, d_qk], bfloat16
- kv: [s_kv, h_kv, d_qk], bfloat16
- indices: [s_q, h_kv, topk], int32.
Invalid indices should be set to -1 or numbers >= s_kv
- sm_scale: float
- d_v: The dimension of value vectors. Can only be 512
Returns:
- (output, max_logits, lse)
About the definition of output,
max_logits and lse, please refer to README.md
- output: [s_q, h_q, d_v], bfloat16
- max_logits: [s_q, h_q], float
- lse: [s_q, h_q], float, 2-based log-sum-exp
"""
results
=
torch
.
ops
.
_flashmla_C
.
sparse_prefill_fwd
(
q
,
kv
,
indices
,
sm_scale
,
d_v
)
return
results
#
#
# TODO: Add fake functions
# TODO: Add fake functions
#
#
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment