Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1acf2d7a
Commit
1acf2d7a
authored
Dec 18, 2025
by
zhuwenwen
Browse files
update get_mla_decoding_metadata_dense_fp8 interface and _k_scale&_v_scale
parent
77210184
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
68 additions
and
26 deletions
+68
-26
vllm/attention/backends/flashmla.py
vllm/attention/backends/flashmla.py
+46
-18
vllm/attention/layer.py
vllm/attention/layer.py
+6
-2
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+16
-6
No files found.
vllm/attention/backends/flashmla.py
View file @
1acf2d7a
...
@@ -17,6 +17,7 @@ from vllm.attention.backends.mla.common import (MLACommonBackend,
...
@@ -17,6 +17,7 @@ from vllm.attention.backends.mla.common import (MLACommonBackend,
from
vllm.attention.ops.flashmla
import
(
flash_mla_with_kvcache
,
from
vllm.attention.ops.flashmla
import
(
flash_mla_with_kvcache
,
get_mla_metadata
,
get_mla_metadata
,
flash_mla_with_kvcache_fp8
,
flash_mla_with_kvcache_fp8
,
get_mla_decoding_metadata_dense_fp8
,
is_flashmla_supported
)
is_flashmla_supported
)
from
vllm
import
envs
from
vllm
import
envs
...
@@ -89,12 +90,21 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
...
@@ -89,12 +90,21 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
batch_size
)
batch_size
)
if
m
.
num_decode_tokens
>
0
:
if
m
.
num_decode_tokens
>
0
:
m
.
decode_tile_scheduler_metadata
,
m
.
decode_num_splits
=
\
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
get_mla_metadata
(
m
.
decode_tile_scheduler_metadata
,
m
.
decode_num_splits
=
\
m
.
seq_lens_tensor
[
m
.
num_prefills
:],
get_mla_decoding_metadata_dense_fp8
(
self
.
num_q_heads
,
m
.
seq_lens_tensor
[
m
.
num_prefills
:],
1
,
# MQA for the decode path
self
.
num_q_heads
,
)
1
,
# MQA for the decode path
16
,
)
else
:
m
.
decode_tile_scheduler_metadata
,
m
.
decode_num_splits
=
\
get_mla_metadata
(
m
.
seq_lens_tensor
[
m
.
num_prefills
:],
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
return
m
return
m
...
@@ -109,13 +119,23 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
...
@@ -109,13 +119,23 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
@
contextmanager
@
contextmanager
def
graph_capture
(
self
,
max_batch_size
:
int
):
def
graph_capture
(
self
,
max_batch_size
:
int
):
# Run a dummy `get_mla_metadata` so we can get the right shapes
# Run a dummy `get_mla_metadata` so we can get the right shapes
self
.
_graph_decoder_tile_scheduler_metadata
,
\
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
self
.
_graph_decode_num_splits
=
get_mla_metadata
(
self
.
_graph_decoder_tile_scheduler_metadata
,
\
torch
.
ones
(
self
.
_graph_decode_num_splits
=
get_mla_decoding_metadata_dense_fp8
(
max_batch_size
,
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
),
torch
.
ones
(
self
.
num_q_heads
,
max_batch_size
,
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
),
1
,
# MQA for the decode path
self
.
num_q_heads
,
)
1
,
# MQA for the decode path
16
,
)
else
:
self
.
_graph_decoder_tile_scheduler_metadata
,
\
self
.
_graph_decode_num_splits
=
get_mla_metadata
(
torch
.
ones
(
max_batch_size
,
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
),
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
with
super
().
graph_capture
(
max_batch_size
):
with
super
().
graph_capture
(
max_batch_size
):
yield
yield
...
@@ -129,11 +149,19 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
...
@@ -129,11 +149,19 @@ class FlashMLAState(MLACommonState[FlashMLAMetadata]):
batch_size
,
is_encoder_decoder_model
)
batch_size
,
is_encoder_decoder_model
)
assert
metadata
.
num_decode_tokens
>
0
assert
metadata
.
num_decode_tokens
>
0
decoder_tile_scheduler_metadata
,
decode_num_splits
=
get_mla_metadata
(
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
self
.
_graph_seq_lens
[:
batch_size
],
decoder_tile_scheduler_metadata
,
decode_num_splits
=
get_mla_decoding_metadata_dense_fp8
(
self
.
num_q_heads
,
self
.
_graph_seq_lens
[:
batch_size
],
1
,
# MQA for the decode path
self
.
num_q_heads
,
)
1
,
# MQA for the decode path
16
,
)
else
:
decoder_tile_scheduler_metadata
,
decode_num_splits
=
get_mla_metadata
(
self
.
_graph_seq_lens
[:
batch_size
],
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
self
.
_graph_decoder_tile_scheduler_metadata
.
copy_
(
self
.
_graph_decoder_tile_scheduler_metadata
.
copy_
(
decoder_tile_scheduler_metadata
)
decoder_tile_scheduler_metadata
)
...
...
vllm/attention/layer.py
View file @
1acf2d7a
...
@@ -98,8 +98,12 @@ class Attention(nn.Module):
...
@@ -98,8 +98,12 @@ class Attention(nn.Module):
# with the model weights.
# with the model weights.
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
calculate_kv_scales
=
calculate_kv_scales
self
.
calculate_kv_scales
=
calculate_kv_scales
self
.
_k_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
kv_cache_dtype
==
"fp8_e4m3"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
self
.
_v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
self
.
_k_scale
=
torch
.
ones
((
1
),
dtype
=
torch
.
float32
)
self
.
_v_scale
=
torch
.
ones
((
1
),
dtype
=
torch
.
float32
)
else
:
self
.
_k_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
self
.
_v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
# FlashAttn doesn't support quantizing the kv-cache only
# FlashAttn doesn't support quantizing the kv-cache only
# but requires q to be quantized as well.
# but requires q to be quantized as well.
self
.
_q_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
self
.
_q_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
1acf2d7a
...
@@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
...
@@ -12,6 +12,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
flash_mla_with_kvcache_q_nope_pe
,
flash_mla_with_kvcache_q_nope_pe
,
get_mla_metadata
,
get_mla_metadata
,
flash_mla_with_kvcache_fp8
,
flash_mla_with_kvcache_fp8
,
get_mla_decoding_metadata_dense_fp8
,
is_flashmla_supported
)
is_flashmla_supported
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.mla.common
import
(
MLACommonBackend
,
from
vllm.v1.attention.backends.mla.common
import
(
MLACommonBackend
,
...
@@ -72,12 +73,21 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
...
@@ -72,12 +73,21 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
def
_build_decode
(
self
,
block_table_tensor
:
torch
.
Tensor
,
def
_build_decode
(
self
,
block_table_tensor
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
)
->
FlashMLADecodeMetadata
:
seq_lens
:
torch
.
Tensor
)
->
FlashMLADecodeMetadata
:
tile_scheduler_metadata
,
num_splits
=
\
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
and
envs
.
VLLM_USE_FLASH_MLA_FP8
:
get_mla_metadata
(
tile_scheduler_metadata
,
num_splits
=
\
seq_lens
,
get_mla_decoding_metadata_dense_fp8
(
self
.
num_q_heads
,
seq_lens
,
1
,
# MQA for the decode path
self
.
num_q_heads
,
)
1
,
# MQA for the decode path
16
,
)
else
:
tile_scheduler_metadata
,
num_splits
=
\
get_mla_metadata
(
seq_lens
,
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
if
self
.
runner
.
full_cuda_graph
:
if
self
.
runner
.
full_cuda_graph
:
# First time around (CUDAGraph capture), allocate the static buffer
# First time around (CUDAGraph capture), allocate the static buffer
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment