Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bc6e542d
Unverified
Commit
bc6e542d
authored
Sep 21, 2025
by
Woosuk Kwon
Committed by
GitHub
Sep 21, 2025
Browse files
Remove V0 attention backends (#25351)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
af7dfb0d
Changes
28
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
76 additions
and
7199 deletions
+76
-7199
examples/offline_inference/qwen_1m.py
examples/offline_inference/qwen_1m.py
+0
-1
tests/compile/test_fusion_attn.py
tests/compile/test_fusion_attn.py
+3
-2
tests/kernels/attention/test_attention.py
tests/kernels/attention/test_attention.py
+3
-3
tests/kernels/attention/test_attention_selector.py
tests/kernels/attention/test_attention_selector.py
+1
-0
tests/kernels/attention/test_prefix_prefill.py
tests/kernels/attention/test_prefix_prefill.py
+3
-3
tests/kernels/attention/test_rocm_attention_selector.py
tests/kernels/attention/test_rocm_attention_selector.py
+1
-0
tests/kernels/utils.py
tests/kernels/utils.py
+56
-10
tests/models/test_initialization.py
tests/models/test_initialization.py
+2
-3
vllm/attention/backends/differential_flash_attn.py
vllm/attention/backends/differential_flash_attn.py
+0
-931
vllm/attention/backends/dual_chunk_flash_attn.py
vllm/attention/backends/dual_chunk_flash_attn.py
+0
-1495
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+0
-929
vllm/attention/backends/flashmla.py
vllm/attention/backends/flashmla.py
+0
-227
vllm/attention/backends/mla/__init__.py
vllm/attention/backends/mla/__init__.py
+0
-0
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+0
-1305
vllm/attention/backends/rocm_aiter_mla.py
vllm/attention/backends/rocm_aiter_mla.py
+0
-407
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+0
-953
vllm/attention/backends/triton_mla.py
vllm/attention/backends/triton_mla.py
+0
-111
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+6
-8
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+0
-805
vllm/config/model.py
vllm/config/model.py
+1
-6
No files found.
examples/offline_inference/qwen_1m.py
View file @
bc6e542d
...
@@ -5,7 +5,6 @@ from urllib.request import urlopen
...
@@ -5,7 +5,6 @@ from urllib.request import urlopen
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
=
"DUAL_CHUNK_FLASH_ATTN"
os
.
environ
[
"VLLM_ALLOW_LONG_MAX_MODEL_LEN"
]
=
"1"
os
.
environ
[
"VLLM_ALLOW_LONG_MAX_MODEL_LEN"
]
=
"1"
...
...
tests/compile/test_fusion_attn.py
View file @
bc6e542d
...
@@ -334,8 +334,9 @@ else:
...
@@ -334,8 +334,9 @@ else:
[
7
,
256
,
533
]
if
current_platform
.
is_cuda
()
else
[
8
])
[
7
,
256
,
533
]
if
current_platform
.
is_cuda
()
else
[
8
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"model_name, model_class"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model_name, model_class"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
_Backend
.
FLASHINFER
]
if
@
pytest
.
mark
.
parametrize
(
"backend"
,
current_platform
.
is_cuda
()
else
[
_Backend
.
ROCM_FLASH
])
[
_Backend
.
FLASHINFER
]
if
current_platform
.
is_cuda
()
else
[
_Backend
.
TRITON_ATTN_VLLM_V1
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"split_attention"
,
"split_attention"
,
[
False
,
True
]
if
current_platform
.
is_rocm
()
else
[
False
])
[
False
,
True
]
if
current_platform
.
is_rocm
()
else
[
False
])
...
...
tests/kernels/attention/test_attention.py
View file @
bc6e542d
...
@@ -18,7 +18,7 @@ if not current_platform.is_rocm():
...
@@ -18,7 +18,7 @@ if not current_platform.is_rocm():
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalMask
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalMask
from
vllm.attention.backends.xformer
s
import
_
make_alibi_bias
from
tests.kernels.util
s
import
make_alibi_bias
FLOAT32_BYTES
=
torch
.
finfo
(
torch
.
float
).
bits
//
8
FLOAT32_BYTES
=
torch
.
finfo
(
torch
.
float
).
bits
//
8
# This will change depending on the compute capability.
# This will change depending on the compute capability.
...
@@ -429,7 +429,7 @@ def test_multi_query_kv_attention(
...
@@ -429,7 +429,7 @@ def test_multi_query_kv_attention(
alibi_bias
=
None
alibi_bias
=
None
if
use_alibi
:
if
use_alibi
:
alibi_slopes
=
torch
.
randn
(
num_query_heads
,
dtype
=
torch
.
float
)
alibi_slopes
=
torch
.
randn
(
num_query_heads
,
dtype
=
torch
.
float
)
attn_bias
=
_
make_alibi_bias
(
alibi_slopes
,
num_kv_heads
,
dtype
,
attn_bias
=
make_alibi_bias
(
alibi_slopes
,
num_kv_heads
,
dtype
,
seq_lens
)
seq_lens
)
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
query
)
start
=
0
start
=
0
...
...
tests/kernels/attention/test_attention_selector.py
View file @
bc6e542d
...
@@ -67,6 +67,7 @@ def generate_params():
...
@@ -67,6 +67,7 @@ def generate_params():
return
params
return
params
@
pytest
.
mark
.
skip
(
reason
=
"Skipped for now. Should be revisited."
)
@
pytest
.
mark
.
parametrize
(
"device, name, use_mla, block_size"
,
@
pytest
.
mark
.
parametrize
(
"device, name, use_mla, block_size"
,
generate_params
())
generate_params
())
def
test_env
(
def
test_env
(
...
...
tests/kernels/attention/test_prefix_prefill.py
View file @
bc6e542d
...
@@ -11,7 +11,7 @@ import torch
...
@@ -11,7 +11,7 @@ import torch
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalFromBottomRightMask
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalCausalFromBottomRightMask
from
vllm.attention.backends.xformer
s
import
_
make_alibi_bias
from
tests.kernels.util
s
import
make_alibi_bias
from
vllm.attention.ops.chunked_prefill_paged_decode
import
(
from
vllm.attention.ops.chunked_prefill_paged_decode
import
(
chunked_prefill_paged_decode
)
chunked_prefill_paged_decode
)
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
...
@@ -470,7 +470,7 @@ def test_contexted_kv_attention_alibi(
...
@@ -470,7 +470,7 @@ def test_contexted_kv_attention_alibi(
key
=
key
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
attn_bias
=
_
make_alibi_bias
(
alibi_slopes
,
num_kv_heads
,
dtype
,
seq_lens
)
attn_bias
=
make_alibi_bias
(
alibi_slopes
,
num_kv_heads
,
dtype
,
seq_lens
)
output_ref
=
torch
.
empty_like
(
output
)
output_ref
=
torch
.
empty_like
(
output
)
seq_start
=
0
seq_start
=
0
query_start
=
0
query_start
=
0
...
@@ -479,7 +479,7 @@ def test_contexted_kv_attention_alibi(
...
@@ -479,7 +479,7 @@ def test_contexted_kv_attention_alibi(
# FIXME(DefTruth): Because xformers does not support dynamic sequence
# FIXME(DefTruth): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
# one. This is inefficient, especially when we have many short prompts.
# modified from: vllm/attention/backends/xformers.py#L343
# modified from: vllm/
v1/
attention/backends/xformers.py#L343
for
i
,
(
query_len
,
seq_len
)
in
enumerate
(
zip
(
query_lens
,
seq_lens
)):
for
i
,
(
query_len
,
seq_len
)
in
enumerate
(
zip
(
query_lens
,
seq_lens
)):
seq_end
=
seq_start
+
seq_len
seq_end
=
seq_start
+
seq_len
query_end
=
query_start
+
query_len
query_end
=
query_start
+
query_len
...
...
tests/kernels/attention/test_rocm_attention_selector.py
View file @
bc6e542d
...
@@ -16,6 +16,7 @@ def clear_cache():
...
@@ -16,6 +16,7 @@ def clear_cache():
_cached_get_attn_backend
.
cache_clear
()
_cached_get_attn_backend
.
cache_clear
()
@
pytest
.
mark
.
skip
(
reason
=
"Skipped for now. Should be revisited."
)
def
test_selector
(
monkeypatch
:
pytest
.
MonkeyPatch
):
def
test_selector
(
monkeypatch
:
pytest
.
MonkeyPatch
):
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
"ROCM_FLASH"
)
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
"ROCM_FLASH"
)
...
...
tests/kernels/utils.py
View file @
bc6e542d
...
@@ -513,10 +513,6 @@ def make_backend(backend_name: str) -> AttentionBackend:
...
@@ -513,10 +513,6 @@ def make_backend(backend_name: str) -> AttentionBackend:
Construct the backend instance determined by the backend_name string
Construct the backend instance determined by the backend_name string
argument.
argument.
"XFORMERS" -> construct xformers backend
TODO: other backends
Note: at time of writing the Attention wrapper automatically selects
Note: at time of writing the Attention wrapper automatically selects
its own backend for Attention.forward(); so the backend instance which
its own backend for Attention.forward(); so the backend instance which
you generate with this function is not meant to be used for *running*
you generate with this function is not meant to be used for *running*
...
@@ -528,18 +524,68 @@ def make_backend(backend_name: str) -> AttentionBackend:
...
@@ -528,18 +524,68 @@ def make_backend(backend_name: str) -> AttentionBackend:
* Backend instance
* Backend instance
'''
'''
if
backend_name
==
STR_XFORMERS_ATTN_VAL
:
if
backend_name
in
(
STR_XFORMERS_ATTN_VAL
,
"XFORMERS_VLLM_V1"
)
:
# NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
from
vllm.v1.attention.backends.xformers
import
(
from
vllm.attention.backends.xformers
import
XFormers
Backend
XFormersAttention
Backend
)
return
XFormersBackend
()
return
XFormers
Attention
Backend
()
el
if
backend_name
==
STR_FLASH_ATTN_VAL
:
if
backend_name
in
(
STR_FLASH_ATTN_VAL
,
"FLASH_ATTN_VLLM_V1"
)
:
from
vllm.attention.backends.flash_attn
import
FlashAttentionBackend
from
vllm.
v1.
attention.backends.flash_attn
import
FlashAttentionBackend
return
FlashAttentionBackend
()
return
FlashAttentionBackend
()
if
backend_name
==
"TRITON_ATTN_VLLM_V1"
:
from
vllm.v1.attention.backends.triton_attn
import
(
TritonAttentionBackend
)
return
TritonAttentionBackend
()
if
backend_name
==
"FLEX_ATTENTION"
:
from
vllm.v1.attention.backends.flex_attention
import
(
FlexAttentionBackend
)
return
FlexAttentionBackend
()
if
backend_name
in
(
"TORCH_SDPA"
,
"TORCH_SDPA_VLLM_V1"
):
from
vllm.v1.attention.backends.cpu_attn
import
TorchSDPABackend
return
TorchSDPABackend
()
if
backend_name
==
"FLASHINFER"
:
from
vllm.v1.attention.backends.flashinfer
import
FlashInferBackend
return
FlashInferBackend
()
raise
AssertionError
(
raise
AssertionError
(
f
"Unrecognized backend_name
{
backend_name
}
for unit test"
)
f
"Unrecognized backend_name
{
backend_name
}
for unit test"
)
def
make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
num_kv_heads
:
int
,
dtype
:
torch
.
dtype
,
seq_lens
:
list
[
int
],
)
->
list
[
Any
]:
"""Create ALiBi biases compatible with xFormers attention tests."""
from
xformers.ops.fmha.attn_bias
import
LowerTriangularMaskWithTensorBias
if
alibi_slopes
is
None
:
return
[
None
for
_
in
seq_lens
]
attn_biases
:
list
[
Any
]
=
[]
num_heads
=
alibi_slopes
.
shape
[
0
]
assert
num_heads
>=
num_kv_heads
,
(
"ALiBi slopes expect at least as many heads as KV heads"
)
for
seq_len
in
seq_lens
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
,
device
=
alibi_slopes
.
device
)
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
padded_len
=
(
seq_len
+
7
)
//
8
*
8
bias_tensor
=
torch
.
empty
(
1
,
num_heads
,
seq_len
,
padded_len
,
device
=
alibi_slopes
.
device
,
dtype
=
dtype
,
)[:,
:,
:,
:
seq_len
].
copy_
(
bias
)
bias_tensor
.
mul_
(
alibi_slopes
[:,
None
,
None
])
attn_biases
.
append
(
LowerTriangularMaskWithTensorBias
(
bias_tensor
))
return
attn_biases
def
_make_metadata_tensors
(
def
_make_metadata_tensors
(
seq_lens
:
Optional
[
list
[
int
]],
seq_lens
:
Optional
[
list
[
int
]],
context_lens
:
Optional
[
list
[
int
]],
context_lens
:
Optional
[
list
[
int
]],
...
...
tests/models/test_initialization.py
View file @
bc6e542d
...
@@ -78,9 +78,8 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
...
@@ -78,9 +78,8 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
return
return
if
model_arch
in
(
"Phi4FlashForCausalLM"
,
"MotifForCausalLM"
):
if
model_arch
in
(
"Phi4FlashForCausalLM"
,
"MotifForCausalLM"
):
# Phi4FlashForCausalLM and MotifForCausalLM
pytest
.
skip
(
# only supports DIFFERENTIAL_FLASH_ATTN backend
"Differential Flash Attention backend has been removed."
)
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"DIFFERENTIAL_FLASH_ATTN"
)
if
model_arch
==
"GptOssForCausalLM"
:
if
model_arch
==
"GptOssForCausalLM"
:
# FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU
# FIXME: A hack to bypass FA3 assertion because our CI's L4 GPU
# has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
# has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
...
...
vllm/attention/backends/differential_flash_attn.py
deleted
100644 → 0
View file @
af7dfb0d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""" An implementation of https://arxiv.org/pdf/2410.05258 """
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
itertools
import
accumulate
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
einops
import
rearrange
from
vllm
import
_custom_ops
as
ops
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.flash_attn
import
FlashAttentionBackend
# yapf: enable
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
CommonAttentionState
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
is_all_cross_attn_metadata_set
,
is_all_encoder_attn_metadata_set
,
is_block_tables_empty
)
from
vllm.attention.utils.fa_utils
import
(
flash_attn_supports_fp8
,
get_flash_attn_version
)
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
)
logger
=
init_logger
(
__name__
)
class
DifferentialFlashAttentionBackend
(
AttentionBackend
):
accept_output_buffer
=
False
@
staticmethod
def
get_supported_head_sizes
()
->
List
[
int
]:
return
[
32
,
64
,
96
,
128
,
160
,
192
,
224
,
256
]
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
if
block_size
%
16
!=
0
:
raise
ValueError
(
"Block size must be a multiple of 16."
)
assert
num_kv_heads
%
2
==
0
,
"num_kv_heads must be divisible by 2"
return
(
2
,
2
,
num_blocks
,
block_size
,
num_kv_heads
//
2
,
head_size
)
@
staticmethod
def
get_name
()
->
str
:
return
"DIFFERENTIAL_FLASH_ATTN"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"DifferentialFlashAttentionImpl"
]:
return
DifferentialFlashAttentionImpl
@
staticmethod
def
get_metadata_cls
()
->
Type
[
"DifferentialFlashAttentionMetadata"
]:
return
DifferentialFlashAttentionMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"DifferentialFlashAttentionMetadataBuilder"
]:
return
DifferentialFlashAttentionMetadataBuilder
@
staticmethod
def
get_state_cls
()
->
Type
[
"CommonAttentionState"
]:
return
CommonAttentionState
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
src_key_cache
=
src_kv_cache
[
0
]
dst_key_cache
=
dst_kv_cache
[
0
]
ops
.
swap_blocks
(
src_key_cache
,
dst_key_cache
,
src_to_dst
)
src_value_cache
=
src_kv_cache
[
1
]
dst_value_cache
=
dst_kv_cache
[
1
]
ops
.
swap_blocks
(
src_value_cache
,
dst_value_cache
,
src_to_dst
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
@
dataclass
class
DifferentialFlashAttentionMetadata
(
AttentionMetadata
):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens
:
Optional
[
List
[
int
]]
# seq_lens stored as a tensor.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len
:
int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len
:
int
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables
:
Optional
[
torch
.
Tensor
]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph
:
bool
# Maximum query length in the batch.
max_query_len
:
Optional
[
int
]
=
None
# Max number of query tokens among request in the batch.
max_decode_query_len
:
Optional
[
int
]
=
None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
_cached_prefill_metadata
:
Optional
[
"DifferentialFlashAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"DifferentialFlashAttentionMetadata"
]
=
None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens
:
Optional
[
List
[
int
]]
=
None
encoder_seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
encoder_seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# Maximum sequence length among encoder sequences
max_encoder_seq_len
:
Optional
[
int
]
=
None
# Number of tokens input to encoder
num_encoder_tokens
:
Optional
[
int
]
=
None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping
:
Optional
[
torch
.
Tensor
]
=
None
cross_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
# Cross-layer shared attention block tables
cross_layer_shared_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
@
property
def
is_all_encoder_attn_metadata_set
(
self
):
'''
All attention metadata required for encoder attention is set.
'''
return
is_all_encoder_attn_metadata_set
(
self
)
@
property
def
is_all_cross_attn_metadata_set
(
self
):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return
is_all_cross_attn_metadata_set
(
self
)
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"DifferentialFlashAttentionMetadata"
]:
if
self
.
num_prefills
==
0
:
return
None
if
self
.
_cached_prefill_metadata
is
not
None
:
return
self
.
_cached_prefill_metadata
assert
((
self
.
seq_lens
is
not
None
)
or
(
self
.
encoder_seq_lens
is
not
None
))
assert
((
self
.
seq_lens_tensor
is
not
None
)
or
(
self
.
encoder_seq_lens_tensor
is
not
None
))
# Compute some attn_metadata fields which default to None
query_start_loc
=
(
None
if
self
.
query_start_loc
is
None
else
self
.
query_start_loc
[:
self
.
num_prefills
+
1
])
slot_mapping
=
(
None
if
self
.
slot_mapping
is
None
else
self
.
slot_mapping
[:
self
.
num_prefill_tokens
])
seq_lens
=
(
None
if
self
.
seq_lens
is
None
else
self
.
seq_lens
[:
self
.
num_prefills
])
seq_lens_tensor
=
(
None
if
self
.
seq_lens_tensor
is
None
else
self
.
seq_lens_tensor
[:
self
.
num_prefills
])
seq_start_loc
=
(
None
if
self
.
seq_start_loc
is
None
else
self
.
seq_start_loc
[:
self
.
num_prefills
+
1
])
context_lens_tensor
=
(
None
if
self
.
context_lens_tensor
is
None
else
self
.
context_lens_tensor
[:
self
.
num_prefills
])
block_tables
=
(
None
if
self
.
block_tables
is
None
else
self
.
block_tables
[:
self
.
num_prefills
])
cross_layer_shared_block_tables
=
(
None
if
self
.
cross_layer_shared_block_tables
is
None
else
self
.
cross_layer_shared_block_tables
[:
self
.
num_prefills
])
self
.
_cached_prefill_metadata
=
DifferentialFlashAttentionMetadata
(
num_prefills
=
self
.
num_prefills
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
self
.
multi_modal_placeholder_index_maps
,
enable_kv_scales_calculation
=
self
.
enable_kv_scales_calculation
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
self
.
max_prefill_seq_len
,
max_decode_query_len
=
0
,
max_decode_seq_len
=
0
,
query_start_loc
=
query_start_loc
,
seq_start_loc
=
seq_start_loc
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
cross_layer_shared_block_tables
=
cross_layer_shared_block_tables
,
use_cuda_graph
=
False
,
# Begin encoder & cross attn fields below...
encoder_seq_lens
=
self
.
encoder_seq_lens
,
encoder_seq_lens_tensor
=
self
.
encoder_seq_lens_tensor
,
encoder_seq_start_loc
=
self
.
encoder_seq_start_loc
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
)
return
self
.
_cached_prefill_metadata
@
property
def
decode_metadata
(
self
)
->
Optional
[
"DifferentialFlashAttentionMetadata"
]:
if
self
.
num_decode_tokens
==
0
:
return
None
if
self
.
_cached_decode_metadata
is
not
None
:
return
self
.
_cached_decode_metadata
assert
((
self
.
seq_lens_tensor
is
not
None
)
or
(
self
.
encoder_seq_lens_tensor
is
not
None
))
# Compute some attn_metadata fields which default to None
slot_mapping
=
(
None
if
self
.
slot_mapping
is
None
else
self
.
slot_mapping
[
self
.
num_prefill_tokens
:])
seq_lens_tensor
=
(
None
if
self
.
seq_lens_tensor
is
None
else
self
.
seq_lens_tensor
[
self
.
num_prefills
:])
block_tables
=
(
None
if
self
.
block_tables
is
None
else
self
.
block_tables
[
self
.
num_prefills
:])
cross_layer_shared_block_tables
=
(
None
if
self
.
cross_layer_shared_block_tables
is
None
else
self
.
cross_layer_shared_block_tables
[
self
.
num_prefills
:])
self
.
_cached_decode_metadata
=
DifferentialFlashAttentionMetadata
(
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
enable_kv_scales_calculation
=
True
,
seq_lens
=
None
,
seq_lens_tensor
=
seq_lens_tensor
,
max_decode_query_len
=
self
.
max_decode_query_len
,
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
# Batch may be composed of prefill|decodes, adjust query start
# indices to refer to the start of decodes. E.g.
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
query_start_loc
=
(
self
.
query_start_loc
[
self
.
num_prefills
:]
-
self
.
query_start_loc
[
self
.
num_prefills
])
if
self
.
query_start_loc
is
not
None
else
None
,
seq_start_loc
=
self
.
seq_start_loc
[
self
.
num_prefills
:]
if
self
.
seq_start_loc
is
not
None
else
None
,
context_lens_tensor
=
None
,
block_tables
=
block_tables
,
cross_layer_shared_block_tables
=
cross_layer_shared_block_tables
,
use_cuda_graph
=
self
.
use_cuda_graph
,
# Begin encoder & cross attn fields below...
encoder_seq_lens
=
self
.
encoder_seq_lens
,
encoder_seq_lens_tensor
=
self
.
encoder_seq_lens_tensor
,
encoder_seq_start_loc
=
self
.
encoder_seq_start_loc
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
)
return
self
.
_cached_decode_metadata
class
DifferentialFlashAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
DifferentialFlashAttentionMetadata
]):
def
__init__
(
self
,
input_builder
):
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
def
prepare
(
self
):
self
.
slot_mapping
:
List
[
int
]
=
[]
self
.
prefill_seq_lens
:
List
[
int
]
=
[]
self
.
context_lens
:
List
[
int
]
=
[]
self
.
block_tables
:
List
[
List
[
int
]]
=
[]
self
.
cross_layer_shared_block_tables
:
List
[
List
[
int
]]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
multimodal_placeholder_maps
:
Dict
[
str
,
MultiModalPlaceholderMap
]
=
defaultdict
(
MultiModalPlaceholderMap
)
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
has_prefix_cache_hit
=
False
def
_add_seq_group
(
self
,
inter_data
,
chunked_prefill_enabled
:
bool
,
prefix_cache_hit
:
bool
):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
# TODO: add support for chunked prefill and prefix caching.
assert
not
chunked_prefill_enabled
,
\
"chunked prefill is not supported for now"
assert
not
prefix_cache_hit
,
"prefix caching is not supported for now"
is_prompt
=
inter_data
.
is_prompt
block_tables
=
inter_data
.
block_tables
for
(
seq_id
,
token_len
,
seq_len
,
curr_seq_len
,
query_len
,
context_len
,
curr_sliding_window_block
)
in
zip
(
inter_data
.
seq_ids
,
[
len
(
t
)
for
t
in
inter_data
.
input_tokens
],
inter_data
.
orig_seq_lens
,
inter_data
.
seq_lens
,
inter_data
.
query_lens
,
inter_data
.
context_lens
,
inter_data
.
curr_sliding_window_blocks
):
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
mm_maps
=
inter_data
.
multi_modal_placeholder_maps
if
mm_maps
:
for
modality
,
placeholders
in
mm_maps
.
items
():
self
.
multimodal_placeholder_maps
[
modality
].
extend
(
placeholders
)
self
.
num_prefills
+=
1
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
else
:
self
.
num_decode_tokens
+=
query_len
self
.
curr_seq_lens
.
append
(
curr_seq_len
)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table
=
[]
if
prefix_cache_hit
:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table
=
block_tables
[
seq_id
]
elif
((
chunked_prefill_enabled
or
not
is_prompt
)
and
block_tables
is
not
None
):
if
curr_sliding_window_block
==
0
:
block_table
=
block_tables
[
seq_id
]
else
:
block_table
=
block_tables
[
seq_id
][
-
curr_sliding_window_block
:]
self
.
block_tables
.
append
(
block_table
)
cross_layer_shared_block_table
=
[]
if
prefix_cache_hit
:
cross_layer_shared_block_table
=
block_tables
[
seq_id
]
elif
block_tables
is
not
None
:
if
curr_sliding_window_block
==
0
:
cross_layer_shared_block_table
=
block_tables
[
seq_id
]
else
:
cross_layer_shared_block_table
=
block_tables
[
seq_id
][
-
curr_sliding_window_block
:]
self
.
cross_layer_shared_block_tables
.
append
(
cross_layer_shared_block_table
)
# Compute slot mapping.
is_profile_run
=
is_block_tables_empty
(
block_tables
)
start_idx
=
compute_slot_mapping_start_idx
(
is_prompt
,
query_len
,
context_len
,
self
.
sliding_window
)
compute_slot_mapping
(
is_profile_run
,
self
.
slot_mapping
,
seq_id
,
seq_len
,
context_len
,
start_idx
,
self
.
block_size
,
inter_data
.
block_tables
)
def
_get_graph_runner_block_tables
(
self
,
num_seqs
:
int
,
block_tables
:
List
[
List
[
int
]],
graph_block_tables
)
->
torch
.
Tensor
:
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
# max_batch_size, max_blocks = self.runner.graph_block_tables.shape
max_batch_size
,
max_blocks
=
graph_block_tables
.
shape
assert
max_batch_size
>=
num_seqs
# graph_block_tables = self.runner.graph_block_tables[:num_seqs]
graph_block_tables
=
graph_block_tables
[:
num_seqs
]
for
i
,
block_table
in
enumerate
(
block_tables
):
if
block_table
:
num_blocks
=
len
(
block_table
)
if
num_blocks
<=
max_blocks
:
graph_block_tables
[
i
,
:
num_blocks
]
=
block_table
else
:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
graph_block_tables
[
i
,
:
max_blocks
]
=
block_table
[:
max_blocks
]
return
torch
.
from_numpy
(
graph_block_tables
).
to
(
device
=
self
.
runner
.
device
,
non_blocking
=
True
)
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
prefix_cache_hit
=
any
([
inter_data
.
prefix_cache_hit
for
inter_data
in
self
.
input_builder
.
inter_data_list
])
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
,
prefix_cache_hit
)
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
max_query_len
=
max
(
query_lens
)
decode_query_lens
=
query_lens
[
self
.
num_prefills
:]
if
len
(
decode_query_lens
)
>
0
:
max_decode_query_len
=
max
(
decode_query_lens
)
else
:
max_decode_query_len
=
1
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
num_decode_tokens
=
self
.
num_decode_tokens
query_start_loc
=
list
(
accumulate
(
query_lens
,
initial
=
0
))
seq_start_loc
=
list
(
accumulate
(
seq_lens
,
initial
=
0
))
num_seqs
=
len
(
seq_lens
)
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
self
.
cross_layer_shared_block_tables
.
extend
([]
*
cuda_graph_pad_size
)
num_decode_tokens
=
batch_size
-
self
.
num_prefill_tokens
block_tables
=
self
.
_get_graph_runner_block_tables
(
num_seqs
,
self
.
block_tables
,
self
.
runner
.
graph_block_tables
)
cross_layer_shared_block_tables
=
\
self
.
_get_graph_runner_block_tables
(
num_seqs
,
self
.
cross_layer_shared_block_tables
,
self
.
runner
.
cross_layer_shared_graph_block_tables
)
else
:
block_tables
=
make_tensor_with_pad
(
self
.
block_tables
,
pad
=
0
,
dtype
=
torch
.
int
,
device
=
device
,
)
cross_layer_shared_block_tables
=
make_tensor_with_pad
(
self
.
cross_layer_shared_block_tables
,
pad
=
0
,
dtype
=
torch
.
int
,
device
=
device
,
)
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
assert
device
is
not
None
context_lens_tensor
=
async_tensor_h2d
(
self
.
context_lens
,
torch
.
int
,
device
,
self
.
runner
.
pin_memory
)
seq_lens_tensor
=
async_tensor_h2d
(
seq_lens
,
torch
.
int
,
device
,
self
.
runner
.
pin_memory
)
slot_mapping_tensor
=
async_tensor_h2d
(
self
.
slot_mapping
,
torch
.
long
,
device
,
self
.
runner
.
pin_memory
)
query_start_loc_tensor
=
async_tensor_h2d
(
query_start_loc
,
torch
.
int32
,
device
,
self
.
runner
.
pin_memory
)
seq_start_loc_tensor
=
async_tensor_h2d
(
seq_start_loc
,
torch
.
int32
,
device
,
self
.
runner
.
pin_memory
)
placeholder_index_maps
=
{
modality
:
placeholder_map
.
index_map
()
for
modality
,
placeholder_map
in
self
.
multimodal_placeholder_maps
.
items
()
}
return
DifferentialFlashAttentionMetadata
(
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
multi_modal_placeholder_index_maps
=
placeholder_index_maps
,
enable_kv_scales_calculation
=
True
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_decode_query_len
=
max_decode_query_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
query_start_loc
=
query_start_loc_tensor
,
seq_start_loc
=
seq_start_loc_tensor
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
cross_layer_shared_block_tables
=
cross_layer_shared_block_tables
,
use_cuda_graph
=
use_captured_graph
,
)
class
DifferentialFlashAttentionImpl
(
AttentionImpl
):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
differential_flash_attention_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
None
:
if
differential_flash_attention_config
is
None
:
differential_flash_attention_config
=
{}
self
.
differential_flash_attention_config
=
\
differential_flash_attention_config
self
.
used_shared_kv_cache
=
kv_sharing_target_layer_name
is
not
None
self
.
kv_sharing_target_layer_name
=
kv_sharing_target_layer_name
if
use_irope
:
logger
.
warning
(
"Using irope in V0 is not supported yet, it will fall back "
"to global attention for long context."
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
sliding_window
=
((
sliding_window
-
1
,
0
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
vllm_flash_attn_version
=
get_flash_attn_version
(
requires_alibi
=
self
.
alibi_slopes
is
not
None
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
and
(
not
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
or
not
flash_attn_supports_fp8
()):
raise
NotImplementedError
(
f
"FlashAttention does not support
{
self
.
kv_cache_dtype
}
"
"kv-cache on this device "
f
"(FA supports fp8 =
{
flash_attn_supports_fp8
()
}
)."
)
if
logits_soft_cap
is
None
:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap
=
0
self
.
logits_soft_cap
=
logits_soft_cap
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
support_head_sizes
=
FlashAttentionBackend
.
get_supported_head_sizes
()
if
head_size
not
in
support_head_sizes
:
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by FlashAttention. "
f
"Supported head sizes are:
{
support_head_sizes
}
."
)
self
.
attn_type
=
attn_type
self
.
lambda_full
=
None
self
.
subln
=
self
.
differential_flash_attention_config
[
"subln"
]
def
split_heads
(
self
,
x
):
# split by num_heads, the stripe pattern is friendly to tensor parallel.
x
=
rearrange
(
x
,
"... (H two) D -> ... H two D"
,
two
=
2
)
x1
=
x
[...,
0
,
:]
x2
=
x
[...,
1
,
:]
return
x1
.
contiguous
(),
x2
.
contiguous
()
def
split_kv_cache
(
self
,
x
):
# split by num_heads, the stripe pattern is friendly to tensor parallel.
if
x
.
numel
()
==
0
:
return
torch
.
empty
(
0
),
torch
.
empty
(
0
)
x1
,
x2
=
x
[
0
],
x
[
1
]
return
x1
,
x2
def
populate_kv_cache
(
self
,
layer
:
AttentionLayer
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
DifferentialFlashAttentionMetadata
):
if
kv_cache
.
numel
()
>
0
and
key
is
not
None
and
value
is
not
None
:
updated_slot_mapping
=
attn_metadata
.
slot_mapping
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
kv_cache
[
0
],
kv_cache
[
1
],
updated_slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
def
forward_generate_kv_cache
(
self
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
],
value
:
Optional
[
torch
.
Tensor
],
k_cache
:
torch
.
Tensor
,
v_cache
:
torch
.
Tensor
,
attn_metadata
:
DifferentialFlashAttentionMetadata
)
->
torch
.
Tensor
:
head_size
=
self
.
head_size
num_heads
=
self
.
num_heads
//
2
num_kv_heads
=
self
.
num_kv_heads
//
2
query
=
query
.
view
(
-
1
,
num_heads
,
head_size
)
if
key
is
not
None
:
assert
value
is
not
None
key
=
key
.
view
(
-
1
,
num_kv_heads
,
head_size
)
value
=
value
.
view
(
-
1
,
num_kv_heads
,
head_size
)
else
:
assert
value
is
None
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
,
"key shape mismatch"
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
,
"value shape mismatch"
output
=
torch
.
empty_like
(
query
)
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[
num_prefill_tokens
:]
# QKV for prefill.
query
=
query
[:
num_prefill_tokens
]
if
key
is
not
None
and
value
is
not
None
:
key
=
key
[:
num_prefill_tokens
]
value
=
value
[:
num_prefill_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_tokens
,
"query shape mismatch"
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
,
"decode query shape mismatch"
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
if
k_cache
.
numel
()
==
0
\
or
prefill_meta
.
block_tables
is
None
\
or
prefill_meta
.
block_tables
.
numel
()
==
0
:
# normal attention
prefill_output
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
softcap
=
self
.
logits_soft_cap
,
fa_version
=
self
.
vllm_flash_attn_version
,
)
assert
prefill_output
.
shape
==
output
[:
num_prefill_tokens
].
shape
output
[:
num_prefill_tokens
]
=
prefill_output
else
:
raise
Exception
(
"prefix caching not supported"
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
block_tables_arg
=
decode_meta
.
block_tables
try
:
output
[
num_prefill_tokens
:]
=
flash_attn_with_kvcache
(
q
=
decode_query
.
unsqueeze
(
1
),
k_cache
=
k_cache
,
v_cache
=
v_cache
,
block_table
=
block_tables_arg
,
cache_seqlens
=
decode_meta
.
seq_lens_tensor
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
softcap
=
self
.
logits_soft_cap
,
fa_version
=
self
.
vllm_flash_attn_version
,
).
squeeze
(
1
)
except
Exception
as
e
:
logger
.
error
(
"Error in PagedAttention.forward_decode: %s"
,
str
(
e
))
raise
e
# Reshape the output tensor.
return
output
.
view
(
-
1
,
num_heads
,
head_size
)
def
forward_with_kv_cache_only
(
self
,
query
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
v_cache
:
torch
.
Tensor
,
attn_metadata
:
DifferentialFlashAttentionMetadata
,
):
if
not
attn_metadata
.
decode_metadata
:
block_tables_arg
=
attn_metadata
.
cross_layer_shared_block_tables
else
:
block_tables_arg
=
attn_metadata
.
block_tables
output
=
flash_attn_with_kvcache
(
q
=
query
.
unsqueeze
(
1
),
k_cache
=
k_cache
,
v_cache
=
v_cache
,
block_table
=
block_tables_arg
,
cache_seqlens
=
attn_metadata
.
seq_lens_tensor
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
softcap
=
self
.
logits_soft_cap
,
fa_version
=
self
.
vllm_flash_attn_version
,
).
squeeze
(
1
)
return
output
def
forward
(
self
,
layer
:
AttentionLayer
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
DifferentialFlashAttentionMetadata
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_block_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention.
Args:
layer: Attention layer instance.
q: Query tensor with shape = [num_tokens, num_heads, head_size]
k: Key tensor with shape = [num_tokens, num_kv_heads, head_size]
v: Value tensor with shape = [num_tokens, num_kv_heads, head_size]
kv_cache: KV cache tensor with shape
[2, num_blocks, block_size, num_kv_heads, head_size].
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
output: Output tensor with shape [num_tokens, num_heads, head_size]
output_scale: Optional output scale tensor.
output_block_scale: Optional output block scale tensor.
NOTE: It in-place updates the output tensor.
NOTE: FP8 quantization, flash-attn expect the size of
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values
"""
if
output_scale
is
not
None
or
output_block_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for DifferentialFlashAttentionImpl"
)
if
self
.
lambda_full
is
None
:
self
.
lambda_init
=
self
.
differential_flash_attention_config
[
"lambda_init"
]
lambda_q1
=
self
.
differential_flash_attention_config
[
"lambda_q1"
]
lambda_k1
=
self
.
differential_flash_attention_config
[
"lambda_k1"
]
lambda_q2
=
self
.
differential_flash_attention_config
[
"lambda_q2"
]
lambda_k2
=
self
.
differential_flash_attention_config
[
"lambda_k2"
]
lambda_1
=
torch
.
exp
(
torch
.
sum
(
lambda_q1
*
lambda_k1
,
dim
=-
1
).
float
()).
type_as
(
q
)
lambda_2
=
torch
.
exp
(
torch
.
sum
(
lambda_q2
*
lambda_k2
,
dim
=-
1
).
float
()).
type_as
(
q
)
self
.
lambda_full
=
lambda_1
-
lambda_2
+
self
.
lambda_init
if
not
self
.
used_shared_kv_cache
:
# need to generate kv-cache
q
=
q
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
k
=
k
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
v
=
v
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
q1
,
q2
=
self
.
split_heads
(
q
)
k1
,
k2
=
self
.
split_heads
(
k
)
v1
,
v2
=
self
.
split_heads
(
v
)
# kv_cache shape is (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) # noqa: E501
# Split by half along the first dimension.
kv_cache1
,
kv_cache2
=
self
.
split_kv_cache
(
kv_cache
)
assert
kv_cache1
.
is_contiguous
(),
"kv_cache1 is not contiguous"
assert
kv_cache2
.
is_contiguous
(),
"kv_cache2 is not contiguous"
if
kv_cache1
.
numel
()
!=
0
:
self
.
populate_kv_cache
(
layer
,
k1
,
v1
,
kv_cache1
,
attn_metadata
)
self
.
populate_kv_cache
(
layer
,
k2
,
v2
,
kv_cache2
,
attn_metadata
)
key_cache1
,
value_cache1
=
self
.
split_kv_cache
(
kv_cache1
)
key_cache2
,
value_cache2
=
self
.
split_kv_cache
(
kv_cache2
)
else
:
key_cache1
,
value_cache1
=
torch
.
empty
(
0
),
torch
.
empty
(
0
)
key_cache2
,
value_cache2
=
torch
.
empty
(
0
),
torch
.
empty
(
0
)
attn11
=
self
.
forward_generate_kv_cache
(
q1
,
k1
,
v1
,
key_cache1
,
value_cache1
,
attn_metadata
)
attn12
=
self
.
forward_generate_kv_cache
(
q1
,
k1
,
v2
,
key_cache1
,
value_cache2
,
attn_metadata
)
attn11
=
attn11
.
view
(
q1
.
shape
)
attn12
=
attn12
.
view
(
q1
.
shape
)
attn1
=
torch
.
cat
([
attn11
,
attn12
],
dim
=-
1
)
attn21
=
self
.
forward_generate_kv_cache
(
q2
,
k2
,
v1
,
key_cache2
,
value_cache1
,
attn_metadata
)
attn22
=
self
.
forward_generate_kv_cache
(
q2
,
k2
,
v2
,
key_cache2
,
value_cache2
,
attn_metadata
)
attn21
=
attn21
.
view
(
q2
.
shape
)
attn22
=
attn22
.
view
(
q2
.
shape
)
attn2
=
torch
.
cat
([
attn21
,
attn22
],
dim
=-
1
)
attn
=
attn1
-
self
.
lambda_full
*
attn2
# attn shape (-1, self.num_heads // 2, 2 * self.head_dim)
attn
=
self
.
subln
(
attn
)
attn
=
attn
*
(
1
-
self
.
lambda_init
)
# reshape back to 2 * num_head
attn_output
=
rearrange
(
attn
,
"... H (two D) -> ... (H two) D"
,
two
=
2
)
else
:
# reuse the kv cache, full attention
q
=
q
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
q1
,
q2
=
self
.
split_heads
(
q
)
# kv_cache shape is (2, num_blocks, block_size, num_kv_heads, head_size) # noqa: E501
kv_cache1
,
kv_cache2
=
self
.
split_kv_cache
(
kv_cache
)
key_cache1
,
value_cache1
=
kv_cache1
[
0
],
kv_cache1
[
1
]
key_cache2
,
value_cache2
=
kv_cache2
[
0
],
kv_cache2
[
1
]
attn11
=
self
.
forward_with_kv_cache_only
(
q1
,
key_cache1
,
value_cache1
,
attn_metadata
)
attn12
=
self
.
forward_with_kv_cache_only
(
q1
,
key_cache1
,
value_cache2
,
attn_metadata
)
attn11
=
attn11
.
view
(
q1
.
shape
)
attn12
=
attn12
.
view
(
q1
.
shape
)
attn1
=
torch
.
cat
([
attn11
,
attn12
],
dim
=-
1
)
attn21
=
self
.
forward_with_kv_cache_only
(
q2
,
key_cache2
,
value_cache1
,
attn_metadata
)
attn22
=
self
.
forward_with_kv_cache_only
(
q2
,
key_cache2
,
value_cache2
,
attn_metadata
)
attn21
=
attn21
.
view
(
q2
.
shape
)
attn22
=
attn22
.
view
(
q2
.
shape
)
attn2
=
torch
.
cat
([
attn21
,
attn22
],
dim
=-
1
)
attn
=
attn1
-
self
.
lambda_full
*
attn2
attn
=
self
.
subln
(
attn
)
attn
=
attn
*
(
1
-
self
.
lambda_init
)
# reshape back to 2 * num_head
attn_output
=
rearrange
(
attn
,
"... H (two D) -> ... (H two) D"
,
two
=
2
)
attn_output
=
attn_output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
return
attn_output
vllm/attention/backends/dual_chunk_flash_attn.py
deleted
100644 → 0
View file @
af7dfb0d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with Dual chunk flash attention and sparse attention.
"""
import
math
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch.distributed
import
torch.nn.functional
as
F
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
AttentionLayer
,
AttentionType
from
vllm.attention.backends.flash_attn
import
(
FlashAttentionBackend
,
FlashAttentionImpl
,
FlashAttentionMetadata
,
FlashAttentionMetadataBuilder
)
from
vllm.distributed.parallel_state
import
get_tensor_model_parallel_rank
from
vllm.logger
import
init_logger
from
vllm.utils
import
async_tensor_h2d
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
,
sparse_attn_func
)
logger
=
init_logger
(
__name__
)
class
DualChunkFlashAttentionBackend
(
FlashAttentionBackend
):
accept_output_buffer
:
bool
=
False
@
staticmethod
def
get_name
()
->
str
:
return
"DUAL_CHUNK_FLASH_ATTN"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"DualChunkFlashAttentionImpl"
]:
return
DualChunkFlashAttentionImpl
@
staticmethod
def
get_metadata_cls
()
->
Type
[
"DualChunkFlashAttentionMetadata"
]:
return
DualChunkFlashAttentionMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"DualChunkFlashAttentionMetadataBuilder"
]:
return
DualChunkFlashAttentionMetadataBuilder
@
dataclass
class
DualChunkFlashAttentionMetadata
(
FlashAttentionMetadata
):
# Block size of the paged kv cache.
block_size
:
int
=
16
# Original max position embeddings.
original_max_position_embeddings
:
int
=
0
# Chunk size
chunk_size
:
int
=
8192
# Local size
local_size
:
int
=
1024
# (batch_size,). The orig sequence length per sequence.
orig_seq_lens
:
Optional
[
List
[
int
]]
=
None
# orig_seq_lens stored as a tensor.
orig_seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# Length scaling factor
scaling_factor
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size,). Sequence lengths for intra attention.
seq_lens_intra
:
Optional
[
torch
.
Tensor
]
=
None
# Max sequence length for intra attention.
max_seq_len_intra
:
Optional
[
int
]
=
None
# (batch_size, num_blocks). Block table for intra attention.
block_tables_intra
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size,). Sequence lengths for succ attention.
seq_lens_succ
:
Optional
[
torch
.
Tensor
]
=
None
# Max sequence length for succ attention.
max_seq_len_succ
:
Optional
[
int
]
=
None
# (batch_size, num_blocks). Block table for succ attention.
block_tables_succ
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size,). Sequence lengths for inter attention.
seq_lens_inter
:
Optional
[
torch
.
Tensor
]
=
None
# Max sequence length for inter attention.
max_seq_len_inter
:
Optional
[
int
]
=
None
_cached_prefill_metadata
:
Optional
[
"DualChunkFlashAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"DualChunkFlashAttentionMetadata"
]
=
None
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"DualChunkFlashAttentionMetadata"
]:
if
self
.
num_prefills
==
0
:
return
None
if
self
.
_cached_prefill_metadata
is
not
None
:
return
self
.
_cached_prefill_metadata
prefill_metadata
=
super
().
prefill_metadata
if
prefill_metadata
is
None
:
return
None
prefill_metadata
=
DualChunkFlashAttentionMetadata
(
**
prefill_metadata
.
asdict_zerocopy
())
prefill_metadata
.
orig_seq_lens
=
(
None
if
self
.
orig_seq_lens
is
None
else
self
.
orig_seq_lens
[:
self
.
num_prefills
])
prefill_metadata
.
orig_seq_lens_tensor
=
(
None
if
self
.
orig_seq_lens_tensor
is
None
else
self
.
orig_seq_lens_tensor
[:
self
.
num_prefills
])
if
self
.
original_max_position_embeddings
>
0
:
assert
prefill_metadata
.
orig_seq_lens_tensor
is
not
None
prefill_metadata
.
scaling_factor
=
(
0.1
*
torch
.
log
(
prefill_metadata
.
orig_seq_lens_tensor
/
self
.
original_max_position_embeddings
)
+
1.0
).
clip
(
min
=
1
)
self
.
_cached_prefill_metadata
=
prefill_metadata
return
prefill_metadata
@
property
def
decode_metadata
(
self
)
->
Optional
[
"DualChunkFlashAttentionMetadata"
]:
if
self
.
num_decode_tokens
==
0
:
return
None
if
self
.
_cached_decode_metadata
is
not
None
:
return
self
.
_cached_decode_metadata
decode_metadata
=
super
().
decode_metadata
if
decode_metadata
is
None
:
return
None
decode_metadata
=
DualChunkFlashAttentionMetadata
(
**
decode_metadata
.
asdict_zerocopy
())
decode_metadata
.
orig_seq_lens_tensor
=
(
None
if
self
.
orig_seq_lens_tensor
is
None
else
self
.
orig_seq_lens_tensor
[
self
.
num_prefills
:])
assert
decode_metadata
.
orig_seq_lens_tensor
is
not
None
assert
decode_metadata
.
block_tables
is
not
None
cache_seq_lens
=
decode_metadata
.
orig_seq_lens_tensor
chunk_len
=
self
.
chunk_size
-
self
.
local_size
chunk_num_curr
=
(
cache_seq_lens
-
1
)
//
chunk_len
batch_size
=
decode_metadata
.
num_decode_tokens
if
self
.
original_max_position_embeddings
>
0
:
decode_metadata
.
scaling_factor
=
(
0.1
*
torch
.
log
(
cache_seq_lens
/
self
.
original_max_position_embeddings
)
+
1.0
).
clip
(
min
=
1
)
seq_lens_intra
=
cache_seq_lens
-
chunk_num_curr
*
chunk_len
max_seq_len_intra
=
seq_lens_intra
.
max
().
item
()
decode_metadata
.
seq_lens_intra
=
seq_lens_intra
decode_metadata
.
max_seq_len_intra
=
max_seq_len_intra
block_tables_intra
=
torch
.
zeros
(
batch_size
,
(
max_seq_len_intra
-
1
)
//
self
.
block_size
+
1
,
dtype
=
decode_metadata
.
block_tables
.
dtype
,
device
=
decode_metadata
.
block_tables
.
device
,
)
for
i
in
range
(
batch_size
):
st
=
chunk_num_curr
[
i
]
*
chunk_len
//
self
.
block_size
ed
=
min
(
st
+
(
max_seq_len_intra
-
1
)
//
self
.
block_size
+
1
,
(
cache_seq_lens
[
i
]
-
1
)
//
self
.
block_size
+
1
,
)
block_tables_intra
[
i
,
:
ed
-
st
]
=
decode_metadata
.
block_tables
[
i
,
st
:
ed
]
decode_metadata
.
block_tables_intra
=
block_tables_intra
seq_lens_succ
=
(
chunk_num_curr
-
(
chunk_num_curr
-
1
).
clip
(
min
=
0
))
*
chunk_len
max_seq_len_succ
=
seq_lens_succ
.
max
().
item
()
decode_metadata
.
seq_lens_succ
=
seq_lens_succ
decode_metadata
.
max_seq_len_succ
=
max_seq_len_succ
if
max_seq_len_succ
:
block_tables_succ
=
torch
.
zeros
(
batch_size
,
(
max_seq_len_succ
-
1
)
//
self
.
block_size
+
1
,
dtype
=
decode_metadata
.
block_tables
.
dtype
,
device
=
decode_metadata
.
block_tables
.
device
,
)
for
i
in
range
(
batch_size
):
start
=
((
chunk_num_curr
[
i
]
-
1
).
clip
(
min
=
0
)
*
chunk_len
//
self
.
block_size
)
end
=
min
(
start
+
(
max_seq_len_succ
-
1
)
//
self
.
block_size
+
1
,
(
cache_seq_lens
[
i
]
-
1
)
//
self
.
block_size
+
1
,
)
block_tables_succ
[
i
,
:
end
-
start
]
=
decode_metadata
.
block_tables
[
i
,
start
:
end
]
decode_metadata
.
block_tables_succ
=
block_tables_succ
seq_lens_inter
=
(
chunk_num_curr
-
1
).
clip
(
min
=
0
)
*
chunk_len
max_seq_len_inter
=
seq_lens_inter
.
max
().
item
()
decode_metadata
.
seq_lens_inter
=
seq_lens_inter
decode_metadata
.
max_seq_len_inter
=
max_seq_len_inter
self
.
_cached_decode_metadata
=
decode_metadata
return
decode_metadata
class
DualChunkFlashAttentionMetadataBuilder
(
FlashAttentionMetadataBuilder
):
def
prepare
(
self
):
super
().
prepare
()
self
.
orig_seq_lens
:
List
[
int
]
=
[]
def
_add_seq_group
(
self
,
inter_data
,
chunked_prefill_enabled
:
bool
,
prefix_cache_hit
:
bool
):
super
().
_add_seq_group
(
inter_data
,
chunked_prefill_enabled
,
prefix_cache_hit
)
for
prompt_len
,
seq_len
in
zip
(
inter_data
.
prompt_lens
,
inter_data
.
seq_lens
):
self
.
orig_seq_lens
.
append
(
max
(
prompt_len
,
seq_len
))
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
attn_metadata
=
super
().
build
(
seq_lens
,
query_lens
,
cuda_graph_pad_size
,
batch_size
)
attn_metadata
=
DualChunkFlashAttentionMetadata
(
**
attn_metadata
.
asdict_zerocopy
())
device
=
self
.
runner
.
device
attn_metadata
.
orig_seq_lens
=
self
.
orig_seq_lens
attn_metadata
.
orig_seq_lens_tensor
=
async_tensor_h2d
(
self
.
orig_seq_lens
,
torch
.
int
,
device
,
self
.
runner
.
pin_memory
)
attn_metadata
.
block_size
=
self
.
runner
.
block_size
dual_chunk_attn_config
=
getattr
(
self
.
runner
.
model_config
.
hf_config
,
"dual_chunk_attention_config"
,
{})
attn_metadata
.
original_max_position_embeddings
=
\
dual_chunk_attn_config
.
get
(
"original_max_position_embeddings"
,
0
)
attn_metadata
.
chunk_size
=
dual_chunk_attn_config
.
get
(
"chunk_size"
,
8192
)
attn_metadata
.
local_size
=
dual_chunk_attn_config
.
get
(
"local_size"
,
1024
)
return
attn_metadata
class
DualChunkFlashAttentionImpl
(
FlashAttentionImpl
):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
layer_idx
:
int
=
-
1
,
dual_chunk_attention_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0 "
"DUAL_CHUNK_FLASH_ATTN backend."
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
sliding_window
=
((
sliding_window
,
sliding_window
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
if
sliding_window
is
not
None
:
# NOTE(woosuk): flash-attn's sliding window does not work with
# paged KV cache.
raise
ValueError
(
"Sliding window is not supported in FlashAttention."
)
support_head_sizes
=
(
DualChunkFlashAttentionBackend
.
get_supported_head_sizes
())
if
head_size
not
in
support_head_sizes
:
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by FlashAttention. "
f
"Supported head sizes are:
{
support_head_sizes
}
."
)
assert
dual_chunk_attention_config
is
not
None
self
.
chunk_size
=
dual_chunk_attention_config
.
get
(
"chunk_size"
,
8192
)
self
.
local_size
=
dual_chunk_attention_config
.
get
(
"local_size"
,
1024
)
self
.
original_max_position_embeddings
=
dual_chunk_attention_config
.
get
(
"original_max_position_embeddings"
,
0
)
self
.
sparse_attention_config
=
dual_chunk_attention_config
.
get
(
"sparse_attention_config"
,
None
)
if
not
self
.
sparse_attention_config
:
logger
.
warning_once
(
"Sparse attention will not be enabled as "
"sparse attention config is not provided."
)
self
.
sparse_attention_enabled
=
dual_chunk_attention_config
.
get
(
"sparse_attention_enabled"
,
self
.
sparse_attention_config
is
not
None
)
self
.
sparse_attention_threshold
=
dual_chunk_attention_config
.
get
(
"sparse_attention_threshold"
,
32768
)
self
.
sparse_attention_last_q
=
dual_chunk_attention_config
.
get
(
"sparse_attention_last_q"
,
64
)
self
.
layer_idx
=
layer_idx
self
.
dual_chunk_attention_config
=
dual_chunk_attention_config
if
self
.
sparse_attention_config
:
self
.
sparse_attention_config
=
{
int
(
i
):
j
for
i
,
j
in
self
.
sparse_attention_config
[
self
.
layer_idx
].
items
()
}
start_head
=
self
.
num_heads
*
get_tensor_model_parallel_rank
()
end_head
=
start_head
+
self
.
num_heads
self
.
sparse_attention_config
=
[
self
.
sparse_attention_config
[
i
]
for
i
in
range
(
start_head
,
end_head
)
]
if
self
.
sparse_attention_enabled
:
self
.
arange
=
torch
.
arange
(
self
.
sparse_attention_last_q
,
device
=
"cuda"
)
self
.
last_q_mask
=
(
self
.
arange
[
None
,
None
,
:,
None
]
>=
self
.
arange
[
None
,
None
,
None
,
:])
def
forward
(
# type: ignore
self
,
layer
:
AttentionLayer
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
DualChunkFlashAttentionMetadata
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_block_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with DualChunkFlashAttention.
Args:
query: shape = [num_tokens, num_heads * head_size]
query_succ: shape = [num_tokens, num_heads * head_size]
query_inter: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = [2, num_blocks, block_size, num_kv_heads * head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert
output
is
None
,
"Output tensor not supported for DualChunk"
if
output_scale
is
not
None
or
output_block_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for FlashAttentionImpl"
)
(
query
,
query_succ
,
query_inter
,
query_succ_critical
,
query_inter_critical
,
)
=
torch
.
split
(
query
,
query
.
shape
[
-
1
]
//
5
,
dim
=-
1
)
assert
(
query_succ
is
not
None
and
query_inter
is
not
None
),
"query_succ and query_inter are required in Dual Chunk Attention."
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query_succ
=
query_succ
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query_inter
=
query_inter
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query_succ_critical
=
query_succ_critical
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query_inter_critical
=
query_inter_critical
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
self
.
original_max_position_embeddings
>
0
:
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
assert
prefill_meta
.
scaling_factor
is
not
None
assert
prefill_meta
.
query_start_loc
is
not
None
assert
prefill_meta
.
orig_seq_lens
is
not
None
current_start
=
0
query_start_loc_cpu
=
prefill_meta
.
query_start_loc
.
cpu
()
for
i
in
range
(
len
(
prefill_meta
.
orig_seq_lens
)):
current_end
=
(
current_start
+
(
query_start_loc_cpu
[
i
+
1
]
-
query_start_loc_cpu
[
i
]).
item
())
key
[
current_start
:
current_end
].
mul_
(
prefill_meta
.
scaling_factor
[
i
])
current_start
=
current_end
assert
current_end
<=
attn_metadata
.
num_prefill_tokens
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
assert
decode_meta
.
scaling_factor
is
not
None
scaling_factor
=
decode_meta
.
scaling_factor
key
[
attn_metadata
.
num_prefill_tokens
:].
mul_
(
scaling_factor
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
))
if
kv_cache
is
not
None
and
kv_cache
.
numel
()
>
0
:
key_cache
=
kv_cache
[
0
]
value_cache
=
kv_cache
[
1
]
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
output
=
torch
.
empty_like
(
query
)
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[
num_prefill_tokens
:]
decode_query_succ
=
query_succ
[
num_prefill_tokens
:]
decode_query_inter
=
query_inter
[
num_prefill_tokens
:]
# QKV for prefill.
query
=
query
[:
num_prefill_tokens
]
query_succ
=
query_succ
[:
num_prefill_tokens
]
query_inter
=
query_inter
[:
num_prefill_tokens
]
query_succ_critical
=
query_succ_critical
[:
num_prefill_tokens
]
query_inter_critical
=
query_inter_critical
[:
num_prefill_tokens
]
key
=
key
[:
num_prefill_tokens
]
value
=
value
[:
num_prefill_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
if
(
kv_cache
is
None
or
prefill_meta
.
block_tables
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
):
# normal attention, called during the profiling run.
out
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
else
:
# prefix-enabled attention
assert
prefill_meta
.
seq_lens
is
not
None
assert
prefill_meta
.
orig_seq_lens
is
not
None
output
[:
num_prefill_tokens
]
=
(
self
.
_dual_chunk_flash_attn_prefill
(
q
=
query
,
q_succ
=
query_succ
,
q_inter
=
query_inter
,
q_succ_critical
=
query_succ_critical
,
q_inter_critical
=
query_inter_critical
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
prefill_meta
.
query_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
orig_seq_lens
=
prefill_meta
.
orig_seq_lens
,
scaling_factor
=
prefill_meta
.
scaling_factor
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
window_size
=
(
-
1
,
-
1
),
alibi_slopes
=
self
.
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
chunk_size
=
self
.
chunk_size
,
local_size
=
self
.
local_size
,
))
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
output
[
num_prefill_tokens
:]
=
(
self
.
_dual_chunk_flash_attn_decoding
(
decode_query
.
unsqueeze
(
1
),
decode_query_succ
.
unsqueeze
(
1
),
decode_query_inter
.
unsqueeze
(
1
),
key_cache
,
value_cache
,
block_table
=
decode_meta
.
block_tables
,
cache_seqlens
=
decode_meta
.
seq_lens_tensor
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
chunk_size
=
self
.
chunk_size
,
local_size
=
self
.
local_size
,
original_max_position_embeddings
=
self
.
original_max_position_embeddings
,
decode_meta
=
decode_meta
,
).
squeeze
(
1
))
# Reshape the output tensor.
return
output
.
view
(
num_tokens
,
hidden_size
)
def
_dual_chunk_flash_attn_prefill
(
self
,
q
,
q_succ
,
q_inter
,
q_succ_critical
,
q_inter_critical
,
k
,
v
,
cu_seqlens_q
,
cu_seqlens_k
,
orig_seq_lens
:
List
[
int
],
scaling_factor
:
torch
.
Tensor
,
softmax_scale
:
float
,
causal
:
Optional
[
bool
]
=
True
,
window_size
:
Tuple
[
int
,
int
]
=
(
-
1
,
-
1
),
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
block_table
:
Optional
[
torch
.
Tensor
]
=
None
,
chunk_size
:
int
=
8192
,
local_size
:
int
=
1024
,
):
if
alibi_slopes
is
not
None
:
raise
ValueError
(
"Dual Chunk Attention does not support alibi_slopes"
)
if
not
causal
:
raise
ValueError
(
"Dual Chunk Attention does not support causal=False"
)
if
window_size
!=
(
-
1
,
-
1
):
raise
ValueError
(
"Dual Chunk Attention does not support window_size"
)
cu_seqlens_q_cpu
=
cu_seqlens_q
.
cpu
().
tolist
()
cu_seqlens_k_cpu
=
cu_seqlens_k
.
cpu
().
tolist
()
all_outputs
=
[]
for
i
in
range
(
0
,
len
(
cu_seqlens_q_cpu
)
-
1
):
qs
=
cu_seqlens_q_cpu
[
i
]
qe
=
cu_seqlens_q_cpu
[
i
:
i
+
2
][
-
1
]
ks
=
cu_seqlens_k_cpu
[
i
]
ke
=
cu_seqlens_k_cpu
[
i
:
i
+
2
][
-
1
]
current_q
=
q
[
qs
:
qe
]
current_q_succ
=
q_succ
[
qs
:
qe
]
current_q_inter
=
q_inter
[
qs
:
qe
]
current_q_succ_critical
=
q_succ_critical
[
qs
:
qe
]
current_q_inter_critical
=
q_inter_critical
[
qs
:
qe
]
if
block_table
is
None
:
current_k
=
k
[
ks
:
ke
]
current_v
=
v
[
ks
:
ke
]
current_block_table
=
None
current_orig_seq_len
=
orig_seq_lens
[
i
]
else
:
current_block_table
=
block_table
[
i
]
current_orig_seq_len
=
orig_seq_lens
[
i
]
current_k
=
k
current_v
=
v
sparse_attn_enabled
=
(
self
.
sparse_attention_enabled
and
current_orig_seq_len
>
self
.
sparse_attention_threshold
)
if
current_q
.
shape
[
0
]
==
0
:
continue
if
current_k
.
shape
[
0
]
==
0
:
all_outputs
.
append
(
torch
.
zeros
(
(
current_q
.
shape
[
0
],
current_q
.
shape
[
1
],
v
.
shape
[
2
]),
device
=
q
.
device
,
dtype
=
q
.
dtype
,
))
continue
current_output
=
torch
.
empty_like
(
current_q
)
group_size
=
int
(
current_q
.
size
(
-
2
)
/
current_k
.
size
(
-
2
))
if
sparse_attn_enabled
:
num_device_q_heads
=
current_q
.
size
(
-
2
)
heads_vertical_size
=
torch
.
empty
(
size
=
(
num_device_q_heads
,
),
dtype
=
torch
.
int32
)
heads_slash_size
=
torch
.
empty
(
size
=
(
num_device_q_heads
,
),
dtype
=
torch
.
int32
)
for
head_id
in
range
(
current_q
.
size
(
-
2
)):
(
ty
,
vertical_size
,
slash_size
,
_
,
)
=
self
.
sparse_attention_config
[
head_id
]
assert
ty
==
"vertical_and_slash"
,
"only support slash mode"
if
vertical_size
==
30
:
vertical_size
+=
100
heads_vertical_size
[
head_id
]
=
vertical_size
heads_slash_size
[
head_id
]
=
slash_size
current_output
=
self
.
_dual_chunk_flash_attn_prefill_func
(
current_q
,
# allheads
current_q_succ
,
current_q_inter
,
current_q_succ_critical
,
current_q_inter_critical
,
current_k
,
current_v
,
current_block_table
,
softmax_scale
,
chunk_size
,
local_size
,
scaling_factor
[
i
].
item
(),
ke
-
ks
,
sparse_attn_enabled
=
sparse_attn_enabled
,
heads_vertical_size
=
heads_vertical_size
,
heads_slash_size
=
heads_slash_size
,
group_size
=
group_size
)
else
:
for
head_id
in
range
(
current_q
.
size
(
-
2
)):
# (seq_len, num_heads, head_size)
current_q_head
=
current_q
[:,
head_id
,
:].
unsqueeze
(
1
)
current_q_succ_head
=
\
current_q_succ
[:,
head_id
,
:].
unsqueeze
(
1
)
current_q_inter_head
=
\
current_q_inter
[:,
head_id
,
:].
unsqueeze
(
1
)
current_q_succ_head_critical
=
\
current_q_succ_critical
[:,
head_id
,
:].
unsqueeze
(
1
)
current_q_inter_head_critical
=
\
current_q_inter_critical
[:,
head_id
,
:].
unsqueeze
(
1
)
if
block_table
is
not
None
:
current_k_head
=
current_k
[...,
head_id
//
group_size
,
:].
unsqueeze
(
2
)
current_v_head
=
current_v
[...,
head_id
//
group_size
,
:].
unsqueeze
(
2
)
else
:
current_k_head
=
current_k
[:,
head_id
,
:].
unsqueeze
(
1
)
current_v_head
=
current_v
[:,
head_id
,
:].
unsqueeze
(
1
)
current_out
=
self
.
_dual_chunk_flash_attn_prefill_func
(
current_q_head
,
current_q_succ_head
,
current_q_inter_head
,
current_q_succ_head_critical
,
current_q_inter_head_critical
,
current_k_head
,
current_v_head
,
current_block_table
,
softmax_scale
,
chunk_size
,
local_size
,
scaling_factor
[
i
].
item
(),
ke
-
ks
,
sparse_attn_enabled
=
sparse_attn_enabled
,
)
current_output
[:,
head_id
:
head_id
+
1
,
:]
=
current_out
all_outputs
.
append
(
current_output
)
return
torch
.
cat
(
all_outputs
,
dim
=
0
)
def
_dual_chunk_flash_attn_prefill_func
(
self
,
q
,
q_succ
,
q_inter
,
q_succ_critical
,
q_inter_critical
,
k
,
v
,
block_table
,
softmax_scale
:
float
,
chunk_size
:
int
,
local_size
:
int
,
scaling_factor
:
float
,
k_length
:
int
,
sparse_attn_enabled
:
Optional
[
bool
]
=
True
,
heads_vertical_size
=
None
,
heads_slash_size
=
None
,
group_size
=
None
,
):
flash_results
=
[]
chunk_len
=
chunk_size
-
local_size
if
block_table
is
not
None
:
block_size
=
v
.
shape
[
1
]
if
chunk_len
%
block_size
!=
0
:
raise
ValueError
(
"chunk_len must be divisible by block_size."
)
else
:
block_size
=
1
if
self
.
original_max_position_embeddings
>
0
:
softmax_scale
=
softmax_scale
*
scaling_factor
begin
=
k_length
-
q
.
shape
[
0
]
while
begin
<
k_length
:
flash_per_chunk
=
[]
prev_chunk_end_pos
=
(
begin
//
chunk_len
)
*
chunk_len
next_chunk_end_pos
=
prev_chunk_end_pos
+
chunk_len
end
=
min
(
next_chunk_end_pos
,
k_length
)
qbegin
=
begin
-
(
k_length
-
q
.
shape
[
0
])
qend
=
end
-
(
k_length
-
q
.
shape
[
0
])
qk_chunks
=
[]
q_states_intra
=
q
[
qbegin
:
qend
]
# choose critical token
if
block_table
is
not
None
:
block_tables_intra
=
_get_block
(
block_table
,
block_size
,
prev_chunk_end_pos
,
end
)
k_states_intra
=
k
[
block_tables_intra
].
view
(
-
1
,
*
k
.
shape
[
-
2
:])[:(
end
-
prev_chunk_end_pos
)]
v_states_intra
=
v
[
block_tables_intra
].
view
(
-
1
,
*
v
.
shape
[
-
2
:])[:(
end
-
prev_chunk_end_pos
)]
else
:
block_tables_intra
=
None
k_states_intra
=
k
[
prev_chunk_end_pos
:
end
]
v_states_intra
=
v
[
prev_chunk_end_pos
:
end
]
if
sparse_attn_enabled
:
last_q_size
=
min
(
qend
-
qbegin
,
self
.
sparse_attention_last_q
)
_
,
num_device_k_heads
,
head_dim
=
k_states_intra
.
shape
k_states_intra
=
(
k_states_intra
.
unsqueeze
(
2
).
repeat
(
1
,
1
,
group_size
,
1
).
reshape
(
-
1
,
num_device_k_heads
*
group_size
,
head_dim
))
v_states_intra
=
(
v_states_intra
.
unsqueeze
(
2
).
repeat
(
1
,
1
,
group_size
,
1
).
reshape
(
-
1
,
num_device_k_heads
*
group_size
,
head_dim
))
qk_chunks
.
append
(
(
q_states_intra
.
transpose
(
0
,
1
)[:,
-
last_q_size
:]
*
softmax_scale
)
@
k_states_intra
.
permute
(
1
,
2
,
0
))
if
prev_chunk_end_pos
-
chunk_len
>=
0
:
q_states_succ
=
q_succ
[
qbegin
:
qend
]
q_states_succ_critical
=
q_succ_critical
[
qbegin
:
qend
]
if
block_table
is
not
None
:
block_tables_succ
=
_get_block
(
block_table
,
block_size
,
prev_chunk_end_pos
-
chunk_len
,
prev_chunk_end_pos
)
k_states_succ
=
k
[
block_tables_succ
].
view
(
-
1
,
*
k
.
shape
[
-
2
:])[:
chunk_len
]
v_states_succ
=
v
[
block_tables_succ
].
view
(
-
1
,
*
v
.
shape
[
-
2
:])[:
chunk_len
]
else
:
k_states_succ
=
k
[
prev_chunk_end_pos
-
chunk_len
:
prev_chunk_end_pos
]
v_states_succ
=
v
[
prev_chunk_end_pos
-
chunk_len
:
prev_chunk_end_pos
]
if
sparse_attn_enabled
:
k_states_succ
=
(
k_states_succ
.
unsqueeze
(
2
).
repeat
(
1
,
1
,
group_size
,
1
).
reshape
(
-
1
,
num_device_k_heads
*
group_size
,
head_dim
))
v_states_succ
=
(
v_states_succ
.
unsqueeze
(
2
).
repeat
(
1
,
1
,
group_size
,
1
).
reshape
(
-
1
,
num_device_k_heads
*
group_size
,
head_dim
))
qk_chunks
.
append
((
q_states_succ_critical
.
transpose
(
0
,
1
)[:,
-
last_q_size
:]
*
softmax_scale
)
@
k_states_succ
.
permute
(
1
,
2
,
0
))
if
prev_chunk_end_pos
-
chunk_len
*
2
>=
0
:
q_states_inter
=
q_inter
[
qbegin
:
qend
]
q_states_inter_critical
=
q_inter_critical
[
qbegin
:
qend
]
if
block_table
is
not
None
:
block_tables_inter
=
_get_block
(
block_table
,
block_size
,
0
,
prev_chunk_end_pos
-
chunk_len
)
k_states_inter
=
k
[
block_tables_inter
].
view
(
-
1
,
*
k
.
shape
[
-
2
:])[:(
prev_chunk_end_pos
-
chunk_len
)]
v_states_inter
=
v
[
block_tables_inter
].
view
(
-
1
,
*
v
.
shape
[
-
2
:])[:(
prev_chunk_end_pos
-
chunk_len
)]
else
:
k_states_inter
=
k
[:
prev_chunk_end_pos
-
chunk_len
]
v_states_inter
=
v
[:
prev_chunk_end_pos
-
chunk_len
]
if
sparse_attn_enabled
:
k_states_inter
=
(
k_states_inter
.
unsqueeze
(
2
).
repeat
(
1
,
1
,
group_size
,
1
).
reshape
(
-
1
,
num_device_k_heads
*
group_size
,
head_dim
))
v_states_inter
=
(
v_states_inter
.
unsqueeze
(
2
).
repeat
(
1
,
1
,
group_size
,
1
).
reshape
(
-
1
,
num_device_k_heads
*
group_size
,
head_dim
))
qk_chunks
.
append
((
q_states_inter_critical
.
transpose
(
0
,
1
)[:,
-
last_q_size
:]
*
softmax_scale
)
@
k_states_inter
.
permute
(
1
,
2
,
0
))
if
sparse_attn_enabled
:
reversed_qk
=
qk_chunks
[::
-
1
]
qk
=
torch
.
cat
(
reversed_qk
,
dim
=-
1
)
qk
[:,
:,
-
last_q_size
:]
=
torch
.
where
(
self
.
last_q_mask
[...,
-
last_q_size
:,
-
last_q_size
:].
to
(
qk
.
device
),
qk
[:,
:,
-
last_q_size
:],
-
torch
.
inf
)
qk
=
F
.
softmax
(
qk
,
dim
=-
1
,
dtype
=
torch
.
float32
)
vertical
=
qk
.
sum
(
-
2
,
keepdim
=
True
)
vertical
[...,
:
30
]
=
torch
.
inf
# Avoid sorting by using the min/max ints to fill the indexer
# buffers.
int32_max
=
torch
.
iinfo
(
torch
.
int32
).
max
int32_min
=
torch
.
iinfo
(
torch
.
int32
).
min
n_heads
=
qk
.
size
()[
0
]
max_slash_topk
=
torch
.
max
(
heads_slash_size
).
item
()
max_vertical_topk
=
torch
.
max
(
heads_vertical_size
).
item
()
# store each head's slash topk, vertical topk
vertical
=
vertical
.
reshape
((
n_heads
,
-
1
))
# prevent out of range when prompt size < max_vertical_topk
max_vertical_topk
=
min
(
vertical
.
shape
[
-
1
],
max_vertical_topk
)
vertical_topk_buffer
=
torch
.
topk
(
vertical
,
max_vertical_topk
,
-
1
).
indices
slash_topk_buffer
=
torch
.
empty
(
size
=
(
n_heads
,
max_slash_topk
),
dtype
=
torch
.
int64
,
device
=
qk
.
device
)
for
head_i
in
range
(
n_heads
):
# (nqheads=1, lastq, k_len)
head_score
=
qk
[
head_i
:
head_i
+
1
,
:,
:]
slash_scores
=
_sum_all_diagonal_matrix
(
head_score
)
if
head_score
.
size
(
1
)
!=
1
:
# drop right up corner
slash_scores
=
slash_scores
[...,
:
-
last_q_size
+
1
]
slash_scores
[...,
-
100
:]
=
torch
.
inf
head_slash_size
=
heads_slash_size
[
head_i
]
head_slash_size
=
min
(
head_slash_size
,
vertical
.
size
(
-
1
))
slash_topk
=
torch
.
topk
(
slash_scores
,
head_slash_size
,
-
1
).
indices
#(nheads, max_topk)
slash_topk_buffer
[
head_i
,
:
head_slash_size
]
=
slash_topk
# reset heads topk
heads_slash_size
[
head_i
]
=
head_slash_size
heads_vertical_size
[
head_i
]
=
min
(
heads_vertical_size
[
head_i
],
max_vertical_topk
)
# store
vertical_buffer
=
torch
.
full
((
n_heads
,
max_vertical_topk
),
int32_max
,
dtype
=
torch
.
int64
,
device
=
q
.
device
)
slash_buffer
=
torch
.
full
((
n_heads
,
max_slash_topk
),
int32_min
,
dtype
=
torch
.
int64
,
device
=
q
.
device
)
succ_vertical_buffer
=
torch
.
full
((
n_heads
,
max_vertical_topk
),
int32_max
,
dtype
=
torch
.
int64
,
device
=
q
.
device
)
succ_slash_buffer
=
torch
.
full
((
n_heads
,
max_slash_topk
),
int32_min
,
dtype
=
torch
.
int64
,
device
=
q
.
device
)
inter_vertical_buffer
=
torch
.
full
(
(
n_heads
,
max_vertical_topk
),
int32_max
,
dtype
=
torch
.
int64
,
device
=
q
.
device
)
inter_slash_buffer
=
torch
.
full
((
n_heads
,
max_slash_topk
),
int32_min
,
dtype
=
torch
.
int64
,
device
=
q
.
device
)
vertical_size_buffer
=
torch
.
empty
(
size
=
(
n_heads
,
),
dtype
=
torch
.
int32
,
device
=
q
.
device
)
slash_sizes_buffer
=
torch
.
empty
(
size
=
(
n_heads
,
),
dtype
=
torch
.
int32
,
device
=
q
.
device
)
succ_vertical_size_buffer
=
torch
.
empty
(
size
=
(
n_heads
,
),
dtype
=
torch
.
int32
,
device
=
q
.
device
)
succ_slash_sizes_buffer
=
torch
.
empty
(
size
=
(
n_heads
,
),
dtype
=
torch
.
int32
,
device
=
q
.
device
)
inter_vertical_size_buffer
=
torch
.
empty
(
size
=
(
n_heads
,
),
dtype
=
torch
.
int32
,
device
=
q
.
device
)
inter_slash_sizes_buffer
=
torch
.
empty
(
size
=
(
n_heads
,
),
dtype
=
torch
.
int32
,
device
=
q
.
device
)
for
head_i
in
range
(
n_heads
):
vertical_topk
=
vertical_topk_buffer
[
head_i
,
:
heads_vertical_size
[
head_i
]]
# intra
intra_vertical_indices
=
vertical_topk
[
vertical_topk
>=
prev_chunk_end_pos
]
-
prev_chunk_end_pos
if
intra_vertical_indices
.
nelement
()
==
0
:
intra_vertical_indices
=
torch
.
cat
([
intra_vertical_indices
,
torch
.
arange
(
0
,
k_states_intra
.
size
(
0
),
max
(
1
,
k_states_intra
.
size
(
0
)
/
5
),
dtype
=
torch
.
int32
,
device
=
intra_vertical_indices
.
device
)
])
slash_topk
=
slash_topk_buffer
[
head_i
,
:
heads_slash_size
[
head_i
]]
intra_slash_indices
=
(
(
qk
.
size
(
-
1
)
-
1
)
-
slash_topk
[
slash_topk
>=
prev_chunk_end_pos
])
# fill buffer
v_count
=
intra_vertical_indices
.
nelement
()
s_count
=
intra_slash_indices
.
nelement
()
vertical_size_buffer
[
head_i
]
=
v_count
slash_sizes_buffer
[
head_i
]
=
s_count
vertical_buffer
[
head_i
,
:
v_count
].
copy_
(
intra_vertical_indices
)
slash_buffer
[
head_i
,
:
s_count
].
copy_
(
intra_slash_indices
)
# succ
if
prev_chunk_end_pos
-
chunk_len
>=
0
:
succ_vertical_indices
=
vertical_topk
[
(
vertical_topk
<
prev_chunk_end_pos
)
&
(
vertical_topk
>=
prev_chunk_end_pos
-
chunk_len
)]
-
(
prev_chunk_end_pos
-
chunk_len
)
# TODO: support no vertical
if
succ_vertical_indices
.
nelement
()
==
0
:
succ_vertical_indices
=
torch
.
cat
([
succ_vertical_indices
,
torch
.
arange
(
0
,
k_states_succ
.
size
(
0
),
max
(
1
,
k_states_succ
.
size
(
0
)
/
5
),
dtype
=
torch
.
int32
,
device
=
intra_vertical_indices
.
device
)
])
succ_slash_indices
=
(
(
prev_chunk_end_pos
+
(
qend
-
qbegin
)
-
1
)
-
slash_topk
[((
slash_topk
>=
(
prev_chunk_end_pos
-
chunk_len
))
&
(
slash_topk
<
(
prev_chunk_end_pos
+
(
qend
-
qbegin
))))])
if
succ_slash_indices
.
nelement
()
==
0
:
succ_slash_indices
=
torch
.
cat
([
succ_slash_indices
,
torch
.
arange
(
0
,
k_states_succ
.
size
(
0
),
max
(
1
,
k_states_succ
.
size
(
0
)
/
5
),
dtype
=
torch
.
int32
,
device
=
intra_vertical_indices
.
device
)
])
# fill buffer
v_count
=
succ_vertical_indices
.
nelement
()
s_count
=
succ_slash_indices
.
nelement
()
succ_vertical_size_buffer
[
head_i
]
=
v_count
succ_slash_sizes_buffer
[
head_i
]
=
s_count
succ_vertical_buffer
[
head_i
,
:
v_count
].
copy_
(
succ_vertical_indices
)
succ_slash_buffer
[
head_i
,
:
s_count
].
copy_
(
succ_slash_indices
)
if
prev_chunk_end_pos
-
2
*
chunk_len
>=
0
:
inter_vertical_indices
=
vertical_topk
[
vertical_topk
<
prev_chunk_end_pos
-
chunk_len
]
if
inter_vertical_indices
.
nelement
()
==
0
:
inter_vertical_indices
=
torch
.
cat
([
inter_vertical_indices
,
torch
.
arange
(
0
,
k_states_inter
.
size
(
0
),
max
(
1
,
k_states_inter
.
size
(
0
)
/
5
),
dtype
=
torch
.
int32
,
device
=
intra_vertical_indices
.
device
)
])
inter_slash_indices
=
(
(
prev_chunk_end_pos
-
chunk_len
+
(
qend
-
qbegin
)
-
1
)
-
slash_topk
[
slash_topk
<
(
prev_chunk_end_pos
-
chunk_len
+
(
qend
-
qbegin
))])
if
inter_slash_indices
.
nelement
()
==
0
:
inter_slash_indices
=
torch
.
cat
([
inter_slash_indices
,
torch
.
arange
(
0
,
k_states_inter
.
size
(
0
),
max
(
1
,
k_states_inter
.
size
(
0
)
/
5
),
dtype
=
torch
.
int32
,
device
=
intra_vertical_indices
.
device
)
])
# fill buffer
v_count
=
inter_vertical_indices
.
nelement
()
s_count
=
inter_slash_indices
.
nelement
()
inter_vertical_size_buffer
[
head_i
]
=
v_count
inter_slash_sizes_buffer
[
head_i
]
=
s_count
inter_vertical_buffer
[
head_i
,
:
v_count
].
copy_
(
inter_vertical_indices
)
inter_slash_buffer
[
head_i
,
:
s_count
].
copy_
(
inter_slash_indices
)
else
:
intra_vertical_indices
,
intra_slash_indices
=
None
,
None
succ_vertical_indices
,
succ_slash_indices
=
None
,
None
inter_vertical_indices
,
inter_slash_indices
=
None
,
None
if
sparse_attn_enabled
:
flash_result
=
self
.
_do_flash_attn
(
q_states_intra
,
k_states_intra
,
v_states_intra
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
stage
=
"intra"
,
vertical_indices
=
vertical_buffer
,
slash_indices
=
slash_buffer
,
vertical_indices_count
=
vertical_size_buffer
,
slash_indices_count
=
slash_sizes_buffer
,
mergehead_softmax_scale
=
softmax_scale
,
sparse_attn_enabled
=
sparse_attn_enabled
)
else
:
flash_result
=
self
.
_do_flash_attn
(
q_states_intra
,
k_states_intra
,
v_states_intra
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
stage
=
"intra"
,
vertical_indices
=
intra_vertical_indices
,
slash_indices
=
intra_slash_indices
,
sparse_attn_enabled
=
sparse_attn_enabled
)
flash_per_chunk
.
append
(
flash_result
)
if
prev_chunk_end_pos
-
chunk_len
>=
0
:
if
sparse_attn_enabled
:
flash_result
=
self
.
_do_flash_attn
(
q_states_succ
,
k_states_succ
,
v_states_succ
,
softmax_scale
=
softmax_scale
,
causal
=
False
,
stage
=
"succ"
,
vertical_indices
=
succ_vertical_buffer
,
slash_indices
=
succ_slash_buffer
,
vertical_indices_count
=
succ_vertical_size_buffer
,
slash_indices_count
=
succ_slash_sizes_buffer
,
mergehead_softmax_scale
=
softmax_scale
,
sparse_attn_enabled
=
sparse_attn_enabled
)
else
:
flash_result
=
self
.
_do_flash_attn
(
q_states_succ
,
k_states_succ
,
v_states_succ
,
softmax_scale
=
softmax_scale
,
causal
=
False
,
stage
=
"succ"
,
vertical_indices
=
succ_vertical_indices
,
slash_indices
=
succ_slash_indices
,
sparse_attn_enabled
=
sparse_attn_enabled
)
flash_per_chunk
.
append
(
flash_result
)
if
prev_chunk_end_pos
-
chunk_len
*
2
>=
0
:
if
sparse_attn_enabled
:
flash_result
=
self
.
_do_flash_attn
(
q_states_inter
,
k_states_inter
,
v_states_inter
,
softmax_scale
=
softmax_scale
,
causal
=
False
,
stage
=
"inter"
,
vertical_indices
=
inter_vertical_buffer
,
slash_indices
=
inter_slash_buffer
,
vertical_indices_count
=
inter_vertical_size_buffer
,
slash_indices_count
=
inter_slash_sizes_buffer
,
mergehead_softmax_scale
=
softmax_scale
,
sparse_attn_enabled
=
sparse_attn_enabled
)
else
:
flash_result
=
self
.
_do_flash_attn
(
q_states_inter
,
k_states_inter
,
v_states_inter
,
softmax_scale
=
softmax_scale
,
causal
=
False
,
stage
=
"inter"
,
vertical_indices
=
inter_vertical_indices
,
slash_indices
=
inter_slash_indices
,
sparse_attn_enabled
=
sparse_attn_enabled
)
flash_per_chunk
.
append
(
flash_result
)
flash_results
.
append
(
flash_per_chunk
)
begin
=
end
attn_output
=
self
.
_merge_attn_outputs
(
flash_results
)
del
flash_results
return
attn_output
def
_do_flash_attn
(
self
,
query_states
:
torch
.
Tensor
,
key_states
:
torch
.
Tensor
,
value_states
:
torch
.
Tensor
,
softmax_scale
:
float
,
causal
:
bool
=
True
,
max_seqlen_k
:
Optional
[
int
]
=
None
,
stage
:
str
=
"intra"
,
vertical_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
slash_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
vertical_indices_count
:
Optional
[
torch
.
Tensor
]
=
None
,
slash_indices_count
:
Optional
[
torch
.
Tensor
]
=
None
,
mergehead_softmax_scale
:
Optional
[
float
]
=
None
,
sparse_attn_enabled
:
Optional
[
bool
]
=
False
,
):
if
max_seqlen_k
is
None
:
max_seqlen_k
=
key_states
.
shape
[
0
]
q_len
=
query_states
.
shape
[
0
]
q_heads
=
query_states
.
shape
[
1
]
h_dim
=
query_states
.
shape
[
-
1
]
if
sparse_attn_enabled
:
assert
slash_indices
is
not
None
if
stage
==
"intra"
:
assert
causal
else
:
assert
not
causal
query_states
=
query_states
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
key_states
=
key_states
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
value_states
=
value_states
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
q
=
query_states
k
=
key_states
v
=
value_states
if
(
vertical_indices_count
is
not
None
and
\
slash_indices_count
is
not
None
):
assert
mergehead_softmax_scale
is
not
None
res
,
s_lse
=
_vertical_slash_sparse_attention
(
q
,
k
,
v
,
vertical_indices
,
slash_indices
,
mergehead_softmax_scale
,
causal
=
causal
,
stage
=
stage
,
vertical_indices_count
=
vertical_indices_count
,
slash_indices_count
=
slash_indices_count
)
res
=
res
.
view
(
q_heads
,
q_len
,
h_dim
).
transpose
(
0
,
1
)
# (qlen,nhead,h_dim)
s_lse
=
s_lse
.
view
(
q_heads
,
q_len
,
1
).
squeeze
(
-
1
).
unsqueeze
(
0
).
float
()
# (1, nhead,qlen)
else
:
res
,
s_lse
=
_vertical_slash_sparse_attention
(
q
,
k
,
v
,
vertical_indices
,
slash_indices
,
softmax_scale
,
causal
=
causal
,
stage
=
stage
)
res
=
res
.
view
(
q_len
,
q_heads
,
h_dim
)
s_lse
=
s_lse
.
view
(
q_len
,
q_heads
,
1
).
transpose
(
0
,
2
).
float
()
return
res
,
s_lse
output
,
softmax_lse
=
flash_attn_varlen_func
(
q
=
query_states
,
k
=
key_states
,
v
=
value_states
,
softmax_scale
=
softmax_scale
,
cu_seqlens_q
=
torch
.
tensor
([
0
,
query_states
.
shape
[
0
]],
dtype
=
torch
.
int32
,
device
=
query_states
.
device
),
max_seqlen_q
=
query_states
.
shape
[
0
],
cu_seqlens_k
=
torch
.
tensor
([
0
,
max_seqlen_k
],
dtype
=
torch
.
int32
,
device
=
query_states
.
device
),
max_seqlen_k
=
max_seqlen_k
,
causal
=
causal
,
return_softmax_lse
=
True
,
)
softmax_lse
=
softmax_lse
.
view
(
q_len
,
q_heads
,
1
).
transpose
(
0
,
2
).
float
()
return
output
,
softmax_lse
def
_merge_attn_outputs
(
self
,
flash_results
:
List
[
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]],
return_lse
:
Optional
[
bool
]
=
False
,
)
->
torch
.
Tensor
:
attn_outputs_all
=
[]
logits_all
=
[]
for
flash_per_chunk
in
flash_results
:
if
len
(
flash_per_chunk
)
==
1
:
attn_outputs_all
.
append
(
flash_per_chunk
[
0
][
0
])
if
return_lse
:
logits_all
.
append
(
flash_per_chunk
[
0
][
1
])
continue
attn_outputs
=
torch
.
stack
([
flash_attn_output
[
0
]
for
flash_attn_output
in
flash_per_chunk
])
logits
=
torch
.
stack
([
flash_attn_output
[
1
]
for
flash_attn_output
in
flash_per_chunk
])
logits
=
logits
.
to
(
torch
.
float32
)
if
return_lse
:
max_val
=
torch
.
max
(
logits
,
dim
=
0
).
values
diff
=
torch
.
abs
(
logits
[
0
]
-
logits
[
1
])
log_sum_exp
=
max_val
+
torch
.
log1p
(
torch
.
exp
(
-
diff
))
logits_all
.
append
(
log_sum_exp
)
max_logits
=
torch
.
max
(
logits
,
dim
=
0
).
values
stable_logits
=
logits
-
max_logits
.
unsqueeze
(
0
)
lse_s
=
torch
.
exp
(
stable_logits
).
detach
()
lse_sum
=
torch
.
sum
(
lse_s
,
dim
=
0
)
lse_s
/=
lse_sum
attn_outputs
*=
lse_s
.
unsqueeze
(
-
1
).
transpose
(
2
,
3
).
squeeze
(
1
)
attn_outputs_all
.
append
(
attn_outputs
.
sum
(
dim
=
0
))
if
return_lse
:
return
(
torch
.
cat
(
attn_outputs_all
,
dim
=
0
),
torch
.
cat
(
logits_all
,
dim
=-
1
))
else
:
return
torch
.
cat
(
attn_outputs_all
,
dim
=
0
)
def
_dual_chunk_flash_attn_decoding
(
self
,
query
:
torch
.
Tensor
,
query_succ
:
torch
.
Tensor
,
query_inter
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
cache_seqlens
:
torch
.
Tensor
,
softmax_scale
:
float
,
causal
:
bool
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
chunk_size
:
int
,
local_size
:
int
,
original_max_position_embeddings
:
int
,
decode_meta
:
DualChunkFlashAttentionMetadata
,
):
if
not
causal
:
raise
ValueError
(
"Dual Chunk Attention does not support causal=False"
)
block_size
=
value_cache
.
shape
[
1
]
chunk_len
=
chunk_size
-
local_size
if
chunk_len
%
block_size
!=
0
:
raise
ValueError
(
"chunk_len must be divisible by block_size."
)
if
original_max_position_embeddings
>
0
:
assert
decode_meta
.
scaling_factor
is
not
None
scaling_factor
=
decode_meta
.
scaling_factor
query
=
(
query
*
scaling_factor
.
view
(
-
1
,
1
,
1
,
1
)).
to
(
query
.
dtype
)
# possible for numerical issue, need to fused in the kernel
query_succ
=
(
query_succ
*
scaling_factor
.
view
(
-
1
,
1
,
1
,
1
)).
to
(
query
.
dtype
)
query_inter
=
(
query_inter
*
scaling_factor
.
view
(
-
1
,
1
,
1
,
1
)).
to
(
query
.
dtype
)
outputs_list
=
[]
softmax_lses_list
=
[]
# intra-attention
intra_output
,
intra_softmax_lse
=
(
self
.
_dual_chunk_flash_attn_decoding_with_exp_sums
(
query
,
key_cache
,
value_cache
,
decode_meta
.
block_tables_intra
,
decode_meta
.
seq_lens_intra
,
softmax_scale
,
alibi_slopes
,
causal
=
False
,
))
outputs_list
.
append
(
intra_output
)
softmax_lses_list
.
append
(
intra_softmax_lse
)
# succ-attention
if
decode_meta
.
max_seq_len_succ
:
succ_output
,
succ_softmax_lse
=
(
self
.
_dual_chunk_flash_attn_decoding_with_exp_sums
(
query_succ
,
key_cache
,
value_cache
,
decode_meta
.
block_tables_succ
,
decode_meta
.
seq_lens_succ
,
softmax_scale
,
alibi_slopes
,
causal
=
False
,
))
outputs_list
.
append
(
succ_output
)
softmax_lses_list
.
append
(
succ_softmax_lse
)
# inter-attention
if
decode_meta
.
max_seq_len_inter
:
inter_output
,
inter_softmax_lse
=
(
self
.
_dual_chunk_flash_attn_decoding_with_exp_sums
(
query_inter
,
key_cache
,
value_cache
,
block_table
[:,
:
decode_meta
.
max_seq_len_inter
],
decode_meta
.
seq_lens_inter
,
softmax_scale
,
alibi_slopes
,
causal
=
False
,
))
outputs_list
.
append
(
inter_output
)
softmax_lses_list
.
append
(
inter_softmax_lse
)
outputs
=
torch
.
stack
(
outputs_list
,
dim
=
0
)
del
outputs_list
softmax_lses
=
torch
.
stack
(
softmax_lses_list
,
dim
=
0
).
to
(
torch
.
float32
)
del
softmax_lses_list
max_logits
=
torch
.
max
(
softmax_lses
,
dim
=
0
).
values
stable_logits
=
softmax_lses
-
max_logits
.
unsqueeze
(
0
)
lse_s
=
torch
.
exp
(
stable_logits
).
detach
()
lse_sum
=
torch
.
sum
(
lse_s
,
dim
=
0
)
lse_s
/=
lse_sum
outputs
*=
lse_s
.
unsqueeze
(
-
1
).
transpose
(
2
,
3
)
return
outputs
.
sum
(
0
)
def
_dual_chunk_flash_attn_decoding_with_exp_sums
(
self
,
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_table
:
torch
.
Tensor
,
cache_seqlens
:
torch
.
Tensor
,
softmax_scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
causal
:
bool
,
):
out
,
softmax_lse
=
flash_attn_with_kvcache
(
q
=
query
,
k_cache
=
key_cache
,
v_cache
=
value_cache
,
block_table
=
block_table
,
cache_seqlens
=
cache_seqlens
,
softmax_scale
=
softmax_scale
,
alibi_slopes
=
alibi_slopes
,
causal
=
causal
,
return_softmax_lse
=
True
,
)
mask
=
(
cache_seqlens
==
0
)
out
[
mask
]
=
0
softmax_lse
[
mask
]
=
-
float
(
"inf"
)
return
out
,
softmax_lse
def
_vertical_slash_sparse_attention
(
query
:
torch
.
Tensor
,
# [BATCH, N_HEADS, N_CTX, D_HEAD]
key
:
torch
.
Tensor
,
# [BATCH, N_HEADS, N_KV_CTX, D_HEAD]
value
:
torch
.
Tensor
,
# [BATCH, N_HEADS, N_KV_CTX, D_HEAD]
v_idx
:
torch
.
Tensor
,
# [BATCH, N_HEADS, NNZ_V]
s_idx
:
torch
.
Tensor
,
# [BATCH, N_HEADS, NNZ_S]
softmax_scale
:
float
,
causal
:
bool
=
True
,
stage
:
str
=
"intra"
,
block_size_M
:
int
=
64
,
block_size_N
:
int
=
64
,
vertical_indices_count
:
torch
.
Tensor
=
None
,
# [N_HEADS,]
slash_indices_count
:
torch
.
Tensor
=
None
,
):
if
stage
==
"intra"
:
assert
causal
else
:
assert
not
causal
batch_size
,
num_heads
,
context_size
,
head_dim
=
query
.
shape
_
,
_
,
kv_seq_len
,
_
=
key
.
shape
if
head_dim
not
in
[
16
,
32
,
64
,
128
,
256
,
512
]:
target_dim
=
2
**
math
.
ceil
(
math
.
log2
(
head_dim
))
-
head_dim
query
=
F
.
pad
(
query
,
[
0
,
target_dim
,
0
,
0
,
0
,
0
,
0
,
0
])
key
=
F
.
pad
(
key
,
[
0
,
target_dim
,
0
,
0
,
0
,
0
,
0
,
0
])
value
=
F
.
pad
(
value
,
[
0
,
target_dim
,
0
,
0
,
0
,
0
,
0
,
0
])
v_idx
=
v_idx
.
to
(
torch
.
int32
).
reshape
(
(
batch_size
,
num_heads
,
-
1
)).
sort
(
dim
=-
1
,
descending
=
False
)[
0
]
s_idx
=
s_idx
.
to
(
torch
.
int32
).
reshape
(
(
batch_size
,
num_heads
,
-
1
)).
sort
(
dim
=-
1
,
descending
=
True
)[
0
]
q_seqlens
=
torch
.
tensor
([
context_size
],
dtype
=
torch
.
int32
,
device
=
query
.
device
)
kv_seqlens
=
torch
.
tensor
([
kv_seq_len
],
dtype
=
torch
.
int32
,
device
=
query
.
device
)
if
vertical_indices_count
is
not
None
and
slash_indices_count
is
not
None
:
(
block_count
,
block_offset
,
column_count
,
column_index
,
)
=
ops
.
convert_vertical_slash_indexes_mergehead
(
q_seqlens
,
kv_seqlens
,
v_idx
,
s_idx
,
vertical_indices_count
,
slash_indices_count
,
context_size
,
block_size_M
,
block_size_N
,
causal
)
else
:
(
block_count
,
block_offset
,
column_count
,
column_index
,
)
=
ops
.
convert_vertical_slash_indexes
(
q_seqlens
,
kv_seqlens
,
v_idx
,
s_idx
,
context_size
,
block_size_M
,
block_size_N
,
causal
)
q
=
query
.
transpose
(
1
,
2
).
contiguous
()
k
=
key
.
transpose
(
1
,
2
).
contiguous
()
v
=
value
.
transpose
(
1
,
2
).
contiguous
()
out
,
lse
=
sparse_attn_func
(
q
,
k
,
v
,
block_count
,
block_offset
,
column_count
,
column_index
,
causal
=
causal
,
softmax_scale
=
softmax_scale
,
return_softmax_lse
=
True
,
)
out
=
out
.
transpose
(
1
,
2
).
contiguous
()
softmax_lse
=
lse
.
reshape
(
*
lse
.
shape
,
1
)
return
(
out
[...,
:
context_size
,
:
head_dim
],
softmax_lse
[...,
:
context_size
,
:])
def
_sum_all_diagonal_matrix
(
mat
:
torch
.
tensor
):
h
,
n
,
m
=
mat
.
shape
# Zero matrix used for padding
zero_mat
=
torch
.
zeros
((
h
,
n
,
n
),
device
=
mat
.
device
)
# pads the matrix on left and right
mat_padded
=
torch
.
cat
((
zero_mat
,
mat
,
zero_mat
),
-
1
)
# Change the strides
mat_strided
=
mat_padded
.
as_strided
((
1
,
n
,
n
+
m
),
(
n
*
(
2
*
n
+
m
),
2
*
n
+
m
+
1
,
1
))
# Sums the resulting matrix's columns
sum_diags
=
torch
.
sum
(
mat_strided
,
1
)
return
sum_diags
[:,
1
:]
# drop left bottom corner
def
_get_block
(
block_table
:
torch
.
Tensor
,
block_size
:
int
,
begin
:
int
,
end
:
int
):
begin_block
=
begin
//
block_size
end_block
=
(
end
-
1
)
//
block_size
+
1
return
block_table
[
begin_block
:
end_block
]
vllm/attention/backends/flash_attn.py
deleted
100755 → 0
View file @
af7dfb0d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
itertools
import
accumulate
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
vllm
import
_custom_ops
as
ops
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionType
,
is_quantized_kv_cache
)
# yapf: enable
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
CommonAttentionState
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
get_num_prefill_decode_query_kv_tokens
,
get_seq_len_block_table_args
,
is_all_cross_attn_metadata_set
,
is_all_encoder_attn_metadata_set
,
is_block_tables_empty
)
from
vllm.attention.utils.fa_utils
import
(
flash_attn_supports_fp8
,
get_flash_attn_version
)
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
)
logger
=
init_logger
(
__name__
)
class
FlashAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
@
staticmethod
def
get_supported_head_sizes
()
->
List
[
int
]:
return
[
32
,
64
,
96
,
128
,
160
,
192
,
224
,
256
]
@
staticmethod
def
get_name
()
->
str
:
return
"FLASH_ATTN"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"FlashAttentionImpl"
]:
return
FlashAttentionImpl
@
staticmethod
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
FlashAttentionMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"FlashAttentionMetadataBuilder"
]:
return
FlashAttentionMetadataBuilder
@
staticmethod
def
get_state_cls
()
->
Type
[
"CommonAttentionState"
]:
return
CommonAttentionState
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
if
block_size
%
16
!=
0
:
raise
ValueError
(
"Block size must be a multiple of 16."
)
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
src_key_cache
=
src_kv_cache
[
0
]
dst_key_cache
=
dst_kv_cache
[
0
]
ops
.
swap_blocks
(
src_key_cache
,
dst_key_cache
,
src_to_dst
)
src_value_cache
=
src_kv_cache
[
1
]
dst_value_cache
=
dst_kv_cache
[
1
]
ops
.
swap_blocks
(
src_value_cache
,
dst_value_cache
,
src_to_dst
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
@
dataclass
class
FlashAttentionMetadata
(
AttentionMetadata
):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens
:
Optional
[
List
[
int
]]
# seq_lens stored as a tensor.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len
:
int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len
:
int
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables
:
Optional
[
torch
.
Tensor
]
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph
:
bool
# Maximum query length in the batch.
max_query_len
:
Optional
[
int
]
=
None
# Max number of query tokens among request in the batch.
max_decode_query_len
:
Optional
[
int
]
=
None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
_cached_prefill_metadata
:
Optional
[
"FlashAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"FlashAttentionMetadata"
]
=
None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens
:
Optional
[
List
[
int
]]
=
None
encoder_seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
encoder_seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# Maximum sequence length among encoder sequences
max_encoder_seq_len
:
Optional
[
int
]
=
None
# Number of tokens input to encoder
num_encoder_tokens
:
Optional
[
int
]
=
None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping
:
Optional
[
torch
.
Tensor
]
=
None
cross_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
@
property
def
is_all_encoder_attn_metadata_set
(
self
):
'''
All attention metadata required for encoder attention is set.
'''
return
is_all_encoder_attn_metadata_set
(
self
)
@
property
def
is_all_cross_attn_metadata_set
(
self
):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return
is_all_cross_attn_metadata_set
(
self
)
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"FlashAttentionMetadata"
]:
if
self
.
num_prefills
==
0
:
return
None
if
self
.
_cached_prefill_metadata
is
not
None
:
return
self
.
_cached_prefill_metadata
assert
((
self
.
seq_lens
is
not
None
)
or
(
self
.
encoder_seq_lens
is
not
None
))
assert
((
self
.
seq_lens_tensor
is
not
None
)
or
(
self
.
encoder_seq_lens_tensor
is
not
None
))
# Compute some attn_metadata fields which default to None
query_start_loc
=
(
None
if
self
.
query_start_loc
is
None
else
self
.
query_start_loc
[:
self
.
num_prefills
+
1
])
slot_mapping
=
(
None
if
self
.
slot_mapping
is
None
else
self
.
slot_mapping
[:
self
.
num_prefill_tokens
])
seq_lens
=
(
None
if
self
.
seq_lens
is
None
else
self
.
seq_lens
[:
self
.
num_prefills
])
seq_lens_tensor
=
(
None
if
self
.
seq_lens_tensor
is
None
else
self
.
seq_lens_tensor
[:
self
.
num_prefills
])
seq_start_loc
=
(
None
if
self
.
seq_start_loc
is
None
else
self
.
seq_start_loc
[:
self
.
num_prefills
+
1
])
context_lens_tensor
=
(
None
if
self
.
context_lens_tensor
is
None
else
self
.
context_lens_tensor
[:
self
.
num_prefills
])
block_tables
=
(
None
if
self
.
block_tables
is
None
else
self
.
block_tables
[:
self
.
num_prefills
])
self
.
_cached_prefill_metadata
=
FlashAttentionMetadata
(
num_prefills
=
self
.
num_prefills
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
self
.
multi_modal_placeholder_index_maps
,
enable_kv_scales_calculation
=
self
.
enable_kv_scales_calculation
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
self
.
max_prefill_seq_len
,
max_decode_query_len
=
0
,
max_decode_seq_len
=
0
,
query_start_loc
=
query_start_loc
,
seq_start_loc
=
seq_start_loc
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
False
,
# Begin encoder & cross attn fields below...
encoder_seq_lens
=
self
.
encoder_seq_lens
,
encoder_seq_lens_tensor
=
self
.
encoder_seq_lens_tensor
,
encoder_seq_start_loc
=
self
.
encoder_seq_start_loc
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
)
return
self
.
_cached_prefill_metadata
@
property
def
decode_metadata
(
self
)
->
Optional
[
"FlashAttentionMetadata"
]:
if
self
.
num_decode_tokens
==
0
:
return
None
if
self
.
_cached_decode_metadata
is
not
None
:
return
self
.
_cached_decode_metadata
assert
((
self
.
seq_lens_tensor
is
not
None
)
or
(
self
.
encoder_seq_lens_tensor
is
not
None
))
# Compute some attn_metadata fields which default to None
slot_mapping
=
(
None
if
self
.
slot_mapping
is
None
else
self
.
slot_mapping
[
self
.
num_prefill_tokens
:])
seq_lens_tensor
=
(
None
if
self
.
seq_lens_tensor
is
None
else
self
.
seq_lens_tensor
[
self
.
num_prefills
:])
block_tables
=
(
None
if
self
.
block_tables
is
None
else
self
.
block_tables
[
self
.
num_prefills
:])
self
.
_cached_decode_metadata
=
FlashAttentionMetadata
(
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
enable_kv_scales_calculation
=
True
,
seq_lens
=
None
,
seq_lens_tensor
=
seq_lens_tensor
,
max_decode_query_len
=
self
.
max_decode_query_len
,
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
# Batch may be composed of prefill|decodes, adjust query start
# indices to refer to the start of decodes. E.g.
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
query_start_loc
=
(
self
.
query_start_loc
[
self
.
num_prefills
:]
-
self
.
query_start_loc
[
self
.
num_prefills
])
if
self
.
query_start_loc
is
not
None
else
None
,
seq_start_loc
=
self
.
seq_start_loc
[
self
.
num_prefills
:]
if
self
.
seq_start_loc
is
not
None
else
None
,
context_lens_tensor
=
None
,
block_tables
=
block_tables
,
use_cuda_graph
=
self
.
use_cuda_graph
,
# Begin encoder & cross attn fields below...
encoder_seq_lens
=
self
.
encoder_seq_lens
,
encoder_seq_lens_tensor
=
self
.
encoder_seq_lens_tensor
,
encoder_seq_start_loc
=
self
.
encoder_seq_start_loc
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
)
return
self
.
_cached_decode_metadata
class
FlashAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
FlashAttentionMetadata
]):
def
__init__
(
self
,
input_builder
):
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
def
prepare
(
self
):
self
.
slot_mapping
:
List
[
int
]
=
[]
self
.
prefill_seq_lens
:
List
[
int
]
=
[]
self
.
context_lens
:
List
[
int
]
=
[]
self
.
block_tables
:
List
[
List
[
int
]]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
multimodal_placeholder_maps
:
Dict
[
str
,
MultiModalPlaceholderMap
]
=
defaultdict
(
MultiModalPlaceholderMap
)
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
has_prefix_cache_hit
=
False
def
_add_seq_group
(
self
,
inter_data
,
chunked_prefill_enabled
:
bool
,
prefix_cache_hit
:
bool
):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt
=
inter_data
.
is_prompt
block_tables
=
inter_data
.
block_tables
for
(
seq_id
,
token_len
,
seq_len
,
curr_seq_len
,
query_len
,
context_len
,
curr_sliding_window_block
)
in
zip
(
inter_data
.
seq_ids
,
[
len
(
t
)
for
t
in
inter_data
.
input_tokens
],
inter_data
.
orig_seq_lens
,
inter_data
.
seq_lens
,
inter_data
.
query_lens
,
inter_data
.
context_lens
,
inter_data
.
curr_sliding_window_blocks
):
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
mm_maps
=
inter_data
.
multi_modal_placeholder_maps
if
mm_maps
:
for
modality
,
placeholders
in
mm_maps
.
items
():
self
.
multimodal_placeholder_maps
[
modality
].
extend
(
placeholders
)
self
.
num_prefills
+=
1
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
else
:
self
.
num_decode_tokens
+=
query_len
self
.
curr_seq_lens
.
append
(
curr_seq_len
)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table
=
[]
if
prefix_cache_hit
:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table
=
block_tables
[
seq_id
]
elif
((
chunked_prefill_enabled
or
not
is_prompt
)
and
block_tables
is
not
None
):
if
curr_sliding_window_block
==
0
:
block_table
=
block_tables
[
seq_id
]
else
:
block_table
=
block_tables
[
seq_id
][
-
curr_sliding_window_block
:]
self
.
block_tables
.
append
(
block_table
)
# Compute slot mapping.
is_profile_run
=
is_block_tables_empty
(
block_tables
)
start_idx
=
compute_slot_mapping_start_idx
(
is_prompt
,
query_len
,
context_len
,
self
.
sliding_window
)
compute_slot_mapping
(
is_profile_run
,
self
.
slot_mapping
,
seq_id
,
seq_len
,
context_len
,
start_idx
,
self
.
block_size
,
inter_data
.
block_tables
)
def
_get_graph_runner_block_tables
(
self
,
num_seqs
:
int
,
block_tables
:
List
[
List
[
int
]])
->
torch
.
Tensor
:
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
max_batch_size
,
max_blocks
=
self
.
runner
.
graph_block_tables
.
shape
assert
max_batch_size
>=
num_seqs
graph_block_tables
=
self
.
runner
.
graph_block_tables
[:
num_seqs
]
for
i
,
block_table
in
enumerate
(
block_tables
):
if
block_table
:
num_blocks
=
len
(
block_table
)
if
num_blocks
<=
max_blocks
:
graph_block_tables
[
i
,
:
num_blocks
]
=
block_table
else
:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
graph_block_tables
[
i
,
:
max_blocks
]
=
block_table
[:
max_blocks
]
return
torch
.
from_numpy
(
graph_block_tables
).
to
(
device
=
self
.
runner
.
device
,
non_blocking
=
True
)
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
prefix_cache_hit
=
any
([
inter_data
.
prefix_cache_hit
for
inter_data
in
self
.
input_builder
.
inter_data_list
])
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
,
prefix_cache_hit
)
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
max_query_len
=
max
(
query_lens
)
decode_query_lens
=
query_lens
[
self
.
num_prefills
:]
if
len
(
decode_query_lens
)
>
0
:
max_decode_query_len
=
max
(
decode_query_lens
)
else
:
max_decode_query_len
=
1
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
num_decode_tokens
=
self
.
num_decode_tokens
query_start_loc
=
list
(
accumulate
(
query_lens
,
initial
=
0
))
seq_start_loc
=
list
(
accumulate
(
seq_lens
,
initial
=
0
))
num_seqs
=
len
(
seq_lens
)
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
([]
*
cuda_graph_pad_size
)
num_decode_tokens
=
batch_size
-
self
.
num_prefill_tokens
block_tables
=
self
.
_get_graph_runner_block_tables
(
num_seqs
,
self
.
block_tables
)
else
:
block_tables
=
make_tensor_with_pad
(
self
.
block_tables
,
pad
=
0
,
dtype
=
torch
.
int
,
device
=
device
,
)
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
assert
device
is
not
None
context_lens_tensor
=
async_tensor_h2d
(
self
.
context_lens
,
torch
.
int
,
device
,
self
.
runner
.
pin_memory
)
seq_lens_tensor
=
async_tensor_h2d
(
seq_lens
,
torch
.
int
,
device
,
self
.
runner
.
pin_memory
)
slot_mapping_tensor
=
async_tensor_h2d
(
self
.
slot_mapping
,
torch
.
long
,
device
,
self
.
runner
.
pin_memory
)
query_start_loc_tensor
=
async_tensor_h2d
(
query_start_loc
,
torch
.
int32
,
device
,
self
.
runner
.
pin_memory
)
seq_start_loc_tensor
=
async_tensor_h2d
(
seq_start_loc
,
torch
.
int32
,
device
,
self
.
runner
.
pin_memory
)
placeholder_index_maps
=
{
modality
:
placeholder_map
.
index_map
()
for
modality
,
placeholder_map
in
self
.
multimodal_placeholder_maps
.
items
()
}
return
FlashAttentionMetadata
(
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
seq_lens
=
seq_lens
,
multi_modal_placeholder_index_maps
=
placeholder_index_maps
,
enable_kv_scales_calculation
=
True
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_decode_query_len
=
max_decode_query_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
query_start_loc
=
query_start_loc_tensor
,
seq_start_loc
=
seq_start_loc_tensor
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
)
class
FlashAttentionImpl
(
AttentionImpl
):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0 "
"FLASH_ATTN backend."
)
if
use_irope
:
logger
.
warning
(
"Using irope in V0 is not supported yet, it will fall back "
"to global attention for long context."
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
sliding_window
=
((
sliding_window
-
1
,
0
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
vllm_flash_attn_version
=
get_flash_attn_version
(
requires_alibi
=
self
.
alibi_slopes
is
not
None
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
and
(
not
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
or
not
flash_attn_supports_fp8
()):
raise
NotImplementedError
(
f
"FlashAttention does not support
{
self
.
kv_cache_dtype
}
"
"kv-cache on this device "
f
"(FA supports fp8 =
{
flash_attn_supports_fp8
()
}
)."
)
if
logits_soft_cap
is
None
:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap
=
0
self
.
logits_soft_cap
=
logits_soft_cap
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
support_head_sizes
=
FlashAttentionBackend
.
get_supported_head_sizes
()
if
head_size
not
in
support_head_sizes
:
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by FlashAttention. "
f
"Supported head sizes are:
{
support_head_sizes
}
."
)
self
.
attn_type
=
attn_type
def
forward
(
self
,
layer
:
AttentionLayer
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashAttentionMetadata
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_block_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention.
Args:
query: shape = [num_tokens, num_heads, head_size]
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
output: shape = [num_tokens, num_heads, head_size]
kv_cache: KV cache tensor with shape
[2, num_blocks, block_size, num_kv_heads, head_size].
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
NOTE: It in-place updates the output tensor.
NOTE: FP8 quantization, flash-attn expect the size of
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values
"""
assert
output
is
not
None
,
"Output tensor must be provided."
if
output_scale
is
not
None
or
output_block_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for FlashAttentionImpl"
)
# NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
if
not
flash_attn_supports_fp8
()
or
output
.
dtype
!=
torch
.
bfloat16
:
assert
(
layer
.
_k_scale_float
==
1.0
and
layer
.
_v_scale_float
==
1.0
),
(
"key/v_scale is only supported in FlashAttention 3 with "
"base dtype bfloat16"
)
attn_type
=
self
.
attn_type
if
(
attn_type
==
AttentionType
.
ENCODER
and
(
not
attn_metadata
.
is_all_encoder_attn_metadata_set
)):
raise
AttributeError
(
"Encoder attention requires setting "
"encoder metadata attributes."
)
elif
(
attn_type
==
AttentionType
.
ENCODER_DECODER
and
(
not
attn_metadata
.
is_all_cross_attn_metadata_set
)):
raise
AttributeError
(
"Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes."
)
kv_cache_dtype
:
str
=
self
.
kv_cache_dtype
softmax_scale
:
float
=
self
.
scale
window_size
=
self
.
sliding_window
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
self
.
alibi_slopes
logits_soft_cap
:
Optional
[
float
]
=
self
.
logits_soft_cap
fp8_attention
=
kv_cache_dtype
.
startswith
(
"fp8"
)
if
fp8_attention
and
not
flash_attn_supports_fp8
():
raise
NotImplementedError
(
"FlashAttention does not support FP8 kv-cache on this device."
)
if
kv_cache
.
numel
()
>
0
:
key_cache
=
kv_cache
[
0
]
value_cache
=
kv_cache
[
1
]
# We skip updating the KV cache under two conditions:
# a. When the Attention Type is ENCODER. In this phase, we compute
# only the encoder attention without updating the cache.
# b. When both Key and Value are None. This occurs during
# cross-attention computation in the decoding phase, where the
# KV cache is already populated with the cross-attention
# tensor. Thus, we skip cache updates during this time.
if
(
attn_type
!=
AttentionType
.
ENCODER
)
and
(
key
is
not
None
)
and
(
value
is
not
None
):
if
attn_type
==
AttentionType
.
ENCODER_DECODER
:
# Update cross-attention KV cache (prefill-only)
updated_slot_mapping
=
attn_metadata
.
cross_slot_mapping
else
:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping
=
attn_metadata
.
slot_mapping
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory
# profiling run.
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
kv_cache
[
0
],
kv_cache
[
1
],
updated_slot_mapping
.
flatten
(),
# type: ignore[union-attr]
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
if
fp8_attention
:
kv_cache
=
kv_cache
.
view
(
torch
.
float8_e4m3fn
)
key_cache
=
key_cache
.
view
(
torch
.
float8_e4m3fn
)
value_cache
=
value_cache
.
view
(
torch
.
float8_e4m3fn
)
if
fp8_attention
:
num_tokens
,
num_heads
,
head_size
=
query
.
shape
query
,
_
=
ops
.
scaled_fp8_quant
(
query
.
reshape
(
(
num_tokens
,
num_heads
*
head_size
)).
contiguous
(),
layer
.
_q_scale
)
query
=
query
.
reshape
((
num_tokens
,
num_heads
,
head_size
))
(
num_prefill_query_tokens
,
num_prefill_kv_tokens
,
num_decode_query_tokens
)
=
\
get_num_prefill_decode_query_kv_tokens
(
attn_metadata
,
attn_type
)
decode_query
=
query
[
num_prefill_query_tokens
:]
decode_output
=
output
[
num_prefill_query_tokens
:]
# QKV for prefill.
query
=
query
[:
num_prefill_query_tokens
]
prefill_output
=
output
[:
num_prefill_query_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_query_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_query_tokens
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
if
(
kv_cache
.
numel
()
==
0
or
prefill_meta
.
block_tables
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
q_seq_start_loc
,
q_seq_len
,
k_seq_start_loc
,
k_seq_len
=
\
_get_query_key_seq_metadata
(
prefill_meta
,
True
,
attn_type
)
key
=
key
[:
num_prefill_kv_tokens
]
value
=
value
[:
num_prefill_kv_tokens
]
if
fp8_attention
:
num_kv_tokens
,
num_kv_heads
,
head_size
=
key
.
shape
key
,
_
=
ops
.
scaled_fp8_quant
(
key
.
reshape
((
num_kv_tokens
,
num_kv_heads
*
head_size
)).
contiguous
(),
layer
.
_k_scale
)
key
=
key
.
reshape
((
num_kv_tokens
,
num_kv_heads
,
head_size
))
value
,
_
=
ops
.
scaled_fp8_quant
(
value
.
reshape
((
num_kv_tokens
,
num_kv_heads
*
head_size
)).
contiguous
(),
layer
.
_v_scale
)
value
=
value
.
reshape
(
(
num_kv_tokens
,
num_kv_heads
,
head_size
))
descale_shape
=
(
q_seq_start_loc
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
q_seq_start_loc
,
cu_seqlens_k
=
k_seq_start_loc
,
max_seqlen_q
=
q_seq_len
,
max_seqlen_k
=
k_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
_get_causal_option
(
attn_type
),
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
out
=
prefill_output
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
else
:
# prefix-enabled attention
assert
attn_type
==
AttentionType
.
DECODER
,
(
"Only decoder-only models support prefix caching"
)
assert
prefill_meta
.
seq_lens
is
not
None
assert
prefill_meta
.
query_start_loc
is
not
None
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
descale_shape
=
(
prefill_meta
.
query_start_loc
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
flash_attn_varlen_func
(
# noqa
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
prefill_meta
.
query_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_query_len
,
seqused_k
=
prefill_meta
.
seq_lens_tensor
,
max_seqlen_k
=
max_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
softcap
=
logits_soft_cap
,
out
=
prefill_output
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
# Use flash_attn_varlen_func kernel for speculative decoding
# because different queries might have different lengths.
assert
decode_meta
.
max_decode_query_len
is
not
None
# use only for actual varlen decoding
if
decode_meta
.
max_decode_query_len
>
1
:
assert
attn_type
==
AttentionType
.
DECODER
,
(
"Only decoder-only models support max_decode_query_len > 1"
)
assert
decode_meta
.
query_start_loc
is
not
None
descale_shape
=
(
decode_meta
.
query_start_loc
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
flash_attn_varlen_func
(
q
=
decode_query
,
k
=
key_cache
,
v
=
value_cache
,
cu_seqlens_q
=
decode_meta
.
query_start_loc
,
max_seqlen_q
=
decode_meta
.
max_decode_query_len
,
seqused_k
=
decode_meta
.
seq_lens_tensor
,
max_seqlen_k
=
decode_meta
.
max_decode_seq_len
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
block_table
=
decode_meta
.
block_tables
,
out
=
decode_output
,
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
else
:
# Use flash_attn_with_kvcache for normal decoding.
(
seq_lens_arg
,
_
,
block_tables_arg
,
)
=
get_seq_len_block_table_args
(
decode_meta
,
False
,
attn_type
)
descale_shape
=
(
seq_lens_arg
.
shape
[
0
],
key_cache
.
shape
[
-
2
])
flash_attn_with_kvcache
(
q
=
decode_query
.
unsqueeze
(
1
),
k_cache
=
key_cache
,
v_cache
=
value_cache
,
block_table
=
block_tables_arg
,
cache_seqlens
=
seq_lens_arg
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
out
=
decode_output
.
unsqueeze
(
1
),
fa_version
=
self
.
vllm_flash_attn_version
,
q_descale
=
layer
.
_q_scale
.
expand
(
descale_shape
),
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
return
output
def
_get_query_key_seq_metadata
(
attn_metadata
:
FlashAttentionMetadata
,
is_prompt
:
bool
,
attn_type
:
str
,
)
->
tuple
:
"""
Returns sequence metadata for key and query based on the specified
attention type and whether input is a prompt.
This function computes the starting locations and maximum sequence lengths
for key and query sequences for different attention types.
Args:
attn_metadata: The attention metadata object
is_prompt (bool): A flag indicating if the input is a prompt
attn_type (AttentionType): The type of attention being used.
Returns:
tuple: A tuple containing four integers:
- Starting location for the query sequence.
- Maximum sequence length for the query sequence.
- Starting location for the key sequence.
- Maximum sequence length for the key sequence.
Raises:
AttributeError: If an invalid attention type is provided.
"""
if
attn_type
==
AttentionType
.
DECODER
:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
if
is_prompt
:
max_seq_len
=
attn_metadata
.
max_prefill_seq_len
else
:
max_seq_len
=
attn_metadata
.
max_decode_seq_len
return
(
attn_metadata
.
seq_start_loc
,
max_seq_len
,
attn_metadata
.
seq_start_loc
,
max_seq_len
)
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
# This is cross attention between the where the key
# is the precomputed encoder attention and query
# is the input sequence.
# Choose query max length based on whether it is prompt
# or not.
if
is_prompt
:
max_seq_len
=
attn_metadata
.
max_prefill_seq_len
else
:
max_seq_len
=
attn_metadata
.
max_decode_seq_len
return
(
attn_metadata
.
seq_start_loc
,
max_seq_len
,
attn_metadata
.
encoder_seq_start_loc
,
attn_metadata
.
max_encoder_seq_len
)
elif
attn_type
==
AttentionType
.
ENCODER
:
# For encoder attention both the query and the key are same i.e. the
# encoder sequence.
return
(
attn_metadata
.
encoder_seq_start_loc
,
attn_metadata
.
max_encoder_seq_len
,
attn_metadata
.
encoder_seq_start_loc
,
attn_metadata
.
max_encoder_seq_len
)
elif
attn_type
==
AttentionType
.
ENCODER_ONLY
:
assert
is_prompt
,
"Should not have decode for encoder only model."
return
(
attn_metadata
.
seq_start_loc
,
attn_metadata
.
max_prefill_seq_len
,
attn_metadata
.
seq_start_loc
,
attn_metadata
.
max_prefill_seq_len
)
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
def
_get_causal_option
(
attn_type
:
str
)
->
bool
:
"""
Determine whether the given attention type is suitable for causal
attention mechanisms.
Args:
attn_type (AttentionType): The type of attention being evaluated
Returns:
bool: Returns `True` if the attention type is suitable for causal
attention (i.e., not encoder, encoder-only, or encoder-decoder),
otherwise returns `False`.
"""
return
not
(
attn_type
==
AttentionType
.
ENCODER
or
attn_type
==
AttentionType
.
ENCODER_ONLY
or
attn_type
==
AttentionType
.
ENCODER_DECODER
)
vllm/attention/backends/flashmla.py
deleted
100644 → 0
View file @
af7dfb0d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.mla.common
import
(
MLACommonBackend
,
MLACommonImpl
,
MLACommonMetadata
,
MLACommonMetadataBuilder
,
MLACommonState
)
from
vllm.attention.ops.flashmla
import
(
flash_mla_with_kvcache
,
get_mla_metadata
,
is_flashmla_supported
)
class
FlashMLABackend
(
MLACommonBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"FLASHMLA"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"FlashMLAImpl"
]:
return
FlashMLAImpl
@
staticmethod
def
get_metadata_cls
()
->
Type
[
"FlashMLAMetadata"
]:
return
FlashMLAMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"FlashMLAMetadataBuilder"
]:
return
FlashMLAMetadataBuilder
@
staticmethod
def
get_state_cls
()
->
Type
[
"FlashMLAState"
]:
return
FlashMLAState
@
dataclass
class
FlashMLAMetadata
(
MLACommonMetadata
):
decode_tile_scheduler_metadata
:
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]
=
None
decode_num_splits
:
Optional
[
torch
.
Tensor
]
=
None
@
property
def
decode_metadata
(
self
):
decode_metadata
=
super
().
decode_metadata
# TODO: cache assignment?
if
decode_metadata
is
not
None
:
decode_metadata
.
decode_tile_scheduler_metadata
=
\
self
.
decode_tile_scheduler_metadata
decode_metadata
.
decode_num_splits
=
\
self
.
decode_num_splits
return
decode_metadata
class
FlashMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
FlashMLAMetadata
]):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
num_q_heads
=
self
.
runner
.
model_config
.
get_num_attention_heads
(
self
.
runner
.
parallel_config
)
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
m
=
super
().
build
(
seq_lens
,
query_lens
,
cuda_graph_pad_size
,
batch_size
)
if
m
.
num_decode_tokens
>
0
:
m
.
decode_tile_scheduler_metadata
,
m
.
decode_num_splits
=
\
get_mla_metadata
(
m
.
seq_lens_tensor
[
m
.
num_prefills
:],
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
return
m
class
FlashMLAState
(
MLACommonState
[
FlashMLAMetadata
]):
def
__init__
(
self
,
*
args
,
**
kwds
):
super
().
__init__
(
*
args
,
**
kwds
)
self
.
num_q_heads
=
self
.
runner
.
model_config
.
get_num_attention_heads
(
self
.
runner
.
parallel_config
)
@
contextmanager
def
graph_capture
(
self
,
max_batch_size
:
int
):
# Run a dummy `get_mla_metadata` so we can get the right shapes
self
.
_graph_decoder_tile_scheduler_metadata
,
\
self
.
_graph_decode_num_splits
=
get_mla_metadata
(
torch
.
ones
(
max_batch_size
,
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
),
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
with
super
().
graph_capture
(
max_batch_size
):
yield
del
self
.
_graph_decoder_tile_scheduler_metadata
del
self
.
_graph_decode_num_splits
def
graph_capture_get_metadata_for_batch
(
self
,
batch_size
:
int
,
is_encoder_decoder_model
:
bool
=
False
):
metadata
=
super
().
graph_capture_get_metadata_for_batch
(
batch_size
,
is_encoder_decoder_model
)
assert
metadata
.
num_decode_tokens
>
0
decoder_tile_scheduler_metadata
,
decode_num_splits
=
get_mla_metadata
(
self
.
_graph_seq_lens
[:
batch_size
],
self
.
num_q_heads
,
1
,
# MQA for the decode path
)
self
.
_graph_decoder_tile_scheduler_metadata
.
copy_
(
decoder_tile_scheduler_metadata
)
self
.
_graph_decode_num_splits
[:
batch_size
+
1
].
copy_
(
decode_num_splits
)
metadata
.
decode_tile_scheduler_metadata
=
\
self
.
_graph_decoder_tile_scheduler_metadata
metadata
.
decode_num_splits
=
\
self
.
_graph_decode_num_splits
[:
batch_size
+
1
]
return
metadata
def
get_graph_input_buffers
(
self
,
attn_metadata
,
is_encoder_decoder_model
:
bool
=
False
):
input_buffers
=
super
().
get_graph_input_buffers
(
attn_metadata
,
is_encoder_decoder_model
)
input_buffers
[
"decode_tile_scheduler_metadata"
]
=
\
attn_metadata
.
decode_metadata
.
decode_tile_scheduler_metadata
input_buffers
[
"decode_num_splits"
]
=
\
attn_metadata
.
decode_metadata
.
decode_num_splits
return
input_buffers
def
prepare_graph_input_buffers
(
self
,
input_buffers
,
attn_metadata
,
is_encoder_decoder_model
:
bool
=
False
):
super
().
prepare_graph_input_buffers
(
input_buffers
,
attn_metadata
,
is_encoder_decoder_model
)
input_buffers
[
"decode_tile_scheduler_metadata"
].
copy_
(
attn_metadata
.
decode_metadata
.
decode_tile_scheduler_metadata
)
input_buffers
[
"decode_num_splits"
].
copy_
(
attn_metadata
.
decode_metadata
.
decode_num_splits
)
class
FlashMLAImpl
(
MLACommonImpl
[
FlashMLAMetadata
]):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
# MLA Specific Arguments
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
logits_soft_cap
,
attn_type
,
kv_sharing_target_layer_name
,
**
mla_args
)
is_supported
,
reason
=
is_flashmla_supported
()
assert
is_supported
,
reason
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
logits_soft_cap
]
if
any
(
unsupported_features
):
raise
NotImplementedError
(
"FlashMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap"
)
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashMLAImpl"
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"FlashMLA with FP8 KV cache not yet supported"
)
def
_forward_decode
(
self
,
q_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashMLAMetadata
,
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
decode_meta
=
attn_metadata
.
decode_metadata
assert
decode_meta
is
not
None
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
o
,
_
=
flash_mla_with_kvcache
(
q
=
q
,
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
block_table
=
decode_meta
.
block_tables
,
cache_seqlens
=
decode_meta
.
seq_lens_tensor
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
decode_meta
.
decode_tile_scheduler_metadata
,
num_splits
=
decode_meta
.
decode_num_splits
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
return
self
.
_v_up_proj
(
o
)
vllm/attention/backends/mla/__init__.py
deleted
100644 → 0
View file @
af7dfb0d
vllm/attention/backends/mla/common.py
deleted
100644 → 0
View file @
af7dfb0d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
# MLA Common Components
This file implements common components for MLA implementations.
First we define:
Sq as Q sequence length
Skv as KV sequence length
MLA has two possible ways of computing, a data-movement friendly approach and a
compute friendly approach, we generally want to use the compute friendly
approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1)
and the data-movement friendly approach for "decode" (i.e. the ratio
Sq / Skv is "large").
NOTE what we deem small and large is currently determined by if its labelled
prefill or decode by the scheduler, but this is something we should probably
tune.
Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the per-token entry of the KV cache.
* For decode (i.e. the memory friendly approach) the attention "simulates" a
multi-head attention, while the compute is similar to multi-query attention.
Below is example of both paths assuming batchsize = 1
## More Extent Definitions:
C Context length, `Skv - Sq`
H hidden size
N number of attention heads
Lq latent dimension for Q 1536 in DSV3
Lkv latent dimension for K/V 512 in DSV3
P nope dimension, no rope. 128 in DSV3
R rope dimension, goes through rope. 64 in DSV3
V V head dim. 128 in DSV3
## Vector/Matrix Definitions
h_t hidden states (input to attention) shape [Sq, H]
q_c latent/compressed Q shape [Sq, Lq]
q_nope uncompressed Q (no-rope) shape [Sq, N, P]
q_pe uncompressed Q (rope) shape [Sq, N, R]
kv_c latent/compressed KV shape [Skv, Lkv]
k_pe decoupled k position embeddings shape [Skv, R]
new_kv_c new kv_c from current iter shape [Sq, Lkv]
new_k_pe new k_pe from current iter shape [Sq, R]
cache_kv_c cached k_c from previous iters shape [C, Lkv]
cache_k_pe cached k_pe from previous iters shape [C, R]
W_DQ project h_t to q_c shape [H, Lq]
W_UQ project q_c to q_nope shape [Lq, N * P]
W_QR project q_c to q_pe shape [Lq, N * R]
W_DKV project h_t to kv_c shape [H, Lkv]
W_UK project kv_c to k_nope shape [Lkv, N, P]
W_KR project h_t to k_pe shape [H, R]
W_UV project kv_c to v shape [Lkv, N, V]
W_O project v to h_t shape [N * V, H]
## Compute Friendly Approach (i.e. "_forward_prefill"):
q_c = h_t @ W_DQ
q_nope = (q_c @ W_UQ).view(Sq, N, P)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)
// MHA with QK headdim = P + R
// V headdim = V
// spda_o shape [Sq, N, V]
spda_o = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1),
torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
v
)
return spda_o @ W_O
NOTE: in the actual code,
`kv_b_proj` is [W_UK; W_UV] concatenated per head
`q_b_proj` is [W_UQ; W_QR] concatenated per head
`out_proj` is W_O
## Data-Movement Friendly Approach (i.e. "_forward_decode"):
Runtime
q_c = h_t @ W_DQ
q_nope = (q_c @ W_UQ).view(-1, N, P)
ql_nope = einsum("snh,lnh->snl", q, W_UK)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
// MQA with QK headdim = Lkv + R
// V headdim = Lkv
// spda_o shape [Sq, N, Lkv]
// NOTE: this is less compute-friendly since Lkv > P
// but is more data-movement friendly since its MQA vs MHA
spda_o = scaled_dot_product_attention(
torch.cat([ql_nope, q_pe], dim=-1),
torch.cat([kv_c, k_pe], dim=-1),
kv_c
)
o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
return o.view(-1, N * V) @ self.num_heads @ W_O
## Chunked Prefill
For chunked prefill we want to use the compute friendly algorithm. We are
assuming sufficiently large Sq / Skv ratio, in the future may want to switch to
the data-movement friendly approach if the chunk (i.e. `Sq`) is small.
However, the compute-friendly approach can potentially run out of memory if Skv
is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`
To mitigate this, we chunk the computation of attention with respect to the
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a
fixed workspace size.
The chunked prefill approach is as follows:
MCC Max chunk of context to process per iter, computed dynamically,
used to bound the memory usage
q_c = h_t @ W_DQ
q_nope = (q_c @ W_UQ).view(Sq, N, P)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)
// MHA between queries and new KV
// with QK headdim = P + R
// V headdim = V
// curr_o shape [Sq, N, V]
// curr_lse shape [N, Sq], this is just order FA returns
curr_o, curr_lse = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1),
torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
new_v,
casual=True,
return_softmax_lse=True
)
// Compute attention with the already existing context
for chunk_idx in range(cdiv(C, MCC)):
chunk_start = chunk_idx * MCC
chunk_end = min(chunk_start + MCC, C)
Sc = chunk_end - chunk_start
cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end]
cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end]
cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V)
chunk_o, chunk_lse = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1),
torch.cat([cache_k_nope_chunk,
cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
dim=-1),
cache_v_chunk,
casual=False,
return_softmax_lse=True
)
curr_o, curr_lse = merge_attn_states(
suffix_output=curr_o,
suffix_lse=curr_lse,
prefix_output=chunk_o,
prefix_lse=chunk_lse,
)
return curr_o @ W_O
"""
import
functools
from
abc
import
abstractmethod
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
itertools
import
accumulate
from
typing
import
Any
,
Dict
,
Generic
,
List
,
Optional
,
Tuple
,
Type
,
TypeVar
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionLayer
,
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionState
,
MLAAttentionImpl
)
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.attention.utils.fa_utils
import
get_flash_attn_version
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
UnquantizedLinearMethod
)
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.utils
import
async_tensor_h2d
,
cdiv
,
make_tensor_with_pad
,
round_down
if
HAS_TRITON
:
from
vllm.attention.ops.triton_flash_attention
import
triton_attention
else
:
triton_attention
=
None
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
is_vllm_fa
=
True
except
ImportError
:
is_vllm_fa
=
False
try
:
# For rocm use upstream flash attention
from
flash_attn
import
flash_attn_varlen_func
except
ImportError
:
flash_attn_varlen_func
=
None
is_hip
=
current_platform
.
is_rocm
()
class
MLACommonBackend
(
AttentionBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"TRITON_MLA"
@
staticmethod
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
MLACommonMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"MLACommonMetadataBuilder"
]:
return
MLACommonMetadataBuilder
@
staticmethod
def
get_state_cls
()
->
Type
[
"MLACommonState"
]:
return
MLACommonState
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
# assumed to be 1 for MLA
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
return
(
num_blocks
,
block_size
,
head_size
)
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
ops
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
ops
.
copy_blocks_mla
(
kv_caches
,
src_to_dists
)
@
staticmethod
def
get_supported_head_sizes
()
->
List
[
int
]:
return
[
576
]
T
=
TypeVar
(
"T"
,
bound
=
"MLACommonMetadata"
)
class
MLACommonState
(
AttentionState
,
Generic
[
T
]):
def
__init__
(
self
,
runner
):
self
.
runner
=
runner
self
.
_is_graph_capturing
=
False
scheduler_config
=
runner
.
scheduler_config
self
.
model_config
=
runner
.
model_config
cache_config
=
runner
.
cache_config
self
.
chunked_prefill_enabled
=
scheduler_config
.
chunked_prefill_enabled
self
.
enable_prefix_caching
=
cache_config
.
enable_prefix_caching
if
self
.
chunked_prefill_enabled
or
self
.
enable_prefix_caching
:
self
.
context_chunk_workspace_size
=
min
(
# Max sure there is enough for 8 full length request or at least
# 4 pages of cache per request
max
(
8
*
self
.
model_config
.
max_model_len
,
4
*
scheduler_config
.
max_num_seqs
*
cache_config
.
block_size
),
# For long-context models try not to over-allocate limiting
# kv-cache space, limiting it to 64k tokens,
# which would result in the workspace being:
# 2*(576)*(64*1024) = 144mb
# (assuming 576 MLA head dim, and fp16)
# which would result in up-projected context being
# 2*(192*128)*(64*1024) = 3gb
# (assuming 192 QK head dim, 128 heads, and fp16)
128
*
1024
)
assert
self
.
context_chunk_workspace_size
>=
\
scheduler_config
.
max_num_seqs
*
cache_config
.
block_size
@
contextmanager
def
graph_capture
(
self
,
max_batch_size
:
int
):
self
.
_is_graph_capturing
=
True
self
.
_graph_slot_mapping
=
torch
.
full
((
max_batch_size
,
),
PAD_SLOT_ID
,
dtype
=
torch
.
long
,
device
=
self
.
runner
.
device
)
self
.
_graph_seq_lens
=
torch
.
ones
(
max_batch_size
,
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
)
self
.
_graph_block_tables
=
torch
.
from_numpy
(
self
.
runner
.
graph_block_tables
).
to
(
device
=
self
.
runner
.
device
)
self
.
_positions
=
torch
.
zeros
((
max_batch_size
,
),
dtype
=
torch
.
long
,
device
=
self
.
runner
.
device
)
yield
self
.
_is_graph_capturing
=
False
del
self
.
_graph_slot_mapping
del
self
.
_graph_seq_lens
del
self
.
_graph_block_tables
del
self
.
_positions
def
graph_clone
(
self
,
batch_size
:
int
):
assert
self
.
_is_graph_capturing
return
self
.
__class__
(
self
.
runner
)
def
graph_capture_get_metadata_for_batch
(
self
,
batch_size
:
int
,
is_encoder_decoder_model
:
bool
=
False
)
->
T
:
assert
self
.
_is_graph_capturing
attn_metadata
=
self
.
runner
.
attn_backend
.
make_metadata
(
multi_modal_placeholder_index_maps
=
None
,
enable_kv_scales_calculation
=
False
,
use_cuda_graph
=
True
,
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
batch_size
,
slot_mapping
=
self
.
_graph_slot_mapping
[:
batch_size
],
seq_lens
=
None
,
seq_lens_tensor
=
self
.
_graph_seq_lens
[:
batch_size
],
max_query_len
=
1
,
max_decode_query_len
=
1
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
runner
.
max_seq_len_to_capture
,
query_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens_tensor
=
None
,
block_tables
=
self
.
_graph_block_tables
[:
batch_size
],
head_dim
=
self
.
runner
.
model_config
.
get_head_size
())
if
is_encoder_decoder_model
:
raise
NotImplementedError
(
"MLACommonState does not support encoder/decoder yet"
)
return
attn_metadata
def
get_graph_input_buffers
(
self
,
attn_metadata
,
is_encoder_decoder_model
:
bool
=
False
):
input_buffers
=
{
"slot_mapping"
:
attn_metadata
.
slot_mapping
,
"seq_lens_tensor"
:
attn_metadata
.
decode_metadata
.
seq_lens_tensor
,
"block_tables"
:
attn_metadata
.
decode_metadata
.
block_tables
,
}
if
is_encoder_decoder_model
:
raise
NotImplementedError
(
"MLACommonState does not support encoder/decoder yet"
)
return
input_buffers
def
prepare_graph_input_buffers
(
self
,
input_buffers
,
attn_metadata
,
is_encoder_decoder_model
:
bool
=
False
):
input_buffers
[
"seq_lens_tensor"
].
copy_
(
attn_metadata
.
decode_metadata
.
seq_lens_tensor
,
non_blocking
=
True
)
input_buffers
[
"block_tables"
].
copy_
(
attn_metadata
.
decode_metadata
.
block_tables
,
non_blocking
=
True
)
if
is_encoder_decoder_model
:
raise
NotImplementedError
(
"TritonMLAState does not support encoder/decoder yet"
)
def
begin_forward
(
self
,
model_input
):
if
self
.
chunked_prefill_enabled
or
self
.
enable_prefix_caching
:
if
not
hasattr
(
self
,
"context_chunk_workspace"
):
# not self.runner.device does not return the correct device
# for this process, (init_device sets the correct device but
# only on the Worker). The only way Ive figured out to get the
# correct device is to allocate the workspace on the first call
# to begin_forward and use the device of the input tokens
assert
model_input
.
input_tokens
is
not
None
self
.
context_chunk_workspace
=
torch
.
empty
(
(
self
.
context_chunk_workspace_size
,
self
.
model_config
.
get_head_size
()),
dtype
=
self
.
model_config
.
dtype
,
device
=
model_input
.
input_tokens
.
device
,
)
model_input
.
attn_metadata
.
context_chunk_workspace
=
\
self
.
context_chunk_workspace
@
dataclass
class
MLACommonMetadata
(
AttentionMetadata
):
"""Metadata for MLACommon.
NOTE: Please read the comment at the top of the file before trying to
understand this class
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph
:
bool
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens
:
Optional
[
List
[
int
]]
# seq_lens stored as a tensor.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len
:
int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len
:
int
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables
:
Optional
[
torch
.
Tensor
]
# Maximum query length in the batch.
max_query_len
:
Optional
[
int
]
=
None
# Max number of query tokens among request in the batch.
max_decode_query_len
:
Optional
[
int
]
=
None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
_cached_prefill_metadata
:
Optional
[
Any
]
=
None
_cached_decode_metadata
:
Optional
[
Any
]
=
None
num_prefill_tokens
:
int
# The dimension of the attention heads
head_dim
:
Optional
[
int
]
=
None
# Used when chunked prefill is enabled to simulate worst case workspace
# allocations, hopefully to avoid going OOM
is_profile_run
:
bool
=
False
# New for MLA (compared to FlashAttention)
# For chunked prefill
context_chunk_cu_seq_lens
:
Optional
[
torch
.
Tensor
]
=
None
context_chunk_starts
:
Optional
[
torch
.
Tensor
]
=
None
context_chunk_seq_tot
:
Optional
[
List
[
int
]]
=
None
context_chunk_max_seq_lens
:
Optional
[
List
[
int
]]
=
None
# Set by MLAAttentionState in `begin_forward` so it doesn't get broadcasted
context_chunk_workspace
:
Optional
[
torch
.
Tensor
]
=
None
def
__post_init__
(
self
):
supported_head_sizes
=
MLACommonBackend
.
get_supported_head_sizes
()
if
self
.
head_dim
is
not
None
and
self
.
head_dim
\
not
in
supported_head_sizes
:
raise
ValueError
(
f
"Only
{
supported_head_sizes
}
are supported for head_dim,"
,
f
" received
{
self
.
head_dim
}
."
)
@
property
def
prefill_metadata
(
self
):
if
self
.
num_prefills
==
0
:
return
None
if
self
.
_cached_prefill_metadata
is
not
None
:
return
self
.
_cached_prefill_metadata
assert
self
.
seq_lens
is
not
None
assert
self
.
seq_lens_tensor
is
not
None
# Compute some attn_metadata fields which default to None
query_start_loc
=
(
None
if
self
.
query_start_loc
is
None
else
self
.
query_start_loc
[:
self
.
num_prefills
+
1
])
slot_mapping
=
(
None
if
self
.
slot_mapping
is
None
else
self
.
slot_mapping
[:
self
.
num_prefill_tokens
])
seq_lens
=
(
None
if
self
.
seq_lens
is
None
else
self
.
seq_lens
[:
self
.
num_prefills
])
seq_lens_tensor
=
(
None
if
self
.
seq_lens_tensor
is
None
else
self
.
seq_lens_tensor
[:
self
.
num_prefills
])
seq_start_loc
=
(
None
if
self
.
seq_start_loc
is
None
else
self
.
seq_start_loc
[:
self
.
num_prefills
+
1
])
context_lens_tensor
=
(
None
if
self
.
context_lens_tensor
is
None
else
self
.
context_lens_tensor
[:
self
.
num_prefills
])
block_tables
=
(
None
if
self
.
block_tables
is
None
else
self
.
block_tables
[:
self
.
num_prefills
])
self
.
_cached_prefill_metadata
=
self
.
__class__
(
# Required by ModelRunner
use_cuda_graph
=
False
,
# Not Attention Related
# Required by Attention Metadata
num_prefills
=
self
.
num_prefills
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
slot_mapping
=
slot_mapping
,
# Required by Attention Metadata (not used)
multi_modal_placeholder_index_maps
=
None
,
enable_kv_scales_calculation
=
False
,
# MLACommonMetadata
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
self
.
max_prefill_seq_len
,
max_decode_query_len
=
0
,
max_decode_seq_len
=
0
,
query_start_loc
=
query_start_loc
,
seq_start_loc
=
seq_start_loc
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
head_dim
=
self
.
head_dim
,
is_profile_run
=
self
.
is_profile_run
,
# MLACommonMetadata Chunk prefill specific
context_chunk_cu_seq_lens
=
self
.
context_chunk_cu_seq_lens
,
context_chunk_starts
=
self
.
context_chunk_starts
,
context_chunk_seq_tot
=
self
.
context_chunk_seq_tot
,
context_chunk_max_seq_lens
=
self
.
context_chunk_max_seq_lens
,
)
return
self
.
_cached_prefill_metadata
@
property
def
decode_metadata
(
self
):
if
self
.
num_decode_tokens
==
0
:
return
None
if
self
.
_cached_decode_metadata
is
not
None
:
return
self
.
_cached_decode_metadata
assert
self
.
seq_lens_tensor
is
not
None
# Compute some attn_metadata fields which default to None
slot_mapping
=
(
None
if
self
.
slot_mapping
is
None
else
self
.
slot_mapping
[
self
.
num_prefill_tokens
:])
seq_lens_tensor
=
(
None
if
self
.
seq_lens_tensor
is
None
else
self
.
seq_lens_tensor
[
self
.
num_prefills
:])
block_tables
=
(
None
if
self
.
block_tables
is
None
else
self
.
block_tables
[
self
.
num_prefills
:])
self
.
_cached_decode_metadata
=
self
.
__class__
(
# Required by ModelRunner
use_cuda_graph
=
self
.
use_cuda_graph
,
# Not Attention Related
# Required by Attention Metadata
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
slot_mapping
,
# Required by Attention Metadata (not used)
multi_modal_placeholder_index_maps
=
None
,
enable_kv_scales_calculation
=
False
,
# MLACommonMetadata
seq_lens
=
None
,
seq_lens_tensor
=
seq_lens_tensor
,
max_decode_query_len
=
self
.
max_decode_query_len
,
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
# Batch may be composed of prefill|decodes, adjust query start
# indices to refer to the start of decodes. E.g.
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
query_start_loc
=
(
self
.
query_start_loc
[
self
.
num_prefills
:]
-
self
.
query_start_loc
[
self
.
num_prefills
])
if
self
.
query_start_loc
is
not
None
else
None
,
seq_start_loc
=
self
.
seq_start_loc
[
self
.
num_prefills
:]
if
self
.
seq_start_loc
is
not
None
else
None
,
context_lens_tensor
=
None
,
block_tables
=
block_tables
,
head_dim
=
self
.
head_dim
,
is_profile_run
=
self
.
is_profile_run
)
return
self
.
_cached_decode_metadata
class
MLACommonMetadataBuilder
(
AttentionMetadataBuilder
[
T
],
Generic
[
T
]):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
BLOCK_TABLE_EXTENDER
:
list
[
list
[
int
]]
=
[]
def
__init__
(
self
,
input_builder
):
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
sliding_window
=
input_builder
.
sliding_window
self
.
block_size
=
input_builder
.
block_size
self
.
chunked_prefill_enabled
=
\
self
.
runner
.
scheduler_config
.
chunked_prefill_enabled
self
.
enable_prefix_caching
=
\
self
.
runner
.
cache_config
.
enable_prefix_caching
if
self
.
chunked_prefill_enabled
or
self
.
enable_prefix_caching
:
attn_state
=
self
.
input_builder
.
runner
.
attn_state
self
.
context_chunk_workspace_size
=
\
attn_state
.
context_chunk_workspace_size
self
.
page_size
=
self
.
runner
.
block_size
def
prepare
(
self
):
self
.
slot_mapping
:
List
[
int
]
=
[]
self
.
prefill_seq_lens
:
List
[
int
]
=
[]
self
.
context_lens
:
List
[
int
]
=
[]
self
.
block_tables
:
List
[
List
[
int
]]
=
[]
self
.
curr_seq_lens
:
List
[
int
]
=
[]
self
.
multimodal_placeholder_maps
:
Dict
[
str
,
MultiModalPlaceholderMap
]
=
defaultdict
(
MultiModalPlaceholderMap
)
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
has_prefix_cache_hit
=
False
def
_add_seq_group
(
self
,
inter_data
,
chunked_prefill_enabled
:
bool
,
prefix_cache_hit
:
bool
):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt
=
inter_data
.
is_prompt
block_tables
=
inter_data
.
block_tables
for
(
seq_id
,
token_len
,
seq_len
,
curr_seq_len
,
query_len
,
context_len
,
curr_sliding_window_block
)
in
zip
(
inter_data
.
seq_ids
,
[
len
(
t
)
for
t
in
inter_data
.
input_tokens
],
inter_data
.
orig_seq_lens
,
inter_data
.
seq_lens
,
inter_data
.
query_lens
,
inter_data
.
context_lens
,
inter_data
.
curr_sliding_window_blocks
):
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
self
.
num_prefills
+=
1
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
else
:
self
.
num_decode_tokens
+=
query_len
self
.
curr_seq_lens
.
append
(
curr_seq_len
)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table
=
[]
if
prefix_cache_hit
:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table
=
block_tables
[
seq_id
]
elif
((
chunked_prefill_enabled
or
not
is_prompt
)
and
block_tables
is
not
None
):
if
curr_sliding_window_block
==
0
:
block_table
=
block_tables
[
seq_id
]
else
:
block_table
=
block_tables
[
seq_id
][
-
curr_sliding_window_block
:]
self
.
block_tables
.
append
(
block_table
)
# Compute slot mapping.
is_profile_run
=
is_block_tables_empty
(
block_tables
)
start_idx
=
compute_slot_mapping_start_idx
(
is_prompt
,
query_len
,
context_len
,
self
.
sliding_window
)
compute_slot_mapping
(
is_profile_run
,
self
.
slot_mapping
,
seq_id
,
seq_len
,
context_len
,
start_idx
,
self
.
block_size
,
inter_data
.
block_tables
)
def
_get_graph_runner_block_tables
(
self
,
num_seqs
:
int
,
block_tables
:
List
[
List
[
int
]])
->
torch
.
Tensor
:
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
max_batch_size
,
max_blocks
=
self
.
runner
.
graph_block_tables
.
shape
assert
max_batch_size
>=
num_seqs
graph_block_tables
=
self
.
runner
.
graph_block_tables
[:
num_seqs
]
for
i
,
block_table
in
enumerate
(
block_tables
):
if
block_table
:
num_blocks
=
len
(
block_table
)
if
num_blocks
<=
max_blocks
:
graph_block_tables
[
i
,
:
num_blocks
]
=
block_table
else
:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
graph_block_tables
[
i
,
:
max_blocks
]
=
block_table
[:
max_blocks
]
return
torch
.
from_numpy
(
graph_block_tables
).
to
(
device
=
self
.
runner
.
device
,
non_blocking
=
True
)
def
build
(
self
,
seq_lens
:
List
[
int
],
query_lens
:
List
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
prefix_cache_hit
=
any
([
inter_data
.
prefix_cache_hit
for
inter_data
in
self
.
input_builder
.
inter_data_list
])
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
,
prefix_cache_hit
)
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
max_query_len
=
max
(
query_lens
)
decode_query_lens
=
query_lens
[
self
.
num_prefills
:]
if
len
(
decode_query_lens
)
>
0
:
max_decode_query_len
=
max
(
decode_query_lens
)
else
:
max_decode_query_len
=
1
max_prefill_seq_len
=
max
(
self
.
prefill_seq_lens
,
default
=
0
)
max_decode_seq_len
=
max
(
self
.
curr_seq_lens
,
default
=
0
)
num_decode_tokens
=
self
.
num_decode_tokens
query_start_loc
=
list
(
accumulate
(
query_lens
,
initial
=
0
))
seq_start_loc
=
list
(
accumulate
(
seq_lens
,
initial
=
0
))
num_seqs
=
len
(
seq_lens
)
if
use_captured_graph
:
self
.
slot_mapping
.
extend
([
PAD_SLOT_ID
]
*
cuda_graph_pad_size
)
self
.
block_tables
.
extend
(
self
.
__class__
.
BLOCK_TABLE_EXTENDER
*
cuda_graph_pad_size
)
num_decode_tokens
=
batch_size
-
self
.
num_prefill_tokens
block_tables
=
self
.
_get_graph_runner_block_tables
(
num_seqs
,
self
.
block_tables
)
else
:
block_tables
=
make_tensor_with_pad
(
self
.
block_tables
,
pad
=
0
,
dtype
=
torch
.
int
,
device
=
device
,
)
assert
max_query_len
>
0
,
(
"query_lens: {}"
.
format
(
query_lens
))
assert
device
is
not
None
context_lens_tensor
=
async_tensor_h2d
(
self
.
context_lens
,
torch
.
int
,
device
,
self
.
runner
.
pin_memory
)
seq_lens_tensor
=
async_tensor_h2d
(
seq_lens
,
torch
.
int
,
device
,
self
.
runner
.
pin_memory
)
slot_mapping_tensor
=
async_tensor_h2d
(
self
.
slot_mapping
,
torch
.
long
,
device
,
self
.
runner
.
pin_memory
)
query_start_loc_tensor
=
async_tensor_h2d
(
query_start_loc
,
torch
.
int32
,
device
,
self
.
runner
.
pin_memory
)
seq_start_loc_tensor
=
async_tensor_h2d
(
seq_start_loc
,
torch
.
int32
,
device
,
self
.
runner
.
pin_memory
)
context_chunk_cu_seq_lens
=
None
context_chunk_starts
=
None
context_chunk_seq_tot
=
None
context_chunk_max_seq_lens
=
None
if
(
self
.
chunked_prefill_enabled
or
self
.
enable_prefix_caching
)
\
and
self
.
num_prefills
>
0
\
and
context_lens_tensor
is
not
None
\
and
context_lens_tensor
[:
self
.
num_prefills
].
max
()
>
0
:
# NOTE: it is recommended you read the `Chunked Prefill` section in
# the comment at the top of the file before trying to understand
# the following code
num_prefills_with_context
=
\
(
context_lens_tensor
[:
self
.
num_prefills
]
>
0
).
sum
().
item
()
# currently we allocate an equal amount of workspace for each
# prefill in the batch, we could probably use a more advanced
# algorithm here and allocate more workspace to prefills with
# longer context lengths
max_context_chunk
=
\
self
.
context_chunk_workspace_size
//
num_prefills_with_context
# align max_context_chunk to page_size by rounding down,
# currently the `gather_and_maybe_dequant_cache` kernel cannot
# handle `context_chunk_starts` that are not aligned to page_size
max_context_chunk
=
round_down
(
max_context_chunk
,
self
.
page_size
)
assert
max_context_chunk
>
0
num_chunks
=
cdiv
(
context_lens_tensor
.
max
(),
max_context_chunk
)
# if `max_context_chunk = 256`, `num_chunks = 3`, and
# `num_prefills_with_context = 4`, create a tensor that looks like
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
context_chunk_starts
=
\
torch
.
arange
(
num_chunks
,
device
=
device
,
dtype
=
torch
.
int32
)
\
.
unsqueeze
(
1
).
expand
(
-
1
,
self
.
num_prefills
)
\
*
max_context_chunk
chunk_ends
=
torch
.
min
(
context_lens_tensor
[:
self
.
num_prefills
]
\
.
unsqueeze
(
0
),
context_chunk_starts
+
max_context_chunk
)
chunk_seq_lens
=
(
chunk_ends
-
context_chunk_starts
).
clamp
(
min
=
0
)
_context_chunk_cu_seq_lens
=
chunk_seq_lens
.
cumsum
(
dim
=
1
).
to
(
torch
.
int32
)
zero
=
torch
.
zeros
(
num_chunks
,
dtype
=
torch
.
int32
,
device
=
device
)
\
.
unsqueeze
(
-
1
)
context_chunk_cu_seq_lens
=
\
torch
.
cat
([
zero
,
_context_chunk_cu_seq_lens
],
dim
=
1
)
context_chunk_max_seq_lens
=
\
chunk_seq_lens
.
max
(
dim
=
1
).
values
.
tolist
()
context_chunk_seq_tot
=
chunk_seq_lens
.
sum
(
dim
=
1
).
tolist
()
assert
max
(
context_chunk_seq_tot
)
<=
\
self
.
context_chunk_workspace_size
return
self
.
runner
.
attn_backend
.
make_metadata
(
# Required by ModelRunner
use_cuda_graph
=
use_captured_graph
,
# Not Attention Related
# Required by Attention Metadata
num_prefills
=
self
.
num_prefills
,
slot_mapping
=
slot_mapping_tensor
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
num_decode_tokens
,
# Required by Attention Metadata (not used)
multi_modal_placeholder_index_maps
=
None
,
# Not Attention Related
enable_kv_scales_calculation
=
False
,
# MLACommonMetadata
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_decode_query_len
=
max_decode_query_len
,
max_prefill_seq_len
=
max_prefill_seq_len
,
max_decode_seq_len
=
max_decode_seq_len
,
query_start_loc
=
query_start_loc_tensor
,
seq_start_loc
=
seq_start_loc_tensor
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
head_dim
=
self
.
runner
.
model_config
.
get_head_size
(),
is_profile_run
=
self
.
runner
.
in_profile_run
,
# MLACommonMetadata Chunk prefill specific
context_chunk_cu_seq_lens
=
context_chunk_cu_seq_lens
,
context_chunk_starts
=
context_chunk_starts
,
context_chunk_seq_tot
=
context_chunk_seq_tot
,
context_chunk_max_seq_lens
=
context_chunk_max_seq_lens
,
)
class
MLACommonImpl
(
MLAAttentionImpl
[
T
],
Generic
[
T
]):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
],
# MLA Specific Arguments
q_lora_rank
:
Optional
[
int
],
kv_lora_rank
:
int
,
qk_nope_head_dim
:
int
,
qk_rope_head_dim
:
int
,
qk_head_dim
:
int
,
v_head_dim
:
int
,
kv_b_proj
:
ColumnParallelLinear
,
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing not supported in V0."
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
q_lora_rank
=
q_lora_rank
self
.
kv_lora_rank
=
kv_lora_rank
self
.
qk_nope_head_dim
=
qk_nope_head_dim
self
.
qk_rope_head_dim
=
qk_rope_head_dim
self
.
qk_head_dim
=
qk_head_dim
self
.
v_head_dim
=
v_head_dim
self
.
kv_b_proj
=
kv_b_proj
self
.
triton_fa_func
=
triton_attention
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on RoCM and the
# latter has an additional parameter to control FA2 vs FA3
self
.
flash_attn_varlen_func
=
flash_attn_varlen_func
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
if
self
.
vllm_flash_attn_version
is
not
None
:
self
.
flash_attn_varlen_func
=
\
functools
.
partial
(
flash_attn_varlen_func
,
fa_version
=
self
.
vllm_flash_attn_version
)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self
.
_pad_v
=
self
.
vllm_flash_attn_version
is
None
or
not
(
self
.
vllm_flash_attn_version
==
3
and
current_platform
.
get_device_capability
()[
0
]
==
9
)
def
_flash_attn_varlen_diff_headdims
(
self
,
q
,
k
,
v
,
softmax_scale
,
return_softmax_lse
,
**
kwargs
):
maybe_padded_v
=
v
if
self
.
_pad_v
:
maybe_padded_v
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
if
is_hip
and
envs
.
VLLM_USE_TRITON_FLASH_ATTN
\
and
not
return_softmax_lse
:
attn_out
=
self
.
triton_fa_func
(
q
,
k
,
maybe_padded_v
,
None
,
# output
kwargs
[
"cu_seqlens_q"
],
kwargs
[
"cu_seqlens_k"
],
kwargs
[
"max_seqlen_q"
],
kwargs
[
"max_seqlen_k"
],
kwargs
[
"causal"
],
softmax_scale
,
None
,
# bias
)
elif
is_vllm_fa
:
attn_out
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
maybe_padded_v
,
return_softmax_lse
=
return_softmax_lse
,
softmax_scale
=
softmax_scale
,
**
kwargs
,
)
else
:
# Use return_attn_probs instead of return_softmax_lse for RoCM
attn_out
=
self
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
maybe_padded_v
,
return_attn_probs
=
return_softmax_lse
,
softmax_scale
=
softmax_scale
,
**
kwargs
,
)
# Unpack the output if there is multiple results,
# triton always returns (output, softmax_lse),
# vllm_flash_attn returns (output, softmax_lse) when
# `return_softmax_lse = True`
# flash_attn (RoCM) returns (output, softmax_lse, ...) when
# `return_attn_probs = True`
rest
=
None
if
isinstance
(
attn_out
,
tuple
):
attn_out
,
*
rest
=
attn_out
# Remain consistent with old `flash_attn_varlen_func` where there
# is only one output tensor if `return_softmax_lse` is False.
if
return_softmax_lse
:
assert
rest
is
not
None
return
attn_out
,
rest
[
0
]
return
attn_out
def
_v_up_proj
(
self
,
x
):
# Convert from (B, N, L) to (N, B, L)
x
=
x
.
view
(
-
1
,
self
.
num_heads
,
self
.
kv_lora_rank
).
transpose
(
0
,
1
)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x
=
torch
.
bmm
(
x
,
self
.
W_UV
)
# Convert from (N, B, V) to (B, N * V)
return
x
.
transpose
(
0
,
1
).
reshape
(
-
1
,
self
.
num_heads
*
self
.
v_head_dim
)
def
process_weights_after_loading
(
self
,
act_dtype
:
torch
.
dtype
):
def
get_layer_weight
(
layer
):
WEIGHT_NAMES
=
(
"weight"
,
"qweight"
,
"weight_packed"
)
for
attr
in
WEIGHT_NAMES
:
if
hasattr
(
layer
,
attr
):
return
getattr
(
layer
,
attr
)
raise
AttributeError
(
f
"Layer '
{
layer
}
' has no recognized weight attribute:"
f
"
{
WEIGHT_NAMES
}
."
)
def
get_and_maybe_dequant_weights
(
layer
:
LinearBase
):
if
not
isinstance
(
layer
.
quant_method
,
UnquantizedLinearMethod
):
# NOTE: This should only be used offline, since it's O(N^3)
eye
=
torch
.
eye
(
layer
.
input_size_per_partition
,
dtype
=
act_dtype
,
device
=
get_layer_weight
(
layer
).
device
)
dequant_weights
=
layer
.
quant_method
.
apply
(
layer
,
eye
,
bias
=
None
)
del
eye
# standardize to (output, input)
return
dequant_weights
.
T
return
layer
.
weight
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight
=
get_and_maybe_dequant_weights
(
self
.
kv_b_proj
).
T
assert
kv_b_proj_weight
.
shape
==
(
self
.
kv_lora_rank
,
self
.
num_heads
*
(
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)),
(
f
"
{
kv_b_proj_weight
.
shape
=
}
, "
f
"
{
self
.
kv_lora_rank
=
}
, "
f
"
{
self
.
num_heads
=
}
, "
f
"
{
self
.
qk_nope_head_dim
=
}
, "
f
"
{
self
.
v_head_dim
=
}
"
)
kv_b_proj_weight
=
kv_b_proj_weight
.
view
(
self
.
kv_lora_rank
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
,
)
W_UK
,
W_UV
=
kv_b_proj_weight
.
split
(
[
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
# Convert from (L, N, V) to (N, L, V)
self
.
W_UV
=
W_UV
.
transpose
(
0
,
1
)
# Convert from (L, N, P) to (N, P, L)
self
.
W_UK_T
=
W_UK
.
permute
(
1
,
2
,
0
)
def
_compute_prefill_context
(
self
,
q
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
k_scale
:
torch
.
Tensor
,
):
prefill_metadata
=
attn_metadata
.
prefill_metadata
assert
prefill_metadata
is
not
None
assert
prefill_metadata
.
context_chunk_seq_tot
is
not
None
assert
prefill_metadata
.
context_chunk_cu_seq_lens
is
not
None
assert
prefill_metadata
.
context_chunk_starts
is
not
None
assert
prefill_metadata
.
context_chunk_max_seq_lens
is
not
None
assert
prefill_metadata
.
context_lens_tensor
is
not
None
output
=
None
iters
=
len
(
prefill_metadata
.
context_chunk_seq_tot
)
# Fetch from attn_metadata directly, since it late bound by
# MLAAttentionState, grabbing it directly `attn_metadata` can avoid
# any weirdness around prefill_metadata caching
assert
attn_metadata
.
context_chunk_workspace
is
not
None
workspace
=
attn_metadata
.
context_chunk_workspace
for
i
in
range
(
iters
):
toks
=
prefill_metadata
.
context_chunk_seq_tot
[
i
]
ops
.
gather_and_maybe_dequant_cache
(
src_cache
=
kv_c_and_k_pe_cache
,
dst
=
workspace
,
block_table
=
prefill_metadata
.
block_tables
,
cu_seq_lens
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
batch_size
=
prefill_metadata
.
num_prefills
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
scale
=
k_scale
,
seq_starts
=
prefill_metadata
.
context_chunk_starts
[
i
],
)
kv_c_normed
=
workspace
[:
toks
]
\
[...,
:
self
.
kv_lora_rank
]
k_pe
=
workspace
[:
toks
]
\
[...,
self
.
kv_lora_rank
:].
unsqueeze
(
1
)
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
\
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
k_nope
,
v
=
kv_nope
\
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
attn_output
,
attn_softmax_lse
=
\
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
context_chunk_max_seq_lens
[
i
],
softmax_scale
=
self
.
scale
,
causal
=
False
,
# Context is unmasked
return_softmax_lse
=
True
,
)
if
output
is
None
:
output
=
attn_output
output_lse
=
attn_softmax_lse
else
:
output_tmp
=
torch
.
empty_like
(
output
)
output_lse_tmp
=
torch
.
empty_like
(
output_lse
)
merge_attn_states
(
output
=
output_tmp
,
output_lse
=
output_lse_tmp
,
prefix_output
=
output
,
prefix_lse
=
output_lse
,
suffix_output
=
attn_output
,
suffix_lse
=
attn_softmax_lse
,
)
output
=
output_tmp
output_lse
=
output_lse_tmp
return
output
,
output_lse
def
_forward_prefill
(
self
,
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
k_scale
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
prefill_metadata
=
attn_metadata
.
prefill_metadata
assert
prefill_metadata
is
not
None
has_context
=
prefill_metadata
.
context_lens_tensor
is
not
None
\
and
prefill_metadata
.
context_lens_tensor
.
max
()
>
0
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
\
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
k_nope
,
v
=
kv_nope
\
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
output
=
self
.
_flash_attn_varlen_diff_headdims
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
prefill_metadata
.
query_start_loc
,
cu_seqlens_k
=
prefill_metadata
.
query_start_loc
,
max_seqlen_q
=
prefill_metadata
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_metadata
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
return_softmax_lse
=
has_context
,
)
if
has_context
:
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output
,
suffix_lse
=
output
context_output
,
context_lse
=
self
.
_compute_prefill_context
(
\
q
,
kv_c_and_k_pe_cache
,
attn_metadata
,
k_scale
)
output
=
torch
.
empty_like
(
suffix_output
)
merge_attn_states
(
output
=
output
,
prefix_output
=
context_output
,
prefix_lse
=
context_lse
,
suffix_output
=
suffix_output
,
suffix_lse
=
suffix_lse
,
)
# unpad if necessary
if
self
.
_pad_v
:
output
=
output
[...,
:
v
.
shape
[
-
1
]]
return
output
.
flatten
(
start_dim
=-
2
)
@
abstractmethod
def
_forward_decode
(
self
,
ql_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
T
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
forward
(
self
,
layer
:
AttentionLayer
,
q
:
torch
.
Tensor
,
# query in unified attn
k_c_normed
:
torch
.
Tensor
,
# key in unified attn
k_pe
:
torch
.
Tensor
,
# value in unified attn
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
T
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_block_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
output
is
not
None
:
raise
NotImplementedError
(
"output is not yet supported for MLAImplBase"
)
if
output_scale
is
not
None
or
output_block_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for MLAImplBase"
)
if
attn_metadata
.
is_profile_run
and
\
attn_metadata
.
context_chunk_workspace
is
not
None
:
# During the profile run try to simulate to worse case output size
# for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
# since this can be large
_
=
torch
.
empty
(
(
attn_metadata
.
context_chunk_workspace
.
shape
[
0
],
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
),
device
=
k_c_normed
.
device
,
dtype
=
k_c_normed
.
dtype
,
)
has_decode
=
attn_metadata
.
decode_metadata
is
not
None
has_prefill
=
attn_metadata
.
prefill_metadata
is
not
None
num_prefill_tokens
:
int
=
attn_metadata
.
num_prefill_tokens
q
=
q
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_head_dim
)
decode_q
=
q
[
num_prefill_tokens
:]
prefill_q
=
q
[:
num_prefill_tokens
]
prefill_k_pe
=
k_pe
[:
num_prefill_tokens
]
prefill_k_c_normed
=
k_c_normed
[:
num_prefill_tokens
]
# write the latent and rope to kv cache
if
kv_cache
.
numel
()
>
0
:
ops
.
concat_and_cache_mla
(
k_c_normed
,
k_pe
.
squeeze
(
1
),
kv_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache_dtype
=
self
.
kv_cache_dtype
,
scale
=
layer
.
_k_scale
,
)
output
=
torch
.
empty
(
attn_metadata
.
num_prefill_tokens
+
attn_metadata
.
num_decode_tokens
,
self
.
v_head_dim
*
self
.
num_heads
,
device
=
q
.
device
,
dtype
=
q
.
dtype
)
if
has_prefill
:
output
[:
num_prefill_tokens
]
=
self
.
_forward_prefill
(
prefill_q
,
prefill_k_c_normed
,
prefill_k_pe
,
kv_cache
,
attn_metadata
,
layer
.
_k_scale
)
if
has_decode
:
decode_q_nope
,
decode_q_pe
=
decode_q
.
split
(
[
self
.
qk_nope_head_dim
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope
=
decode_q_nope
.
transpose
(
0
,
1
)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope
=
torch
.
bmm
(
decode_q_nope
,
self
.
W_UK_T
)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
output
[
num_prefill_tokens
:]
=
self
.
_forward_decode
(
decode_ql_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
)
return
output
vllm/attention/backends/rocm_aiter_mla.py
deleted
100644 → 0
View file @
af7dfb0d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Type
,
Union
import
torch
import
vllm.envs
as
envs
from
vllm.attention.backends.mla.common
import
(
MLACommonBackend
,
MLACommonImpl
,
MLACommonMetadata
,
MLACommonMetadataBuilder
,
MLACommonState
)
from
vllm.attention.backends.utils
import
(
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
from
vllm.attention.ops.rocm_aiter_mla
import
(
aiter_mla_decode_fwd
,
get_aiter_mla_metadata
)
def
is_aiter_mla_enabled
()
->
bool
:
return
envs
.
VLLM_ROCM_USE_AITER
\
and
envs
.
VLLM_ROCM_USE_AITER_MLA
class
AiterMLABackend
(
MLACommonBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"ROCM_AITER_MLA"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"AiterMLAImpl"
]:
return
AiterMLAImpl
@
staticmethod
def
get_metadata_cls
()
->
Type
[
"AiterMLAMetadata"
]:
return
AiterMLAMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"AiterMLAMetadataBuilder"
]:
return
AiterMLAMetadataBuilder
@
staticmethod
def
get_state_cls
()
->
Type
[
"AiterMLAState"
]:
return
AiterMLAState
@
dataclass
class
AiterMLAMetadata
(
MLACommonMetadata
):
# The following 5 tensors are for current version of AITER MLA
block_table_bound
:
Optional
[
torch
.
Tensor
]
=
None
# The indptr of the paged kv cache, shape: [batch_size + 1]
paged_kv_indptr
:
Optional
[
torch
.
Tensor
]
=
None
# The page indices of the paged kv cache
paged_kv_indices
:
Optional
[
torch
.
Tensor
]
=
None
# The number of entries in the last page of each request in
# the paged kv cache, shape: [batch_size]
paged_kv_last_page_lens
:
Optional
[
torch
.
Tensor
]
=
None
# This is just to make new AITER MLA API work
# -- MTP support is not added yet.
qo_indptr
:
Optional
[
torch
.
Tensor
]
=
None
@
property
def
prefill_metadata
(
self
):
prefill_metadata
=
super
().
prefill_metadata
self
.
_cached_prefill_metadata
=
prefill_metadata
if
prefill_metadata
is
not
None
:
prefill_metadata
.
paged_kv_indptr
=
self
.
paged_kv_indptr
prefill_metadata
.
paged_kv_indices
=
self
.
paged_kv_indices
prefill_metadata
\
.
paged_kv_last_page_lens
=
self
.
paged_kv_last_page_lens
prefill_metadata
.
block_table_bound
=
self
.
block_table_bound
prefill_metadata
.
qo_indptr
=
self
.
qo_indptr
# update the cache
self
.
_cached_prefill_metadata
=
self
.
__class__
(
**
prefill_metadata
.
__dict__
)
return
self
.
_cached_prefill_metadata
@
property
def
decode_metadata
(
self
):
decode_metadata
=
super
().
decode_metadata
self
.
_cached_decode_metadata
=
decode_metadata
if
decode_metadata
is
not
None
:
decode_metadata
.
paged_kv_indptr
=
self
.
paged_kv_indptr
decode_metadata
.
paged_kv_indices
=
self
.
paged_kv_indices
decode_metadata
\
.
paged_kv_last_page_lens
=
self
.
paged_kv_last_page_lens
decode_metadata
.
block_table_bound
=
self
.
block_table_bound
decode_metadata
.
qo_indptr
=
self
.
qo_indptr
# update the cache
self
.
_cached_decode_metadata
=
self
.
__class__
(
**
decode_metadata
.
__dict__
)
return
self
.
_cached_decode_metadata
class
AiterMLAMetadataBuilder
(
MLACommonMetadataBuilder
[
AiterMLAMetadata
]):
BLOCK_TABLE_EXTENDER
:
list
[
list
[
int
]]
=
[[]]
def
__init__
(
self
,
input_builder
):
super
().
__init__
(
input_builder
)
assert
self
.
block_size
==
1
,
"AITER MLA requires only block size 1."
def
prepare
(
self
):
super
().
prepare
()
self
.
paged_kv_indices
:
list
[
int
]
=
[]
self
.
paged_kv_indptr
:
list
[
int
]
=
[
0
]
self
.
paged_kv_last_page_lens
:
list
[
int
]
=
[]
self
.
total_blocks
=
0
self
.
qo_indptr
:
list
[
int
]
=
[
0
]
def
_add_seq_group
(
self
,
inter_data
,
chunked_prefill_enabled
:
bool
,
prefix_cache_hit
:
bool
):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt
=
inter_data
.
is_prompt
block_tables
=
inter_data
.
block_tables
for
(
seq_id
,
token_len
,
seq_len
,
curr_seq_len
,
query_len
,
context_len
,
curr_sliding_window_block
)
in
zip
(
inter_data
.
seq_ids
,
[
len
(
t
)
for
t
in
inter_data
.
input_tokens
],
inter_data
.
orig_seq_lens
,
inter_data
.
seq_lens
,
inter_data
.
query_lens
,
inter_data
.
context_lens
,
inter_data
.
curr_sliding_window_blocks
):
self
.
context_lens
.
append
(
context_len
)
if
is_prompt
:
self
.
num_prefills
+=
1
self
.
num_prefill_tokens
+=
token_len
self
.
prefill_seq_lens
.
append
(
seq_len
)
else
:
self
.
num_decode_tokens
+=
query_len
self
.
curr_seq_lens
.
append
(
curr_seq_len
)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table
=
[]
if
prefix_cache_hit
:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table
=
block_tables
[
seq_id
]
elif
((
chunked_prefill_enabled
or
not
is_prompt
)
and
block_tables
is
not
None
):
if
curr_sliding_window_block
==
0
:
block_table
=
block_tables
[
seq_id
]
else
:
block_table
=
block_tables
[
seq_id
][
-
curr_sliding_window_block
:]
self
.
block_tables
.
append
(
block_table
)
# Compute slot mapping.
is_profile_run
=
is_block_tables_empty
(
block_tables
)
start_idx
=
compute_slot_mapping_start_idx
(
is_prompt
,
query_len
,
context_len
,
self
.
sliding_window
)
compute_slot_mapping
(
is_profile_run
,
self
.
slot_mapping
,
seq_id
,
seq_len
,
context_len
,
start_idx
,
self
.
block_size
,
inter_data
.
block_tables
)
if
is_profile_run
:
return
# Update paged_kv_* tensors only for non-profile run
block_table
=
block_tables
[
seq_id
]
self
.
_update_paged_kv_tensors
(
block_table
,
seq_len
)
def
_update_paged_kv_tensors
(
self
,
block_table
:
list
[
int
],
seq_len
:
int
):
# Get the number of valid blocks based on sequence length.
# If seq_len = 16, block_size = 16,
# block_table_bound is 1 with 1 valid block.
# If seq_len = 15, block_size = 16,
# block_table_bound is 0 + 1 with 1 valid block.
self
.
total_blocks
+=
len
(
block_table
)
block_table_bound
=
seq_len
//
self
.
block_size
+
1
\
if
seq_len
%
self
.
block_size
!=
0
\
else
seq_len
//
self
.
block_size
self
.
paged_kv_indices
.
extend
(
block_table
[:
block_table_bound
])
self
.
paged_kv_indptr
.
append
(
self
.
paged_kv_indptr
[
-
1
]
+
block_table_bound
)
self
.
qo_indptr
.
append
(
self
.
qo_indptr
[
-
1
]
+
1
)
last_page_len
=
seq_len
%
self
.
block_size
if
last_page_len
==
0
:
last_page_len
=
self
.
block_size
self
.
paged_kv_last_page_lens
.
append
(
last_page_len
)
def
build
(
self
,
seq_lens
:
list
[
int
],
query_lens
:
list
[
int
],
cuda_graph_pad_size
:
int
,
batch_size
:
int
)
->
AiterMLAMetadata
:
metadata
=
super
().
build
(
seq_lens
,
query_lens
,
cuda_graph_pad_size
,
batch_size
)
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
if
use_captured_graph
:
last_paged_kv_indptr
=
self
.
paged_kv_indptr
[
-
1
]
self
.
paged_kv_indptr
.
extend
([
last_paged_kv_indptr
]
*
cuda_graph_pad_size
)
self
.
paged_kv_last_page_lens
.
extend
([
0
]
*
cuda_graph_pad_size
)
last_qo_indptr
=
self
.
qo_indptr
[
-
1
]
self
.
qo_indptr
.
extend
([
last_qo_indptr
]
*
cuda_graph_pad_size
)
# For current version of AITER MLA
if
len
(
self
.
paged_kv_indptr
)
>
0
:
# extend to the maximum number of blocks as returned by the
# scheduler
self
.
paged_kv_indices
.
extend
(
[
0
]
*
(
self
.
total_blocks
-
len
(
self
.
paged_kv_indices
)))
paged_kv_indices_tensor
=
torch
.
tensor
(
self
.
paged_kv_indices
,
device
=
device
,
dtype
=
torch
.
int
)
paged_kv_indptr_tensor
=
torch
.
tensor
(
self
.
paged_kv_indptr
,
device
=
device
,
dtype
=
torch
.
int
)
paged_kv_last_page_lens_tensor
=
torch
.
tensor
(
self
.
paged_kv_last_page_lens
,
device
=
device
,
dtype
=
torch
.
int
)
block_table_bound_tensor
=
torch
.
zeros
(
len
(
self
.
paged_kv_indptr
)
-
1
,
device
=
device
,
dtype
=
torch
.
int
)
qo_indptr
=
torch
.
tensor
(
self
.
qo_indptr
,
device
=
device
,
dtype
=
torch
.
int
)
else
:
paged_kv_indices_tensor
=
None
paged_kv_indptr_tensor
=
None
paged_kv_last_page_lens_tensor
=
None
block_table_bound_tensor
=
None
qo_indptr
=
None
metadata
.
paged_kv_indptr
=
paged_kv_indptr_tensor
metadata
.
paged_kv_indices
=
paged_kv_indices_tensor
metadata
.
paged_kv_last_page_lens
=
paged_kv_last_page_lens_tensor
metadata
.
block_table_bound
=
block_table_bound_tensor
metadata
.
qo_indptr
=
qo_indptr
return
metadata
class
AiterMLAState
(
MLACommonState
[
AiterMLAMetadata
]):
@
contextmanager
def
graph_capture
(
self
,
max_batch_size
:
int
):
kv_indices
,
kv_indptr
,
last_page_lens
,
qo_indptr
=
\
get_aiter_mla_metadata
(
max_batch_size
=
max_batch_size
,
block_size
=
self
.
runner
.
block_size
,
max_block_per_batch
=
\
self
.
runner
.
get_max_block_per_batch
(),
device
=
self
.
runner
.
device
)
self
.
_paged_kv_indices_tensor
=
kv_indices
self
.
_paged_kv_indptr_tensor
=
kv_indptr
self
.
_paged_kv_last_page_lens_tensor
=
last_page_lens
self
.
_qo_indptr_tensor
=
qo_indptr
with
super
().
graph_capture
(
max_batch_size
):
yield
del
self
.
_paged_kv_indices_tensor
del
self
.
_paged_kv_indptr_tensor
del
self
.
_paged_kv_last_page_lens_tensor
del
self
.
_qo_indptr_tensor
def
graph_capture_get_metadata_for_batch
(
self
,
batch_size
:
int
,
is_encoder_decoder_model
:
bool
=
False
)
->
AiterMLAMetadata
:
metadata
=
super
().
graph_capture_get_metadata_for_batch
(
batch_size
,
is_encoder_decoder_model
)
paged_kv_indptr
=
self
.
_paged_kv_indptr_tensor
[:
batch_size
+
1
]
paged_kv_indices
=
self
.
_paged_kv_indices_tensor
paged_kv_last_page_lens
=
self
.
_paged_kv_last_page_lens_tensor
[:
batch_size
]
qo_indptr
=
self
.
_qo_indptr_tensor
[:
batch_size
+
1
]
metadata
.
paged_kv_indptr
=
paged_kv_indptr
metadata
.
paged_kv_indices
=
paged_kv_indices
metadata
.
paged_kv_last_page_lens
=
paged_kv_last_page_lens
metadata
.
qo_indptr
=
qo_indptr
return
metadata
def
get_graph_input_buffers
(
self
,
attn_metadata
:
AiterMLAMetadata
,
is_encoder_decoder_model
:
bool
=
False
):
input_buffers
=
super
().
get_graph_input_buffers
(
attn_metadata
,
is_encoder_decoder_model
)
input_buffers
[
'paged_kv_indptr'
]
=
attn_metadata
.
decode_metadata
.
paged_kv_indptr
input_buffers
[
"paged_kv_indices"
]
=
attn_metadata
.
\
decode_metadata
.
paged_kv_indices
input_buffers
[
"paged_kv_last_page_lens"
]
=
attn_metadata
.
\
decode_metadata
.
paged_kv_last_page_lens
input_buffers
[
'qo_indptr'
]
=
attn_metadata
.
qo_indptr
return
input_buffers
def
prepare_graph_input_buffers
(
self
,
input_buffers
,
attn_metadata
:
AiterMLAMetadata
,
is_encoder_decoder_model
:
bool
=
False
):
super
().
prepare_graph_input_buffers
(
input_buffers
,
attn_metadata
,
is_encoder_decoder_model
)
num_total_blocks
=
attn_metadata
.
decode_metadata
.
paged_kv_indices
.
shape
[
0
]
input_buffers
[
"paged_kv_indptr"
].
copy_
(
attn_metadata
.
decode_metadata
.
paged_kv_indptr
,
non_blocking
=
True
)
input_buffers
[
"paged_kv_indices"
][:
num_total_blocks
].
copy_
(
attn_metadata
.
decode_metadata
.
paged_kv_indices
,
non_blocking
=
True
)
input_buffers
[
"paged_kv_last_page_lens"
].
copy_
(
attn_metadata
.
decode_metadata
.
paged_kv_last_page_lens
,
non_blocking
=
True
)
input_buffers
[
"qo_indptr"
].
copy_
(
attn_metadata
.
decode_metadata
.
qo_indptr
,
non_blocking
=
True
)
class
AiterMLAImpl
(
MLACommonImpl
[
AiterMLAMetadata
]):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
list
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
],
# MLA Specific Arguments
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
logits_soft_cap
,
attn_type
,
kv_sharing_target_layer_name
,
**
mla_args
)
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
logits_soft_cap
]
if
any
(
unsupported_features
):
raise
NotImplementedError
(
"Aiter MLA does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap"
)
from
aiter
import
flash_attn_varlen_func
self
.
flash_attn_varlen_func
=
flash_attn_varlen_func
def
_flash_attn_varlen_diff_headdims
(
self
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
softmax_scale
:
float
,
return_softmax_lse
:
bool
,
**
kwargs
)
->
Union
[
tuple
[
torch
.
Tensor
,
...],
torch
.
Tensor
]:
output
=
self
.
flash_attn_varlen_func
(
q
,
k
,
v
,
**
kwargs
,
)
return
output
def
_forward_decode
(
self
,
q_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
AiterMLAMetadata
,
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
decode_meta
=
attn_metadata
.
decode_metadata
assert
decode_meta
is
not
None
B
=
q_nope
.
shape
[
0
]
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
o
=
torch
.
empty
(
B
,
self
.
num_heads
,
self
.
kv_lora_rank
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
kv_buffer
=
kv_c_and_k_pe_cache
.
unsqueeze
(
2
)
aiter_mla_decode_fwd
(
q
,
kv_buffer
,
o
,
self
.
scale
,
attn_metadata
.
qo_indptr
,
attn_metadata
.
max_query_len
,
attn_metadata
.
paged_kv_indptr
,
attn_metadata
.
paged_kv_indices
,
attn_metadata
.
paged_kv_last_page_lens
)
return
self
.
_v_up_proj
(
o
)
vllm/attention/backends/rocm_flash_attn.py
deleted
100644 → 0
View file @
af7dfb0d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer ROCm GPUs."""
import
itertools
from
dataclasses
import
dataclass
from
functools
import
cache
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
(
CommonAttentionState
,
CommonMetadataBuilder
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
from
vllm.config
import
get_current_vllm_config
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
QuantKey
,
kFp8StaticTensorSym
)
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
_PARTITION_SIZE_ROCM
=
256
@
cache
def
is_rocm_aiter_paged_attn_enabled
()
->
bool
:
return
envs
.
VLLM_ROCM_USE_AITER_PAGED_ATTN
\
and
envs
.
VLLM_ROCM_USE_AITER
\
@
cache
def
_get_paged_attn_module
()
->
PagedAttention
:
"""
Initializes the appropriate PagedAttention module from `attention/ops`,
which is used as helper function
by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`.
The choice of attention module depends on whether
AITER paged attention is enabled:
- If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`.
- Otherwise, it defaults to using the original `PagedAttention`.
"""
if
is_rocm_aiter_paged_attn_enabled
():
# Import AITERPagedAttention only when the flag is enabled
from
vllm.attention.ops.rocm_aiter_paged_attn
import
(
AITERPagedAttention
)
return
AITERPagedAttention
()
return
PagedAttention
()
class
ROCmFlashAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
@
staticmethod
def
get_name
()
->
str
:
return
"ROCM_FLASH"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"ROCmFlashAttentionImpl"
]:
return
ROCmFlashAttentionImpl
@
staticmethod
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
ROCmFlashAttentionMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"ROCmFlashAttentionMetadataBuilder"
]:
return
ROCmFlashAttentionMetadataBuilder
@
staticmethod
def
get_state_cls
()
->
Type
[
"CommonAttentionState"
]:
return
CommonAttentionState
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
paged_attn
=
_get_paged_attn_module
()
return
paged_attn
.
get_kv_cache_shape
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
paged_attn
=
_get_paged_attn_module
()
paged_attn
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
paged_attn
=
_get_paged_attn_module
()
paged_attn
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@
dataclass
class
ROCmFlashAttentionMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
"""Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens
:
Optional
[
List
[
int
]]
# seq_lens stored as a tensor.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len
:
int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len
:
int
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph
:
bool
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# Maximum query length in the batch. None for decoding.
max_query_len
:
Optional
[
int
]
=
None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# Max number of query tokens among request in the batch.
max_decode_query_len
:
Optional
[
int
]
=
None
_cached_prefill_metadata
:
Optional
[
"ROCmFlashAttentionMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"ROCmFlashAttentionMetadata"
]
=
None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens
:
Optional
[
List
[
int
]]
=
None
encoder_seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# Maximum sequence length among encoder sequences
max_encoder_seq_len
:
Optional
[
int
]
=
None
# Number of tokens input to encoder
num_encoder_tokens
:
Optional
[
int
]
=
None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping
:
Optional
[
torch
.
Tensor
]
=
None
cross_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"ROCmFlashAttentionMetadata"
]:
if
self
.
num_prefills
==
0
:
return
None
if
self
.
_cached_prefill_metadata
is
not
None
:
return
self
.
_cached_prefill_metadata
assert
self
.
seq_lens
is
not
None
assert
self
.
seq_lens_tensor
is
not
None
assert
self
.
block_tables
is
not
None
self
.
_cached_prefill_metadata
=
ROCmFlashAttentionMetadata
(
num_prefills
=
self
.
num_prefills
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
slot_mapping
=
self
.
slot_mapping
[:
self
.
num_prefill_tokens
],
multi_modal_placeholder_index_maps
=
self
.
multi_modal_placeholder_index_maps
,
enable_kv_scales_calculation
=
self
.
enable_kv_scales_calculation
,
seq_lens
=
self
.
seq_lens
[:
self
.
num_prefills
],
seq_lens_tensor
=
self
.
seq_lens_tensor
[:
self
.
num_prefills
],
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
self
.
max_prefill_seq_len
,
max_decode_seq_len
=
0
,
query_start_loc
=
None
if
self
.
query_start_loc
is
None
else
self
.
query_start_loc
[:
self
.
num_prefills
+
1
],
seq_start_loc
=
None
if
self
.
seq_start_loc
is
None
else
self
.
seq_start_loc
[:
self
.
num_prefills
+
1
],
context_lens_tensor
=
None
if
self
.
context_lens_tensor
is
None
else
self
.
context_lens_tensor
[:
self
.
num_prefills
],
block_tables
=
self
.
block_tables
[:
self
.
num_prefills
],
use_cuda_graph
=
False
,
# Begin encoder & cross attn fields below...
encoder_seq_lens
=
self
.
encoder_seq_lens
,
encoder_seq_lens_tensor
=
self
.
encoder_seq_lens_tensor
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
)
return
self
.
_cached_prefill_metadata
@
property
def
decode_metadata
(
self
)
->
Optional
[
"ROCmFlashAttentionMetadata"
]:
if
self
.
num_decode_tokens
==
0
:
return
None
if
self
.
_cached_decode_metadata
is
not
None
:
return
self
.
_cached_decode_metadata
assert
self
.
block_tables
is
not
None
assert
self
.
seq_lens_tensor
is
not
None
self
.
_cached_decode_metadata
=
ROCmFlashAttentionMetadata
(
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
self
.
slot_mapping
[
self
.
num_prefill_tokens
:],
multi_modal_placeholder_index_maps
=
None
,
enable_kv_scales_calculation
=
True
,
seq_lens
=
None
,
seq_lens_tensor
=
self
.
seq_lens_tensor
[
self
.
num_prefills
:],
max_query_len
=
None
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
query_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens_tensor
=
None
,
block_tables
=
self
.
block_tables
[
self
.
num_prefills
:],
use_cuda_graph
=
self
.
use_cuda_graph
,
# Begin encoder & cross attn fields below...
encoder_seq_lens
=
self
.
encoder_seq_lens
,
encoder_seq_lens_tensor
=
self
.
encoder_seq_lens_tensor
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
)
# Batch may be composed of prefill|decodes, adjust query start indices
# to refer to the start of decodes when the two are split apart.
# E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
if
self
.
_cached_decode_metadata
.
query_start_loc
is
not
None
:
qs
=
self
.
_cached_decode_metadata
.
query_start_loc
self
.
_cached_decode_metadata
.
query_start_loc
=
qs
-
qs
[
0
]
return
self
.
_cached_decode_metadata
class
ROCmFlashAttentionMetadataBuilder
(
CommonMetadataBuilder
[
ROCmFlashAttentionMetadata
]):
_metadata_cls
=
ROCmFlashAttentionMetadata
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
,
seq_lens
:
Optional
[
List
[
int
]],
make_attn_mask
:
bool
=
True
)
->
List
[
torch
.
Tensor
]:
attn_biases
=
[]
if
seq_lens
:
for
seq_len
in
seq_lens
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
bias
[
None
,
:].
repeat
(
(
num_heads
,
1
,
1
)).
to
(
alibi_slopes
.
device
)
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
if
make_attn_mask
:
inf_mask
=
torch
.
empty
(
(
1
,
seq_len
,
seq_len
),
dtype
=
bias
.
dtype
).
fill_
(
-
torch
.
inf
).
triu_
(
diagonal
=
1
).
to
(
alibi_slopes
.
device
)
attn_biases
.
append
((
bias
+
inf_mask
).
to
(
dtype
))
else
:
attn_biases
.
append
(
bias
.
to
(
dtype
))
return
attn_biases
def
_get_seq_len_block_table_args
(
attn_metadata
:
ROCmFlashAttentionMetadata
,
attn_type
:
str
,
)
->
tuple
:
'''
The particular choice of sequence-length
attributes which should be extracted from attn_metadata is dependent
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths
Encoder attn -> select encoder sequence lengths fields
Encoder-only attn -> select prefill sequence lengths with
bidirectional attention
Arguments:
* attn_metadata: Attention metadata structure associated with attention op
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention, encoder-only
Returns:
* Appropriate sequence-lengths tensors for query and key
* Appropriate max sequence-length scalar
* Causal masking flag
'''
if
attn_type
==
AttentionType
.
ENCODER
:
assert
attn_metadata
.
encoder_seq_lens
is
not
None
assert
attn_metadata
.
encoder_seq_lens_tensor
is
not
None
query_seq_start_loc
=
torch
.
tensor
(
list
(
itertools
.
accumulate
([
0
]
+
attn_metadata
.
encoder_seq_lens
)),
device
=
attn_metadata
.
encoder_seq_lens_tensor
.
device
,
dtype
=
attn_metadata
.
encoder_seq_lens_tensor
.
dtype
)
causal_mask
=
False
# No block tables associated with encoder attention
return
(
query_seq_start_loc
,
attn_metadata
.
max_encoder_seq_len
,
query_seq_start_loc
,
attn_metadata
.
max_encoder_seq_len
,
attn_metadata
.
encoder_seq_lens
,
causal_mask
)
elif
attn_type
==
AttentionType
.
ENCODER_ONLY
:
# For encoder-only models, we use the prefill sequence lengths
assert
attn_metadata
.
seq_lens
is
not
None
assert
attn_metadata
.
seq_lens_tensor
is
not
None
query_seq_start_loc
=
torch
.
tensor
(
list
(
itertools
.
accumulate
([
0
]
+
attn_metadata
.
seq_lens
)),
device
=
attn_metadata
.
seq_lens_tensor
.
device
,
dtype
=
attn_metadata
.
seq_lens_tensor
.
dtype
)
max_seq_len
=
attn_metadata
.
max_prefill_seq_len
# Encoder-only models typically use bidirectional attention
causal_mask
=
False
return
(
query_seq_start_loc
,
max_seq_len
,
query_seq_start_loc
,
max_seq_len
,
attn_metadata
.
seq_lens
,
causal_mask
)
elif
attn_type
==
AttentionType
.
DECODER
:
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
assert
attn_metadata
.
seq_lens
is
not
None
assert
attn_metadata
.
seq_lens_tensor
is
not
None
query_seq_start_loc
=
torch
.
tensor
(
list
(
itertools
.
accumulate
([
0
]
+
attn_metadata
.
seq_lens
)),
device
=
attn_metadata
.
seq_lens_tensor
.
device
,
dtype
=
attn_metadata
.
seq_lens_tensor
.
dtype
)
max_seq_len
=
attn_metadata
.
max_prefill_seq_len
causal_mask
=
True
return
(
query_seq_start_loc
,
max_seq_len
,
query_seq_start_loc
,
max_seq_len
,
attn_metadata
.
seq_lens
,
causal_mask
)
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
assert
attn_metadata
.
seq_lens
is
not
None
assert
attn_metadata
.
encoder_seq_lens_tensor
is
not
None
query_start_loc
=
torch
.
tensor
(
list
(
itertools
.
accumulate
([
0
]
+
attn_metadata
.
seq_lens
)),
device
=
attn_metadata
.
encoder_seq_lens_tensor
.
device
,
dtype
=
attn_metadata
.
encoder_seq_lens_tensor
.
dtype
)
assert
attn_metadata
.
encoder_seq_lens
is
not
None
assert
attn_metadata
.
seq_lens_tensor
is
not
None
key_seq_start_loc
=
torch
.
tensor
(
list
(
itertools
.
accumulate
([
0
]
+
attn_metadata
.
encoder_seq_lens
)),
device
=
attn_metadata
.
seq_lens_tensor
.
device
,
dtype
=
attn_metadata
.
seq_lens_tensor
.
dtype
)
causal_mask
=
False
# Enc/dec cross-attention KVs match encoder sequence length;
# cross-attention utilizes special "cross" block tables
return
(
query_start_loc
,
attn_metadata
.
max_prefill_seq_len
,
key_seq_start_loc
,
attn_metadata
.
max_encoder_seq_len
,
attn_metadata
.
seq_lens
,
causal_mask
)
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
class
ROCmFlashAttentionImpl
(
AttentionImpl
):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens -------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
|<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0 "
"ROCM_FLASH backend."
)
if
use_irope
:
logger
.
warning_once
(
"Using irope in ROCm Flash Attention is not supported yet, it "
"will fail back to global attention for long context."
)
if
use_irope
:
logger
.
warning
(
"Using irope in V0 is not supported yet, it will fall back "
"to global attention for long context."
)
if
logits_soft_cap
is
None
:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
self
.
logits_soft_cap
=
0.0
else
:
self
.
logits_soft_cap
=
logits_soft_cap
self
.
attn_type
=
attn_type
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
sliding_window
=
((
sliding_window
,
sliding_window
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
paged_attn_module
=
_get_paged_attn_module
()
supported_head_sizes
=
self
.
paged_attn_module
.
get_supported_head_sizes
(
)
if
head_size
not
in
supported_head_sizes
:
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
supported_head_sizes
}
."
)
self
.
use_naive_attn
=
False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self
.
use_triton_flash_attn
=
envs
.
VLLM_USE_TRITON_FLASH_ATTN
if
self
.
use_triton_flash_attn
:
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"ROCm Triton FlashAttention does not support attention"
" logits soft capping."
" please try using the ROCm CK "
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`"
)
from
vllm.attention.ops.triton_flash_attention
import
(
# noqa: F401
triton_attention
)
self
.
triton_attn_func
=
triton_attention
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
if
self
.
sliding_window
!=
(
-
1
,
-
1
):
logger
.
warning
(
"ROCm Triton FA does not currently support "
"sliding window attention. If using half "
"precision, please try using the ROCm CK "
"FA backend instead by setting the env var "
"`VLLM_USE_TRITON_FLASH_ATTN=0`"
)
else
:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either
if
not
current_platform
.
has_device_capability
(
90
):
self
.
use_naive_attn
=
True
else
:
try
:
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
self
.
fa_attn_func
=
flash_attn_varlen_func
logger
.
debug
(
"Using CK FA in ROCmBackend"
)
except
ModuleNotFoundError
:
self
.
use_naive_attn
=
True
if
self
.
use_naive_attn
:
if
logits_soft_cap
is
not
None
:
raise
ValueError
(
"ROCm Naive FlashAttention does not support "
"attention logits soft capping."
)
self
.
sdpa_attn_func
=
_sdpa_attention
logger
.
debug
(
"Using naive (SDPA) attention in ROCmBackend"
)
self
.
aiter_kv_scales_initialized
=
False
self
.
force_fp8_attention
=
(
get_current_vllm_config
()
is
not
None
and
get_current_vllm_config
().
model_config
.
override_attention_dtype
==
"fp8"
)
def
repeat_kv
(
self
,
x
:
torch
.
Tensor
,
n_rep
:
int
)
->
torch
.
Tensor
:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
tokens
,
n_kv_heads
,
head_dim
=
x
.
shape
return
(
x
[:,
:,
None
,
:].
expand
(
tokens
,
n_kv_heads
,
n_rep
,
head_dim
).
reshape
(
tokens
,
n_kv_heads
*
n_rep
,
head_dim
))
def
fused_output_quant_supported
(
self
,
quant_key
:
QuantKey
):
if
self
.
use_triton_flash_attn
:
return
quant_key
==
kFp8StaticTensorSym
# Only supported in the Triton backend
return
False
def
forward
(
self
,
layer
:
AttentionLayer
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
ROCmFlashAttentionMetadata
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_block_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
For decoder-only models: query, key and value must be non-None.
For encoder/decoder models:
* ROCmFlashAttentionImpl.forward() may be invoked for both self- and
cross-attention layers.
* For self-attention: query, key and value must be non-None.
* For cross-attention:
* Query must be non-None
* During prefill, key and value must be non-None; key and value
get cached for use during decode.
* During decode, key and value may be None, since:
(1) key and value tensors were cached during prefill, and
(2) cross-attention key and value tensors do not grow during
decode
A note on how the attn_type (attention type enum) argument impacts
attention forward() behavior:
* DECODER: normal decoder-only behavior;
use decoder self-attention block table
* ENCODER: no KV caching; pass encoder sequence
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len) to kernel, in lieu of decoder
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
* ENCODER_DECODER: cross-attention behavior;
use cross-attention block table for caching KVs derived
from encoder hidden states; since KV sequence lengths
will match encoder sequence lengths, pass encoder sequence
attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len)
* ENCODER_ONLY: bidirectional attention with no KV caching;
use prefill sequence attributes
Args:
layer: Attention layer instance.
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache: KV cache tensor with shape
[2, num_blocks, block_size * num_kv_heads * head_size].
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
output: Optional output tensor.
output_scale: Optional output scale tensor.
output_block_scale: Optional output block scale tensor.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert
output
is
not
None
,
"Output tensor must be provided."
if
output_scale
is
not
None
and
not
self
.
use_triton_flash_attn
:
raise
NotImplementedError
(
"fused output quantization only supported for Triton"
" implementation in ROCMFlashAttentionImpl for now"
)
if
output_block_scale
is
not
None
:
raise
NotImplementedError
(
"fused nvfp4 output quantization is not supported"
" for ROCMFlashAttentionImpl"
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
if
key
is
not
None
:
assert
value
is
not
None
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
else
:
assert
value
is
None
paged_attn
=
self
.
paged_attn_module
# Reshaping kv tensors is required for AITER paged attention kernel
# because it works on a different tensor shape,
# when the size of one element is one byte (int8/fp8 dtypes).
# This reshaping is only required on the first forward call
# and the kv cache must not be empty.
if
(
is_rocm_aiter_paged_attn_enabled
()
and
kv_cache
.
dtype
.
itemsize
==
1
and
not
self
.
aiter_kv_scales_initialized
and
kv_cache
.
shape
!=
torch
.
Size
([
0
])):
num_blocks
=
kv_cache
.
shape
[
1
]
block_size
=
kv_cache
.
shape
[
2
]
//
(
self
.
num_kv_heads
*
self
.
head_size
)
k_scale
=
torch
.
empty
((
self
.
num_kv_heads
,
num_blocks
*
block_size
),
dtype
=
torch
.
float32
,
device
=
kv_cache
.
device
)
v_scale
=
torch
.
empty
((
self
.
num_kv_heads
,
num_blocks
*
block_size
),
dtype
=
torch
.
float32
,
device
=
kv_cache
.
device
)
self
.
aiter_kv_scales_initialized
=
True
k_scale
.
fill_
(
layer
.
_k_scale
.
item
())
v_scale
.
fill_
(
layer
.
_v_scale
.
item
())
layer
.
_k_scale
=
k_scale
layer
.
_v_scale
=
v_scale
# Only update KV cache for decoder self-attention
# and encoder-decoder cross-attention
if
self
.
attn_type
not
in
[
AttentionType
.
ENCODER
,
AttentionType
.
ENCODER_ONLY
]
and
kv_cache
.
numel
()
>
0
:
key_cache
,
value_cache
=
paged_attn
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
if
key
is
not
None
and
value
is
not
None
:
# Reshape the input keys and values and store them in the
# cache. If kv_cache is not provided, the new key and value
# tensors are not cached. This happens during the initial
# memory profiling run.
paged_attn
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
if
self
.
attn_type
!=
AttentionType
.
ENCODER_DECODER
else
attn_metadata
.
cross_slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
if
self
.
attn_type
!=
AttentionType
.
ENCODER
:
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
elif
self
.
attn_type
==
AttentionType
.
ENCODER_ONLY
:
# For encoder-only models, all tokens are processed in one go
num_prefill_tokens
=
query
.
shape
[
0
]
else
:
assert
attn_metadata
.
num_encoder_tokens
is
not
None
num_prefill_tokens
=
attn_metadata
.
num_encoder_tokens
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[
num_prefill_tokens
:]
# QKV for prefill.
query
=
query
[:
num_prefill_tokens
]
# For encoder-only and encoder models,
# we process all tokens at once
# For decoder and encoder-decoder,
# we may need to limit key/value to prefill tokens
if
key
is
not
None
and
value
is
not
None
\
and
self
.
attn_type
not
in
[
AttentionType
.
ENCODER_DECODER
,
AttentionType
.
ENCODER_ONLY
]:
key
=
key
[:
num_prefill_tokens
]
value
=
value
[:
num_prefill_tokens
]
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
# normal attention and DECODER
if
self
.
attn_type
==
AttentionType
.
DECODER
and
(
kv_cache
.
numel
()
==
0
or
prefill_meta
.
block_tables
is
None
or
prefill_meta
.
block_tables
.
numel
()
==
0
):
(
query_seq_start_loc
,
query_max_seq_len
,
key_seq_start_loc
,
key_max_seq_len
,
seq_lens
,
causal_mask
)
=
(
prefill_meta
.
seq_start_loc
,
prefill_meta
.
max_prefill_seq_len
,
prefill_meta
.
seq_start_loc
,
prefill_meta
.
max_prefill_seq_len
,
attn_metadata
.
seq_lens
,
True
)
# prefix-enabled attention and ENCODER/ENCODER_DECODER
else
:
(
query_seq_start_loc
,
query_max_seq_len
,
key_seq_start_loc
,
key_max_seq_len
,
seq_lens
,
causal_mask
)
=
_get_seq_len_block_table_args
(
prefill_meta
,
self
.
attn_type
)
# Prompt run.
if
kv_cache
.
numel
()
==
0
or
prefill_meta
.
block_tables
.
numel
()
==
0
:
# triton attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
attn_masks
=
None
if
self
.
use_triton_flash_attn
:
if
self
.
alibi_slopes
is
not
None
:
attn_masks
=
_make_alibi_bias
(
self
.
alibi_slopes
,
query
.
dtype
,
seq_lens
,
make_attn_mask
=
causal_mask
)
# type: ignore
use_fp8_scales
=
(
layer
.
_q_scale
and
layer
.
_k_scale
and
layer
.
_v_scale
and
layer
.
_prob_scale
and
(
self
.
kv_cache_dtype
==
"fp8"
or
self
.
force_fp8_attention
))
full_scales
=
(
layer
.
_q_scale
.
item
(),
layer
.
_k_scale
.
item
(),
layer
.
_v_scale
.
item
(),
layer
.
_prob_scale
.
item
())
if
use_fp8_scales
else
None
self
.
triton_attn_func
(
query
,
key
,
value
,
output
[:
num_prefill_tokens
],
query_seq_start_loc
,
key_seq_start_loc
,
query_max_seq_len
,
key_max_seq_len
,
causal_mask
,
self
.
scale
,
attn_masks
[
0
][
None
]
if
attn_masks
is
not
None
else
None
,
full_scales
,
output_scale
,
)
elif
self
.
use_naive_attn
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# Interleave for MQA workaround.
key
=
self
.
repeat_kv
(
key
,
self
.
num_queries_per_kv
)
value
=
self
.
repeat_kv
(
value
,
self
.
num_queries_per_kv
)
if
self
.
alibi_slopes
is
not
None
:
attn_masks
=
_make_alibi_bias
(
self
.
alibi_slopes
,
query
.
dtype
,
attn_metadata
.
seq_lens
,
make_attn_mask
=
causal_mask
)
# type: ignore
query
=
query
.
movedim
(
0
,
query
.
dim
()
-
2
)
key
=
key
.
movedim
(
0
,
key
.
dim
()
-
2
)
value
=
value
.
movedim
(
0
,
value
.
dim
()
-
2
)
# sdpa math backend attention
self
.
sdpa_attn_func
(
query
,
key
,
value
,
output
[:
num_prefill_tokens
],
query_seq_start_loc
,
num_prefill_tokens
,
self
.
num_heads
,
self
.
head_size
,
self
.
scale
,
attn_masks
,
)
else
:
# upstream FA does not support an output arg, copy
output
[:
num_prefill_tokens
]
=
self
.
fa_attn_func
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
query_seq_start_loc
,
cu_seqlens_k
=
key_seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlen_k
=
key_max_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
causal_mask
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
softcap
=
self
.
logits_soft_cap
,
)
else
:
# prefix-enabled attention -
# not applicable for encoder-only models
if
self
.
attn_type
!=
AttentionType
.
ENCODER_ONLY
:
output
[:
num_prefill_tokens
]
=
paged_attn
.
forward_prefix
(
query
,
key
,
value
,
self
.
kv_cache_dtype
,
key_cache
,
value_cache
,
prefill_meta
.
block_tables
,
prefill_meta
.
query_start_loc
,
prefill_meta
.
seq_lens_tensor
,
prefill_meta
.
max_query_len
,
self
.
alibi_slopes
,
self
.
sliding_window
[
0
],
layer
.
_k_scale
,
layer
.
_v_scale
,
)
# Skip decode phase for encoder-only models
if
(
decode_meta
:
=
attn_metadata
.
decode_metadata
)
and
(
self
.
attn_type
!=
AttentionType
.
ENCODER_ONLY
):
# Decoding run.
# Whether to use rocm custom paged attention or not
num_seqs
,
num_heads
,
head_size
=
decode_query
.
shape
block_size
=
value_cache
.
shape
[
3
]
gqa_ratio
=
num_heads
//
self
.
num_kv_heads
from
vllm.platforms.rocm
import
use_rocm_custom_paged_attention
use_custom
=
use_rocm_custom_paged_attention
(
decode_query
.
dtype
,
head_size
,
block_size
,
gqa_ratio
,
decode_meta
.
max_decode_seq_len
,
self
.
sliding_window
,
self
.
kv_cache_dtype
,
self
.
alibi_slopes
)
if
use_custom
:
max_seq_len
=
(
decode_meta
.
max_decode_seq_len
if
self
.
attn_type
!=
AttentionType
.
ENCODER_DECODER
else
decode_meta
.
max_encoder_seq_len
)
assert
max_seq_len
is
not
None
max_num_partitions
=
(
(
max_seq_len
+
_PARTITION_SIZE_ROCM
-
1
)
//
_PARTITION_SIZE_ROCM
)
assert
_PARTITION_SIZE_ROCM
%
block_size
==
0
tmp_output
=
torch
.
empty
(
size
=
(
num_seqs
,
num_heads
,
max_num_partitions
,
head_size
),
dtype
=
query
.
dtype
,
device
=
output
.
device
,
)
exp_sums
=
torch
.
empty
(
size
=
(
num_seqs
,
num_heads
,
max_num_partitions
),
dtype
=
torch
.
float32
,
device
=
output
.
device
,
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
query_start_loc
=
None
ops
.
paged_attention_rocm
(
output
[
num_prefill_tokens
:],
exp_sums
,
max_logits
,
tmp_output
,
decode_query
,
key_cache
,
value_cache
,
self
.
num_kv_heads
,
self
.
scale
,
decode_meta
.
block_tables
if
self
.
attn_type
!=
AttentionType
.
ENCODER_DECODER
else
decode_meta
.
cross_block_tables
,
decode_meta
.
seq_lens_tensor
if
self
.
attn_type
!=
AttentionType
.
ENCODER_DECODER
else
decode_meta
.
encoder_seq_lens_tensor
,
query_start_loc
,
block_size
,
max_seq_len
,
self
.
alibi_slopes
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
output_scale
,
)
else
:
# PagedAttention does not support fused quant, manually quantize
if
output_scale
is
None
:
out_pa
=
output
[
num_prefill_tokens
:]
else
:
out_pa
=
torch
.
empty_like
(
output
[
num_prefill_tokens
:],
dtype
=
query
.
dtype
)
out_pa
[:]
=
paged_attn
.
forward_decode
(
decode_query
,
key_cache
,
value_cache
,
decode_meta
.
block_tables
if
self
.
attn_type
!=
AttentionType
.
ENCODER_DECODER
else
decode_meta
.
cross_block_tables
,
decode_meta
.
seq_lens_tensor
if
self
.
attn_type
!=
AttentionType
.
ENCODER_DECODER
else
decode_meta
.
encoder_seq_lens_tensor
,
decode_meta
.
max_decode_seq_len
if
self
.
attn_type
!=
AttentionType
.
ENCODER_DECODER
else
decode_meta
.
max_encoder_seq_len
,
self
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
# Manually perform quantization
if
output_scale
is
not
None
:
out_uq
=
out_pa
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
out_q
=
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
ops
.
scaled_fp8_quant
(
out_uq
,
output_scale
,
output
=
out_q
[
num_prefill_tokens
:])
# Reshape the output tensor.
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
def
_sdpa_attention
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
output
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
num_tokens
:
int
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
attn_masks
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
)
->
torch
.
Tensor
:
start
=
0
assert
output
.
shape
==
(
num_tokens
,
num_heads
,
head_size
)
assert
output
.
dtype
==
query
.
dtype
assert
output
.
device
==
query
.
device
for
i
,
seq_len
in
enumerate
(
seq_lens
):
end
=
start
+
seq_len
with
torch
.
nn
.
attention
.
sdpa_kernel
(
torch
.
nn
.
attention
.
SDPBackend
.
MATH
):
sub_out
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
query
[:,
start
:
end
,
:],
key
[:,
start
:
end
,
:],
value
[:,
start
:
end
,
:],
dropout_p
=
0.0
,
is_causal
=
attn_masks
is
None
,
attn_mask
=
attn_masks
[
i
]
if
attn_masks
else
None
,
scale
=
scale
).
movedim
(
query
.
dim
()
-
2
,
0
)
output
[
start
:
end
,
:,
:]
=
sub_out
start
=
end
return
output
vllm/attention/backends/triton_mla.py
deleted
100644 → 0
View file @
af7dfb0d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
List
,
Optional
,
Type
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.mla.common
import
(
MLACommonBackend
,
MLACommonImpl
,
MLACommonMetadata
)
from
vllm.attention.ops.triton_decode_attention
import
decode_attention_fwd
class
TritonMLABackend
(
MLACommonBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"TRITON_MLA"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"TritonMLAImpl"
]:
return
TritonMLAImpl
class
TritonMLAImpl
(
MLACommonImpl
[
MLACommonMetadata
]):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
kv_sharing_target_layer_name
:
Optional
[
str
],
# MLA Specific Arguments
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
logits_soft_cap
,
attn_type
,
kv_sharing_target_layer_name
,
**
mla_args
)
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
logits_soft_cap
]
if
any
(
unsupported_features
):
raise
NotImplementedError
(
"TritonMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, logits_soft_cap"
)
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TritonMLAImpl"
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"TritonMLA with FP8 KV cache not yet supported"
)
def
_forward_decode
(
self
,
q_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
MLACommonMetadata
,
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
decode_meta
=
attn_metadata
.
decode_metadata
assert
decode_meta
is
not
None
B
=
q_nope
.
shape
[
0
]
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
o
=
torch
.
zeros
(
B
,
self
.
num_heads
,
self
.
kv_lora_rank
,
dtype
=
q
.
dtype
,
device
=
q
.
device
)
num_kv_splits
=
4
# TODO: heuristic
# TODO(lucas) Allocate ahead of time
attn_logits
=
torch
.
empty
(
(
B
,
self
.
num_heads
,
num_kv_splits
,
# NOTE(lucas) idk why the +1 is here but sglang has it so we
# just mirror that
self
.
kv_lora_rank
+
1
,
),
dtype
=
torch
.
float32
,
device
=
q
.
device
,
)
# Add a head dim of 1
kv_c_and_k_pe_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
2
)
kv_c_cache
=
kv_c_and_k_pe_cache
[...,
:
self
.
kv_lora_rank
]
PAGE_SIZE
=
kv_c_and_k_pe_cache
.
size
(
1
)
# Run MQA
decode_attention_fwd
(
q
,
kv_c_and_k_pe_cache
,
kv_c_cache
,
o
,
decode_meta
.
block_tables
,
decode_meta
.
seq_lens_tensor
,
attn_logits
,
num_kv_splits
,
self
.
scale
,
PAGE_SIZE
)
return
self
.
_v_up_proj
(
o
)
vllm/attention/backends/utils.py
View file @
bc6e542d
...
@@ -338,10 +338,9 @@ class CommonAttentionState(AttentionState):
...
@@ -338,10 +338,9 @@ class CommonAttentionState(AttentionState):
# The encoder decoder model works only with XFormers and
# The encoder decoder model works only with XFormers and
# Flash Attention backend. Assert the same.
# Flash Attention backend. Assert the same.
assert
self
.
runner
.
attn_backend
.
get_name
()
in
\
assert
self
.
runner
.
attn_backend
.
get_name
()
in
\
[
"XFORMERS"
,
"FLASH_ATTN"
,
"ROCM_FLASH"
],
\
[
"XFORMERS"
,
"FLASH_ATTN"
],
\
f
"Expected attn_backend name to be either 'XFORMERS',"
\
f
"Expected attn_backend name to be either 'XFORMERS' or "
\
f
"'ROCM_FLASH', or 'FLASH_ATTN', but "
\
f
"'FLASH_ATTN', but got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
f
"got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
self
.
_update_captured_metadata_for_enc_dec_model
(
self
.
_update_captured_metadata_for_enc_dec_model
(
batch_size
=
batch_size
,
attn_metadata
=
attn_metadata
)
batch_size
=
batch_size
,
attn_metadata
=
attn_metadata
)
...
@@ -360,10 +359,9 @@ class CommonAttentionState(AttentionState):
...
@@ -360,10 +359,9 @@ class CommonAttentionState(AttentionState):
# The encoder decoder model works only with XFormers and
# The encoder decoder model works only with XFormers and
# Flash Attention backend. Assert the same.
# Flash Attention backend. Assert the same.
assert
self
.
runner
.
attn_backend
.
get_name
()
in
\
assert
self
.
runner
.
attn_backend
.
get_name
()
in
\
[
"XFORMERS"
,
"FLASH_ATTN"
,
"ROCM_FLASH"
],
\
[
"XFORMERS"
,
"FLASH_ATTN"
],
\
f
"Expected attn_backend name to be either 'XFORMERS',"
\
f
"Expected attn_backend name to be either 'XFORMERS' or "
\
f
"'ROCM_FLASH', or 'FLASH_ATTN', but "
\
f
"'FLASH_ATTN', but got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
f
"got '
{
self
.
runner
.
attn_backend
.
get_name
()
}
'"
self
.
_add_additional_input_buffers_for_enc_dec_model
(
self
.
_add_additional_input_buffers_for_enc_dec_model
(
attn_metadata
=
attn_metadata
,
input_buffers
=
input_buffers
)
attn_metadata
=
attn_metadata
,
input_buffers
=
input_buffers
)
return
input_buffers
return
input_buffers
...
...
vllm/attention/backends/xformers.py
deleted
100644 → 0
View file @
af7dfb0d
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with xFormers and PagedAttention."""
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
from
xformers
import
ops
as
xops
from
xformers.ops.fmha.attn_bias
import
(
AttentionBias
,
BlockDiagonalCausalMask
,
BlockDiagonalMask
,
LowerTriangularMaskWithTensorBias
)
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
(
CommonAttentionState
,
CommonMetadataBuilder
,
get_num_prefill_decode_query_kv_tokens
,
get_seq_len_block_table_args
,
is_all_cross_attn_metadata_set
,
is_all_encoder_attn_metadata_set
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
class
XFormersBackend
(
AttentionBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"XFORMERS"
@
staticmethod
def
get_impl_cls
()
->
Type
[
"XFormersImpl"
]:
return
XFormersImpl
@
staticmethod
def
get_metadata_cls
()
->
Type
[
"AttentionMetadata"
]:
return
XFormersMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"XFormersMetadataBuilder"
]:
return
XFormersMetadataBuilder
@
staticmethod
def
get_state_cls
()
->
Type
[
"CommonAttentionState"
]:
return
CommonAttentionState
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
return
PagedAttention
.
get_kv_cache_shape
(
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
Dict
[
int
,
int
],
)
->
None
:
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
PagedAttention
.
copy_blocks
(
kv_caches
,
src_to_dists
)
@
dataclass
class
XFormersMetadata
(
AttentionMetadata
,
PagedAttentionMetadata
):
"""Metadata for XFormersbackend.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ----------------------|
# |-- query_len ---|
# seq_lens stored as a tensor.
seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
# FIXME: It is for flash attn.
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len
:
int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len
:
int
# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph
:
bool
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens None if it is a decoding.
seq_lens
:
Optional
[
List
[
int
]]
=
None
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# Maximum query length in the batch. None for decoding.
max_query_len
:
Optional
[
int
]
=
None
# Max number of query tokens among request in the batch.
max_decode_query_len
:
Optional
[
int
]
=
None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# Self-attention prefill/decode metadata cache
_cached_prefill_metadata
:
Optional
[
"XFormersMetadata"
]
=
None
_cached_decode_metadata
:
Optional
[
"XFormersMetadata"
]
=
None
# Begin encoder attn & enc/dec cross-attn fields...
# Encoder sequence lengths representation
encoder_seq_lens
:
Optional
[
List
[
int
]]
=
None
encoder_seq_lens_tensor
:
Optional
[
torch
.
Tensor
]
=
None
# FIXME: It is for flash attn.
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
encoder_seq_start_loc
:
Optional
[
torch
.
Tensor
]
=
None
# Maximum sequence length among encoder sequences
max_encoder_seq_len
:
Optional
[
int
]
=
None
# Number of tokens input to encoder
num_encoder_tokens
:
Optional
[
int
]
=
None
# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping
:
Optional
[
torch
.
Tensor
]
=
None
cross_block_tables
:
Optional
[
torch
.
Tensor
]
=
None
def
__post_init__
(
self
):
# Set during the execution of the first attention op.
# It is a list because it is needed to set per prompt
# when alibi slopes is used. It is because of the limitation
# from xformer API.
# will not appear in the __repr__ and __init__
self
.
attn_bias
:
Optional
[
List
[
AttentionBias
]]
=
None
self
.
encoder_attn_bias
:
Optional
[
List
[
AttentionBias
]]
=
None
self
.
cross_attn_bias
:
Optional
[
List
[
AttentionBias
]]
=
None
@
property
def
is_all_encoder_attn_metadata_set
(
self
):
'''
All attention metadata required for encoder attention is set.
'''
return
is_all_encoder_attn_metadata_set
(
self
)
@
property
def
is_all_cross_attn_metadata_set
(
self
):
'''
All attention metadata required for enc/dec cross-attention is set.
Superset of encoder attention required metadata.
'''
return
is_all_cross_attn_metadata_set
(
self
)
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"XFormersMetadata"
]:
if
self
.
num_prefills
==
0
:
return
None
if
self
.
_cached_prefill_metadata
is
not
None
:
# Recover cached prefill-phase attention
# metadata structure
return
self
.
_cached_prefill_metadata
assert
((
self
.
seq_lens
is
not
None
)
or
(
self
.
encoder_seq_lens
is
not
None
))
assert
((
self
.
seq_lens_tensor
is
not
None
)
or
(
self
.
encoder_seq_lens_tensor
is
not
None
))
# Compute some attn_metadata fields which default to None
query_start_loc
=
(
None
if
self
.
query_start_loc
is
None
else
self
.
query_start_loc
[:
self
.
num_prefills
+
1
])
seq_start_loc
=
(
None
if
self
.
seq_start_loc
is
None
else
self
.
seq_start_loc
[:
self
.
num_prefills
+
1
])
slot_mapping
=
(
None
if
self
.
slot_mapping
is
None
else
self
.
slot_mapping
[:
self
.
num_prefill_tokens
])
seq_lens
=
(
None
if
self
.
seq_lens
is
None
else
self
.
seq_lens
[:
self
.
num_prefills
])
seq_lens_tensor
=
(
None
if
self
.
seq_lens_tensor
is
None
else
self
.
seq_lens_tensor
[:
self
.
num_prefills
])
context_lens_tensor
=
(
None
if
self
.
context_lens_tensor
is
None
else
self
.
context_lens_tensor
[:
self
.
num_prefills
])
block_tables
=
(
None
if
self
.
block_tables
is
None
else
self
.
block_tables
[:
self
.
num_prefills
])
# Construct & cache prefill-phase attention metadata structure
self
.
_cached_prefill_metadata
=
XFormersMetadata
(
num_prefills
=
self
.
num_prefills
,
num_prefill_tokens
=
self
.
num_prefill_tokens
,
num_decode_tokens
=
0
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
self
.
multi_modal_placeholder_index_maps
,
enable_kv_scales_calculation
=
self
.
enable_kv_scales_calculation
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
self
.
max_query_len
,
max_prefill_seq_len
=
self
.
max_prefill_seq_len
,
max_decode_seq_len
=
0
,
query_start_loc
=
query_start_loc
,
seq_start_loc
=
seq_start_loc
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
False
,
# Begin encoder & cross attn fields below...
encoder_seq_lens
=
self
.
encoder_seq_lens
,
encoder_seq_lens_tensor
=
self
.
encoder_seq_lens_tensor
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
)
return
self
.
_cached_prefill_metadata
@
property
def
decode_metadata
(
self
)
->
Optional
[
"XFormersMetadata"
]:
if
self
.
num_decode_tokens
==
0
:
return
None
if
self
.
_cached_decode_metadata
is
not
None
:
# Recover cached decode-phase attention
# metadata structure
return
self
.
_cached_decode_metadata
assert
((
self
.
seq_lens_tensor
is
not
None
)
or
(
self
.
encoder_seq_lens_tensor
is
not
None
))
# Compute some attn_metadata fields which default to None
slot_mapping
=
(
None
if
self
.
slot_mapping
is
None
else
self
.
slot_mapping
[
self
.
num_prefill_tokens
:])
seq_lens_tensor
=
(
None
if
self
.
seq_lens_tensor
is
None
else
self
.
seq_lens_tensor
[
self
.
num_prefills
:])
block_tables
=
(
None
if
self
.
block_tables
is
None
else
self
.
block_tables
[
self
.
num_prefills
:])
# Construct & cache decode-phase attention metadata structure
self
.
_cached_decode_metadata
=
XFormersMetadata
(
num_prefills
=
0
,
num_prefill_tokens
=
0
,
num_decode_tokens
=
self
.
num_decode_tokens
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
None
,
enable_kv_scales_calculation
=
True
,
seq_lens_tensor
=
seq_lens_tensor
,
max_prefill_seq_len
=
0
,
max_decode_seq_len
=
self
.
max_decode_seq_len
,
block_tables
=
block_tables
,
use_cuda_graph
=
self
.
use_cuda_graph
,
# Begin encoder & cross attn fields below...
encoder_seq_lens
=
self
.
encoder_seq_lens
,
encoder_seq_lens_tensor
=
self
.
encoder_seq_lens_tensor
,
max_encoder_seq_len
=
self
.
max_encoder_seq_len
,
cross_slot_mapping
=
self
.
cross_slot_mapping
,
cross_block_tables
=
self
.
cross_block_tables
)
# Batch may be composed of prefill|decodes, adjust query start indices
# to refer to the start of decodes when the two are split apart.
# E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
if
self
.
_cached_decode_metadata
.
query_start_loc
is
not
None
:
qs
=
self
.
_cached_decode_metadata
.
query_start_loc
self
.
_cached_decode_metadata
.
query_start_loc
=
qs
-
qs
[
0
]
return
self
.
_cached_decode_metadata
def
_get_attn_bias
(
attn_metadata
:
XFormersMetadata
,
attn_type
:
str
,
)
->
Optional
[
AttentionBias
]:
'''
Extract appropriate attention bias from attention metadata
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
Returns:
* Appropriate attention bias value given the attention type
'''
if
(
attn_type
==
AttentionType
.
DECODER
or
attn_type
==
AttentionType
.
ENCODER_ONLY
):
return
attn_metadata
.
attn_bias
elif
attn_type
==
AttentionType
.
ENCODER
:
return
attn_metadata
.
encoder_attn_bias
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
return
attn_metadata
.
cross_attn_bias
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
def
_set_attn_bias
(
attn_metadata
:
XFormersMetadata
,
attn_bias
:
List
[
Optional
[
AttentionBias
]],
attn_type
:
str
,
)
->
None
:
'''
Update appropriate attention bias field of attention metadata,
according to attention type.
Arguments:
* attn_metadata: Attention metadata structure associated with attention
* attn_bias: The desired attention bias value
* attn_type: encoder attention, decoder self-attention,
encoder/decoder cross-attention
'''
if
(
attn_type
==
AttentionType
.
DECODER
or
attn_type
==
AttentionType
.
ENCODER_ONLY
):
attn_metadata
.
attn_bias
=
attn_bias
elif
attn_type
==
AttentionType
.
ENCODER
:
attn_metadata
.
encoder_attn_bias
=
attn_bias
elif
attn_type
==
AttentionType
.
ENCODER_DECODER
:
attn_metadata
.
cross_attn_bias
=
attn_bias
else
:
raise
AttributeError
(
f
"Invalid attention type
{
str
(
attn_type
)
}
"
)
class
XFormersMetadataBuilder
(
CommonMetadataBuilder
[
XFormersMetadata
]):
_metadata_cls
=
XFormersMetadata
class
XFormersImpl
(
AttentionImpl
[
XFormersMetadata
]):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows:
|<----------------- num_decode_tokens ------------------>|
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens
always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
"""
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
attn_type
:
str
=
AttentionType
.
DECODER
,
kv_sharing_target_layer_name
:
Optional
[
str
]
=
None
,
use_irope
:
bool
=
False
,
)
->
None
:
if
kv_sharing_target_layer_name
is
not
None
:
raise
NotImplementedError
(
"KV sharing is not supported in V0 "
"XFORMERS backend."
)
if
logits_soft_cap
is
not
None
:
logger
.
warning_once
(
"XFormers does not support logits soft cap. "
"Outputs may be slightly off."
)
if
use_irope
:
logger
.
warning_once
(
"Using irope in XFormers is not supported yet, it will fall"
" back to global attention for long context."
)
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_kv_heads
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
sliding_window
=
sliding_window
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
supported_head_sizes
=
PagedAttention
.
get_supported_head_sizes
()
if
head_size
not
in
supported_head_sizes
:
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
supported_head_sizes
}
."
)
self
.
attn_type
=
attn_type
def
forward
(
self
,
layer
:
AttentionLayer
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
],
value
:
Optional
[
torch
.
Tensor
],
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
"XFormersMetadata"
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
output_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
output_block_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
For decoder-only models: query, key and value must be non-None.
For encoder/decoder models:
* XFormersImpl.forward() may be invoked for both self- and cross-
attention layers.
* For self-attention: query, key and value must be non-None.
* For cross-attention:
* Query must be non-None
* During prefill, key and value must be non-None; key and value
get cached for use during decode.
* During decode, key and value may be None, since:
(1) key and value tensors were cached during prefill, and
(2) cross-attention key and value tensors do not grow during
decode
A note on how the attn_type (attention type enum) argument impacts
attention forward() behavior:
* DECODER: normal decoder-only behavior;
use decoder self-attention block table
* ENCODER: no KV caching; pass encoder sequence
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len) to kernel, in lieu of decoder
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len).
Used for encoder branch of encoder-decoder models.
* ENCODER_ONLY: no kv_caching, uses the normal attention
attributes (seq_lens/seq_lens_tensor/max_seq_len).
* ENCODER_DECODER: cross-attention behavior;
use cross-attention block table for caching KVs derived
from encoder hidden states; since KV sequence lengths
will match encoder sequence lengths, pass encoder sequence
attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len)
Args:
layer: Attention layer instance.
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache: KV cache tensor with shape
[2, num_blocks, block_size * num_kv_heads * head_size].
NOTE: kv_cache will be an empty tensor with shape [0]
for profiling run.
attn_metadata: Metadata for attention.
output: Optional output tensor.
output_scale: Optional output scale tensor.
output_block_scale: Optional output block scale tensor.
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if
output_scale
is
not
None
or
output_block_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
" for XFormersImpl"
)
attn_type
=
self
.
attn_type
# Check that appropriate attention metadata attributes are
# selected for the desired attention type
if
(
attn_type
==
AttentionType
.
ENCODER
and
(
not
attn_metadata
.
is_all_encoder_attn_metadata_set
)):
raise
AttributeError
(
"Encoder attention requires setting "
"encoder metadata attributes."
)
elif
(
attn_type
==
AttentionType
.
ENCODER_DECODER
and
(
not
attn_metadata
.
is_all_cross_attn_metadata_set
)):
raise
AttributeError
(
"Encoder/decoder cross-attention "
"requires setting cross-attention "
"metadata attributes."
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
if
key
is
not
None
:
assert
value
is
not
None
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
else
:
assert
value
is
None
# Self-attention vs. cross-attention will impact
# which KV cache memory-mapping & which
# seqlen datastructures we utilize
if
(
attn_type
!=
AttentionType
.
ENCODER
and
kv_cache
.
numel
()
>
0
):
# KV-cache during decoder-self- or
# encoder-decoder-cross-attention, but not
# during encoder attention.
#
# Even if there are no new key/value pairs to cache,
# we still need to break out key_cache and value_cache
# i.e. for later use by paged attention
key_cache
,
value_cache
=
PagedAttention
.
split_kv_cache
(
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
if
(
key
is
not
None
)
and
(
value
is
not
None
):
if
attn_type
==
AttentionType
.
ENCODER_DECODER
:
# Update cross-attention KV cache (prefill-only)
# During cross-attention decode, key & value will be None,
# preventing this IF-statement branch from running
updated_slot_mapping
=
attn_metadata
.
cross_slot_mapping
else
:
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping
=
attn_metadata
.
slot_mapping
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory
# profiling run.
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
updated_slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
)
(
num_prefill_query_tokens
,
num_prefill_kv_tokens
,
num_decode_query_tokens
)
=
\
get_num_prefill_decode_query_kv_tokens
(
attn_metadata
,
attn_type
)
output
=
torch
.
empty_like
(
query
)
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[
num_prefill_query_tokens
:]
# QKV for prefill.
query
=
query
[:
num_prefill_query_tokens
]
if
key
is
not
None
and
value
is
not
None
:
key
=
key
[:
num_prefill_kv_tokens
]
value
=
value
[:
num_prefill_kv_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_query_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_query_tokens
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
# Prompt run.
if
kv_cache
.
numel
()
==
0
or
prefill_meta
.
block_tables
.
numel
()
==
0
:
# normal attention.
# block tables are empty if the prompt does not have a cached
# prefix.
out
=
self
.
_run_memory_efficient_xformers_forward
(
query
,
key
,
value
,
prefill_meta
,
attn_type
=
attn_type
)
assert
out
.
shape
==
output
[:
num_prefill_query_tokens
].
shape
output
[:
num_prefill_query_tokens
]
=
out
else
:
assert
attn_type
!=
AttentionType
.
ENCODER_ONLY
,
(
"Encoder-only models should not have prefix attention."
)
assert
prefill_meta
.
query_start_loc
is
not
None
assert
prefill_meta
.
max_query_len
is
not
None
# prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache,
# to be addressed separately.
out
=
PagedAttention
.
forward_prefix
(
query
,
key
,
value
,
self
.
kv_cache_dtype
,
key_cache
,
value_cache
,
prefill_meta
.
block_tables
,
prefill_meta
.
query_start_loc
,
prefill_meta
.
seq_lens_tensor
,
prefill_meta
.
max_query_len
,
self
.
alibi_slopes
,
self
.
sliding_window
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
assert
output
[:
num_prefill_query_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_query_tokens
]
=
out
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
assert
attn_type
!=
AttentionType
.
ENCODER_ONLY
,
(
"Encoder-only models should not have decode metadata."
)
(
seq_lens_arg
,
max_seq_len_arg
,
block_tables_arg
,
)
=
get_seq_len_block_table_args
(
decode_meta
,
False
,
attn_type
)
output
[
num_prefill_query_tokens
:]
=
PagedAttention
.
forward_decode
(
decode_query
,
key_cache
,
value_cache
,
block_tables_arg
,
seq_lens_arg
,
max_seq_len_arg
,
self
.
kv_cache_dtype
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
# Reshape the output tensor.
return
output
.
view
(
-
1
,
self
.
num_heads
*
self
.
head_size
)
def
_run_memory_efficient_xformers_forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
attn_metadata
:
XFormersMetadata
,
attn_type
:
str
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
"""Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input.
See https://facebookresearch.github.io/xformers/components/ops.html
for API spec.
Args:
query: shape = [num_prefill_tokens, num_heads, head_size]
key: shape = [num_prefill_tokens, num_kv_heads, head_size]
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
attn_type: Select attention type, between encoder attention,
decoder self-attention, or encoder/decoder cross-
attention. Defaults to decoder self-attention,
which is the vLLM default generally
"""
original_query
=
query
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# GQA/MQA requires the shape [B, M, G, H, K].
# Note that the output also has the same shape (which is different
# from a spec from the doc).
query
=
query
.
view
(
query
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
query
.
shape
[
-
1
])
key
=
key
[:,
:,
None
,
:].
expand
(
key
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
key
.
shape
[
-
1
])
value
=
value
[:,
:,
None
,
:].
expand
(
value
.
shape
[
0
],
self
.
num_kv_heads
,
self
.
num_queries_per_kv
,
value
.
shape
[
-
1
])
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
attn_bias
=
_get_attn_bias
(
attn_metadata
,
attn_type
)
if
attn_bias
is
None
:
if
self
.
alibi_slopes
is
None
:
# Cross attention block of decoder branch of encoder-decoder
# model uses seq_lens for dec / encoder_seq_lens for enc
if
(
attn_type
==
AttentionType
.
ENCODER_DECODER
):
assert
attn_metadata
.
seq_lens
is
not
None
assert
attn_metadata
.
encoder_seq_lens
is
not
None
# Cross-attention mask is non-causal
attn_bias
=
BlockDiagonalMask
.
from_seqlens
(
attn_metadata
.
seq_lens
,
attn_metadata
.
encoder_seq_lens
,
device
=
query
.
device
)
# Encoder branch of encoder-decoder model uses
# attn_metadata.encoder_seq_lens
elif
attn_type
==
AttentionType
.
ENCODER
:
assert
attn_metadata
.
encoder_seq_lens
is
not
None
# Encoder self-attention mask is non-causal
attn_bias
=
BlockDiagonalMask
.
from_seqlens
(
attn_metadata
.
encoder_seq_lens
,
device
=
query
.
device
)
# Self-attention block of encoder-only model just
# uses the seq_lens directly.
elif
attn_type
==
AttentionType
.
ENCODER_ONLY
:
assert
attn_metadata
.
seq_lens
is
not
None
# Encoder self-attention mask is non-causal
attn_bias
=
BlockDiagonalMask
.
from_seqlens
(
attn_metadata
.
seq_lens
,
device
=
query
.
device
)
# Self-attention block of decoder branch just
# uses the seq_lens directly
elif
attn_type
==
AttentionType
.
DECODER
:
assert
attn_metadata
.
seq_lens
is
not
None
# Decoder self-attention mask is causal
attn_bias
=
BlockDiagonalCausalMask
.
from_seqlens
(
attn_metadata
.
seq_lens
,
device
=
query
.
device
)
else
:
raise
ValueError
(
"Unknown AttentionType: %s"
,
attn_type
)
if
self
.
sliding_window
is
not
None
:
attn_bias
=
attn_bias
.
make_local_attention
(
self
.
sliding_window
)
attn_bias
=
[
attn_bias
]
else
:
assert
attn_type
==
AttentionType
.
DECODER
assert
attn_metadata
.
seq_lens
is
not
None
attn_bias
=
_make_alibi_bias
(
self
.
alibi_slopes
,
self
.
num_kv_heads
,
query
.
dtype
,
attn_metadata
.
seq_lens
)
_set_attn_bias
(
attn_metadata
,
attn_bias
,
attn_type
)
# No alibi slopes.
# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if
self
.
alibi_slopes
is
None
:
# Add the batch dimension.
query
=
query
.
unsqueeze
(
0
)
key
=
key
.
unsqueeze
(
0
)
value
=
value
.
unsqueeze
(
0
)
out
=
xops
.
memory_efficient_attention_forward
(
query
,
key
,
value
,
attn_bias
=
attn_bias
[
0
],
p
=
0.0
,
scale
=
self
.
scale
)
return
out
.
view_as
(
original_query
)
# Attention with alibi slopes.
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
assert
attn_metadata
.
seq_lens
is
not
None
output
=
torch
.
empty_like
(
original_query
)
start
=
0
for
i
,
seq_len
in
enumerate
(
attn_metadata
.
seq_lens
):
end
=
start
+
seq_len
out
=
xops
.
memory_efficient_attention_forward
(
query
[
None
,
start
:
end
],
key
[
None
,
start
:
end
],
value
[
None
,
start
:
end
],
attn_bias
=
attn_bias
[
i
],
p
=
0.0
,
scale
=
self
.
scale
)
# TODO(woosuk): Unnecessary copy. Optimize.
output
[
start
:
end
].
copy_
(
out
.
view_as
(
original_query
[
start
:
end
]))
start
+=
seq_len
return
output
def
_make_alibi_bias
(
alibi_slopes
:
torch
.
Tensor
,
num_kv_heads
:
int
,
dtype
:
torch
.
dtype
,
seq_lens
:
List
[
int
],
)
->
List
[
AttentionBias
]:
attn_biases
:
List
[
AttentionBias
]
=
[]
for
seq_len
in
seq_lens
:
bias
=
torch
.
arange
(
seq_len
,
dtype
=
dtype
)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(seq_len, 1)`
# here. We find that both biases give the same results, but
# the bias below more accurately follows the original ALiBi
# paper.
# Calculate a matrix where each element represents ith element- jth
# element.
bias
=
bias
[
None
,
:]
-
bias
[:,
None
]
padded_len
=
(
seq_len
+
7
)
//
8
*
8
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
torch
.
empty
(
1
,
# batch size
num_heads
,
seq_len
,
padded_len
,
device
=
alibi_slopes
.
device
,
dtype
=
dtype
,
)[:,
:,
:,
:
seq_len
].
copy_
(
bias
)
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
attn_biases
.
append
(
LowerTriangularMaskWithTensorBias
(
bias
))
return
attn_biases
vllm/config/model.py
View file @
bc6e542d
...
@@ -32,8 +32,7 @@ from vllm.transformers_utils.config import (
...
@@ -32,8 +32,7 @@ from vllm.transformers_utils.config import (
from
vllm.transformers_utils.runai_utils
import
(
ObjectStorageModel
,
from
vllm.transformers_utils.runai_utils
import
(
ObjectStorageModel
,
is_runai_obj_uri
)
is_runai_obj_uri
)
from
vllm.transformers_utils.utils
import
maybe_model_redirect
from
vllm.transformers_utils.utils
import
maybe_model_redirect
from
vllm.utils
import
(
STR_DUAL_CHUNK_FLASH_ATTN_VAL
,
LayerBlockType
,
from
vllm.utils
import
LayerBlockType
,
LazyLoader
,
common_broadcastable_dtype
LazyLoader
,
common_broadcastable_dtype
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
...
@@ -1103,10 +1102,6 @@ class ModelConfig:
...
@@ -1103,10 +1102,6 @@ class ModelConfig:
self
.
hf_config
.
dual_chunk_attention_config
[
self
.
hf_config
.
dual_chunk_attention_config
[
"sparse_attention_enabled"
]
=
True
"sparse_attention_enabled"
]
=
True
if
envs
.
VLLM_ATTENTION_BACKEND
!=
STR_DUAL_CHUNK_FLASH_ATTN_VAL
:
raise
ValueError
(
"please set VLLM_ATTENTION_BACKEND to "
f
"
{
STR_DUAL_CHUNK_FLASH_ATTN_VAL
}
"
)
def
verify_with_parallel_config
(
def
verify_with_parallel_config
(
self
,
self
,
parallel_config
:
ParallelConfig
,
parallel_config
:
ParallelConfig
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment