Unverified Commit 29f64c5f authored by Fynn Schmitt-Ulms's avatar Fynn Schmitt-Ulms Committed by GitHub
Browse files

FlexAttention non-causal support (#40394)


Signed-off-by: default avatarFynn Schmitt-Ulms <fschmitt@redhat.com>
Co-authored-by: default avatarmergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
parent eb6661d5
......@@ -315,6 +315,7 @@ def _test_backend_correctness(
backend_to_test: list[AttentionBackendEnum | str],
mask_mod,
*,
causal: bool = True,
attn_type: AttentionType = AttentionType.DECODER,
block_size: int = 16,
atol: float = 1e-2,
......@@ -370,7 +371,7 @@ def _test_backend_correctness(
)
device = torch.device(f"{DEVICE_TYPE}:0")
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
kv_cache_spec = create_standard_kv_cache_spec(vllm_config, attn_type)
# 1. Setup
batch_size = batch_spec.batch_size
......@@ -453,9 +454,7 @@ def _test_backend_correctness(
common_attn_metadata = create_common_attn_metadata(
batch_spec, vllm_config.cache_config.block_size, device
)
if attn_type == AttentionType.ENCODER_ONLY:
# For encoder-only, all tokens are prefill tokens
common_attn_metadata.causal = False
common_attn_metadata.causal = causal
# 3. Simulate Paged KV Cache and a realistic slot_mapping
kv_cache = create_and_prepopulate_kv_cache(
......@@ -736,6 +735,76 @@ def test_sliding_window_encoder_backend_correctness(
model,
SLIDING_WINDOW_BACKENDS_TO_TEST,
sliding_window_mask_mod_fn,
causal=False,
attn_type=AttentionType.ENCODER_ONLY,
tensor_parallel_size=tensor_parallel_size,
)
NON_CAUSAL_BACKENDS_TO_TEST = [
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.FLEX_ATTENTION,
"FLEX_ATTENTION_SLOW",
]
if current_platform.is_rocm():
NON_CAUSAL_BACKENDS_TO_TEST = [
x
for x in NON_CAUSAL_BACKENDS_TO_TEST
if x is not AttentionBackendEnum.FLASH_ATTN
]
@pytest.mark.parametrize(
"batch_spec_name",
[
"small_decode",
"small_prefill",
"mixed_small",
],
)
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_non_causal_backend_correctness(
default_vllm_config, batch_spec_name: str, model: str
):
"""Test backend's correctness with non-causal (bidirectional) decoder
attention, as used by DFlash speculative decoding."""
def bidirectional_mask_mod(
b: torch.Tensor,
h: torch.Tensor,
q_idx: torch.Tensor,
kv_idx: torch.Tensor,
*,
context_len: int,
):
return q_idx >= 0 # Always True
batch_spec = BATCH_SPECS[batch_spec_name]
LARGE_BLOCK_BACKENDS = (
[AttentionBackendEnum.FLEX_ATTENTION]
if is_torch_equal_or_newer("2.9.0.dev0")
else []
)
SMALL_BLOCK_BACKENDS = [
x for x in NON_CAUSAL_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
]
_test_backend_correctness(
batch_spec,
model,
SMALL_BLOCK_BACKENDS,
bidirectional_mask_mod,
causal=False,
)
if LARGE_BLOCK_BACKENDS:
_test_backend_correctness(
batch_spec,
model,
LARGE_BLOCK_BACKENDS,
bidirectional_mask_mod,
causal=False,
block_size=128,
)
......@@ -21,10 +21,11 @@ from vllm.config.model import ModelDType
from vllm.v1.attention.backend import (
AttentionImpl,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import FullAttentionSpec
from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, FullAttentionSpec
@dataclass
......@@ -142,8 +143,24 @@ def try_backend_includes_kv_cache_update(
raise AssertionError("unreachable") from None
def create_standard_kv_cache_spec(vllm_config: VllmConfig) -> FullAttentionSpec:
"""Create a FullAttentionSpec from ModelParams only."""
def create_standard_kv_cache_spec(
vllm_config: VllmConfig,
attn_type: AttentionType = AttentionType.DECODER,
) -> FullAttentionSpec | EncoderOnlyAttentionSpec:
"""Create an AttentionSpec from VllmConfig.
Returns an EncoderOnlyAttentionSpec for encoder-only attention (no KV
cache), and a FullAttentionSpec otherwise.
"""
if attn_type == AttentionType.ENCODER_ONLY:
return EncoderOnlyAttentionSpec(
block_size=vllm_config.cache_config.block_size,
num_kv_heads=vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config
),
head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype,
)
return FullAttentionSpec(
block_size=vllm_config.cache_config.block_size,
num_kv_heads=vllm_config.model_config.get_num_kv_heads(
......
......@@ -36,7 +36,7 @@ from vllm.v1.attention.backend import (
AttentionType,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.kv_cache_interface import AttentionSpec, EncoderOnlyAttentionSpec
logger = init_logger(__name__)
......@@ -90,6 +90,10 @@ class FlexAttentionBackend(AttentionBackend):
def get_name() -> str:
return "FLEX_ATTENTION"
@classmethod
def supports_non_causal(cls) -> bool:
return True
@classmethod
def supports_attn_type(cls, attn_type: str) -> bool:
"""FlexAttention supports both decoder and encoder-only attention."""
......@@ -294,6 +298,12 @@ def causal_mask_mod(
return q_idx >= kv_idx
def bidirectional_mask_mod(
b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
):
return q_idx >= 0
# Type alias for the block sparsity hint callable signature.
_block_sparsity_hint_signature = Callable[
[torch.Tensor, torch.Tensor, int], torch.Tensor
......@@ -364,6 +374,7 @@ class FlexAttentionMetadata:
block_mask: BlockMask | None = None
score_mod: _score_mod_signature | None = None
logical_mask_mod: _mask_mod_signature = causal_mask_mod
uses_paged_kv: bool = True
doc_ids: torch.Tensor | None = None
direct_build: bool = True
q_block_size: int = 16
......@@ -497,7 +508,7 @@ class FlexAttentionMetadata:
False,
)
return final_mask_mod if self.causal else sliding_window_mask_mod
return final_mask_mod if self.uses_paged_kv else sliding_window_mask_mod
def get_prefix_lm_mask_mod(self) -> _mask_mod_signature:
"""Creates the prefix LM mask_mod function for FlexAttention."""
......@@ -541,8 +552,7 @@ class FlexAttentionMetadata:
def get_mask_mod(self):
# Stage-1: initialize the base mask_mod
# (causal mask for decoder or bidirectional mask for encoder)
has_custom_mask = self.logical_mask_mod is not causal_mask_mod
if self.causal or has_custom_mask:
if self.uses_paged_kv:
mask_mod = self.get_paged_mask_mod()
else:
mask_mod = self.get_bidirectional_mask_mod()
......@@ -595,7 +605,7 @@ class FlexAttentionMetadata:
return transformed_score_mod
def _build_block_mask_direct(self) -> BlockMask:
"""Direct block mask construction for standard causal attention.
"""Direct block mask construction for paged KV cache attention.
This method constructs the block mask directly using
BlockMask.from_kv_blocks which is much more efficient than the
......@@ -693,7 +703,9 @@ class FlexAttentionMetadata:
def build_block_mask(self) -> BlockMask:
mask_mod = self.get_mask_mod()
kv_len = self.total_cache_tokens if self.causal else self.num_actual_tokens
kv_len = (
self.total_cache_tokens if self.uses_paged_kv else self.num_actual_tokens
)
return create_block_mask_compiled(
mask_mod,
None,
......@@ -842,8 +854,16 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
offset_tensor = common_attn_metadata.compute_num_computed_tokens()
offset_tensor = copy_to_persistent(self.persistent_offset_tensor, offset_tensor)
uses_paged_kv = not isinstance(self.kv_cache_spec, EncoderOnlyAttentionSpec)
logical_mask_mod = (
bidirectional_mask_mod
if uses_paged_kv and not common_attn_metadata.causal
else causal_mask_mod
)
out = FlexAttentionMetadata(
causal=common_attn_metadata.causal,
logical_mask_mod=logical_mask_mod,
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
......@@ -863,10 +883,11 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
total_cache_tokens=total_cache_tokens,
decode_offset=offset_tensor,
num_blocks_per_seq=num_blocks_per_seq,
uses_paged_kv=uses_paged_kv,
# FIXME(Isotr0py): direct build has issue to build bidirectional
# attention block mask for encoder-only models, disable it temporarily.
# see: https://github.com/vllm-project/vllm/pull/27329#issuecomment-3431484053
direct_build=(self.direct_build and common_attn_metadata.causal),
direct_build=self.direct_build and uses_paged_kv,
q_block_size=self.q_block_size,
kv_block_size=self.kv_block_size,
persistent_kv_indices=self.persistent_kv_indices,
......@@ -1055,9 +1076,7 @@ class FlexAttentionImpl(AttentionImpl):
else:
attn_metadata.block_mask = attn_metadata.build_block_mask()
if not attn_metadata.causal:
assert self.attn_type == AttentionType.ENCODER_ONLY
if self.attn_type == AttentionType.ENCODER_ONLY:
query, key_tensor, value_tensor = map(
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
(query, key, value),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment