Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
19fe1a05
Unverified
Commit
19fe1a05
authored
Aug 21, 2025
by
Matthew Bonanni
Committed by
GitHub
Aug 22, 2025
Browse files
[Kernel] Add FP8 support with FlashMLA backend (#22668)
Signed-off-by:
Matthew Bonanni
<
mbonanni001@gmail.com
>
parent
480bdf5a
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
235 additions
and
109 deletions
+235
-109
cmake/external_projects/flashmla.cmake
cmake/external_projects/flashmla.cmake
+5
-4
csrc/cache.h
csrc/cache.h
+4
-2
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+29
-28
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+9
-4
tests/kernels/attention/test_cache.py
tests/kernels/attention/test_cache.py
+29
-13
tests/kernels/attention/test_flashmla.py
tests/kernels/attention/test_flashmla.py
+51
-18
vllm/_custom_ops.py
vllm/_custom_ops.py
+12
-8
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+9
-5
vllm/attention/ops/flashmla.py
vllm/attention/ops/flashmla.py
+6
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-2
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+33
-8
vllm/platforms/interface.py
vllm/platforms/interface.py
+2
-1
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+2
-1
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+2
-1
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+31
-6
vllm/v1/attention/backends/mla/cutlass_mla.py
vllm/v1/attention/backends/mla/cutlass_mla.py
+2
-1
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+4
-6
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+2
-0
vllm/v1/attention/backends/mla/triton_mla.py
vllm/v1/attention/backends/mla/triton_mla.py
+2
-1
No files found.
cmake/external_projects/flashmla.cmake
View file @
19fe1a05
...
...
@@ -19,7 +19,7 @@ else()
FetchContent_Declare
(
flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
GIT_TAG
0e43e774597682284358ff2c54530757b654b8d1
GIT_TAG
a757314c04eedd166e329e846c820eb1bdd702de
GIT_PROGRESS TRUE
CONFIGURE_COMMAND
""
BUILD_COMMAND
""
...
...
@@ -37,13 +37,14 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS
)
set
(
FlashMLA_SOURCES
${
flashmla_SOURCE_DIR
}
/csrc/flash_api.cpp
${
flashmla_SOURCE_DIR
}
/csrc/kernels/
splitkv_ml
a.cu
${
flashmla_SOURCE_DIR
}
/csrc/kernels/
get_mla_metadat
a.cu
${
flashmla_SOURCE_DIR
}
/csrc/kernels/mla_combine.cu
${
flashmla_SOURCE_DIR
}
/csrc/kernels/get_mla_metadata.cu
)
${
flashmla_SOURCE_DIR
}
/csrc/kernels/splitkv_mla.cu
${
flashmla_SOURCE_DIR
}
/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu
)
set
(
FlashMLA_INCLUDES
${
flashmla_SOURCE_DIR
}
/csrc/cutlass/include
${
flashmla_SOURCE_DIR
}
/csrc
/include
)
${
flashmla_SOURCE_DIR
}
/csrc
)
set_gencode_flags_for_srcs
(
SRCS
"
${
FlashMLA_SOURCES
}
"
...
...
csrc/cache.h
View file @
19fe1a05
...
...
@@ -40,9 +40,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
void
convert_fp8
(
torch
::
Tensor
&
dst_cache
,
torch
::
Tensor
&
src_cache
,
const
double
scale
,
const
std
::
string
&
kv_cache_dtype
);
void
gather_cache
(
void
gather_
and_maybe_dequant_
cache
(
torch
::
Tensor
const
&
src_cache
,
// [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch
::
Tensor
const
&
dst
,
// [TOT_TOKENS, ENTRIES...]
torch
::
Tensor
const
&
block_table
,
// [BATCH, BLOCK_INDICES]
torch
::
Tensor
const
&
cu_seq_lens
,
// [BATCH+1]
int64_t
batch_size
,
std
::
optional
<
torch
::
Tensor
>
seq_starts
=
std
::
nullopt
);
\ No newline at end of file
int64_t
batch_size
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
const
&
scale
,
std
::
optional
<
torch
::
Tensor
>
seq_starts
=
std
::
nullopt
);
\ No newline at end of file
csrc/cache_kernels.cu
View file @
19fe1a05
...
...
@@ -624,9 +624,9 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
namespace
vllm
{
// grid is launched with dimensions (batch, num_splits)
template
<
typename
scalar_t
>
__global__
void
gather_cache
(
const
s
ca
lar
_t
*
__restrict__
src_cache
,
// [NUM_BLOCKS, BLOCK_SIZE,
template
<
typename
scalar_t
,
typename
cache_t
,
Fp8KVCacheDataType
kv_dt
>
__global__
void
gather_
and_maybe_dequant_
cache
(
const
ca
che
_t
*
__restrict__
src_cache
,
// [NUM_BLOCKS, BLOCK_SIZE,
// ENTRIES...]
scalar_t
*
__restrict__
dst
,
// [TOT_TOKENS, ENTRIES...]
const
int32_t
*
__restrict__
block_table
,
// [BATCH, BLOCK_INDICES]
...
...
@@ -634,6 +634,7 @@ __global__ void gather_cache(
const
int32_t
block_size
,
const
int32_t
entry_size
,
const
int64_t
block_table_stride
,
const
int64_t
cache_block_stride
,
const
int64_t
cache_entry_stride
,
const
int64_t
dst_entry_stride
,
const
float
*
__restrict__
scale
,
const
int32_t
*
__restrict__
seq_starts
)
{
// Optional: starting offsets per
// batch
...
...
@@ -675,10 +676,16 @@ __global__ void gather_cache(
if
(
partial_block_size
)
full_blocks_end
-=
1
;
}
auto
copy_entry
=
[
&
](
const
s
ca
lar
_t
*
__restrict__
_src
,
auto
copy_entry
=
[
&
](
const
ca
che
_t
*
__restrict__
_src
,
scalar_t
*
__restrict__
_dst
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
entry_size
;
i
+=
blockDim
.
x
)
_dst
[
i
]
=
_src
[
i
];
for
(
int
i
=
threadIdx
.
x
;
i
<
entry_size
;
i
+=
blockDim
.
x
)
{
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kAuto
)
{
_dst
[
i
]
=
static_cast
<
scalar_t
>
(
_src
[
i
]);
}
else
{
_dst
[
i
]
=
fp8
::
scaled_convert
<
scalar_t
,
cache_t
,
kv_dt
>
(
_src
[
i
],
*
scale
);
}
}
};
for
(
int
pid
=
split_start
;
pid
<
full_blocks_end
;
++
pid
)
{
...
...
@@ -705,25 +712,31 @@ __global__ void gather_cache(
}
// namespace vllm
// Macro to dispatch the kernel based on the data type.
#define CALL_GATHER_CACHE(CPY_DTYPE) \
vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()), \
reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
block_size, entry_size, block_table_stride, cache_block_stride, \
cache_entry_stride, dst_entry_stride, seq_starts_ptr);
// SCALAR_T is the data type of the destination tensor.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<CACHE_T*>(src_cache.data_ptr()), \
reinterpret_cast<SCALAR_T*>(dst.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
block_size, entry_size, block_table_stride, cache_block_stride, \
cache_entry_stride, dst_entry_stride, \
reinterpret_cast<const float*>(scale.data_ptr()), seq_starts_ptr);
// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence
// - Optionally, seq_starts (if provided) offsets the starting block index by
// (seq_starts[bid] / page_size)
void
gather_cache
(
void
gather_
and_maybe_dequant_
cache
(
torch
::
Tensor
const
&
src_cache
,
// [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch
::
Tensor
const
&
dst
,
// [TOT_TOKENS, ENTRIES...]
torch
::
Tensor
const
&
block_table
,
// [BATCH, BLOCK_INDICES]
torch
::
Tensor
const
&
cu_seq_lens
,
// [BATCH+1]
int64_t
batch_size
,
int64_t
batch_size
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
const
&
scale
,
std
::
optional
<
torch
::
Tensor
>
seq_starts
=
std
::
nullopt
)
{
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
src_cache
.
device
());
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
...
...
@@ -761,20 +774,8 @@ void gather_cache(
dim3
grid
(
batch_size
,
num_splits
);
dim3
block
(
1024
);
TORCH_CHECK
(
src_cache
.
dtype
()
==
dst
.
dtype
(),
"src_cache and dst must have the same dtype"
);
const
int
dtype_bits
=
src_cache
.
element_size
()
*
8
;
const
int32_t
*
seq_starts_ptr
=
seq_starts
.
has_value
()
?
seq_starts
.
value
().
data_ptr
<
int32_t
>
()
:
nullptr
;
if
(
dtype_bits
==
32
)
{
CALL_GATHER_CACHE
(
uint32_t
);
}
else
if
(
dtype_bits
==
16
)
{
CALL_GATHER_CACHE
(
uint16_t
);
}
else
if
(
dtype_bits
==
8
)
{
CALL_GATHER_CACHE
(
uint8_t
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported data type width: "
,
dtype_bits
);
}
DISPATCH_BY_KV_CACHE_DTYPE
(
dst
.
dtype
(),
kv_cache_dtype
,
CALL_GATHER_CACHE
);
}
csrc/torch_bindings.cpp
View file @
19fe1a05
...
...
@@ -672,11 +672,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"str kv_cache_dtype) -> ()"
);
cache_ops
.
impl
(
"convert_fp8"
,
torch
::
kCUDA
,
&
convert_fp8
);
// Gather cache blocks from src_cache to dst.
// Gather cache blocks from src_cache to dst, dequantizing from
// src_cache's dtype to dst's dtype if necessary.
cache_ops
.
def
(
"gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"
);
cache_ops
.
impl
(
"gather_cache"
,
torch
::
kCUDA
,
&
gather_cache
);
"gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
" Tensor block_table, Tensor cu_seq_lens, "
" int batch_size, "
" str kv_cache_dtype, "
" Tensor scale, Tensor? seq_starts) -> ()"
);
cache_ops
.
impl
(
"gather_and_maybe_dequant_cache"
,
torch
::
kCUDA
,
&
gather_and_maybe_dequant_cache
);
}
TORCH_LIBRARY_EXPAND
(
CONCAT
(
TORCH_EXTENSION_NAME
,
_cuda_utils
),
cuda_utils
)
{
...
...
tests/kernels/attention/test_cache.py
View file @
19fe1a05
...
...
@@ -709,14 +709,15 @@ def test_swap_blocks_mla(
@
pytest
.
mark
.
parametrize
(
"max_seq_len"
,
[
512
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float32
])
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
[
"auto"
])
# You can also test "fp8" if needed.
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
[
"auto"
,
"fp8"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_gather_cache_mla
(
kv_lora_rank
,
qk_rope_head_dim
,
block_size
,
num_blocks
,
max_seq_len
,
batch_size
,
dtype
,
kv_cache_dtype
,
device
):
def
test_gather_and_maybe_dequant_cache_mla
(
kv_lora_rank
,
qk_rope_head_dim
,
block_size
,
num_blocks
,
max_seq_len
,
batch_size
,
dtype
,
kv_cache_dtype
,
device
):
entry_size
=
kv_lora_rank
+
qk_rope_head_dim
scale
=
torch
.
tensor
(
0.1
,
dtype
=
torch
.
float32
,
device
=
device
)
src_cache
=
_create_mla_cache
(
num_blocks
,
block_size
,
entry_size
,
dtype
,
kv_cache_dtype
,
device
)
_fill_mla_cache
(
src_cache
,
kv_cache_dtype
=
kv_cache_dtype
)
...
...
@@ -742,9 +743,7 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
perm
=
torch
.
randperm
(
num_blocks
,
device
=
device
)
block_table
[
b
,
:]
=
perm
dst
=
torch
.
zeros
((
total_tokens
,
entry_size
),
dtype
=
src_cache
.
dtype
,
device
=
device
)
dst
=
torch
.
zeros
((
total_tokens
,
entry_size
),
dtype
=
dtype
,
device
=
device
)
expected_batches
=
[]
for
b
in
range
(
batch_size
):
...
...
@@ -756,21 +755,38 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
gathered_rows
=
[]
for
i
in
range
(
tot
-
1
):
gathered_rows
.
append
(
src_cache
[
blocks
[
i
]])
block_data
=
src_cache
[
blocks
[
i
]]
if
kv_cache_dtype
==
"fp8"
:
dequantized_block
=
torch
.
empty_like
(
block_data
,
dtype
=
dtype
)
ops
.
convert_fp8
(
dequantized_block
,
block_data
,
scale
.
item
())
gathered_rows
.
append
(
dequantized_block
)
else
:
gathered_rows
.
append
(
block_data
)
remaining
=
s
-
(
tot
-
1
)
*
block_size
gathered_rows
.
append
(
src_cache
[
blocks
[
-
1
],
:
remaining
,
:])
last_block_data
=
src_cache
[
blocks
[
-
1
],
:
remaining
,
:]
if
kv_cache_dtype
==
"fp8"
:
dequantized_last_block
=
torch
.
empty_like
(
last_block_data
,
dtype
=
dtype
)
ops
.
convert_fp8
(
dequantized_last_block
,
last_block_data
,
scale
.
item
())
gathered_rows
.
append
(
dequantized_last_block
)
else
:
gathered_rows
.
append
(
last_block_data
)
batch_expected
=
torch
.
cat
(
gathered_rows
,
dim
=
0
)
expected_batches
.
append
(
batch_expected
)
expected
=
torch
.
cat
(
expected_batches
,
dim
=
0
)
opcheck
(
torch
.
ops
.
_C_cache_ops
.
gather_cache
,
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
,
None
),
torch
.
ops
.
_C_cache_ops
.
gather_and_maybe_dequant_cache
,
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
,
kv_cache_dtype
,
scale
,
None
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
,
)
ops
.
gather_cache
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
)
ops
.
gather_and_maybe_dequant_cache
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
,
kv_cache_dtype
,
scale
,
None
)
torch
.
testing
.
assert_close
(
dst
,
expected
)
...
...
tests/kernels/attention/test_flashmla.py
View file @
19fe1a05
...
...
@@ -13,11 +13,17 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
from
vllm.triton_utils
import
triton
def
cal_diff
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
name
:
str
)
->
None
:
def
cal_diff
(
x
:
torch
.
Tensor
,
y
:
torch
.
Tensor
,
name
:
str
,
use_fp8
:
bool
=
False
)
->
None
:
x
,
y
=
x
.
double
(),
y
.
double
()
cos_diff
=
1
-
2
*
(
x
*
y
).
sum
().
item
()
/
max
(
(
x
*
x
+
y
*
y
).
sum
().
item
(),
1e-12
)
assert
cos_diff
<
1e-5
if
(
use_fp8
):
assert
cos_diff
<
1e-4
else
:
assert
cos_diff
<
1e-5
FLASH_MLA_UNSUPPORTED_REASON
=
is_flashmla_supported
()[
1
]
\
if
not
is_flashmla_supported
()[
0
]
else
"FlashMLA is supported"
...
...
@@ -27,7 +33,7 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
reason
=
FLASH_MLA_UNSUPPORTED_REASON
)
@
pytest
.
mark
.
parametrize
(
"b"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"s_q"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"mean_sk"
,
[
4096
,
8192
])
@
pytest
.
mark
.
parametrize
(
"mean_sk"
,
[
4096
,
8192
,
16384
])
@
pytest
.
mark
.
parametrize
(
"h_q"
,
[
16
,
32
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"h_kv"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"d"
,
[
576
])
...
...
@@ -35,20 +41,26 @@ FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"causal"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"varlen"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"torch_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
,
torch
.
float8_e4m3fn
])
@
torch
.
inference_mode
()
def
test_flash_mla
(
b
,
s_q
,
mean_sk
,
h_q
,
h_kv
,
d
,
dv
,
block_size
,
causal
,
varlen
,
dtype
):
varlen
,
torch_
dtype
):
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_dtype
(
dtype
)
if
torch_dtype
==
torch
.
float8_e4m3fn
:
init_dtype
=
torch
.
bfloat16
else
:
init_dtype
=
torch_dtype
torch
.
set_default_dtype
(
init_dtype
)
torch
.
set_default_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
torch
.
manual_seed
(
0
)
random
.
seed
(
0
)
print
(
f
"
{
b
=
}
,
{
s_q
=
}
,
{
mean_sk
=
}
,
{
h_q
=
}
,
{
h_kv
=
}
, "
f
"
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
varlen
=
}
,
{
dtype
=
}
"
)
f
"
{
d
=
}
,
{
dv
=
}
,
{
causal
=
}
,
{
varlen
=
}
,
{
torch_
dtype
=
}
"
)
use_fp8
=
torch_dtype
==
torch
.
float8_e4m3fn
cache_seqlens
=
torch
.
full
((
b
,
),
mean_sk
,
dtype
=
torch
.
int32
)
if
varlen
:
for
i
in
range
(
b
):
...
...
@@ -71,6 +83,19 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
tile_scheduler_metadata
,
num_splits
=
get_mla_metadata
(
cache_seqlens
,
s_q
*
h_q
//
h_kv
,
h_kv
)
init_dtype
=
q
.
dtype
if
use_fp8
:
fp8_dtype
=
torch
.
float8_e4m3fn
descale_q
=
torch
.
ones
((
1
),
dtype
=
torch
.
float32
)
descale_k
=
torch
.
ones
((
1
),
dtype
=
torch
.
float32
)
q
=
q
.
to
(
fp8_dtype
)
blocked_k
=
blocked_k
.
to
(
fp8_dtype
)
blocked_v
=
blocked_v
.
to
(
fp8_dtype
)
else
:
descale_q
=
None
descale_k
=
None
def
flash_mla
():
return
flash_mla_with_kvcache
(
q
,
...
...
@@ -81,6 +106,8 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
tile_scheduler_metadata
,
num_splits
,
causal
=
causal
,
descale_q
=
descale_q
,
descale_k
=
descale_k
,
)
def
scaled_dot_product_attention
(
query
,
key
,
value
,
is_causal
=
False
):
...
...
@@ -104,29 +131,35 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
return
attn_weight
@
value
,
lse
def
ref_mla
():
q_
=
(
q
.
to
(
torch
.
float
)
*
descale_q
).
to
(
init_dtype
)
if
use_fp8
else
q
blocked_k_
=
(
blocked_k
.
to
(
torch
.
float
)
*
descale_k
).
to
(
init_dtype
)
if
use_fp8
else
blocked_k
blocked_v_
=
(
blocked_v
.
to
(
torch
.
float
)
*
descale_k
).
to
(
init_dtype
)
if
use_fp8
else
blocked_v
out
=
torch
.
empty
(
b
,
s_q
,
h_q
,
dv
,
dtype
=
torch
.
float32
)
lse
=
torch
.
empty
(
b
,
h_q
,
s_q
,
dtype
=
torch
.
float32
)
for
i
in
range
(
b
):
begin
=
i
*
max_seqlen_pad
end
=
begin
+
cache_seqlens
[
i
]
ref_O
,
LSE
=
scaled_dot_product_attention
(
q
[
i
].
transpose
(
0
,
1
),
blocked_k
.
view
(
-
1
,
h_kv
,
d
)[
begin
:
end
].
transpose
(
0
,
1
),
blocked_v
.
view
(
-
1
,
h_kv
,
dv
)[
begin
:
end
].
transpose
(
0
,
1
),
out_i
,
lse_i
=
scaled_dot_product_attention
(
q
_
[
i
].
transpose
(
0
,
1
),
blocked_k
_
.
view
(
-
1
,
h_kv
,
d
)[
begin
:
end
].
transpose
(
0
,
1
),
blocked_v
_
.
view
(
-
1
,
h_kv
,
dv
)[
begin
:
end
].
transpose
(
0
,
1
),
is_causal
=
causal
,
)
out
[
i
]
=
ref_O
.
transpose
(
0
,
1
)
lse
[
i
]
=
LSE
out
[
i
]
=
out_i
.
transpose
(
0
,
1
)
lse
[
i
]
=
lse_i
return
out
,
lse
out_flash
,
lse_flash
=
flash_mla
()
out_torch
,
lse_torch
=
ref_mla
()
cal_diff
(
out_flash
,
out_torch
,
"out"
)
cal_diff
(
out_flash
,
out_torch
,
"out"
,
use_fp8
)
cal_diff
(
lse_flash
,
lse_torch
,
"lse"
)
t
=
triton
.
testing
.
do_bench
(
flash_mla
)
FLOPS
=
s_q
*
total_seqlens
*
h_q
*
(
d
+
dv
)
*
2
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
+
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
dtype
).
bits
//
8
)
print
(
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
"
f
"TFLOPS,
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
bytes
=
(
total_seqlens
*
h_kv
*
d
+
b
*
s_q
*
h_q
*
d
)
*
(
torch
.
finfo
(
torch_dtype
).
bits
//
8
)
+
(
b
*
s_q
*
h_q
*
dv
)
*
(
torch
.
finfo
(
init_dtype
).
bits
//
8
)
print
(
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
TFLOPS,"
,
f
"
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
vllm/_custom_ops.py
View file @
19fe1a05
...
...
@@ -1589,14 +1589,18 @@ def convert_fp8(output: torch.Tensor,
torch
.
ops
.
_C_cache_ops
.
convert_fp8
(
output
,
input
,
scale
,
kv_dtype
)
def
gather_cache
(
src_cache
:
torch
.
Tensor
,
dst
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
cu_seq_lens
:
torch
.
Tensor
,
batch_size
:
int
,
seq_starts
:
Optional
[
torch
.
Tensor
]
=
None
)
->
None
:
torch
.
ops
.
_C_cache_ops
.
gather_cache
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
,
seq_starts
)
def
gather_and_maybe_dequant_cache
(
src_cache
:
torch
.
Tensor
,
dst
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
cu_seq_lens
:
torch
.
Tensor
,
batch_size
:
int
,
kv_cache_dtype
:
str
,
scale
:
torch
.
Tensor
,
seq_starts
:
Optional
[
torch
.
Tensor
]
=
None
)
->
None
:
torch
.
ops
.
_C_cache_ops
.
gather_and_maybe_dequant_cache
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
,
kv_cache_dtype
,
scale
,
seq_starts
)
def
get_device_attribute
(
attribute
:
int
,
device
:
int
)
->
int
:
...
...
vllm/attention/backends/mla/common.py
View file @
19fe1a05
...
...
@@ -837,8 +837,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
self
.
context_chunk_workspace_size
//
num_prefills_with_context
# align max_context_chunk to page_size by rounding down,
# currently the `gather_cache` kernel cannot
handle
# `context_chunk_starts` that are not aligned to page_size
# currently the `gather_
and_maybe_dequant_
cache` kernel cannot
#
handle
`context_chunk_starts` that are not aligned to page_size
max_context_chunk
=
round_down
(
max_context_chunk
,
self
.
page_size
)
assert
max_context_chunk
>
0
num_chunks
=
cdiv
(
context_lens_tensor
.
max
(),
max_context_chunk
)
...
...
@@ -1082,6 +1082,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
q
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
k_scale
:
torch
.
Tensor
,
):
prefill_metadata
=
attn_metadata
.
prefill_metadata
assert
prefill_metadata
is
not
None
...
...
@@ -1103,12 +1104,14 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
for
i
in
range
(
iters
):
toks
=
prefill_metadata
.
context_chunk_seq_tot
[
i
]
ops
.
gather_cache
(
ops
.
gather_
and_maybe_dequant_
cache
(
src_cache
=
kv_c_and_k_pe_cache
,
dst
=
workspace
,
block_table
=
prefill_metadata
.
block_tables
,
cu_seq_lens
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
batch_size
=
prefill_metadata
.
num_prefills
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
scale
=
k_scale
,
seq_starts
=
prefill_metadata
.
context_chunk_starts
[
i
],
)
...
...
@@ -1165,6 +1168,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
k_scale
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
prefill_metadata
=
attn_metadata
.
prefill_metadata
...
...
@@ -1197,7 +1201,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output
,
suffix_lse
=
output
context_output
,
context_lse
=
self
.
_compute_prefill_context
(
\
q
,
kv_c_and_k_pe_cache
,
attn_metadata
)
q
,
kv_c_and_k_pe_cache
,
attn_metadata
,
k_scale
)
output
=
torch
.
empty_like
(
suffix_output
)
merge_attn_states
(
...
...
@@ -1287,7 +1291,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
if
has_prefill
:
output
[:
num_prefill_tokens
]
=
self
.
_forward_prefill
(
prefill_q
,
prefill_k_c_normed
,
prefill_k_pe
,
kv_cache
,
attn_metadata
)
attn_metadata
,
layer
.
_k_scale
)
if
has_decode
:
decode_q_nope
,
decode_q_pe
=
decode_q
.
split
(
...
...
vllm/attention/ops/flashmla.py
View file @
19fe1a05
...
...
@@ -67,6 +67,8 @@ def flash_mla_with_kvcache(
num_splits
:
torch
.
Tensor
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
descale_q
:
Optional
[
torch
.
Tensor
]
=
None
,
descale_k
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Arguments:
...
...
@@ -81,6 +83,8 @@ def flash_mla_with_kvcache(
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(head_dim).
causal: bool. Whether to apply causal attention mask.
descale_q: (batch_size), torch.float32. Descaling factors for Q.
descale_k: (batch_size), torch.float32. Descaling factors for K.
Return:
out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
...
...
@@ -98,6 +102,8 @@ def flash_mla_with_kvcache(
causal
,
tile_scheduler_metadata
,
num_splits
,
descale_q
,
descale_k
,
)
return
out
,
softmax_lse
...
...
vllm/engine/arg_utils.py
View file @
19fe1a05
...
...
@@ -1445,10 +1445,9 @@ class EngineArgs:
recommend_to_remove
=
False
)
return
False
# No Fp8 KV cache so far.
if
self
.
kv_cache_dtype
!=
"auto"
:
supported
=
current_platform
.
is_kv_cache_dtype_supported
(
self
.
kv_cache_dtype
)
self
.
kv_cache_dtype
,
model_config
)
if
not
supported
:
_raise_or_fallback
(
feature_name
=
"--kv-cache-dtype"
,
recommend_to_remove
=
False
)
...
...
vllm/platforms/cuda.py
View file @
19fe1a05
...
...
@@ -481,16 +481,41 @@ class CudaPlatformBase(Platform):
return
cuda_device_count_stateless
()
@
classmethod
def
is_kv_cache_dtype_supported
(
cls
,
kv_cache_dtype
:
str
)
->
bool
:
def
is_kv_cache_dtype_supported
(
cls
,
kv_cache_dtype
:
str
,
model_config
:
"ModelConfig"
)
->
bool
:
fp8_attention
=
kv_cache_dtype
.
startswith
(
"fp8"
)
will_use_fa
=
(
not
envs
.
is_set
(
"
VLLM_ATTENTION_BACKEND
"
)
)
or
envs
.
VLLM_ATTENTION_BACKEND
==
"FLASH_ATTN_VLLM_V1"
attention_backend
=
envs
.
VLLM_ATTENTION_BACKEND
supported
=
False
if
cls
.
is_device_capability
(
100
):
supported
=
True
elif
fp8_attention
and
will_use_fa
:
from
vllm.attention.utils.fa_utils
import
flash_attn_supports_fp8
supported
=
flash_attn_supports_fp8
()
if
model_config
is
not
None
and
model_config
.
use_mla
:
# Default to CutlassMLA for blackwell,
# FlashMLA otherwise
if
attention_backend
is
None
:
if
cls
.
is_device_capability
(
100
):
attention_backend
=
"CUTLASS_MLA"
else
:
attention_backend
=
"FLASHMLA"
# Only FlashMLA supports fp8
if
attention_backend
==
"FLASHMLA"
:
supported
=
True
else
:
supported
=
(
not
fp8_attention
)
else
:
# Default to FlashAttention
if
attention_backend
is
None
:
attention_backend
=
"FLASH_ATTN_VLLM_V1"
# All Blackwell backends support fp8
if
cls
.
is_device_capability
(
100
):
supported
=
True
elif
attention_backend
==
"FLASH_ATTN_VLLM_V1"
:
if
fp8_attention
:
from
vllm.attention.utils.fa_utils
import
(
flash_attn_supports_fp8
)
supported
=
flash_attn_supports_fp8
()
else
:
supported
=
True
return
supported
...
...
vllm/platforms/interface.py
View file @
19fe1a05
...
...
@@ -565,7 +565,8 @@ class Platform:
raise
RuntimeError
(
f
"Unsupported torch distributed backend:
{
backend
}
"
)
@
classmethod
def
is_kv_cache_dtype_supported
(
cls
,
kv_cache_dtype
:
str
)
->
bool
:
def
is_kv_cache_dtype_supported
(
cls
,
kv_cache_dtype
:
str
,
model_config
:
"ModelConfig"
)
->
bool
:
"""
Returns if the kv_cache_dtype is supported by the current platform.
"""
...
...
vllm/platforms/rocm.py
View file @
19fe1a05
...
...
@@ -459,5 +459,6 @@ class RocmPlatform(Platform):
return
cuda_device_count_stateless
()
@
classmethod
def
is_kv_cache_dtype_supported
(
cls
,
kv_cache_dtype
:
str
)
->
bool
:
def
is_kv_cache_dtype_supported
(
cls
,
kv_cache_dtype
:
str
,
model_config
:
"ModelConfig"
)
->
bool
:
return
True
vllm/platforms/tpu.py
View file @
19fe1a05
...
...
@@ -196,7 +196,8 @@ class TpuPlatform(Platform):
raise
ValueError
(
"Torch XLA does not support per-request seed."
)
@
classmethod
def
is_kv_cache_dtype_supported
(
cls
,
kv_cache_dtype
:
str
)
->
bool
:
def
is_kv_cache_dtype_supported
(
cls
,
kv_cache_dtype
:
str
,
model_config
:
"ModelConfig"
)
->
bool
:
return
True
...
...
vllm/v1/attention/backends/mla/common.py
View file @
19fe1a05
...
...
@@ -631,8 +631,9 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
if
self
.
aot_schedule
:
# align max_context_chunk to page_size by rounding down,
# currently the `gather_cache` kernel cannot handle
# `context_chunk_starts` that are not aligned to page_size
# currently the `gather_and_maybe_dequant_cache` kernel
# cannot handle `context_chunk_starts` that are not aligned
# to page_size
max_context_chunk
=
round_down
(
max_context_chunk
,
self
.
page_size
)
...
...
@@ -1005,6 +1006,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
k_scale
:
torch
.
Tensor
,
):
assert
attn_metadata
.
prefill
is
not
None
prefill_metadata
=
attn_metadata
.
prefill
...
...
@@ -1017,12 +1019,14 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
for
i
in
range
(
iters
):
toks
=
prefill_metadata
.
chunked_context
.
seq_tot
[
i
]
ops
.
gather_cache
(
ops
.
gather_
and_maybe_dequant_
cache
(
src_cache
=
kv_c_and_k_pe_cache
,
dst
=
workspace
,
block_table
=
prefill_metadata
.
block_table
,
cu_seq_lens
=
prefill_metadata
.
chunked_context
.
cu_seq_lens
[
i
],
batch_size
=
attn_metadata
.
num_prefills
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
scale
=
k_scale
,
seq_starts
=
prefill_metadata
.
chunked_context
.
starts
[
i
],
)
...
...
@@ -1073,6 +1077,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
k_scale
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
assert
attn_metadata
.
prefill
is
not
None
...
...
@@ -1095,7 +1100,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if
has_context
:
suffix_output
,
suffix_lse
=
output
context_output
,
context_lse
=
self
.
_compute_prefill_context
(
\
q
,
kv_c_and_k_pe_cache
,
attn_metadata
)
q
,
kv_c_and_k_pe_cache
,
attn_metadata
,
k_scale
)
output
=
torch
.
empty_like
(
suffix_output
)
merge_attn_states
(
...
...
@@ -1119,6 +1124,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
M
,
layer
:
AttentionLayer
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
...
...
@@ -1146,6 +1152,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# same expert outputs.
return
output
.
fill_
(
0
)
fp8_attention
=
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
num_actual_toks
=
attn_metadata
.
num_actual_tokens
# Inputs and outputs may be padded for CUDA graphs
...
...
@@ -1180,10 +1188,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
scale
=
layer
.
_k_scale
,
)
if
fp8_attention
:
kv_cache
=
kv_cache
.
view
(
current_platform
.
fp8_dtype
())
if
has_prefill
:
output
[
num_decode_tokens
:]
=
self
.
_forward_prefill
(
prefill_q
,
prefill_k_c_normed
,
prefill_k_pe
,
kv_cache
,
attn_metadata
)
attn_metadata
,
layer
.
_k_scale
)
if
has_decode
:
assert
attn_metadata
.
decode
is
not
None
...
...
@@ -1196,7 +1207,21 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
if
fp8_attention
:
ql_nope_shape
=
decode_ql_nope
.
shape
decode_ql_nope
,
_
=
ops
.
scaled_fp8_quant
(
decode_ql_nope
.
reshape
([
ql_nope_shape
[
0
],
ql_nope_shape
[
1
]
*
ql_nope_shape
[
2
]
]),
layer
.
_q_scale
)
decode_ql_nope
=
decode_ql_nope
.
reshape
(
ql_nope_shape
)
q_pe_shape
=
decode_q_pe
.
shape
decode_q_pe
,
_
=
ops
.
scaled_fp8_quant
(
decode_q_pe
.
reshape
(
[
q_pe_shape
[
0
],
q_pe_shape
[
1
]
*
q_pe_shape
[
2
]]),
layer
.
_q_scale
)
decode_q_pe
=
decode_q_pe
.
reshape
(
q_pe_shape
)
output
[:
num_decode_tokens
]
=
self
.
_forward_decode
(
decode_ql_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
)
decode_ql_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
,
layer
)
return
output_padded
vllm/v1/attention/backends/mla/cutlass_mla.py
View file @
19fe1a05
...
...
@@ -7,7 +7,7 @@ from typing import ClassVar, Optional
import
torch
import
vllm._custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionType
,
from
vllm.attention.backends.abstract
import
(
AttentionLayer
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.mla.common
import
(
MLACommonBackend
,
...
...
@@ -278,6 +278,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
layer
:
AttentionLayer
,
)
->
torch
.
Tensor
:
if
self
.
_use_old_cutlass_mla
:
# TODO: Remove the old cutlass MLA kernel after more extensive
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
19fe1a05
...
...
@@ -6,8 +6,7 @@ from typing import ClassVar, Optional
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.abstract
import
AttentionLayer
,
AttentionType
from
vllm.attention.ops.flashmla
import
(
flash_mla_with_kvcache
,
get_mla_metadata
,
is_flashmla_supported
)
...
...
@@ -166,16 +165,13 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"are not implemented for "
"FlashMLAImpl"
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"FlashMLA V1 with FP8 KV cache not yet supported"
)
def
_forward_decode
(
self
,
q_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashMLAMetadata
,
layer
:
AttentionLayer
,
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
attn_metadata
.
decode
is
not
None
...
...
@@ -194,6 +190,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits
=
attn_metadata
.
decode
.
num_splits
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
descale_q
=
layer
.
_q_scale
.
reshape
(
1
),
descale_k
=
layer
.
_k_scale
.
reshape
(
1
),
)
return
self
.
_v_up_proj
(
o
)
vllm/v1/attention/backends/mla/rocm_aiter_mla.py
View file @
19fe1a05
...
...
@@ -7,6 +7,7 @@ from typing import ClassVar, Optional
import
torch
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionLayer
from
vllm.attention.ops.rocm_aiter_mla
import
aiter_mla_decode_fwd
from
vllm.config
import
VllmConfig
from
vllm.utils
import
cdiv
...
...
@@ -221,6 +222,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
AiterMLAMetadata
,
layer
:
AttentionLayer
,
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
attn_metadata
.
decode
is
not
None
...
...
vllm/v1/attention/backends/mla/triton_mla.py
View file @
19fe1a05
...
...
@@ -6,7 +6,7 @@ from typing import Optional
import
torch
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
(
AttentionType
,
from
vllm.attention.backends.abstract
import
(
AttentionLayer
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.ops.triton_decode_attention
import
decode_attention_fwd
from
vllm.attention.ops.triton_flash_attention
import
triton_attention
...
...
@@ -127,6 +127,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
layer
:
AttentionLayer
,
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
attn_metadata
.
decode
is
not
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment