Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5b032352
Unverified
Commit
5b032352
authored
Jul 10, 2025
by
Alexander Matveev
Committed by
GitHub
Jul 10, 2025
Browse files
[Attention] MLA - Flashinfer Ragged Prefill (#20034)
parent
922f3164
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
422 additions
and
215 deletions
+422
-215
tests/v1/kv_connector/__init__.py
tests/v1/kv_connector/__init__.py
+0
-0
tests/v1/kv_connector/unit/test_multi_connector.py
tests/v1/kv_connector/unit/test_multi_connector.py
+15
-72
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+62
-0
vllm/attention/layer.py
vllm/attention/layer.py
+1
-1
vllm/attention/utils/kv_sharing_utils.py
vllm/attention/utils/kv_sharing_utils.py
+33
-0
vllm/logger.py
vllm/logger.py
+14
-0
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+5
-68
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+222
-40
vllm/v1/attention/backends/mla/cutlass_mla.py
vllm/v1/attention/backends/mla/cutlass_mla.py
+1
-0
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+69
-34
No files found.
tests/v1/kv_connector/__init__.py
0 → 100644
View file @
5b032352
tests/v1/kv_connector/unit/test_multi_connector.py
View file @
5b032352
...
...
@@ -3,16 +3,10 @@
import
filecmp
import
shutil
import
tempfile
from
collections
import
defaultdict
from
pathlib
import
Path
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
KVTransferConfig
,
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.factory
import
(
KVConnectorFactory
)
from
vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector
import
(
# noqa
SharedStorageConnector
)
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.config
import
KVTransferConfig
MODEL_NAME
=
"meta-llama/Llama-3.2-1B-Instruct"
...
...
@@ -25,65 +19,6 @@ PROMPTS = [
SAMPLING_PARAMS
=
SamplingParams
(
temperature
=
0
,
max_tokens
=
20
)
class
TestSharedStorageConnector
(
SharedStorageConnector
):
def
__init__
(
self
,
config
:
VllmConfig
,
role
):
self
.
name
=
config
.
kv_transfer_config
.
kv_connector_extra_config
[
"name"
]
self
.
_connector
=
SharedStorageConnector
(
config
,
role
)
self
.
call_record
:
dict
[
str
,
int
]
=
defaultdict
(
int
)
# Use a unique temp file per connector
self
.
_event_file
=
tempfile
.
gettempdir
(
)
+
f
"/connector_
{
self
.
name
}
-
{
self
.
role
.
name
}
_events.log"
# Start with an empty file
with
open
(
self
.
_event_file
,
"w"
)
as
_
:
pass
def
__getattribute__
(
self
,
name
):
if
name
in
(
"_connector"
,
"call_record"
,
"name"
,
"_event_file"
,
"__class__"
,
"__dict__"
,
"__getattribute__"
,
"__init__"
):
# avoid recursion
return
object
.
__getattribute__
(
self
,
name
)
if
not
hasattr
(
self
.
_connector
,
name
):
return
object
.
__getattribute__
(
self
,
name
)
attr
=
getattr
(
self
.
_connector
,
name
)
# Intercept calls to the connector interface and write an event
# for each one to a file, which can be read back in the main test proc.
if
callable
(
attr
):
def
wrapper
(
*
args
,
**
kwargs
):
self
.
call_record
[
name
]
+=
1
# Include args that we're interested in
to_log
=
[
name
]
for
arg
in
args
:
if
isinstance
(
arg
,
int
):
to_log
.
append
(
str
(
arg
))
elif
isinstance
(
arg
,
KVCacheBlocks
):
to_log
.
append
(
f
"num_blocks=
{
[
len
(
b
)
for
b
in
arg
.
blocks
]
}
"
)
# Log the event as a line to the file
try
:
with
open
(
self
.
_event_file
,
"a"
)
as
f
:
f
.
write
(
' '
.
join
(
to_log
)
+
"
\n
"
)
except
Exception
as
e
:
print
(
f
"[ERROR] Could not log event
{
name
}
"
f
"for
{
self
.
name
}
:
{
e
}
"
)
return
attr
(
*
args
,
**
kwargs
)
return
wrapper
return
attr
# This relies on "fork" multiprocessing method being used.
# It's the default but vLLM may fall back to spawn if for example CUDA
# is already initialized.
KVConnectorFactory
.
register_connector
(
"TestSharedStorageConnector"
,
TestSharedStorageConnector
.
__module__
,
TestSharedStorageConnector
.
__name__
)
# Helper function to compare directories recursively
def
_compare_directories
(
dir1
:
Path
,
dir2
:
Path
)
->
bool
:
"""Compares two directories recursively for identical content."""
...
...
@@ -118,19 +53,27 @@ def test_multi_shared_storage_connector_consistency():
kv_role
=
"kv_both"
,
kv_connector_extra_config
=
{
"connectors"
:
[{
"kv_connector"
:
"TestSharedStorageConnector"
,
"kv_role"
:
"kv_both"
,
"kv_connector"
:
"TestSharedStorageConnector"
,
"kv_role"
:
"kv_both"
,
"kv_connector_extra_config"
:
{
"shared_storage_path"
:
str
(
storage_1_path
),
"name"
:
"storage1"
,
}
},
"kv_connector_module_path"
:
"tests.v1.kv_connector.unit.utils"
,
},
{
"kv_connector"
:
"TestSharedStorageConnector"
,
"kv_role"
:
"kv_both"
,
"kv_connector"
:
"TestSharedStorageConnector"
,
"kv_role"
:
"kv_both"
,
"kv_connector_extra_config"
:
{
"shared_storage_path"
:
str
(
storage_2_path
),
"name"
:
"storage2"
,
}
},
"kv_connector_module_path"
:
"tests.v1.kv_connector.unit.utils"
,
}]
},
)
...
...
tests/v1/kv_connector/unit/utils.py
View file @
5b032352
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
tempfile
from
collections
import
defaultdict
from
typing
import
Any
,
Optional
import
torch
...
...
@@ -7,6 +9,11 @@ import torch
from
vllm
import
SamplingParams
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
KVTransferConfig
,
ModelConfig
,
SchedulerConfig
,
VllmConfig
)
from
vllm.distributed.kv_transfer.kv_connector.factory
import
(
KVConnectorFactory
)
from
vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector
import
(
# noqa
SharedStorageConnector
)
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.core.sched.scheduler
import
Scheduler
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
)
...
...
@@ -187,3 +194,58 @@ def create_model_runner_output(
finished_sending
=
finished_sending
,
finished_recving
=
finished_recving
,
)
class
TestSharedStorageConnector
(
SharedStorageConnector
):
def
__init__
(
self
,
config
:
VllmConfig
,
role
):
self
.
name
=
config
.
kv_transfer_config
.
kv_connector_extra_config
[
"name"
]
self
.
_connector
=
SharedStorageConnector
(
config
,
role
)
self
.
call_record
:
dict
[
str
,
int
]
=
defaultdict
(
int
)
# Use a unique temp file per connector
self
.
_event_file
=
tempfile
.
gettempdir
(
)
+
f
"/connector_
{
self
.
name
}
-
{
self
.
role
.
name
}
_events.log"
# Start with an empty file
with
open
(
self
.
_event_file
,
"w"
)
as
_
:
pass
def
__getattribute__
(
self
,
name
):
if
name
in
(
"_connector"
,
"call_record"
,
"name"
,
"_event_file"
,
"__class__"
,
"__dict__"
,
"__getattribute__"
,
"__init__"
):
# avoid recursion
return
object
.
__getattribute__
(
self
,
name
)
if
not
hasattr
(
self
.
_connector
,
name
):
return
object
.
__getattribute__
(
self
,
name
)
attr
=
getattr
(
self
.
_connector
,
name
)
# Intercept calls to the connector interface and write an event
# for each one to a file, which can be read back in the main test proc.
if
callable
(
attr
):
def
wrapper
(
*
args
,
**
kwargs
):
self
.
call_record
[
name
]
+=
1
# Include args that we're interested in
to_log
=
[
name
]
for
arg
in
args
:
if
isinstance
(
arg
,
int
):
to_log
.
append
(
str
(
arg
))
elif
isinstance
(
arg
,
KVCacheBlocks
):
to_log
.
append
(
f
"num_blocks=
{
[
len
(
b
)
for
b
in
arg
.
blocks
]
}
"
)
# Log the event as a line to the file
try
:
with
open
(
self
.
_event_file
,
"a"
)
as
f
:
f
.
write
(
' '
.
join
(
to_log
)
+
"
\n
"
)
except
Exception
as
e
:
print
(
f
"[ERROR] Could not log event
{
name
}
"
f
"for
{
self
.
name
}
:
{
e
}
"
)
return
attr
(
*
args
,
**
kwargs
)
return
wrapper
return
attr
KVConnectorFactory
.
register_connector
(
"TestSharedStorageConnector"
,
__name__
,
TestSharedStorageConnector
.
__name__
)
vllm/attention/layer.py
View file @
5b032352
...
...
@@ -10,6 +10,7 @@ import torch.nn.functional as F
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionType
from
vllm.attention.selector
import
backend_name_to_enum
,
get_attn_backend
from
vllm.attention.utils.kv_sharing_utils
import
validate_kv_sharing_target
from
vllm.config
import
CacheConfig
,
get_current_vllm_config
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
has_kv_transfer_group
,
...
...
@@ -21,7 +22,6 @@ from vllm.model_executor.layers.quantization.base_config import (
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.v1.attention.backends.utils
import
validate_kv_sharing_target
class
Attention
(
nn
.
Module
):
...
...
vllm/attention/utils/kv_sharing_utils.py
0 → 100644
View file @
5b032352
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def
validate_kv_sharing_target
(
current_layer_name
,
target_layer_name
,
static_forward_context
):
error_msg
=
(
f
"Specified KV sharing target layer for
{
current_layer_name
}
"
f
"is not valid: target layer
{
target_layer_name
}
"
)
if
current_layer_name
==
target_layer_name
:
raise
ValueError
(
error_msg
+
"cannot be the same as the current layer."
)
if
target_layer_name
not
in
static_forward_context
:
from
vllm.model_executor.models.utils
import
extract_layer_index
# If target layer name is not in the static fwd context, it means either
# a) the target layer does not come BEFORE the current layer, or
# b) the target layer is not an Attention layer that exists in the model
current_layer_idx
=
extract_layer_index
(
current_layer_name
)
target_layer_idx
=
extract_layer_index
(
target_layer_name
)
if
current_layer_idx
<=
target_layer_idx
:
raise
ValueError
(
error_msg
+
"must come before the current layer."
)
else
:
raise
ValueError
(
error_msg
+
"is not a valid Attention layer in the model."
)
# Currently KV sharing is only supported between layers of the same type
target_layer_attn_type
=
static_forward_context
[
target_layer_name
].
attn_type
expected
=
static_forward_context
[
current_layer_name
].
attn_type
if
target_layer_attn_type
!=
expected
:
raise
ValueError
(
error_msg
+
f
"must be the same type as the current layer (
{
expected
}
)."
)
vllm/logger.py
View file @
5b032352
...
...
@@ -53,6 +53,12 @@ DEFAULT_LOGGING_CONFIG = {
}
@
lru_cache
def
_print_debug_once
(
logger
:
Logger
,
msg
:
str
,
*
args
:
Hashable
)
->
None
:
# Set the stacklevel to 2 to print the original caller's line info
logger
.
debug
(
msg
,
*
args
,
stacklevel
=
2
)
@
lru_cache
def
_print_info_once
(
logger
:
Logger
,
msg
:
str
,
*
args
:
Hashable
)
->
None
:
# Set the stacklevel to 2 to print the original caller's line info
...
...
@@ -74,6 +80,13 @@ class _VllmLogger(Logger):
`intel_extension_for_pytorch.utils._logger`.
"""
def
debug_once
(
self
,
msg
:
str
,
*
args
:
Hashable
)
->
None
:
"""
As [`debug`][logging.Logger.debug], but subsequent calls with
the same message are silently dropped.
"""
_print_debug_once
(
self
,
msg
,
*
args
)
def
info_once
(
self
,
msg
:
str
,
*
args
:
Hashable
)
->
None
:
"""
As [`info`][logging.Logger.info], but subsequent calls with
...
...
@@ -132,6 +145,7 @@ def init_logger(name: str) -> _VllmLogger:
logger
=
logging
.
getLogger
(
name
)
methods_to_patch
=
{
"debug_once"
:
_print_debug_once
,
"info_once"
:
_print_info_once
,
"warning_once"
:
_print_warning_once
,
}
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
5b032352
...
...
@@ -14,13 +14,14 @@ from flashinfer import (BatchDecodeWithPagedKVCacheWrapper,
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionType
)
from
vllm.attention.layer
import
Attention
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.flash_attn
import
use_cascade_attention
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
get_kv_cache_layout
)
PerLayerParameters
,
get_kv_cache_layout
,
get_per_layer_parameters
,
infer_global_hyperparameters
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
...
...
@@ -93,70 +94,6 @@ class FlashInferBackend(AttentionBackend):
return
stride_order
@
dataclass
class
PerLayerParameters
:
"""
Currently, the FlashInfer backend only supports models in which all layers share
the same values for the following hyperparameters.
"""
window_left
:
int
logits_soft_cap
:
Optional
[
float
]
sm_scale
:
float
def
get_per_layer_parameters
(
vllm_config
:
VllmConfig
)
->
dict
[
str
,
PerLayerParameters
]:
"""
Scan all attention layers and determine some hyperparameters
to use during `plan`.
"""
layers
=
get_layers_from_vllm_config
(
vllm_config
,
Attention
)
per_layer_params
:
dict
[
str
,
PerLayerParameters
]
=
{}
for
key
,
layer
in
layers
.
items
():
impl
=
layer
.
impl
assert
isinstance
(
impl
,
FlashInferImpl
)
# Infer hyperparameters from the attention layer
window_size
=
impl
.
sliding_window
window_left
=
window_size
[
0
]
if
window_size
is
not
None
else
-
1
logits_soft_cap
=
impl
.
logits_soft_cap
sm_scale
=
impl
.
scale
per_layer_params
[
key
]
=
PerLayerParameters
(
window_left
,
logits_soft_cap
,
sm_scale
)
return
per_layer_params
def
infer_global_hyperparameters
(
per_layer_params
:
dict
[
str
,
PerLayerParameters
])
->
PerLayerParameters
:
"""
Currently, the FlashInfer backend only supports models in which all layers share
the same values for the following hyperparameters:
- `window_left`
- `logits_soft_cap`
- `sm_scale`
So this function asserts that all layers share the same values for these
hyperparameters and returns the global values.
"""
assert
len
(
per_layer_params
)
>
0
,
"No attention layers found in the model."
param_sets
=
list
(
per_layer_params
.
values
())
global_params
=
param_sets
[
0
]
for
params
in
param_sets
:
assert
params
==
global_params
,
(
"FlashInfer backend currently only supports models in which all "
"layers share the same values for the following hyperparameters: "
"`window_left`, `logits_soft_cap`, `sm_scale`."
)
return
global_params
@
dataclass
class
FlashInferMetadata
:
...
...
@@ -336,7 +273,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def
_plan
(
self
,
attn_metadata
:
FlashInferMetadata
):
if
self
.
global_hyperparameters
is
None
:
self
.
global_hyperparameters
=
infer_global_hyperparameters
(
get_per_layer_parameters
(
self
.
vllm_config
))
get_per_layer_parameters
(
self
.
vllm_config
,
FlashInferImpl
))
if
attn_metadata
.
use_cascade
:
attn_metadata
.
cascade_wrapper
=
self
.
_get_cascade_wrapper
()
attn_metadata
.
cascade_wrapper
.
plan
(
...
...
vllm/v1/attention/backends/mla/common.py
View file @
5b032352
...
...
@@ -189,8 +189,8 @@ return curr_o @ W_O
import
functools
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Generic
,
Optional
,
TypeVar
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Any
,
Generic
,
Optional
,
TypeVar
,
Union
import
torch
...
...
@@ -208,7 +208,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
,
round_down
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
)
CommonAttentionMetadata
,
get_per_layer_parameters
,
infer_global_hyperparameters
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
...
...
@@ -221,6 +223,12 @@ except ImportError:
from
flash_attn
import
flash_attn_varlen_func
is_vllm_fa
=
False
try
:
from
flashinfer
import
BatchPrefillWithRaggedKVCacheWrapper
flashinfer_available
=
True
except
ImportError
:
flashinfer_available
=
False
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
...
...
@@ -290,6 +298,13 @@ class MLACommonPrefillMetadata:
chunked_context
:
Optional
[
ChunkedContextMetadata
]
=
None
@
dataclass
class
FlashInferPrefillMetadata
(
MLACommonPrefillMetadata
):
prefill_main
:
Optional
[
'BatchPrefillWithRaggedKVCacheWrapper'
]
=
None
prefill_chunks
:
list
[
'BatchPrefillWithRaggedKVCacheWrapper'
]
=
field
(
default_factory
=
list
)
@
dataclass
class
MLACommonDecodeMetadata
:
block_table
:
torch
.
Tensor
...
...
@@ -328,7 +343,8 @@ class MLACommonMetadata(Generic[D]):
head_dim
:
Optional
[
int
]
=
None
decode
:
Optional
[
D
]
=
None
prefill
:
Optional
[
MLACommonPrefillMetadata
]
=
None
prefill
:
Optional
[
Union
[
MLACommonPrefillMetadata
,
FlashInferPrefillMetadata
]]
=
None
def
__post_init__
(
self
):
if
self
.
head_dim
is
not
None
:
...
...
@@ -338,6 +354,20 @@ class MLACommonMetadata(Generic[D]):
M
=
TypeVar
(
"M"
,
bound
=
MLACommonMetadata
)
def
use_flashinfer_prefill
()
->
bool
:
if
flashinfer_available
:
# For Blackwell, default to FlashInfer prefill if it's available, since
# it's faster than FA2.
return
current_platform
.
has_device_capability
(
100
)
return
False
# Currently 394MB, this can be tuned based on GEMM sizes used.
# Chosen to be the same as sglang:
# https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
394
*
1024
*
1024
class
MLACommonMetadataBuilder
(
AttentionMetadataBuilder
[
M
]):
"""
NOTE: Please read the comment at the top of the file before trying to
...
...
@@ -392,6 +422,101 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
)
self
.
block_table
=
block_table
self
.
_use_fi_prefill
=
use_flashinfer_prefill
()
self
.
prefill_metadata_cls
=
FlashInferPrefillMetadata
\
if
self
.
_use_fi_prefill
else
MLACommonPrefillMetadata
if
self
.
_use_fi_prefill
:
self
.
_workspace_buffer
=
torch
.
empty
(
FLASHINFER_WORKSPACE_BUFFER_SIZE
,
dtype
=
torch
.
uint8
,
device
=
runner
.
device
)
self
.
_fi_prefill_main
:
Optional
[
BatchPrefillWithRaggedKVCacheWrapper
]
=
None
self
.
_fi_prefill_chunks
:
list
[
BatchPrefillWithRaggedKVCacheWrapper
]
=
[]
self
.
_global_hyperparameters
=
infer_global_hyperparameters
(
get_per_layer_parameters
(
runner
.
vllm_config
,
MLACommonImpl
))
def
_build_fi_prefill_wrappers
(
self
,
prefill
:
FlashInferPrefillMetadata
):
qo_indptr
=
prefill
.
query_start_loc
has_context
=
False
if
prefill
.
chunked_context
is
not
None
:
chunked_context
=
prefill
.
chunked_context
has_context
=
True
if
self
.
_fi_prefill_main
is
None
:
self
.
_fi_prefill_main
=
BatchPrefillWithRaggedKVCacheWrapper
(
self
.
_workspace_buffer
,
"NHD"
,
backend
=
"cutlass"
)
if
has_context
:
num_chunks
=
chunked_context
.
cu_seq_lens
.
shape
[
0
]
# Allocate more prefill chunk wrappers if needed
if
len
(
self
.
_fi_prefill_chunks
)
<
num_chunks
:
for
_
in
range
(
len
(
self
.
_fi_prefill_chunks
),
num_chunks
):
self
.
_fi_prefill_chunks
.
append
(
BatchPrefillWithRaggedKVCacheWrapper
(
self
.
_workspace_buffer
,
"NHD"
,
backend
=
"cutlass"
))
assert
num_chunks
<=
len
(
self
.
_fi_prefill_chunks
)
# In MLA, the non-latent num_qo_heads == num_kv_heads
num_qo_heads
=
self
.
runner
.
num_query_heads
num_kv_heads
=
num_qo_heads
# Sanity: Verify that num_kv_heads == 1 since it is latent space
assert
self
.
kv_cache_spec
.
num_kv_heads
==
1
# Get non-latent head_dim_qk and head_dim_vo
head_dim_qk
=
(
self
.
mla_dims
.
qk_nope_head_dim
+
self
.
mla_dims
.
qk_rope_head_dim
)
head_dim_vo
=
self
.
mla_dims
.
v_head_dim
# For main run, qo_indptr == kv_indptr
kv_indptr
=
qo_indptr
.
clone
()
# Prepare main prefill
self
.
_fi_prefill_main
.
plan
(
qo_indptr
=
qo_indptr
,
kv_indptr
=
kv_indptr
,
num_qo_heads
=
num_qo_heads
,
num_kv_heads
=
num_kv_heads
,
head_dim_qk
=
head_dim_qk
,
head_dim_vo
=
head_dim_vo
,
causal
=
True
,
# This is main run
sm_scale
=
self
.
_global_hyperparameters
.
sm_scale
,
window_left
=
self
.
_global_hyperparameters
.
window_left
,
logits_soft_cap
=
self
.
_global_hyperparameters
.
logits_soft_cap
,
q_data_type
=
self
.
runner
.
dtype
,
kv_data_type
=
self
.
kv_cache_spec
.
dtype
,
)
# Prepare context prefills
if
has_context
:
for
i
in
range
(
num_chunks
):
kv_indptr_chunk
=
chunked_context
.
cu_seq_lens
[
i
]
self
.
_fi_prefill_chunks
[
i
].
plan
(
qo_indptr
=
qo_indptr
,
kv_indptr
=
kv_indptr_chunk
,
num_qo_heads
=
num_qo_heads
,
num_kv_heads
=
num_kv_heads
,
head_dim_qk
=
head_dim_qk
,
head_dim_vo
=
head_dim_vo
,
causal
=
False
,
# This is context run
sm_scale
=
self
.
_global_hyperparameters
.
sm_scale
,
window_left
=
self
.
_global_hyperparameters
.
window_left
,
logits_soft_cap
=
self
.
_global_hyperparameters
.
logits_soft_cap
,
q_data_type
=
self
.
runner
.
dtype
,
kv_data_type
=
self
.
kv_cache_spec
.
dtype
,
)
prefill
.
prefill_main
=
self
.
_fi_prefill_main
prefill
.
prefill_chunks
=
self
.
_fi_prefill_chunks
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
# We now want to reorder the batch so that the "decode" requests are and
...
...
@@ -572,7 +697,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
assert
max
(
chunked_context_metadata
.
max_seq_lens
)
<=
\
self
.
chunked_prefill_workspace_size
prefill_metadata
=
MLACommonP
refill
M
etadata
(
prefill_metadata
=
self
.
p
refill
_m
etadata
_cls
(
block_table
=
block_table_tensor
[
reqs_start
:,
...],
query_start_loc
=
prefill_query_start_loc
,
max_query_len
=
max_query_len
,
...
...
@@ -586,7 +711,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_lens
=
seq_lens
[:
self
.
_num_decodes
],
)
return
self
.
metadata_cls
(
attn_metadata
=
self
.
metadata_cls
(
num_actual_tokens
=
num_actual_tokens
,
query_start_loc
=
query_start_loc
,
slot_mapping
=
slot_mapping
,
...
...
@@ -599,6 +724,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
decode
=
decode_metadata
,
)
if
self
.
_use_fi_prefill
and
self
.
_num_prefills
>
0
:
assert
isinstance
(
attn_metadata
.
prefill
,
FlashInferPrefillMetadata
)
self
.
_build_fi_prefill_wrappers
(
attn_metadata
.
prefill
)
return
attn_metadata
def
can_run_in_cudagraph
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
bool
:
return
common_attn_metadata
.
max_query_len
==
1
...
...
@@ -649,9 +780,20 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self
.
v_head_dim
=
v_head_dim
self
.
kv_b_proj
=
kv_b_proj
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on ROCm and the
# latter has an additional parameter to control FA2 vs FA3
if
use_flashinfer_prefill
():
logger
.
debug_once
(
"Using FlashInfer prefill for MLA"
)
self
.
_run_prefill_context_chunk
=
self
.
_run_prefill_context_chunk_fi
self
.
_run_prefill_new_tokens
=
self
.
_run_prefill_new_tokens_fi
self
.
_pad_v
=
False
else
:
# Use FlashAttention
logger
.
debug_once
(
"Using FlashAttention prefill for MLA"
)
self
.
_run_prefill_context_chunk
=
self
.
_run_prefill_context_chunk_fa
self
.
_run_prefill_new_tokens
=
self
.
_run_prefill_new_tokens_fa
# Handle the differences between the flash_attn_varlen from
# flash_attn and the one from vllm_flash_attn. The former is used on
# ROCm and the latter has an additional parameter to control
# FA2 vs FA3
self
.
flash_attn_varlen_func
=
flash_attn_varlen_func
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
if
self
.
vllm_flash_attn_version
is
not
None
:
...
...
@@ -705,6 +847,58 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
return
attn_out
,
lse
return
attn_out
def
_run_prefill_new_tokens_fa
(
self
,
prefill
:
MLACommonPrefillMetadata
,
q
,
k
,
v
,
return_softmax_lse
):
return
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
prefill
.
query_start_loc
,
cu_seqlens_k
=
prefill
.
query_start_loc
,
max_seqlen_q
=
prefill
.
max_query_len
,
max_seqlen_k
=
prefill
.
max_query_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
return_softmax_lse
=
return_softmax_lse
,
)
def
_run_prefill_new_tokens_fi
(
self
,
prefill
:
MLACommonPrefillMetadata
,
q
,
k
,
v
,
return_softmax_lse
):
assert
isinstance
(
prefill
,
FlashInferPrefillMetadata
)
assert
prefill
.
prefill_main
is
not
None
return
prefill
.
prefill_main
.
run
(
q
=
q
,
k
=
k
,
v
=
v
,
return_lse
=
return_softmax_lse
,
)
def
_run_prefill_context_chunk_fa
(
self
,
prefill
:
MLACommonPrefillMetadata
,
chunk_idx
:
int
,
q
,
k
,
v
):
assert
prefill
.
chunked_context
is
not
None
return
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
prefill
.
query_start_loc
,
cu_seqlens_k
=
prefill
.
chunked_context
.
cu_seq_lens
[
chunk_idx
],
max_seqlen_q
=
prefill
.
max_query_len
,
max_seqlen_k
=
prefill
.
chunked_context
.
max_seq_lens
[
chunk_idx
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
return_softmax_lse
=
True
,
)
def
_run_prefill_context_chunk_fi
(
self
,
prefill
:
MLACommonPrefillMetadata
,
chunk_idx
:
int
,
q
,
k
,
v
):
assert
isinstance
(
prefill
,
FlashInferPrefillMetadata
)
return
prefill
.
prefill_chunks
[
chunk_idx
].
run
(
q
=
q
,
k
=
k
,
v
=
v
,
return_lse
=
True
,
)
def
_v_up_proj
(
self
,
x
):
# Convert from (B, N, L) to (N, B, L)
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
...
...
@@ -803,18 +997,12 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
attn_output
,
attn_softmax_lse
=
\
self
.
_flash_attn_varlen_diff_headdims
(
attn_output
,
attn_softmax_lse
=
self
.
_run_prefill_context_chunk
(
prefill
=
prefill_metadata
,
chunk_idx
=
i
,
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
chunked_context
.
cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
chunked_context
.
max_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
return_softmax_lse
=
True
,
)
if
output
is
None
:
...
...
@@ -854,16 +1042,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
output
=
self
.
_flash_attn_varlen_diff_headdims
(
output
=
self
.
_run_prefill_new_tokens
(
prefill
=
attn_metadata
.
prefill
,
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
attn_metadata
.
prefill
.
query_start_loc
,
cu_seqlens_k
=
attn_metadata
.
prefill
.
query_start_loc
,
max_seqlen_q
=
attn_metadata
.
prefill
.
max_query_len
,
max_seqlen_k
=
attn_metadata
.
prefill
.
max_query_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
return_softmax_lse
=
has_context
,
)
...
...
@@ -908,7 +1091,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
assert
output
is
not
None
,
"Output tensor must be provided."
if
output_scale
is
not
None
:
...
...
vllm/v1/attention/backends/mla/cutlass_mla.py
View file @
5b032352
...
...
@@ -91,6 +91,7 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
# Clone q_nope and q_pe to make sure strides computation is correct.
q_nope
=
q_nope
.
clone
()
q_pe
=
q_pe
.
clone
()
ops
.
cutlass_mla_decode
(
o
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
attn_metadata
.
decode
.
seq_lens
,
attn_metadata
.
decode
.
block_table
,
self
.
scale
)
...
...
vllm/v1/attention/backends/utils.py
View file @
5b032352
...
...
@@ -4,14 +4,17 @@ import abc
import
functools
from
abc
import
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
ClassVar
,
Generic
,
TypeVar
from
typing
import
TYPE_CHECKING
,
ClassVar
,
Generic
,
Optional
,
TypeVar
import
numpy
as
np
import
torch
from
vllm.attention.layer
import
Attention
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.utils
import
cdiv
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionImpl
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
...
...
@@ -98,39 +101,6 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
return
False
def
validate_kv_sharing_target
(
current_layer_name
,
target_layer_name
,
static_forward_context
):
error_msg
=
(
f
"Specified KV sharing target layer for
{
current_layer_name
}
"
f
"is not valid: target layer
{
target_layer_name
}
"
)
if
current_layer_name
==
target_layer_name
:
raise
ValueError
(
error_msg
+
"cannot be the same as the current layer."
)
if
target_layer_name
not
in
static_forward_context
:
from
vllm.model_executor.models.utils
import
extract_layer_index
# If target layer name is not in the static fwd context, it means either
# a) the target layer does not come BEFORE the current layer, or
# b) the target layer is not an Attention layer that exists in the model
current_layer_idx
=
extract_layer_index
(
current_layer_name
)
target_layer_idx
=
extract_layer_index
(
target_layer_name
)
if
current_layer_idx
<=
target_layer_idx
:
raise
ValueError
(
error_msg
+
"must come before the current layer."
)
else
:
raise
ValueError
(
error_msg
+
"is not a valid Attention layer in the model."
)
# Currently KV sharing is only supported between layers of the same type
target_layer_attn_type
=
static_forward_context
[
target_layer_name
].
attn_type
expected
=
static_forward_context
[
current_layer_name
].
attn_type
if
target_layer_attn_type
!=
expected
:
raise
ValueError
(
error_msg
+
f
"must be the same type as the current layer (
{
expected
}
)."
)
@
functools
.
lru_cache
def
get_kv_cache_layout
():
# Override with format specified by the user.
...
...
@@ -144,6 +114,71 @@ def get_kv_cache_layout():
return
cache_layout
@
dataclass
class
PerLayerParameters
:
"""
Currently, the FlashInfer backend only supports models in which all layers share
the same values for the following hyperparameters.
"""
window_left
:
int
logits_soft_cap
:
Optional
[
float
]
sm_scale
:
float
def
get_per_layer_parameters
(
vllm_config
:
VllmConfig
,
cls_
:
type
[
'AttentionImpl'
])
->
dict
[
str
,
PerLayerParameters
]:
"""
Scan all attention layers and determine some hyperparameters
to use during `plan`.
"""
layers
=
get_layers_from_vllm_config
(
vllm_config
,
Attention
)
per_layer_params
:
dict
[
str
,
PerLayerParameters
]
=
{}
for
key
,
layer
in
layers
.
items
():
impl
=
layer
.
impl
assert
isinstance
(
impl
,
cls_
)
# Infer hyperparameters from the attention layer
window_size
=
getattr
(
impl
,
"sliding_window"
,
None
)
window_left
=
window_size
[
0
]
if
window_size
is
not
None
else
-
1
logits_soft_cap
=
getattr
(
impl
,
"logits_soft_cap"
,
None
)
sm_scale
=
impl
.
scale
per_layer_params
[
key
]
=
PerLayerParameters
(
window_left
,
logits_soft_cap
,
sm_scale
)
return
per_layer_params
def
infer_global_hyperparameters
(
per_layer_params
:
dict
[
str
,
PerLayerParameters
])
->
PerLayerParameters
:
"""
Currently, the FlashInfer backend only supports models in which all layers share
the same values for the following hyperparameters:
- `window_left`
- `logits_soft_cap`
- `sm_scale`
So this function asserts that all layers share the same values for these
hyperparameters and returns the global values.
"""
assert
len
(
per_layer_params
)
>
0
,
"No attention layers found in the model."
param_sets
=
list
(
per_layer_params
.
values
())
global_params
=
param_sets
[
0
]
for
params
in
param_sets
:
assert
params
==
global_params
,
(
"FlashInfer backend currently only supports models in which all "
"layers share the same values for the following hyperparameters: "
"`window_left`, `logits_soft_cap`, `sm_scale`."
)
return
global_params
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment