Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
41aa5784
Unverified
Commit
41aa5784
authored
Jun 04, 2025
by
Kaixi Hou
Committed by
GitHub
Jun 03, 2025
Browse files
[NVIDIA] Add Cutlass MLA backend (#17625)
parent
8d646c2e
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
111 additions
and
3 deletions
+111
-3
csrc/attention/mla/cutlass_mla_kernels.cu
csrc/attention/mla/cutlass_mla_kernels.cu
+1
-1
tests/kernels/test_cutlass_mla_decode.py
tests/kernels/test_cutlass_mla_decode.py
+3
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-0
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+8
-0
vllm/platforms/interface.py
vllm/platforms/interface.py
+1
-0
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+1
-1
vllm/v1/attention/backends/mla/cutlass_mla.py
vllm/v1/attention/backends/mla/cutlass_mla.py
+96
-0
No files found.
csrc/attention/mla/cutlass_mla_kernels.cu
View file @
41aa5784
...
@@ -119,7 +119,7 @@ typename T::Fmha::Arguments args_from_options(
...
@@ -119,7 +119,7 @@ typename T::Fmha::Arguments args_from_options(
{
static_cast
<
ElementOut
*>
(
out
.
data_ptr
()),
stride_O
,
{
static_cast
<
ElementOut
*>
(
out
.
data_ptr
()),
stride_O
,
static_cast
<
ElementAcc
*>
(
nullptr
),
stride_LSE
},
static_cast
<
ElementAcc
*>
(
nullptr
),
stride_LSE
},
hw_info
,
hw_info
,
-
1
,
// split_kv
1
,
// split_kv
nullptr
,
// is_var_split_kv
nullptr
,
// is_var_split_kv
};
};
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
...
...
tests/kernels/test_cutlass_mla_decode.py
View file @
41aa5784
...
@@ -76,7 +76,9 @@ def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int,
...
@@ -76,7 +76,9 @@ def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int,
pack_factor
=
128
//
block_size
pack_factor
=
128
//
block_size
block_num
=
((
block_num
+
pack_factor
-
1
)
//
pack_factor
)
*
pack_factor
block_num
=
((
block_num
+
pack_factor
-
1
)
//
pack_factor
)
*
pack_factor
q
=
torch
.
randn
(
bs
,
h_q
,
d
)
# Amplify input values to ensure test coverage of edge cases where CUTLASS
# kernel errors occur with split_k settings.
q
=
torch
.
randn
(
bs
,
h_q
,
d
)
*
100
block_table
=
torch
.
randint
(
0
,
block_table
=
torch
.
randint
(
0
,
bs
*
block_num
,
(
bs
,
block_num
),
bs
*
block_num
,
(
bs
,
block_num
),
dtype
=
torch
.
int32
)
dtype
=
torch
.
int32
)
...
...
vllm/engine/arg_utils.py
View file @
41aa5784
...
@@ -1395,6 +1395,7 @@ class EngineArgs:
...
@@ -1395,6 +1395,7 @@ class EngineArgs:
"PALLAS_VLLM_V1"
,
"PALLAS_VLLM_V1"
,
"TRITON_ATTN_VLLM_V1"
,
"TRITON_ATTN_VLLM_V1"
,
"TRITON_MLA"
,
"TRITON_MLA"
,
"CUTLASS_MLA_VLLM_V1"
,
"FLASHMLA"
,
"FLASHMLA"
,
"FLASHINFER"
,
"FLASHINFER"
,
"FLASHINFER_VLLM_V1"
,
"FLASHINFER_VLLM_V1"
,
...
...
vllm/platforms/cuda.py
View file @
41aa5784
...
@@ -183,6 +183,14 @@ class CudaPlatformBase(Platform):
...
@@ -183,6 +183,14 @@ class CudaPlatformBase(Platform):
if
use_mla
:
if
use_mla
:
# TODO(lucas): refactor to be more concise
# TODO(lucas): refactor to be more concise
# we should probably consider factoring out V1 here
# we should probably consider factoring out V1 here
if
selected_backend
==
_Backend
.
CUTLASS_MLA_VLLM_V1
:
if
use_v1
:
logger
.
info_once
(
"Using Cutlass MLA backend on V1 engine."
)
return
(
"vllm.v1.attention.backends.mla."
"cutlass_mla.CutlassMLABackend"
)
else
:
logger
.
warning
(
"Cutlass MLA backend is only supported on V1 engine"
)
if
selected_backend
==
_Backend
.
TRITON_MLA
or
block_size
!=
64
:
if
selected_backend
==
_Backend
.
TRITON_MLA
or
block_size
!=
64
:
if
use_v1
:
if
use_v1
:
logger
.
info_once
(
"Using Triton MLA backend on V1 engine."
)
logger
.
info_once
(
"Using Triton MLA backend on V1 engine."
)
...
...
vllm/platforms/interface.py
View file @
41aa5784
...
@@ -51,6 +51,7 @@ class _Backend(enum.Enum):
...
@@ -51,6 +51,7 @@ class _Backend(enum.Enum):
TRITON_MLA_VLLM_V1
=
enum
.
auto
()
TRITON_MLA_VLLM_V1
=
enum
.
auto
()
FLASHMLA_VLLM_V1
=
enum
.
auto
()
FLASHMLA_VLLM_V1
=
enum
.
auto
()
FLASHMLA
=
enum
.
auto
()
# Supported by V1
FLASHMLA
=
enum
.
auto
()
# Supported by V1
CUTLASS_MLA_VLLM_V1
=
enum
.
auto
()
HPU_ATTN
=
enum
.
auto
()
HPU_ATTN
=
enum
.
auto
()
PALLAS
=
enum
.
auto
()
PALLAS
=
enum
.
auto
()
PALLAS_VLLM_V1
=
enum
.
auto
()
PALLAS_VLLM_V1
=
enum
.
auto
()
...
...
vllm/v1/attention/backends/mla/common.py
View file @
41aa5784
...
@@ -350,7 +350,7 @@ class MLACommonMetadataBuilder(Generic[M]):
...
@@ -350,7 +350,7 @@ class MLACommonMetadataBuilder(Generic[M]):
self
.
num_heads
=
model_config
.
get_num_attention_heads
(
self
.
num_heads
=
model_config
.
get_num_attention_heads
(
runner
.
parallel_config
)
runner
.
parallel_config
)
self
.
mla_dims
=
get_mla_dims
(
model_config
)
self
.
mla_dims
=
get_mla_dims
(
model_config
)
self
.
aot_schedule
=
is_vllm_fa
and
(
ge
t_
f
la
sh_attn_version
()
==
3
)
self
.
aot_schedule
=
curren
t_
p
la
tform
.
is_cuda
(
)
self
.
kv_cache_spec
=
kv_cache_spec
self
.
kv_cache_spec
=
kv_cache_spec
# Dont try to access the runner on AMD
# Dont try to access the runner on AMD
...
...
vllm/v1/attention/backends/mla/cutlass_mla.py
0 → 100644
View file @
41aa5784
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Optional
import
torch
import
vllm._custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionType
,
is_quantized_kv_cache
)
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.mla.common
import
(
MLACommonBackend
,
MLACommonImpl
,
MLACommonMetadata
)
logger
=
init_logger
(
__name__
)
class
CutlassMLABackend
(
MLACommonBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"CUTLASS_MLA_VLLM_V1"
@
staticmethod
def
get_impl_cls
()
->
type
[
"CutlassMLAImpl"
]:
return
CutlassMLAImpl
class
CutlassMLAImpl
(
MLACommonImpl
[
MLACommonMetadata
]):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
list
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
# MLA Specific Arguments
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
**
mla_args
)
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
]
if
any
(
unsupported_features
):
raise
NotImplementedError
(
"CutlassMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap"
)
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"CutlassMLAImpl"
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"CutlassMLA V1 with FP8 KV cache not yet supported"
)
def
_forward_decode
(
self
,
q_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
attn_metadata
.
decode
is
not
None
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
raise
NotImplementedError
(
"FP8 Cutlass MLA not yet supported"
)
B
=
q_nope
.
shape
[
0
]
o
=
torch
.
empty
((
B
,
self
.
num_heads
,
self
.
kv_lora_rank
),
dtype
=
q_nope
.
dtype
,
device
=
q_nope
.
device
)
# Run MLA
# Clone q_nope and q_pe to make sure strides computation is correct.
q_nope
=
q_nope
.
clone
()
q_pe
=
q_pe
.
clone
()
ops
.
cutlass_mla_decode
(
o
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
attn_metadata
.
decode
.
seq_lens
,
attn_metadata
.
decode
.
block_table
,
self
.
scale
)
return
self
.
_v_up_proj
(
o
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment