Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
183dad7a
Unverified
Commit
183dad7a
authored
Apr 17, 2025
by
Lucas Wilkinson
Committed by
GitHub
Apr 17, 2025
Browse files
[Attention] Update to lastest FA3 code (#13111)
Signed-off-by:
Lucas Wilkinson
<
lwilkinson@neuralmagic.com
>
parent
3408e471
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
241 additions
and
118 deletions
+241
-118
cmake/external_projects/vllm_flash_attn.cmake
cmake/external_projects/vllm_flash_attn.cmake
+1
-1
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+92
-90
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+25
-1
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+58
-1
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+65
-25
No files found.
cmake/external_projects/vllm_flash_attn.cmake
View file @
183dad7a
...
@@ -38,7 +38,7 @@ else()
...
@@ -38,7 +38,7 @@ else()
FetchContent_Declare
(
FetchContent_Declare
(
vllm-flash-attn
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG
dc9d410b3e2d6534a4c70724c2515f4def670a22
GIT_TAG
0a721daebe4fa7149f06ecf3d3eabeb6dcd0f1fa
GIT_PROGRESS TRUE
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
# Don't share the vllm-flash-attn build between build types
BINARY_DIR
${
CMAKE_BINARY_DIR
}
/vllm-flash-attn
BINARY_DIR
${
CMAKE_BINARY_DIR
}
/vllm-flash-attn
...
...
vllm/attention/backends/mla/common.py
View file @
183dad7a
...
@@ -1043,8 +1043,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1043,8 +1043,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self
.
q_proj
=
q_proj
self
.
q_proj
=
q_proj
self
.
kv_b_proj
=
kv_b_proj
self
.
kv_b_proj
=
kv_b_proj
self
.
o_proj
=
o_proj
self
.
o_proj
=
o_proj
self
.
triton_fa_func
=
triton_attention
self
.
triton_fa_func
=
triton_attention
# Handle the differences between the flash_attn_varlen from flash_attn
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on ROCm and the
# and the one from vllm_flash_attn. The former is used on ROCm and the
# latter has an additional parameter to control FA2 vs FA3
# latter has an additional parameter to control FA2 vs FA3
...
@@ -1055,6 +1055,70 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1055,6 +1055,70 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
functools
.
partial
(
flash_attn_varlen_func
,
functools
.
partial
(
flash_attn_varlen_func
,
fa_version
=
self
.
vllm_flash_attn_version
)
fa_version
=
self
.
vllm_flash_attn_version
)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self
.
_pad_v
=
self
.
vllm_flash_attn_version
is
None
or
not
(
self
.
vllm_flash_attn_version
==
3
and
current_platform
.
get_device_capability
()[
0
]
==
9
)
def
_flash_attn_varlen_diff_headdims
(
self
,
q
,
k
,
v
,
softmax_scale
,
return_softmax_lse
,
**
kwargs
):
maybe_padded_v
=
v
if
self
.
_pad_v
:
maybe_padded_v
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
if
is_hip
and
envs
.
VLLM_USE_TRITON_FLASH_ATTN
\
and
not
return_softmax_lse
:
attn_out
=
self
.
triton_fa_func
(
q
,
k
,
maybe_padded_v
,
**
kwargs
,
)
if
is_vllm_fa
:
attn_out
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
maybe_padded_v
,
return_softmax_lse
=
return_softmax_lse
,
softmax_scale
=
softmax_scale
,
**
kwargs
,
)
else
:
# Use return_attn_probs instead of return_softmax_lse for ROCm
attn_out
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
maybe_padded_v
,
return_attn_probs
=
return_softmax_lse
,
softmax_scale
=
softmax_scale
,
**
kwargs
,
)
# Unpack the output if there are multiple results:
# triton always returns (output, softmax_lse),
# vllm_flash_attn returns (output, softmax_lse) when
# `return_softmax_lse = True`,
# flash_attn (ROCm) returns (output, softmax_lse, ...) when
# `return_attn_probs = True`
rest
=
None
if
isinstance
(
attn_out
,
tuple
):
attn_out
,
*
rest
=
attn_out
# unpad if necessary
if
self
.
_pad_v
:
attn_out
=
attn_out
[...,
:
v
.
shape
[
-
1
]]
# Remain consistent with old `flash_attn_varlen_func` where there
# is only one output tensor if `return_softmax_lse` is False.
if
return_softmax_lse
:
assert
rest
is
not
None
return
attn_out
,
rest
[
0
]
return
attn_out
def
_v_up_proj_and_o_proj
(
self
,
x
):
def
_v_up_proj_and_o_proj
(
self
,
x
):
# Convert from (B, N, L) to (N, B, L)
# Convert from (B, N, L) to (N, B, L)
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
...
@@ -1176,40 +1240,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1176,40 +1240,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
dim
=-
1
)
# For MLA the v head dim is smaller than qk head dim so we pad
attn_output
,
attn_softmax_lse
=
\
# out v with 0s to match the qk head dim
self
.
_flash_attn_varlen_diff_headdims
(
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
if
is_vllm_fa
:
attn_output
,
attn_softmax_lse
=
self
.
flash_attn_varlen_func
(
q
=
q
,
q
=
q
,
k
=
k
,
k
=
k
,
v
=
v_padded
,
v
=
v
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
cu_seqlens_k
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
max_seqlen_k
=
prefill_metadata
.
context_chunk_max_seq_lens
[
i
],
context_chunk_max_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
causal
=
False
,
# Context is unmasked
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
)
)
else
:
attn_output
,
attn_softmax_lse
,
_
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v_padded
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
context_chunk_max_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
return_attn_probs
=
True
,
)
if
output
is
None
:
if
output
is
None
:
output
=
attn_output
output
=
attn_output
...
@@ -1252,33 +1295,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1252,33 +1295,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
# For MLA the v head dim is smaller than qk head dim so we pad out
output
=
self
.
_flash_attn_varlen_diff_headdims
(
# v with 0s to match the qk head dim
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
if
is_hip
and
envs
.
VLLM_USE_TRITON_FLASH_ATTN
and
not
has_context
:
output
=
self
.
triton_fa_func
(
q
,
k
,
v_padded
,
None
,
prefill_metadata
.
query_start_loc
,
prefill_metadata
.
query_start_loc
,
prefill_metadata
.
max_prefill_seq_len
,
prefill_metadata
.
max_prefill_seq_len
,
True
,
# causal
self
.
scale
,
None
,
# attn_mask is None unless applying ALiBi mask
)
## triton flash attention always returns 2 objects
if
not
has_context
:
output
=
output
[
0
]
elif
is_vllm_fa
:
output
=
self
.
flash_attn_varlen_func
(
q
=
q
,
q
=
q
,
k
=
k
,
k
=
k
,
v
=
v_padded
,
v
=
v
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
query_start_loc
,
max_seqlen_q
=
prefill_metadata
.
max_prefill_seq_len
,
max_seqlen_q
=
prefill_metadata
.
max_prefill_seq_len
,
...
@@ -1287,23 +1307,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1287,23 +1307,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
causal
=
True
,
causal
=
True
,
return_softmax_lse
=
has_context
,
return_softmax_lse
=
has_context
,
)
)
else
:
output
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v_padded
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
query_start_loc
,
max_seqlen_q
=
prefill_metadata
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_metadata
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
return_attn_probs
=
has_context
,
)
if
has_context
:
if
has_context
:
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output
,
suffix_lse
,
*
rest
=
output
suffix_output
,
suffix_lse
=
output
context_output
,
context_lse
=
self
.
_compute_prefill_context
(
\
context_output
,
context_lse
=
self
.
_compute_prefill_context
(
\
q
,
kv_c_and_k_pe_cache
,
attn_metadata
)
q
,
kv_c_and_k_pe_cache
,
attn_metadata
)
...
@@ -1316,12 +1323,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1316,12 +1323,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
suffix_lse
=
suffix_lse
,
suffix_lse
=
suffix_lse
,
)
)
# slice by `:v.shape[-1]` in order to remove v headdim padding
return
self
.
o_proj
(
output
.
flatten
(
start_dim
=-
2
))[
0
]
output
=
output
\
.
view
(
-
1
,
self
.
num_heads
,
q
.
shape
[
-
1
])[...,
:
v
.
shape
[
-
1
]]
\
.
reshape
(
-
1
,
self
.
num_heads
*
v
.
shape
[
-
1
])
return
self
.
o_proj
(
output
)[
0
]
@
abstractmethod
@
abstractmethod
def
_forward_decode
(
def
_forward_decode
(
...
...
vllm/attention/backends/utils.py
View file @
183dad7a
...
@@ -2,8 +2,10 @@
...
@@ -2,8 +2,10 @@
"""Attention backend utils"""
"""Attention backend utils"""
from
collections
import
defaultdict
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
itertools
import
accumulate
from
itertools
import
accumulate
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Tuple
,
Type
,
TypeVar
,
Union
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
)
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -11,6 +13,7 @@ import torch
...
@@ -11,6 +13,7 @@ import torch
from
vllm.attention
import
(
AttentionMetadata
,
AttentionMetadataBuilder
,
from
vllm.attention
import
(
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionState
)
AttentionState
)
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
...
@@ -583,3 +586,24 @@ def get_num_prefill_decode_query_kv_tokens(
...
@@ -583,3 +586,24 @@ def get_num_prefill_decode_query_kv_tokens(
return
(
num_prefill_query_tokens
,
num_prefill_kv_tokens
,
return
(
num_prefill_query_tokens
,
num_prefill_kv_tokens
,
num_decode_query_tokens
)
num_decode_query_tokens
)
@
dataclass
class
MLADims
:
q_lora_rank
:
Optional
[
int
]
kv_lora_rank
:
int
qk_nope_head_dim
:
int
qk_rope_head_dim
:
int
v_head_dim
:
int
def
get_mla_dims
(
model_config
:
ModelConfig
)
->
MLADims
:
hf_text_config
=
model_config
.
hf_text_config
return
MLADims
(
q_lora_rank
=
getattr
(
hf_text_config
,
"q_lora_rank"
,
None
),
kv_lora_rank
=
hf_text_config
.
kv_lora_rank
,
qk_nope_head_dim
=
hf_text_config
.
qk_nope_head_dim
,
qk_rope_head_dim
=
hf_text_config
.
qk_rope_head_dim
,
v_head_dim
=
hf_text_config
.
v_head_dim
,
)
vllm/v1/attention/backends/flash_attn.py
View file @
183dad7a
...
@@ -23,7 +23,8 @@ if TYPE_CHECKING:
...
@@ -23,7 +23,8 @@ if TYPE_CHECKING:
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
if
current_platform
.
is_cuda
():
if
current_platform
.
is_cuda
():
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
get_scheduler_metadata
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -93,6 +94,10 @@ class FlashAttentionMetadata:
...
@@ -93,6 +94,10 @@ class FlashAttentionMetadata:
prefix_kv_lens
:
Optional
[
torch
.
Tensor
]
prefix_kv_lens
:
Optional
[
torch
.
Tensor
]
suffix_kv_lens
:
Optional
[
torch
.
Tensor
]
suffix_kv_lens
:
Optional
[
torch
.
Tensor
]
# Optional aot scheduling
scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
prefix_scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
# For logging.
# For logging.
num_input_tokens
:
int
=
0
# Number of tokens including padding.
num_input_tokens
:
int
=
0
# Number of tokens including padding.
...
@@ -277,7 +282,14 @@ def make_local_attention_virtual_batches(
...
@@ -277,7 +282,14 @@ def make_local_attention_virtual_batches(
class
FlashAttentionMetadataBuilder
:
class
FlashAttentionMetadataBuilder
:
def
__init__
(
self
,
runner
:
"GPUModelRunner"
):
def
__init__
(
self
,
runner
:
"GPUModelRunner"
):
model_config
=
runner
.
model_config
self
.
runner
=
runner
self
.
runner
=
runner
self
.
aot_schedule
=
(
get_flash_attn_version
()
==
3
)
self
.
num_heads
=
model_config
.
get_num_attention_heads
(
runner
.
parallel_config
)
self
.
headdim
=
model_config
.
get_head_size
()
self
.
page_size
=
self
.
runner
.
block_size
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
...
@@ -319,6 +331,24 @@ class FlashAttentionMetadataBuilder:
...
@@ -319,6 +331,24 @@ class FlashAttentionMetadataBuilder:
)
)
use_cascade
=
common_prefix_len
>
0
use_cascade
=
common_prefix_len
>
0
def
schedule
(
cu_query_lens
,
max_query_len
,
seqlens
,
max_seq_len
,
causal
):
if
self
.
aot_schedule
:
return
get_scheduler_metadata
(
batch_size
=
num_reqs
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_seq_len
,
cache_seqlens
=
seqlens
,
num_heads_q
=
self
.
num_heads
,
num_heads_kv
=
self
.
num_heads
,
headdim
=
self
.
headdim
,
page_size
=
self
.
page_size
,
cu_seqlens_q
=
cu_query_lens
,
causal
=
causal
,
)
return
None
if
use_cascade
:
if
use_cascade
:
cu_prefix_query_lens
=
torch
.
tensor
([
0
,
num_actual_tokens
],
cu_prefix_query_lens
=
torch
.
tensor
([
0
,
num_actual_tokens
],
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
...
@@ -330,10 +360,28 @@ class FlashAttentionMetadataBuilder:
...
@@ -330,10 +360,28 @@ class FlashAttentionMetadataBuilder:
common_prefix_len
)
common_prefix_len
)
suffix_kv_lens
=
torch
.
from_numpy
(
suffix_kv_lens
).
to
(
suffix_kv_lens
=
torch
.
from_numpy
(
suffix_kv_lens
).
to
(
self
.
runner
.
device
)
self
.
runner
.
device
)
prefix_scheduler_metadata
=
schedule
(
cu_query_lens
=
cu_prefix_query_lens
,
max_query_len
=
num_actual_tokens
,
seqlens
=
prefix_kv_lens
,
max_seq_len
=
common_prefix_len
,
causal
=
False
)
scheduler_metadata
=
schedule
(
cu_query_lens
=
query_start_loc
,
max_query_len
=
max_query_len
,
seqlens
=
suffix_kv_lens
,
max_seq_len
=
max_seq_len
-
common_prefix_len
,
causal
=
True
)
else
:
else
:
cu_prefix_query_lens
=
None
cu_prefix_query_lens
=
None
prefix_kv_lens
=
None
prefix_kv_lens
=
None
suffix_kv_lens
=
None
suffix_kv_lens
=
None
prefix_scheduler_metadata
=
None
scheduler_metadata
=
schedule
(
cu_query_lens
=
query_start_loc
,
max_query_len
=
max_query_len
,
seqlens
=
seq_lens
,
max_seq_len
=
max_seq_len
,
causal
=
True
)
attn_metadata
=
FlashAttentionMetadata
(
attn_metadata
=
FlashAttentionMetadata
(
num_actual_tokens
=
num_actual_tokens
,
num_actual_tokens
=
num_actual_tokens
,
...
@@ -345,10 +393,12 @@ class FlashAttentionMetadataBuilder:
...
@@ -345,10 +393,12 @@ class FlashAttentionMetadataBuilder:
slot_mapping
=
slot_mapping
,
slot_mapping
=
slot_mapping
,
use_cascade
=
use_cascade
,
use_cascade
=
use_cascade
,
common_prefix_len
=
common_prefix_len
,
common_prefix_len
=
common_prefix_len
,
scheduler_metadata
=
scheduler_metadata
,
cu_prefix_query_lens
=
cu_prefix_query_lens
,
cu_prefix_query_lens
=
cu_prefix_query_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
local_attn_metadata
=
local_attn_metadata
,
local_attn_metadata
=
local_attn_metadata
,
prefix_scheduler_metadata
=
prefix_scheduler_metadata
,
)
)
return
attn_metadata
return
attn_metadata
...
@@ -515,6 +565,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -515,6 +565,7 @@ class FlashAttentionImpl(AttentionImpl):
window_size
=
self
.
sliding_window
,
window_size
=
self
.
sliding_window
,
block_table
=
block_table
,
block_table
=
block_table
,
softcap
=
self
.
logits_soft_cap
,
softcap
=
self
.
logits_soft_cap
,
scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
fa_version
=
self
.
vllm_flash_attn_version
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
...
@@ -543,6 +594,8 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -543,6 +594,8 @@ class FlashAttentionImpl(AttentionImpl):
block_table
=
attn_metadata
.
block_table
,
block_table
=
attn_metadata
.
block_table
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
fa_version
=
self
.
vllm_flash_attn_version
,
fa_version
=
self
.
vllm_flash_attn_version
,
prefix_scheduler_metadata
=
attn_metadata
.
prefix_scheduler_metadata
,
suffix_scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
q_descale
=
layer
.
_q_scale
,
q_descale
=
layer
.
_q_scale
,
k_descale
=
layer
.
_k_scale
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
v_descale
=
layer
.
_v_scale
,
...
@@ -636,6 +689,8 @@ def cascade_attention(
...
@@ -636,6 +689,8 @@ def cascade_attention(
block_table
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
common_prefix_len
:
int
,
common_prefix_len
:
int
,
fa_version
:
int
,
fa_version
:
int
,
prefix_scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
,
suffix_scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
,
q_descale
:
Optional
[
torch
.
Tensor
]
=
None
,
q_descale
:
Optional
[
torch
.
Tensor
]
=
None
,
k_descale
:
Optional
[
torch
.
Tensor
]
=
None
,
k_descale
:
Optional
[
torch
.
Tensor
]
=
None
,
v_descale
:
Optional
[
torch
.
Tensor
]
=
None
,
v_descale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -667,6 +722,7 @@ def cascade_attention(
...
@@ -667,6 +722,7 @@ def cascade_attention(
block_table
=
block_table
[:
1
],
block_table
=
block_table
[:
1
],
softcap
=
logits_soft_cap
,
softcap
=
logits_soft_cap
,
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
scheduler_metadata
=
prefix_scheduler_metadata
,
fa_version
=
fa_version
,
fa_version
=
fa_version
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
if
q_descale
is
not
None
else
None
,
...
@@ -693,6 +749,7 @@ def cascade_attention(
...
@@ -693,6 +749,7 @@ def cascade_attention(
block_table
=
block_table
[:,
num_common_kv_blocks
:],
block_table
=
block_table
[:,
num_common_kv_blocks
:],
softcap
=
logits_soft_cap
,
softcap
=
logits_soft_cap
,
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
scheduler_metadata
=
suffix_scheduler_metadata
,
fa_version
=
fa_version
,
fa_version
=
fa_version
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
if
q_descale
is
not
None
else
None
,
...
...
vllm/v1/attention/backends/mla/common.py
View file @
183dad7a
...
@@ -195,6 +195,7 @@ from vllm import _custom_ops as ops
...
@@ -195,6 +195,7 @@ from vllm import _custom_ops as ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionLayer
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionLayer
,
AttentionMetadata
,
AttentionMetadata
,
MLAAttentionImpl
)
MLAAttentionImpl
)
from
vllm.attention.backends.utils
import
get_mla_dims
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
@@ -207,9 +208,11 @@ from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
...
@@ -207,9 +208,11 @@ from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
try
:
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
is_vllm_fa
=
True
except
ImportError
:
except
ImportError
:
# For rocm use upstream flash attention
# For rocm use upstream flash attention
from
flash_attn
import
flash_attn_varlen_func
from
flash_attn
import
flash_attn_varlen_func
is_vllm_fa
=
False
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
@@ -350,6 +353,14 @@ class MLACommonMetadataBuilder(Generic[M]):
...
@@ -350,6 +353,14 @@ class MLACommonMetadataBuilder(Generic[M]):
model_config
=
runner
.
model_config
model_config
=
runner
.
model_config
cache_config
=
runner
.
cache_config
cache_config
=
runner
.
cache_config
self
.
chunked_prefill_enabled
=
scheduler_config
.
chunked_prefill_enabled
self
.
chunked_prefill_enabled
=
scheduler_config
.
chunked_prefill_enabled
self
.
num_heads
=
model_config
.
get_num_attention_heads
(
runner
.
parallel_config
)
self
.
mla_dims
=
get_mla_dims
(
model_config
)
self
.
aot_schedule
=
is_vllm_fa
and
(
get_flash_attn_version
()
==
3
)
# Don't try to access the runner on AMD
if
self
.
aot_schedule
:
self
.
page_size
=
self
.
runner
.
block_size
if
self
.
chunked_prefill_enabled
:
if
self
.
chunked_prefill_enabled
:
self
.
chunked_prefill_workspace_size
=
min
(
self
.
chunked_prefill_workspace_size
=
min
(
...
@@ -375,7 +386,6 @@ class MLACommonMetadataBuilder(Generic[M]):
...
@@ -375,7 +386,6 @@ class MLACommonMetadataBuilder(Generic[M]):
dtype
=
model_config
.
dtype
,
dtype
=
model_config
.
dtype
,
device
=
runner
.
device
,
device
=
runner
.
device
,
)
)
self
.
page_size
=
self
.
runner
.
block_size
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
...
@@ -464,7 +474,6 @@ class MLACommonMetadataBuilder(Generic[M]):
...
@@ -464,7 +474,6 @@ class MLACommonMetadataBuilder(Generic[M]):
seq_lens_cpu
=
self
.
runner
.
seq_lens_cpu
[:
num_reqs
]
seq_lens_cpu
=
self
.
runner
.
seq_lens_cpu
[:
num_reqs
]
seq_lens
=
seq_lens_cpu
.
to
(
device
,
non_blocking
=
True
)
seq_lens
=
seq_lens_cpu
.
to
(
device
,
non_blocking
=
True
)
max_query_len
=
seq_lens_cpu
.
max
().
item
()
prefill_metadata
=
None
prefill_metadata
=
None
if
self
.
_num_prefills
>
0
:
if
self
.
_num_prefills
>
0
:
...
@@ -475,6 +484,8 @@ class MLACommonMetadataBuilder(Generic[M]):
...
@@ -475,6 +484,8 @@ class MLACommonMetadataBuilder(Generic[M]):
num_computed_tokens_cpu_tensor
[
reqs_start
:
num_reqs
]
num_computed_tokens_cpu_tensor
[
reqs_start
:
num_reqs
]
max_context_len_cpu
=
context_lens_cpu
.
max
().
item
()
max_context_len_cpu
=
context_lens_cpu
.
max
().
item
()
num_prefills_with_context_cpu
=
(
context_lens_cpu
>
0
).
sum
().
item
()
num_prefills_with_context_cpu
=
(
context_lens_cpu
>
0
).
sum
().
item
()
prefill_query_start_loc
=
query_start_loc
[
reqs_start
:]
-
query_start_loc
[
reqs_start
]
chunked_context_metadata
=
None
chunked_context_metadata
=
None
if
self
.
chunked_prefill_enabled
and
self
.
_num_prefills
>
0
\
if
self
.
chunked_prefill_enabled
and
self
.
_num_prefills
>
0
\
...
@@ -537,8 +548,7 @@ class MLACommonMetadataBuilder(Generic[M]):
...
@@ -537,8 +548,7 @@ class MLACommonMetadataBuilder(Generic[M]):
prefill_metadata
=
MLACommonPrefillMetadata
(
prefill_metadata
=
MLACommonPrefillMetadata
(
input_positions
=
input_positions
[
tokens_start
:],
input_positions
=
input_positions
[
tokens_start
:],
block_table
=
block_table
[
reqs_start
:,
...],
block_table
=
block_table
[
reqs_start
:,
...],
query_start_loc
=
query_start_loc
[
reqs_start
:]
-
query_start_loc
=
prefill_query_start_loc
,
query_start_loc
[
reqs_start
],
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
chunked_context
=
chunked_context_metadata
,
chunked_context
=
chunked_context_metadata
,
)
)
...
@@ -628,11 +638,56 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -628,11 +638,56 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# and the one from vllm_flash_attn. The former is used on ROCm and the
# and the one from vllm_flash_attn. The former is used on ROCm and the
# latter has an additional parameter to control FA2 vs FA3
# latter has an additional parameter to control FA2 vs FA3
self
.
flash_attn_varlen_func
=
flash_attn_varlen_func
self
.
flash_attn_varlen_func
=
flash_attn_varlen_func
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
if
self
.
vllm_flash_attn_version
is
not
None
:
if
self
.
vllm_flash_attn_version
is
not
None
:
self
.
flash_attn_varlen_func
=
\
self
.
flash_attn_varlen_func
=
\
functools
.
partial
(
flash_attn_varlen_func
,
functools
.
partial
(
flash_attn_varlen_func
,
fa_version
=
self
.
vllm_flash_attn_version
)
fa_version
=
self
.
vllm_flash_attn_version
)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self
.
_pad_v
=
self
.
vllm_flash_attn_version
is
None
or
not
(
self
.
vllm_flash_attn_version
==
3
and
current_platform
.
get_device_capability
()[
0
]
==
9
)
def
_flash_attn_varlen_diff_headdims
(
self
,
q
,
k
,
v
,
return_softmax_lse
=
False
,
softmax_scale
=
None
,
**
kwargs
):
maybe_padded_v
=
v
if
self
.
_pad_v
:
maybe_padded_v
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
attn_out
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
maybe_padded_v
,
return_softmax_lse
=
return_softmax_lse
,
softmax_scale
=
softmax_scale
,
**
kwargs
,
)
# Unpack the output if there are multiple results
lse
=
None
if
isinstance
(
attn_out
,
tuple
):
attn_out
,
lse
=
attn_out
[
0
],
attn_out
[
1
]
# unpad if necessary
if
self
.
_pad_v
:
attn_out
=
attn_out
[...,
:
v
.
shape
[
-
1
]]
# Remain consistent with old `flash_attn_varlen_func` where there
# is only one output tensor if `return_softmax_lse` is False.
if
return_softmax_lse
:
return
attn_out
,
lse
return
attn_out
def
_v_up_proj_and_o_proj
(
self
,
x
):
def
_v_up_proj_and_o_proj
(
self
,
x
):
# Convert from (B, N, L) to (N, B, L)
# Convert from (B, N, L) to (N, B, L)
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
...
@@ -745,16 +800,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -745,16 +800,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
dim
=-
1
)
# For MLA the v head dim is smaller than qk head dim so we pad
attn_output
,
attn_softmax_lse
=
\
# out v with 0s to match the qk head dim
self
.
_flash_attn_varlen_diff_headdims
(
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
attn_output
,
attn_softmax_lse
=
self
.
flash_attn_varlen_func
(
q
=
q
,
q
=
q
,
k
=
k
,
k
=
k
,
v
=
v
_padded
,
v
=
v
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
chunked_context
.
cu_seq_lens
[
i
],
cu_seqlens_k
=
prefill_metadata
.
chunked_context
.
cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
...
@@ -801,15 +851,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -801,15 +851,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
# For MLA the v head dim is smaller than qk head dim so we pad out
output
=
self
.
_flash_attn_varlen_diff_headdims
(
# v with 0s to match the qk head dim
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
output
=
self
.
flash_attn_varlen_func
(
q
=
q
,
q
=
q
,
k
=
k
,
k
=
k
,
v
=
v
_padded
,
v
=
v
,
cu_seqlens_q
=
attn_metadata
.
prefill
.
query_start_loc
,
cu_seqlens_q
=
attn_metadata
.
prefill
.
query_start_loc
,
cu_seqlens_k
=
attn_metadata
.
prefill
.
query_start_loc
,
cu_seqlens_k
=
attn_metadata
.
prefill
.
query_start_loc
,
max_seqlen_q
=
attn_metadata
.
prefill
.
max_query_len
,
max_seqlen_q
=
attn_metadata
.
prefill
.
max_query_len
,
...
@@ -833,12 +878,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -833,12 +878,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
suffix_lse
=
suffix_lse
,
suffix_lse
=
suffix_lse
,
)
)
# slice by `:v.shape[-1]` in order to remove v headdim padding
return
self
.
o_proj
(
output
.
flatten
(
start_dim
=-
2
))[
0
]
output
=
output
\
.
view
(
-
1
,
self
.
num_heads
,
q
.
shape
[
-
1
])[...,
:
v
.
shape
[
-
1
]]
\
.
reshape
(
-
1
,
self
.
num_heads
*
v
.
shape
[
-
1
])
return
self
.
o_proj
(
output
)[
0
]
@
abstractmethod
@
abstractmethod
def
_forward_decode
(
def
_forward_decode
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment