Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
183dad7a
Unverified
Commit
183dad7a
authored
Apr 17, 2025
by
Lucas Wilkinson
Committed by
GitHub
Apr 17, 2025
Browse files
[Attention] Update to lastest FA3 code (#13111)
Signed-off-by:
Lucas Wilkinson
<
lwilkinson@neuralmagic.com
>
parent
3408e471
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
241 additions
and
118 deletions
+241
-118
cmake/external_projects/vllm_flash_attn.cmake
cmake/external_projects/vllm_flash_attn.cmake
+1
-1
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+92
-90
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+25
-1
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+58
-1
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+65
-25
No files found.
cmake/external_projects/vllm_flash_attn.cmake
View file @
183dad7a
...
...
@@ -38,7 +38,7 @@ else()
FetchContent_Declare
(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG
dc9d410b3e2d6534a4c70724c2515f4def670a22
GIT_TAG
0a721daebe4fa7149f06ecf3d3eabeb6dcd0f1fa
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR
${
CMAKE_BINARY_DIR
}
/vllm-flash-attn
...
...
vllm/attention/backends/mla/common.py
View file @
183dad7a
...
...
@@ -1043,8 +1043,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self
.
q_proj
=
q_proj
self
.
kv_b_proj
=
kv_b_proj
self
.
o_proj
=
o_proj
self
.
triton_fa_func
=
triton_attention
self
.
triton_fa_func
=
triton_attention
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on RoCM and the
# latter has an additional parameter to control FA2 vs FA3
...
...
@@ -1055,6 +1055,70 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
functools
.
partial
(
flash_attn_varlen_func
,
fa_version
=
self
.
vllm_flash_attn_version
)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self
.
_pad_v
=
self
.
vllm_flash_attn_version
is
None
or
not
(
self
.
vllm_flash_attn_version
==
3
and
current_platform
.
get_device_capability
()[
0
]
==
9
)
def
_flash_attn_varlen_diff_headdims
(
self
,
q
,
k
,
v
,
softmax_scale
,
return_softmax_lse
,
**
kwargs
):
maybe_padded_v
=
v
if
self
.
_pad_v
:
maybe_padded_v
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
if
is_hip
and
envs
.
VLLM_USE_TRITON_FLASH_ATTN
\
and
not
return_softmax_lse
:
attn_out
=
self
.
triton_fa_func
(
q
,
k
,
maybe_padded_v
,
**
kwargs
,
)
if
is_vllm_fa
:
attn_out
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
maybe_padded_v
,
return_softmax_lse
=
return_softmax_lse
,
softmax_scale
=
softmax_scale
,
**
kwargs
,
)
else
:
# Use return_attn_probs instead of return_softmax_lse for RoCM
attn_out
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
maybe_padded_v
,
return_attn_probs
=
return_softmax_lse
,
softmax_scale
=
softmax_scale
,
**
kwargs
,
)
# Unpack the output if there is multiple results,
# triton always returns (output, softmax_lse),
# vllm_flash_attn returns (output, softmax_lse) when
# `return_softmax_lse = True`
# flash_attn (RoCM) returns (output, softmax_lse, ...) when
# `return_attn_probs = True`
rest
=
None
if
isinstance
(
attn_out
,
tuple
):
attn_out
,
*
rest
=
attn_out
# unpad if necessary
if
self
.
_pad_v
:
attn_out
=
attn_out
[...,
:
v
.
shape
[
-
1
]]
# Remain consistent with old `flash_attn_varlen_func` where there
# is only one output tensor if `return_softmax_lse` is False.
if
return_softmax_lse
:
assert
rest
is
not
None
return
attn_out
,
rest
[
0
]
return
attn_out
def
_v_up_proj_and_o_proj
(
self
,
x
):
# Convert from (B, N, L) to (N, B, L)
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
...
...
@@ -1176,40 +1240,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
# For MLA the v head dim is smaller than qk head dim so we pad
# out v with 0s to match the qk head dim
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
if
is_vllm_fa
:
attn_output
,
attn_softmax_lse
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v_padded
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
context_chunk_max_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
return_softmax_lse
=
True
,
)
else
:
attn_output
,
attn_softmax_lse
,
_
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v_padded
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
context_chunk_max_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
return_attn_probs
=
True
,
)
attn_output
,
attn_softmax_lse
=
\
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
context_chunk_max_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
return_softmax_lse
=
True
,
)
if
output
is
None
:
output
=
attn_output
...
...
@@ -1252,58 +1295,22 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
if
is_hip
and
envs
.
VLLM_USE_TRITON_FLASH_ATTN
and
not
has_context
:
output
=
self
.
triton_fa_func
(
q
,
k
,
v_padded
,
None
,
prefill_metadata
.
query_start_loc
,
prefill_metadata
.
query_start_loc
,
prefill_metadata
.
max_prefill_seq_len
,
prefill_metadata
.
max_prefill_seq_len
,
True
,
# causal
self
.
scale
,
None
,
# attn_mask is None unless applying ALiBi mask
)
## triton flash attention always return 2 objects
if
not
has_context
:
output
=
output
[
0
]
elif
is_vllm_fa
:
output
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v_padded
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
query_start_loc
,
max_seqlen_q
=
prefill_metadata
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_metadata
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
return_softmax_lse
=
has_context
,
)
else
:
output
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v_padded
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
query_start_loc
,
max_seqlen_q
=
prefill_metadata
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_metadata
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
return_attn_probs
=
has_context
,
)
output
=
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
query_start_loc
,
max_seqlen_q
=
prefill_metadata
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_metadata
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
return_softmax_lse
=
has_context
,
)
if
has_context
:
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output
,
suffix_lse
,
*
rest
=
output
suffix_output
,
suffix_lse
=
output
context_output
,
context_lse
=
self
.
_compute_prefill_context
(
\
q
,
kv_c_and_k_pe_cache
,
attn_metadata
)
...
...
@@ -1316,12 +1323,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
suffix_lse
=
suffix_lse
,
)
# slice by `:v.shape[-1]` in order to remove v headdim padding
output
=
output
\
.
view
(
-
1
,
self
.
num_heads
,
q
.
shape
[
-
1
])[...,
:
v
.
shape
[
-
1
]]
\
.
reshape
(
-
1
,
self
.
num_heads
*
v
.
shape
[
-
1
])
return
self
.
o_proj
(
output
)[
0
]
return
self
.
o_proj
(
output
.
flatten
(
start_dim
=-
2
))[
0
]
@
abstractmethod
def
_forward_decode
(
...
...
vllm/attention/backends/utils.py
View file @
183dad7a
...
...
@@ -2,8 +2,10 @@
"""Attention backend utils"""
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
itertools
import
accumulate
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Tuple
,
Type
,
TypeVar
,
Union
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
)
import
numpy
as
np
import
torch
...
...
@@ -11,6 +13,7 @@ import torch
from
vllm.attention
import
(
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionState
)
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
...
...
@@ -583,3 +586,24 @@ def get_num_prefill_decode_query_kv_tokens(
return
(
num_prefill_query_tokens
,
num_prefill_kv_tokens
,
num_decode_query_tokens
)
@
dataclass
class
MLADims
:
q_lora_rank
:
Optional
[
int
]
kv_lora_rank
:
int
qk_nope_head_dim
:
int
qk_rope_head_dim
:
int
v_head_dim
:
int
def
get_mla_dims
(
model_config
:
ModelConfig
)
->
MLADims
:
hf_text_config
=
model_config
.
hf_text_config
return
MLADims
(
q_lora_rank
=
getattr
(
hf_text_config
,
"q_lora_rank"
,
None
),
kv_lora_rank
=
hf_text_config
.
kv_lora_rank
,
qk_nope_head_dim
=
hf_text_config
.
qk_nope_head_dim
,
qk_rope_head_dim
=
hf_text_config
.
qk_rope_head_dim
,
v_head_dim
=
hf_text_config
.
v_head_dim
,
)
vllm/v1/attention/backends/flash_attn.py
View file @
183dad7a
...
...
@@ -23,7 +23,8 @@ if TYPE_CHECKING:
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
if
current_platform
.
is_cuda
():
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
get_scheduler_metadata
)
logger
=
init_logger
(
__name__
)
...
...
@@ -93,6 +94,10 @@ class FlashAttentionMetadata:
prefix_kv_lens
:
Optional
[
torch
.
Tensor
]
suffix_kv_lens
:
Optional
[
torch
.
Tensor
]
# Optional aot scheduling
scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
prefix_scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
# For logging.
num_input_tokens
:
int
=
0
# Number of tokens including padding.
...
...
@@ -277,7 +282,14 @@ def make_local_attention_virtual_batches(
class
FlashAttentionMetadataBuilder
:
def
__init__
(
self
,
runner
:
"GPUModelRunner"
):
model_config
=
runner
.
model_config
self
.
runner
=
runner
self
.
aot_schedule
=
(
get_flash_attn_version
()
==
3
)
self
.
num_heads
=
model_config
.
get_num_attention_heads
(
runner
.
parallel_config
)
self
.
headdim
=
model_config
.
get_head_size
()
self
.
page_size
=
self
.
runner
.
block_size
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
...
...
@@ -319,6 +331,24 @@ class FlashAttentionMetadataBuilder:
)
use_cascade
=
common_prefix_len
>
0
def
schedule
(
cu_query_lens
,
max_query_len
,
seqlens
,
max_seq_len
,
causal
):
if
self
.
aot_schedule
:
return
get_scheduler_metadata
(
batch_size
=
num_reqs
,
max_seqlen_q
=
max_query_len
,
max_seqlen_k
=
max_seq_len
,
cache_seqlens
=
seqlens
,
num_heads_q
=
self
.
num_heads
,
num_heads_kv
=
self
.
num_heads
,
headdim
=
self
.
headdim
,
page_size
=
self
.
page_size
,
cu_seqlens_q
=
cu_query_lens
,
causal
=
causal
,
)
return
None
if
use_cascade
:
cu_prefix_query_lens
=
torch
.
tensor
([
0
,
num_actual_tokens
],
dtype
=
torch
.
int32
,
...
...
@@ -330,10 +360,28 @@ class FlashAttentionMetadataBuilder:
common_prefix_len
)
suffix_kv_lens
=
torch
.
from_numpy
(
suffix_kv_lens
).
to
(
self
.
runner
.
device
)
prefix_scheduler_metadata
=
schedule
(
cu_query_lens
=
cu_prefix_query_lens
,
max_query_len
=
num_actual_tokens
,
seqlens
=
prefix_kv_lens
,
max_seq_len
=
common_prefix_len
,
causal
=
False
)
scheduler_metadata
=
schedule
(
cu_query_lens
=
query_start_loc
,
max_query_len
=
max_query_len
,
seqlens
=
suffix_kv_lens
,
max_seq_len
=
max_seq_len
-
common_prefix_len
,
causal
=
True
)
else
:
cu_prefix_query_lens
=
None
prefix_kv_lens
=
None
suffix_kv_lens
=
None
prefix_scheduler_metadata
=
None
scheduler_metadata
=
schedule
(
cu_query_lens
=
query_start_loc
,
max_query_len
=
max_query_len
,
seqlens
=
seq_lens
,
max_seq_len
=
max_seq_len
,
causal
=
True
)
attn_metadata
=
FlashAttentionMetadata
(
num_actual_tokens
=
num_actual_tokens
,
...
...
@@ -345,10 +393,12 @@ class FlashAttentionMetadataBuilder:
slot_mapping
=
slot_mapping
,
use_cascade
=
use_cascade
,
common_prefix_len
=
common_prefix_len
,
scheduler_metadata
=
scheduler_metadata
,
cu_prefix_query_lens
=
cu_prefix_query_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
local_attn_metadata
=
local_attn_metadata
,
prefix_scheduler_metadata
=
prefix_scheduler_metadata
,
)
return
attn_metadata
...
...
@@ -515,6 +565,7 @@ class FlashAttentionImpl(AttentionImpl):
window_size
=
self
.
sliding_window
,
block_table
=
block_table
,
softcap
=
self
.
logits_soft_cap
,
scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
...
...
@@ -543,6 +594,8 @@ class FlashAttentionImpl(AttentionImpl):
block_table
=
attn_metadata
.
block_table
,
common_prefix_len
=
attn_metadata
.
common_prefix_len
,
fa_version
=
self
.
vllm_flash_attn_version
,
prefix_scheduler_metadata
=
attn_metadata
.
prefix_scheduler_metadata
,
suffix_scheduler_metadata
=
attn_metadata
.
scheduler_metadata
,
q_descale
=
layer
.
_q_scale
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
...
...
@@ -636,6 +689,8 @@ def cascade_attention(
block_table
:
torch
.
Tensor
,
common_prefix_len
:
int
,
fa_version
:
int
,
prefix_scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
,
suffix_scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
,
q_descale
:
Optional
[
torch
.
Tensor
]
=
None
,
k_descale
:
Optional
[
torch
.
Tensor
]
=
None
,
v_descale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
...
@@ -667,6 +722,7 @@ def cascade_attention(
block_table
=
block_table
[:
1
],
softcap
=
logits_soft_cap
,
return_softmax_lse
=
True
,
scheduler_metadata
=
prefix_scheduler_metadata
,
fa_version
=
fa_version
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
...
...
@@ -693,6 +749,7 @@ def cascade_attention(
block_table
=
block_table
[:,
num_common_kv_blocks
:],
softcap
=
logits_soft_cap
,
return_softmax_lse
=
True
,
scheduler_metadata
=
suffix_scheduler_metadata
,
fa_version
=
fa_version
,
q_descale
=
q_descale
.
expand
(
descale_shape
)
if
q_descale
is
not
None
else
None
,
...
...
vllm/v1/attention/backends/mla/common.py
View file @
183dad7a
...
...
@@ -195,6 +195,7 @@ from vllm import _custom_ops as ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionLayer
,
AttentionMetadata
,
MLAAttentionImpl
)
from
vllm.attention.backends.utils
import
get_mla_dims
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
...
@@ -207,9 +208,11 @@ from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
is_vllm_fa
=
True
except
ImportError
:
# For rocm use upstream flash attention
from
flash_attn
import
flash_attn_varlen_func
is_vllm_fa
=
False
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
...
@@ -350,6 +353,14 @@ class MLACommonMetadataBuilder(Generic[M]):
model_config
=
runner
.
model_config
cache_config
=
runner
.
cache_config
self
.
chunked_prefill_enabled
=
scheduler_config
.
chunked_prefill_enabled
self
.
num_heads
=
model_config
.
get_num_attention_heads
(
runner
.
parallel_config
)
self
.
mla_dims
=
get_mla_dims
(
model_config
)
self
.
aot_schedule
=
is_vllm_fa
and
(
get_flash_attn_version
()
==
3
)
# Dont try to access the runner on AMD
if
self
.
aot_schedule
:
self
.
page_size
=
self
.
runner
.
block_size
if
self
.
chunked_prefill_enabled
:
self
.
chunked_prefill_workspace_size
=
min
(
...
...
@@ -375,7 +386,6 @@ class MLACommonMetadataBuilder(Generic[M]):
dtype
=
model_config
.
dtype
,
device
=
runner
.
device
,
)
self
.
page_size
=
self
.
runner
.
block_size
def
reorder_batch
(
self
,
input_batch
:
"InputBatch"
,
scheduler_output
:
"SchedulerOutput"
)
->
bool
:
...
...
@@ -464,7 +474,6 @@ class MLACommonMetadataBuilder(Generic[M]):
seq_lens_cpu
=
self
.
runner
.
seq_lens_cpu
[:
num_reqs
]
seq_lens
=
seq_lens_cpu
.
to
(
device
,
non_blocking
=
True
)
max_query_len
=
seq_lens_cpu
.
max
().
item
()
prefill_metadata
=
None
if
self
.
_num_prefills
>
0
:
...
...
@@ -475,6 +484,8 @@ class MLACommonMetadataBuilder(Generic[M]):
num_computed_tokens_cpu_tensor
[
reqs_start
:
num_reqs
]
max_context_len_cpu
=
context_lens_cpu
.
max
().
item
()
num_prefills_with_context_cpu
=
(
context_lens_cpu
>
0
).
sum
().
item
()
prefill_query_start_loc
=
query_start_loc
[
reqs_start
:]
-
query_start_loc
[
reqs_start
]
chunked_context_metadata
=
None
if
self
.
chunked_prefill_enabled
and
self
.
_num_prefills
>
0
\
...
...
@@ -537,8 +548,7 @@ class MLACommonMetadataBuilder(Generic[M]):
prefill_metadata
=
MLACommonPrefillMetadata
(
input_positions
=
input_positions
[
tokens_start
:],
block_table
=
block_table
[
reqs_start
:,
...],
query_start_loc
=
query_start_loc
[
reqs_start
:]
-
query_start_loc
[
reqs_start
],
query_start_loc
=
prefill_query_start_loc
,
max_query_len
=
max_query_len
,
chunked_context
=
chunked_context_metadata
,
)
...
...
@@ -628,11 +638,56 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# and the one from vllm_flash_attn. The former is used on RoCM and the
# latter has an additional parameter to control FA2 vs FA3
self
.
flash_attn_varlen_func
=
flash_attn_varlen_func
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
if
self
.
vllm_flash_attn_version
is
not
None
:
self
.
flash_attn_varlen_func
=
\
functools
.
partial
(
flash_attn_varlen_func
,
fa_version
=
self
.
vllm_flash_attn_version
)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self
.
_pad_v
=
self
.
vllm_flash_attn_version
is
None
or
not
(
self
.
vllm_flash_attn_version
==
3
and
current_platform
.
get_device_capability
()[
0
]
==
9
)
def
_flash_attn_varlen_diff_headdims
(
self
,
q
,
k
,
v
,
return_softmax_lse
=
False
,
softmax_scale
=
None
,
**
kwargs
):
maybe_padded_v
=
v
if
self
.
_pad_v
:
maybe_padded_v
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
attn_out
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
maybe_padded_v
,
return_softmax_lse
=
return_softmax_lse
,
softmax_scale
=
softmax_scale
,
**
kwargs
,
)
# Unpack the output if there is multiple results
lse
=
None
if
isinstance
(
attn_out
,
tuple
):
attn_out
,
lse
=
attn_out
[
0
],
attn_out
[
1
]
# unpad if necessary
if
self
.
_pad_v
:
attn_out
=
attn_out
[...,
:
v
.
shape
[
-
1
]]
# Remain consistent with old `flash_attn_varlen_func` where there
# is only one output tensor if `return_softmax_lse` is False.
if
return_softmax_lse
:
return
attn_out
,
lse
return
attn_out
def
_v_up_proj_and_o_proj
(
self
,
x
):
# Convert from (B, N, L) to (N, B, L)
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
...
...
@@ -745,16 +800,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
# For MLA the v head dim is smaller than qk head dim so we pad
# out v with 0s to match the qk head dim
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
attn_output
,
attn_softmax_lse
=
self
.
flash_attn_varlen_func
(
attn_output
,
attn_softmax_lse
=
\
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
v
=
v
_padded
,
v
=
v
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
chunked_context
.
cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
...
...
@@ -801,15 +851,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
output
=
self
.
flash_attn_varlen_func
(
output
=
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
v
=
v
_padded
,
v
=
v
,
cu_seqlens_q
=
attn_metadata
.
prefill
.
query_start_loc
,
cu_seqlens_k
=
attn_metadata
.
prefill
.
query_start_loc
,
max_seqlen_q
=
attn_metadata
.
prefill
.
max_query_len
,
...
...
@@ -833,12 +878,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
suffix_lse
=
suffix_lse
,
)
# slice by `:v.shape[-1]` in order to remove v headdim padding
output
=
output
\
.
view
(
-
1
,
self
.
num_heads
,
q
.
shape
[
-
1
])[...,
:
v
.
shape
[
-
1
]]
\
.
reshape
(
-
1
,
self
.
num_heads
*
v
.
shape
[
-
1
])
return
self
.
o_proj
(
output
)[
0
]
return
self
.
o_proj
(
output
.
flatten
(
start_dim
=-
2
))[
0
]
@
abstractmethod
def
_forward_decode
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment