Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a7423220
Unverified
Commit
a7423220
authored
Sep 03, 2025
by
Matthew Bonanni
Committed by
GitHub
Sep 03, 2025
Browse files
[Attention] Blackwell FP8 MLA support with CUTLASS_MLA backend (#23289)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
731a6940
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
186 additions
and
107 deletions
+186
-107
csrc/attention/mla/sm100_cutlass_mla_kernel.cu
csrc/attention/mla/sm100_cutlass_mla_kernel.cu
+8
-8
tests/kernels/test_cutlass_mla_decode.py
tests/kernels/test_cutlass_mla_decode.py
+167
-83
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+2
-2
vllm/v1/attention/backends/mla/cutlass_mla.py
vllm/v1/attention/backends/mla/cutlass_mla.py
+9
-14
No files found.
csrc/attention/mla/sm100_cutlass_mla_kernel.cu
View file @
a7423220
...
@@ -64,11 +64,11 @@ struct IsPersistent {
...
@@ -64,11 +64,11 @@ struct IsPersistent {
static
const
bool
value
=
v
;
static
const
bool
value
=
v
;
};
};
template
<
typename
T
,
bool
IsPaged128
,
typename
PersistenceOption
=
IsPersistent
<
true
>
>
template
<
typename
T
,
typename
TOut
,
bool
IsPaged128
,
typename
PersistenceOption
=
IsPersistent
<
true
>
>
struct
MlaSm100
{
struct
MlaSm100
{
using
Element
=
T
;
using
Element
=
T
;
using
ElementAcc
=
float
;
using
ElementAcc
=
float
;
using
ElementOut
=
T
;
using
ElementOut
=
T
Out
;
using
TileShape
=
Shape
<
_128
,
_128
,
Shape
<
_512
,
_64
>>
;
using
TileShape
=
Shape
<
_128
,
_128
,
Shape
<
_512
,
_64
>>
;
using
TileShapeH
=
cute
::
tuple_element_t
<
0
,
TileShape
>
;
using
TileShapeH
=
cute
::
tuple_element_t
<
0
,
TileShape
>
;
...
@@ -178,7 +178,7 @@ typename T::Fmha::Arguments args_from_options(
...
@@ -178,7 +178,7 @@ typename T::Fmha::Arguments args_from_options(
return
arguments
;
return
arguments
;
}
}
template
<
typename
Element
,
bool
IsPaged128
,
typename
PersistenceOption
>
template
<
typename
Element
,
typename
ElementOut
,
bool
IsPaged128
,
typename
PersistenceOption
>
void
runMla
(
void
runMla
(
at
::
Tensor
const
&
out
,
at
::
Tensor
const
&
out
,
at
::
Tensor
const
&
q_nope
,
at
::
Tensor
const
&
q_nope
,
...
@@ -190,7 +190,7 @@ void runMla(
...
@@ -190,7 +190,7 @@ void runMla(
double
sm_scale
,
double
sm_scale
,
int64_t
num_kv_splits
,
int64_t
num_kv_splits
,
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
using
MlaSm100Type
=
MlaSm100
<
Element
,
IsPaged128
,
PersistenceOption
>
;
using
MlaSm100Type
=
MlaSm100
<
Element
,
ElementOut
,
IsPaged128
,
PersistenceOption
>
;
typename
MlaSm100Type
::
Fmha
fmha
;
typename
MlaSm100Type
::
Fmha
fmha
;
auto
arguments
=
args_from_options
<
MlaSm100Type
>
(
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
sm_scale
,
num_kv_splits
);
auto
arguments
=
args_from_options
<
MlaSm100Type
>
(
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
sm_scale
,
num_kv_splits
);
...
@@ -233,13 +233,13 @@ void sm100_cutlass_mla_decode(
...
@@ -233,13 +233,13 @@ void sm100_cutlass_mla_decode(
DISPATCH_BOOL
(
page_size
==
128
,
IsPaged128
,
[
&
]
{
DISPATCH_BOOL
(
page_size
==
128
,
IsPaged128
,
[
&
]
{
DISPATCH_BOOL
(
num_kv_splits
<=
1
,
NotManualSplitKV
,
[
&
]
{
DISPATCH_BOOL
(
num_kv_splits
<=
1
,
NotManualSplitKV
,
[
&
]
{
if
(
in_dtype
==
at
::
ScalarType
::
Half
)
{
if
(
in_dtype
==
at
::
ScalarType
::
Half
)
{
runMla
<
cutlass
::
half_t
,
IsPaged128
,
IsPersistent
<
NotManualSplitKV
>>
(
runMla
<
cutlass
::
half_t
,
cutlass
::
half_t
,
IsPaged128
,
IsPersistent
<
NotManualSplitKV
>>
(
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
workspace
,
sm_scale
,
num_kv_splits
,
stream
);
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
workspace
,
sm_scale
,
num_kv_splits
,
stream
);
}
else
if
(
in_dtype
==
at
::
ScalarType
::
BFloat16
)
{
}
else
if
(
in_dtype
==
at
::
ScalarType
::
BFloat16
)
{
runMla
<
cutlass
::
bfloat16_t
,
IsPaged128
,
IsPersistent
<
NotManualSplitKV
>>
(
runMla
<
cutlass
::
bfloat16_t
,
cutlass
::
bfloat16_t
,
IsPaged128
,
IsPersistent
<
NotManualSplitKV
>>
(
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
workspace
,
sm_scale
,
num_kv_splits
,
stream
);
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
workspace
,
sm_scale
,
num_kv_splits
,
stream
);
}
else
if
(
in_dtype
==
at
::
ScalarType
::
Float8_e4m3fn
)
{
}
else
if
(
in_dtype
==
at
::
ScalarType
::
Float8_e4m3fn
)
{
runMla
<
cutlass
::
float_e4m3_t
,
IsPaged128
,
IsPersistent
<
NotManualSplitKV
>>
(
runMla
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
IsPaged128
,
IsPersistent
<
NotManualSplitKV
>>
(
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
workspace
,
sm_scale
,
num_kv_splits
,
stream
);
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
workspace
,
sm_scale
,
num_kv_splits
,
stream
);
}
else
{
}
else
{
TORCH_CHECK
(
false
,
"Unsupported input data type of MLA"
);
TORCH_CHECK
(
false
,
"Unsupported input data type of MLA"
);
...
@@ -253,7 +253,7 @@ void sm100_cutlass_mla_decode(
...
@@ -253,7 +253,7 @@ void sm100_cutlass_mla_decode(
int64_t
sm100_cutlass_mla_get_workspace_size
(
int64_t
max_seq_len
,
int64_t
num_batches
,
int64_t
sm_count
,
int64_t
num_kv_splits
)
{
int64_t
sm100_cutlass_mla_get_workspace_size
(
int64_t
max_seq_len
,
int64_t
num_batches
,
int64_t
sm_count
,
int64_t
num_kv_splits
)
{
// Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
// Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
// which are float, so Element type here doesn't matter.
// which are float, so Element type here doesn't matter.
using
MlaSm100Type
=
MlaSm100
<
cutlass
::
half_t
,
true
>
;
using
MlaSm100Type
=
MlaSm100
<
cutlass
::
half_t
,
cutlass
::
half_t
,
true
>
;
// Get split kv. Requires problem shape and sm_count only.
// Get split kv. Requires problem shape and sm_count only.
typename
MlaSm100Type
::
Fmha
::
Arguments
arguments
;
typename
MlaSm100Type
::
Fmha
::
Arguments
arguments
;
...
...
tests/kernels/test_cutlass_mla_decode.py
View file @
a7423220
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
import
random
import
pytest
import
pytest
import
torch
import
torch
import
torch.nn.functional
as
F
from
torch
import
Tensor
import
vllm._custom_ops
as
ops
import
vllm._custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
triton
if
not
current_platform
.
has_device_capability
(
100
):
pytest
.
skip
(
reason
=
"Cutlass MLA Requires compute capability of 10 or above."
,
def
cal_diff
(
x
:
torch
.
Tensor
,
allow_module_level
=
True
)
y
:
torch
.
Tensor
,
name
:
str
,
use_fp8
:
bool
=
False
)
->
None
:
def
ref_mla
(
x
,
y
=
x
.
double
(),
y
.
double
()
out
:
Tensor
,
# (bs, num_heads, v_head_dim)
cos_diff
=
1
-
2
*
(
x
*
y
).
sum
().
item
()
/
max
(
query
:
Tensor
,
# (bs, num_heads, head_dim)
(
x
*
x
+
y
*
y
).
sum
().
item
(),
1e-12
)
kv_cache
:
Tensor
,
# (num_blocks, block_size, head_dim)
if
(
use_fp8
):
scale
:
float
,
assert
cos_diff
<
1e-4
block_tables
:
Tensor
,
# (bs, max_num_blocks)
else
:
seq_lens
:
Tensor
,
# (bs,)
assert
cos_diff
<
1e-5
):
bs
,
num_heads
,
v_head_dim
=
out
.
shape
head_dim
=
query
.
shape
[
2
]
CUTLASS_MLA_UNSUPPORTED_REASON
=
\
"Cutlass MLA Requires compute capability of 10 or above."
\
for
i
in
range
(
bs
):
if
not
current_platform
.
is_device_capability
(
100
)
\
# gather and flatten KV-cache
else
"Cutlass MLA is supported"
kv
=
kv_cache
[
block_tables
[
i
]]
# (max_num_blocks, block_size, head_dim)
kv
=
kv
.
view
(
1
,
-
1
,
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
100
),
head_dim
)[:,
:
seq_lens
[
i
]]
# (1, seq_len, head_dim)
reason
=
CUTLASS_MLA_UNSUPPORTED_REASON
)
v
=
kv
[:,
:,
:
v_head_dim
]
@
pytest
.
mark
.
parametrize
(
"b"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"s_q"
,
[
1
])
q
=
query
[
i
].
view
(
num_heads
,
1
,
head_dim
)
@
pytest
.
mark
.
parametrize
(
"mean_sk"
,
[
4096
,
8192
,
16384
])
o
=
F
.
scaled_dot_product_attention
(
q
,
@
pytest
.
mark
.
parametrize
(
"h_q"
,
[
16
,
32
,
64
,
128
])
kv
,
@
pytest
.
mark
.
parametrize
(
"h_kv"
,
[
1
])
v
,
@
pytest
.
mark
.
parametrize
(
"d"
,
[
576
])
scale
=
scale
,
@
pytest
.
mark
.
parametrize
(
"dv"
,
[
512
])
enable_gqa
=
True
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
64
])
out
[
i
]
=
o
.
view
(
num_heads
,
v_head_dim
)
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
True
])
return
out
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"mean_seq_len"
,
[
128
,
1024
,
4096
])
@
pytest
.
mark
.
parametrize
(
"bs"
,
[
1
,
2
,
4
])
@
pytest
.
mark
.
parametrize
(
"varlen"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"varlen"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"torch_dtype"
,
[
torch
.
bfloat16
,
torch
.
float8_e4m3fn
])
def
test_cutlass_mla_decode
(
dtype
:
torch
.
dtype
,
mean_seq_len
:
int
,
bs
:
int
,
@
torch
.
inference_mode
()
varlen
:
bool
,
block_size
:
int
):
def
test_cutlass_mla_decode
(
b
,
s_q
,
mean_sk
,
h_q
,
h_kv
,
d
,
dv
,
block_size
,
torch
.
set_default_dtype
(
dtype
)
causal
,
varlen
,
torch_dtype
):
torch
.
set_default_device
(
'cuda'
)
device
=
torch
.
device
(
"cuda:0"
)
if
torch_dtype
==
torch
.
float8_e4m3fn
:
init_dtype
=
torch
.
bfloat16
else
:
init_dtype
=
torch_dtype
torch
.
set_default_dtype
(
init_dtype
)
torch
.
set_default_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
manual_seed
(
42
)
torch
.
manual_seed
(
42
)
random
.
seed
(
42
)
d
=
576
print
(
f
"
{
b
=
}
,
{
s_q
=
}
,
{
mean_sk
=
}
,
{
h_q
=
}
,
{
h_kv
=
}
, "
h_q
=
128
f
"
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
varlen
=
}
,
{
torch_dtype
=
}
"
)
dv
=
512
q_nope_dim
=
128
use_fp8
=
torch_dtype
==
torch
.
float8_e4m3fn
q_pe_dim
=
64
scale
=
math
.
sqrt
(
d
)
**
(
-
1
)
s
ca
le
=
(
q_nope_dim
+
q_pe_dim
)
**
(
-
0.5
)
ca
che_seqlens
=
torch
.
full
((
b
,
),
mean_sk
,
dtype
=
torch
.
int32
)
if
varlen
:
if
varlen
:
seq_lens
=
torch
.
empty
(
bs
).
normal_
(
mean_seq_len
,
mean_seq_len
/
2
)
for
i
in
range
(
b
):
seq_lens
=
seq_lens
.
clip
(
2
).
to
(
torch
.
int32
)
cache_seqlens
[
i
]
=
max
(
random
.
normalvariate
(
mean_sk
,
mean_sk
/
2
),
s_q
)
total_seqlens
=
cache_seqlens
.
sum
().
item
()
max_seqlen
=
cache_seqlens
.
max
().
item
()
max_seqlen_pad
=
triton
.
cdiv
(
max_seqlen
,
256
)
*
256
q
=
torch
.
randn
(
b
,
s_q
,
h_q
,
d
)
block_table
=
torch
.
arange
(
b
*
max_seqlen_pad
//
block_size
,
dtype
=
torch
.
int32
).
view
(
b
,
max_seqlen_pad
//
block_size
)
blocked_k
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
h_kv
,
d
)
blocked_v
=
blocked_k
[...,
:
dv
]
init_dtype
=
q
.
dtype
if
use_fp8
:
fp8_dtype
=
torch
.
float8_e4m3fn
descale_q
=
torch
.
ones
((
1
),
dtype
=
torch
.
float32
)
descale_k
=
torch
.
ones
((
1
),
dtype
=
torch
.
float32
)
q
=
q
.
to
(
fp8_dtype
)
blocked_k
=
blocked_k
.
to
(
fp8_dtype
)
blocked_v
=
blocked_v
.
to
(
fp8_dtype
)
else
:
else
:
seq_lens
=
torch
.
full
((
bs
,
),
mean_seq_len
,
dtype
=
torch
.
int32
)
descale_q
=
None
max_seq_len
=
seq_lens
.
max
().
item
()
descale_k
=
None
block_num
=
(
max_seq_len
+
block_size
-
1
)
//
block_size
def
cutlass_mla
():
# Pad block_num so that small blocks can be packed into full 128-sized
MAX_HEADS
=
128
# CUTLASS tiles. One 128-wide tile can hold (128 // block_size) small
# blocks.
q_reshaped
=
q
.
squeeze
(
1
)
pack_factor
=
128
//
block_size
q_nope
=
q_reshaped
[:,
:,
:
dv
].
clone
()
block_num
=
((
block_num
+
pack_factor
-
1
)
//
pack_factor
)
*
pack_factor
q_pe
=
q_reshaped
[:,
:,
dv
:].
clone
()
# Amplify input values to ensure test coverage of edge cases where CUTLASS
if
h_q
<
MAX_HEADS
:
# kernel errors occur with split_k settings.
q_nope_padded
=
q_nope
.
new_empty
((
b
,
MAX_HEADS
,
dv
))
q
=
torch
.
randn
(
bs
,
h_q
,
d
)
*
100
q_nope_padded
[:,
:
h_q
]
=
q_nope
block_table
=
torch
.
randint
(
0
,
q_nope
=
q_nope_padded
bs
*
block_num
,
(
bs
,
block_num
),
dtype
=
torch
.
int32
)
q_pe_padded
=
q_pe
.
new_empty
((
b
,
MAX_HEADS
,
d
-
dv
))
q_pe_padded
[:,
:
h_q
]
=
q_pe
kv_cache
=
torch
.
randn
(
block_table
.
numel
(),
block_size
,
d
)
q_pe
=
q_pe_padded
out_ref
=
q
.
new_zeros
(
bs
,
h_q
,
dv
)
kv_cache_flat
=
blocked_k
.
squeeze
(
2
)
ref_mla
(
out_ref
,
q
,
kv_cache
,
scale
,
block_table
,
seq_lens
)
device_properties
=
torch
.
cuda
.
get_device_properties
(
out_ans
=
torch
.
zeros_like
(
out_ref
)
torch
.
device
(
"cuda:0"
))
q_nope
=
q
[:,
:,
:
dv
].
clone
()
sm_count
=
device_properties
.
multi_processor_count
q_pe
=
q
[:,
:,
dv
:].
clone
()
workspace_size
=
ops
.
sm100_cutlass_mla_get_workspace_size
(
ops
.
cutlass_mla_decode
(
out_ans
,
q_nope
,
q_pe
,
kv_cache
,
seq_lens
,
max_seqlen
*
block_size
,
b
,
sm_count
,
num_kv_splits
=
1
)
block_table
,
scale
)
workspace
=
torch
.
empty
(
workspace_size
,
device
=
"cuda"
,
torch
.
testing
.
assert_close
(
out_ans
,
out_ref
,
atol
=
1e-2
,
rtol
=
1e-2
)
dtype
=
torch
.
uint8
)
out_ans
=
torch
.
empty
(
b
,
MAX_HEADS
,
dv
,
dtype
=
init_dtype
)
ops
.
sm100_cutlass_mla_decode
(
out_ans
,
q_nope
,
q_pe
,
kv_cache_flat
,
cache_seqlens
,
block_table
,
workspace
,
scale
,
1
)
return
out_ans
[:,
:
h_q
].
contiguous
()
def
scaled_dot_product_attention
(
query
,
key
,
value
,
is_causal
=
False
):
query
=
query
.
float
()
key
=
key
.
float
()
value
=
value
.
float
()
key
=
key
.
repeat_interleave
(
h_q
//
h_kv
,
dim
=
0
)
value
=
value
.
repeat_interleave
(
h_q
//
h_kv
,
dim
=
0
)
attn_weight
=
query
@
key
.
transpose
(
-
2
,
-
1
)
/
math
.
sqrt
(
query
.
size
(
-
1
))
if
is_causal
:
s_q
=
query
.
shape
[
-
2
]
s_k
=
key
.
shape
[
-
2
]
attn_bias
=
torch
.
zeros
(
s_q
,
s_k
,
dtype
=
query
.
dtype
)
temp_mask
=
torch
.
ones
(
s_q
,
s_k
,
dtype
=
torch
.
bool
).
tril
(
diagonal
=
s_k
-
s_q
)
attn_bias
.
masked_fill_
(
temp_mask
.
logical_not
(),
float
(
"-inf"
))
attn_bias
.
to
(
query
.
dtype
)
attn_weight
+=
attn_bias
lse
=
attn_weight
.
logsumexp
(
dim
=-
1
)
attn_weight
=
torch
.
softmax
(
attn_weight
,
dim
=-
1
,
dtype
=
torch
.
float32
)
return
attn_weight
@
value
,
lse
def
ref_mla
():
q_
=
(
q
.
to
(
torch
.
float
)
*
descale_q
).
to
(
init_dtype
)
if
use_fp8
else
q
blocked_k_
=
(
blocked_k
.
to
(
torch
.
float
)
*
descale_k
).
to
(
init_dtype
)
if
use_fp8
else
blocked_k
blocked_v_
=
(
blocked_v
.
to
(
torch
.
float
)
*
descale_k
).
to
(
init_dtype
)
if
use_fp8
else
blocked_v
out
=
torch
.
empty
(
b
,
s_q
,
h_q
,
dv
,
dtype
=
torch
.
float32
)
lse
=
torch
.
empty
(
b
,
h_q
,
s_q
,
dtype
=
torch
.
float32
)
for
i
in
range
(
b
):
begin
=
i
*
max_seqlen_pad
end
=
begin
+
cache_seqlens
[
i
]
out_i
,
lse_i
=
scaled_dot_product_attention
(
q_
[
i
].
transpose
(
0
,
1
),
blocked_k_
.
view
(
-
1
,
h_kv
,
d
)[
begin
:
end
].
transpose
(
0
,
1
),
blocked_v_
.
view
(
-
1
,
h_kv
,
dv
)[
begin
:
end
].
transpose
(
0
,
1
),
is_causal
=
causal
,
)
out
[
i
]
=
out_i
.
transpose
(
0
,
1
)
lse
[
i
]
=
lse_i
return
out
,
lse
out_cutlass
=
cutlass_mla
()
out_torch
,
lse_torch
=
ref_mla
()
# Extract the single token (s_q=1) slice to match cutlass output shape
out_torch_slice
=
out_torch
[:,
0
,
:,
:]
# [b, h_q, dv]
cal_diff
(
out_cutlass
,
out_torch_slice
,
"out"
,
use_fp8
)
t
=
triton
.
testing
.
do_bench
(
cutlass_mla
)
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
)
*
(
torch
.
finfo
(
torch_dtype
).
bits
//
8
)
+
(
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
init_dtype
).
bits
//
8
)
print
(
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
TFLOPS,"
,
f
"
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
vllm/platforms/cuda.py
View file @
a7423220
...
@@ -500,8 +500,8 @@ class CudaPlatformBase(Platform):
...
@@ -500,8 +500,8 @@ class CudaPlatformBase(Platform):
else
:
else
:
attention_backend
=
"FLASHMLA"
attention_backend
=
"FLASHMLA"
# Only FlashMLA support
s
fp8
# Only FlashMLA
and CUTLASS_MLA
support fp8
if
attention_backend
==
"FLASHMLA"
:
if
attention_backend
in
[
"FLASHMLA"
,
"CUTLASS_MLA"
]
:
supported
=
True
supported
=
True
else
:
else
:
supported
=
(
not
fp8_attention
)
supported
=
(
not
fp8_attention
)
...
...
vllm/v1/attention/backends/mla/cutlass_mla.py
View file @
a7423220
...
@@ -108,10 +108,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
...
@@ -108,10 +108,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
"are not implemented for "
"are not implemented for "
"CutlassMLAImpl"
)
"CutlassMLAImpl"
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"CutlassMLA V1 with FP8 KV cache not yet supported"
)
self
.
_use_old_cutlass_mla
=
False
self
.
_use_old_cutlass_mla
=
False
force_old_cutlass
=
os
.
environ
.
get
(
"FORCE_OLD_CUTLASS_MLA"
,
None
)
force_old_cutlass
=
os
.
environ
.
get
(
"FORCE_OLD_CUTLASS_MLA"
,
None
)
if
force_old_cutlass
:
if
force_old_cutlass
:
...
@@ -182,11 +178,10 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
...
@@ -182,11 +178,10 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
>
0
),
f
"block num must be greater than 0, got
{
block_num
}
"
>
0
),
f
"block num must be greater than 0, got
{
block_num
}
"
assert
block_num
%
(
128
/
PAGE_SIZE
)
==
0
assert
block_num
%
(
128
/
PAGE_SIZE
)
==
0
# TODO(kaixih@nvidia): support fp8
assert
q_nope
.
dtype
in
(
assert
q_nope
.
dtype
in
(
torch
.
float16
,
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float8_e4m3fn
),
(
torch
.
bfloat16
,
f
"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got "
),
f
"q_nope.dtype needs to be fp16 or bf16 but got
{
q_nope
.
dtype
}
."
f
"
{
q_nope
.
dtype
}
."
)
assert
q_nope
.
dtype
==
q_pe
.
dtype
==
kv_c_and_k_pe_cache
.
dtype
assert
q_nope
.
dtype
==
q_pe
.
dtype
==
kv_c_and_k_pe_cache
.
dtype
assert
(
assert
(
seq_lens
.
dtype
==
torch
.
int32
seq_lens
.
dtype
==
torch
.
int32
...
@@ -195,7 +190,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
...
@@ -195,7 +190,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
page_table
.
dtype
==
torch
.
int32
page_table
.
dtype
==
torch
.
int32
),
f
"page_table.dtype needs to be int32 but got
{
page_table
.
dtype
}
."
),
f
"page_table.dtype needs to be int32 but got
{
page_table
.
dtype
}
."
out
=
q_nope
.
new_empty
((
B_q
,
MAX_HEADS
,
D_latent
))
dtype
=
(
torch
.
bfloat16
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
else
q_nope
.
dtype
)
out
=
q_nope
.
new_empty
((
B_q
,
MAX_HEADS
,
D_latent
),
dtype
=
dtype
)
ops
.
sm100_cutlass_mla_decode
(
ops
.
sm100_cutlass_mla_decode
(
out
,
out
,
...
@@ -220,9 +217,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
...
@@ -220,9 +217,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
attn_metadata
.
decode
is
not
None
assert
attn_metadata
.
decode
is
not
None
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
raise
NotImplementedError
(
"FP8 Cutlass MLA not yet supported"
)
# Adjust workspace size (if necessary)
# Adjust workspace size (if necessary)
self
.
_workspace
.
ensure_size
(
attn_metadata
,
self
.
_num_kv_splits
)
self
.
_workspace
.
ensure_size
(
attn_metadata
,
self
.
_num_kv_splits
)
...
@@ -252,8 +246,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
...
@@ -252,8 +246,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
attn_metadata
.
decode
is
not
None
assert
attn_metadata
.
decode
is
not
None
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"FP8 Cutlass MLA not yet supported"
)
raise
NotImplementedError
(
"FP8 Cutlass MLA not supported with FORCE_OLD_CUTLASS_MLA"
)
B
=
q_nope
.
shape
[
0
]
B
=
q_nope
.
shape
[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment