Unverified commit 29f64c5f, authored by Fynn Schmitt-Ulms and committed by GitHub
Browse files

FlexAttention non-causal support (#40394)


Signed-off-by: Fynn Schmitt-Ulms <fschmitt@redhat.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
parent eb6661d5
...@@ -315,6 +315,7 @@ def _test_backend_correctness( ...@@ -315,6 +315,7 @@ def _test_backend_correctness(
backend_to_test: list[AttentionBackendEnum | str], backend_to_test: list[AttentionBackendEnum | str],
mask_mod, mask_mod,
*, *,
causal: bool = True,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
block_size: int = 16, block_size: int = 16,
atol: float = 1e-2, atol: float = 1e-2,
...@@ -370,7 +371,7 @@ def _test_backend_correctness( ...@@ -370,7 +371,7 @@ def _test_backend_correctness(
) )
device = torch.device(f"{DEVICE_TYPE}:0") device = torch.device(f"{DEVICE_TYPE}:0")
kv_cache_spec = create_standard_kv_cache_spec(vllm_config) kv_cache_spec = create_standard_kv_cache_spec(vllm_config, attn_type)
# 1. Setup # 1. Setup
batch_size = batch_spec.batch_size batch_size = batch_spec.batch_size
...@@ -453,9 +454,7 @@ def _test_backend_correctness( ...@@ -453,9 +454,7 @@ def _test_backend_correctness(
common_attn_metadata = create_common_attn_metadata( common_attn_metadata = create_common_attn_metadata(
batch_spec, vllm_config.cache_config.block_size, device batch_spec, vllm_config.cache_config.block_size, device
) )
if attn_type == AttentionType.ENCODER_ONLY: common_attn_metadata.causal = causal
# For encoder-only, all tokens are prefill tokens
common_attn_metadata.causal = False
# 3. Simulate Paged KV Cache and a realistic slot_mapping # 3. Simulate Paged KV Cache and a realistic slot_mapping
kv_cache = create_and_prepopulate_kv_cache( kv_cache = create_and_prepopulate_kv_cache(
...@@ -736,6 +735,76 @@ def test_sliding_window_encoder_backend_correctness( ...@@ -736,6 +735,76 @@ def test_sliding_window_encoder_backend_correctness(
model, model,
SLIDING_WINDOW_BACKENDS_TO_TEST, SLIDING_WINDOW_BACKENDS_TO_TEST,
sliding_window_mask_mod_fn, sliding_window_mask_mod_fn,
causal=False,
attn_type=AttentionType.ENCODER_ONLY, attn_type=AttentionType.ENCODER_ONLY,
tensor_parallel_size=tensor_parallel_size, tensor_parallel_size=tensor_parallel_size,
) )
# Backends exercised by test_non_causal_backend_correctness. Entries are
# either AttentionBackendEnum members or a plain string name.
NON_CAUSAL_BACKENDS_TO_TEST = [
    AttentionBackendEnum.FLASH_ATTN,
    AttentionBackendEnum.FLEX_ATTENTION,
    "FLEX_ATTENTION_SLOW",
]

if current_platform.is_rocm():
    # On ROCm, FLASH_ATTN is removed from the list; every other entry
    # (including the string entry) is kept in its original order.
    NON_CAUSAL_BACKENDS_TO_TEST = [
        x
        for x in NON_CAUSAL_BACKENDS_TO_TEST
        if x is not AttentionBackendEnum.FLASH_ATTN
    ]
@pytest.mark.parametrize(
    "batch_spec_name",
    [
        "small_decode",
        "small_prefill",
        "mixed_small",
    ],
)
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_non_causal_backend_correctness(
    default_vllm_config, batch_spec_name: str, model: str
):
    """Test backend's correctness with non-causal (bidirectional) decoder
    attention, as used by DFlash speculative decoding.

    The backends in NON_CAUSAL_BACKENDS_TO_TEST are split into two groups
    and each group is checked with ``causal=False``:

    * the "small block" group runs with the default block size of
      ``_test_backend_correctness``;
    * the "large block" group (FLEX_ATTENTION, only when torch is at least
      2.9.0.dev0) runs with ``block_size=128``.

    Args:
        default_vllm_config: pytest fixture; not referenced in the body.
        batch_spec_name: Key into ``BATCH_SPECS`` selecting the batch layout.
        model: Model identifier forwarded to ``_test_backend_correctness``.
    """

    def bidirectional_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor,
        *,
        context_len: int,
    ):
        # No position is masked out: the comparison only serves to produce
        # a boolean result derived from q_idx.
        return q_idx >= 0  # Always True

    batch_spec = BATCH_SPECS[batch_spec_name]

    # FLEX_ATTENTION moves to the large-block run only on new enough torch;
    # on older torch this list is empty and FLEX_ATTENTION stays in the
    # small-block run below.
    LARGE_BLOCK_BACKENDS = (
        [AttentionBackendEnum.FLEX_ATTENTION]
        if is_torch_equal_or_newer("2.9.0.dev0")
        else []
    )
    SMALL_BLOCK_BACKENDS = [
        x for x in NON_CAUSAL_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
    ]

    _test_backend_correctness(
        batch_spec,
        model,
        SMALL_BLOCK_BACKENDS,
        bidirectional_mask_mod,
        causal=False,
    )

    # Skip the second pass entirely when there is nothing to run in it.
    if LARGE_BLOCK_BACKENDS:
        _test_backend_correctness(
            batch_spec,
            model,
            LARGE_BLOCK_BACKENDS,
            bidirectional_mask_mod,
            causal=False,
            block_size=128,
        )
...@@ -21,10 +21,11 @@ from vllm.config.model import ModelDType ...@@ -21,10 +21,11 @@ from vllm.config.model import ModelDType
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionImpl, AttentionImpl,
AttentionMetadataBuilder, AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata, CommonAttentionMetadata,
) )
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import FullAttentionSpec from vllm.v1.kv_cache_interface import EncoderOnlyAttentionSpec, FullAttentionSpec
@dataclass @dataclass
...@@ -142,8 +143,24 @@ def try_backend_includes_kv_cache_update( ...@@ -142,8 +143,24 @@ def try_backend_includes_kv_cache_update(
raise AssertionError("unreachable") from None raise AssertionError("unreachable") from None
def create_standard_kv_cache_spec(vllm_config: VllmConfig) -> FullAttentionSpec: def create_standard_kv_cache_spec(
"""Create a FullAttentionSpec from ModelParams only.""" vllm_config: VllmConfig,
attn_type: AttentionType = AttentionType.DECODER,
) -> FullAttentionSpec | EncoderOnlyAttentionSpec:
"""Create an AttentionSpec from VllmConfig.
Returns an EncoderOnlyAttentionSpec for encoder-only attention (no KV
cache), and a FullAttentionSpec otherwise.
"""
if attn_type == AttentionType.ENCODER_ONLY:
return EncoderOnlyAttentionSpec(
block_size=vllm_config.cache_config.block_size,
num_kv_heads=vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config
),
head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype,
)
return FullAttentionSpec( return FullAttentionSpec(
block_size=vllm_config.cache_config.block_size, block_size=vllm_config.cache_config.block_size,
num_kv_heads=vllm_config.model_config.get_num_kv_heads( num_kv_heads=vllm_config.model_config.get_num_kv_heads(
......
...@@ -36,7 +36,7 @@ from vllm.v1.attention.backend import ( ...@@ -36,7 +36,7 @@ from vllm.v1.attention.backend import (
AttentionType, AttentionType,
CommonAttentionMetadata, CommonAttentionMetadata,
) )
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec, EncoderOnlyAttentionSpec
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -90,6 +90,10 @@ class FlexAttentionBackend(AttentionBackend): ...@@ -90,6 +90,10 @@ class FlexAttentionBackend(AttentionBackend):
def get_name() -> str: def get_name() -> str:
return "FLEX_ATTENTION" return "FLEX_ATTENTION"
@classmethod
def supports_non_causal(cls) -> bool:
    """Return True: this backend declares support for non-causal attention."""
    return True
@classmethod @classmethod
def supports_attn_type(cls, attn_type: str) -> bool: def supports_attn_type(cls, attn_type: str) -> bool:
"""FlexAttention supports both decoder and encoder-only attention.""" """FlexAttention supports both decoder and encoder-only attention."""
...@@ -294,6 +298,12 @@ def causal_mask_mod( ...@@ -294,6 +298,12 @@ def causal_mask_mod(
return q_idx >= kv_idx return q_idx >= kv_idx
def bidirectional_mask_mod(
    b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
) -> torch.Tensor:
    """Logical mask_mod that does not mask any (query, key) pair.

    Args:
        b: Batch index (unused).
        h: Head index (unused).
        q_idx: Query position index.
        kv_idx: Key/value position index (unused).

    Returns:
        The result of ``q_idx >= 0``, which is True for every non-negative
        query index. The comparison is used instead of a literal ``True``
        so the result is derived from ``q_idx``.
    """
    return q_idx >= 0
# Type alias for the block sparsity hint callable signature. # Type alias for the block sparsity hint callable signature.
_block_sparsity_hint_signature = Callable[ _block_sparsity_hint_signature = Callable[
[torch.Tensor, torch.Tensor, int], torch.Tensor [torch.Tensor, torch.Tensor, int], torch.Tensor
...@@ -364,6 +374,7 @@ class FlexAttentionMetadata: ...@@ -364,6 +374,7 @@ class FlexAttentionMetadata:
block_mask: BlockMask | None = None block_mask: BlockMask | None = None
score_mod: _score_mod_signature | None = None score_mod: _score_mod_signature | None = None
logical_mask_mod: _mask_mod_signature = causal_mask_mod logical_mask_mod: _mask_mod_signature = causal_mask_mod
uses_paged_kv: bool = True
doc_ids: torch.Tensor | None = None doc_ids: torch.Tensor | None = None
direct_build: bool = True direct_build: bool = True
q_block_size: int = 16 q_block_size: int = 16
...@@ -497,7 +508,7 @@ class FlexAttentionMetadata: ...@@ -497,7 +508,7 @@ class FlexAttentionMetadata:
False, False,
) )
return final_mask_mod if self.causal else sliding_window_mask_mod return final_mask_mod if self.uses_paged_kv else sliding_window_mask_mod
def get_prefix_lm_mask_mod(self) -> _mask_mod_signature: def get_prefix_lm_mask_mod(self) -> _mask_mod_signature:
"""Creates the prefix LM mask_mod function for FlexAttention.""" """Creates the prefix LM mask_mod function for FlexAttention."""
...@@ -541,8 +552,7 @@ class FlexAttentionMetadata: ...@@ -541,8 +552,7 @@ class FlexAttentionMetadata:
def get_mask_mod(self): def get_mask_mod(self):
# Stage-1: initialize the base mask_mod # Stage-1: initialize the base mask_mod
# (causal mask for decoder or bidirectional mask for encoder) # (causal mask for decoder or bidirectional mask for encoder)
has_custom_mask = self.logical_mask_mod is not causal_mask_mod if self.uses_paged_kv:
if self.causal or has_custom_mask:
mask_mod = self.get_paged_mask_mod() mask_mod = self.get_paged_mask_mod()
else: else:
mask_mod = self.get_bidirectional_mask_mod() mask_mod = self.get_bidirectional_mask_mod()
...@@ -595,7 +605,7 @@ class FlexAttentionMetadata: ...@@ -595,7 +605,7 @@ class FlexAttentionMetadata:
return transformed_score_mod return transformed_score_mod
def _build_block_mask_direct(self) -> BlockMask: def _build_block_mask_direct(self) -> BlockMask:
"""Direct block mask construction for standard causal attention. """Direct block mask construction for paged KV cache attention.
This method constructs the block mask directly using This method constructs the block mask directly using
BlockMask.from_kv_blocks which is much more efficient than the BlockMask.from_kv_blocks which is much more efficient than the
...@@ -693,7 +703,9 @@ class FlexAttentionMetadata: ...@@ -693,7 +703,9 @@ class FlexAttentionMetadata:
def build_block_mask(self) -> BlockMask: def build_block_mask(self) -> BlockMask:
mask_mod = self.get_mask_mod() mask_mod = self.get_mask_mod()
kv_len = self.total_cache_tokens if self.causal else self.num_actual_tokens kv_len = (
self.total_cache_tokens if self.uses_paged_kv else self.num_actual_tokens
)
return create_block_mask_compiled( return create_block_mask_compiled(
mask_mod, mask_mod,
None, None,
...@@ -842,8 +854,16 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat ...@@ -842,8 +854,16 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
offset_tensor = common_attn_metadata.compute_num_computed_tokens() offset_tensor = common_attn_metadata.compute_num_computed_tokens()
offset_tensor = copy_to_persistent(self.persistent_offset_tensor, offset_tensor) offset_tensor = copy_to_persistent(self.persistent_offset_tensor, offset_tensor)
uses_paged_kv = not isinstance(self.kv_cache_spec, EncoderOnlyAttentionSpec)
logical_mask_mod = (
bidirectional_mask_mod
if uses_paged_kv and not common_attn_metadata.causal
else causal_mask_mod
)
out = FlexAttentionMetadata( out = FlexAttentionMetadata(
causal=common_attn_metadata.causal, causal=common_attn_metadata.causal,
logical_mask_mod=logical_mask_mod,
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len, max_query_len=max_query_len,
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
...@@ -863,10 +883,11 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat ...@@ -863,10 +883,11 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
total_cache_tokens=total_cache_tokens, total_cache_tokens=total_cache_tokens,
decode_offset=offset_tensor, decode_offset=offset_tensor,
num_blocks_per_seq=num_blocks_per_seq, num_blocks_per_seq=num_blocks_per_seq,
uses_paged_kv=uses_paged_kv,
# FIXME(Isotr0py): direct build has issue to build bidirectional # FIXME(Isotr0py): direct build has issue to build bidirectional
# attention block mask for encoder-only models, disable it temporarily. # attention block mask for encoder-only models, disable it temporarily.
# see: https://github.com/vllm-project/vllm/pull/27329#issuecomment-3431484053 # see: https://github.com/vllm-project/vllm/pull/27329#issuecomment-3431484053
direct_build=(self.direct_build and common_attn_metadata.causal), direct_build=self.direct_build and uses_paged_kv,
q_block_size=self.q_block_size, q_block_size=self.q_block_size,
kv_block_size=self.kv_block_size, kv_block_size=self.kv_block_size,
persistent_kv_indices=self.persistent_kv_indices, persistent_kv_indices=self.persistent_kv_indices,
...@@ -1055,9 +1076,7 @@ class FlexAttentionImpl(AttentionImpl): ...@@ -1055,9 +1076,7 @@ class FlexAttentionImpl(AttentionImpl):
else: else:
attn_metadata.block_mask = attn_metadata.build_block_mask() attn_metadata.block_mask = attn_metadata.build_block_mask()
if not attn_metadata.causal: if self.attn_type == AttentionType.ENCODER_ONLY:
assert self.attn_type == AttentionType.ENCODER_ONLY
query, key_tensor, value_tensor = map( query, key_tensor, value_tensor = map(
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3), lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
(query, key, value), (query, key, value),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment