Unverified Commit 8c47fdfd authored by liangel-02's avatar liangel-02 Committed by GitHub
Browse files

[FlexAttention] allow custom mask mod (#37692)


Signed-off-by: default avatarAngel Li <liangel@meta.com>
parent 54b0578a
...@@ -14,6 +14,7 @@ from tests.v1.attention.utils import ( ...@@ -14,6 +14,7 @@ from tests.v1.attention.utils import (
create_vllm_config, create_vllm_config,
) )
from vllm.v1.attention.backends.flex_attention import ( from vllm.v1.attention.backends.flex_attention import (
BlockSparsityHint,
FlexAttentionMetadataBuilder, FlexAttentionMetadataBuilder,
physical_to_logical_mapping, physical_to_logical_mapping,
) )
...@@ -223,5 +224,55 @@ def test_physical_to_logical_mapping_handles_reused_blocks(): ...@@ -223,5 +224,55 @@ def test_physical_to_logical_mapping_handles_reused_blocks():
assert out2[0, 2].item() == 1 assert out2[0, 2].item() == 1
@pytest.mark.skipif(
not torch.cuda.is_available() or TORCH_VERSION < DIRECT_BUILD_VERSION,
reason="CUDA not available or PyTorch version < 2.9",
)
def test_block_sparsity_hint_prunes_blocks():
"""Test that BlockSparsityHint prunes KV blocks from the direct build path.
Uses a hint that only keeps the diagonal (q_block == kv_block) to verify
that off-diagonal blocks are excluded from the resulting BlockMask.
"""
device = torch.device("cuda")
vllm_config = create_vllm_config(
model_name="facebook/opt-125m",
block_size=16,
max_model_len=1024,
)
kv_cache_spec = create_standard_kv_cache_spec(vllm_config)
batch_spec = BatchSpec(
seq_lens=[256],
query_lens=[256],
name="test_sparsity_hint",
)
common_attn_metadata = create_common_attn_metadata(
batch_spec, vllm_config.cache_config.block_size, device
)
builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config, device)
metadata_no_hint = builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata
)
metadata_no_hint.block_mask = metadata_no_hint._build_block_mask_direct()
assert metadata_no_hint.block_mask.kv_num_blocks.max().item() > 1
def diagonal_hint(q_block_idx, kv_block_idx, block_size):
return q_block_idx == kv_block_idx
metadata_with_hint = builder.build(
common_prefix_len=0, common_attn_metadata=common_attn_metadata
)
metadata_with_hint.block_sparsity_hint = BlockSparsityHint(
hint_fn=diagonal_hint,
)
metadata_with_hint.block_mask = metadata_with_hint._build_block_mask_direct()
assert metadata_with_hint.block_mask.kv_num_blocks.max().item() <= 1
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
...@@ -3,9 +3,10 @@ ...@@ -3,9 +3,10 @@
"""Attention layer with FlexAttention.""" """Attention layer with FlexAttention."""
import math import math
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property from functools import cached_property
from typing import ClassVar from typing import ClassVar, NamedTuple
import torch import torch
import torch._dynamo.decorators import torch._dynamo.decorators
...@@ -294,6 +295,27 @@ def causal_mask_mod( ...@@ -294,6 +295,27 @@ def causal_mask_mod(
return q_idx >= kv_idx return q_idx >= kv_idx
# Type alias for the block sparsity hint callable signature.
_block_sparsity_hint_signature = Callable[
[torch.Tensor, torch.Tensor, int], torch.Tensor
]
class BlockSparsityHint(NamedTuple):
"""This prunes KV blocks from the BlockMask before the flex_attention kernel
is invoked, so that blocks that are fully masked never get loaded.
Use this with custom mask_mods that are sparse to avoid
the kernel iterating over all KV blocks unnecessarily.
Attributes:
hint_fn: (q_block_idx [num_tokens, 1], kv_block_idx [1, num_kv_blocks],
block_size int) -> bool Tensor [num_tokens, num_kv_blocks].
Returns True for block pairs that may contain non-masked elements.
"""
hint_fn: _block_sparsity_hint_signature
@dataclass @dataclass
class FlexAttentionMetadata: class FlexAttentionMetadata:
causal: bool causal: bool
...@@ -335,6 +357,7 @@ class FlexAttentionMetadata: ...@@ -335,6 +357,7 @@ class FlexAttentionMetadata:
transformed_score_mod: _score_mod_signature | None = None transformed_score_mod: _score_mod_signature | None = None
sliding_window: int | None = None sliding_window: int | None = None
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
block_sparsity_hint: BlockSparsityHint | None = None
@cached_property @cached_property
def logical_block_ids(self): def logical_block_ids(self):
...@@ -378,7 +401,7 @@ class FlexAttentionMetadata: ...@@ -378,7 +401,7 @@ class FlexAttentionMetadata:
return is_valid, logical_q_idx, logical_kv_idx return is_valid, logical_q_idx, logical_kv_idx
def get_causal_mask_mod(self) -> _mask_mod_signature: def get_paged_mask_mod(self) -> _mask_mod_signature:
"""Creates the mask_mod function for FlexAttention. """Creates the mask_mod function for FlexAttention.
This function creates the combined mask mod function that handles: This function creates the combined mask mod function that handles:
...@@ -504,8 +527,9 @@ class FlexAttentionMetadata: ...@@ -504,8 +527,9 @@ class FlexAttentionMetadata:
def get_mask_mod(self): def get_mask_mod(self):
# Stage-1: initialize the base mask_mod # Stage-1: initialize the base mask_mod
# (causal mask for decoder or bidirectional mask for encoder) # (causal mask for decoder or bidirectional mask for encoder)
if self.causal: has_custom_mask = self.logical_mask_mod is not causal_mask_mod
mask_mod = self.get_causal_mask_mod() if self.causal or has_custom_mask:
mask_mod = self.get_paged_mask_mod()
else: else:
mask_mod = self.get_bidirectional_mask_mod() mask_mod = self.get_bidirectional_mask_mod()
# stage-2: add external mask_mod for special attention during # stage-2: add external mask_mod for special attention during
...@@ -591,7 +615,9 @@ class FlexAttentionMetadata: ...@@ -591,7 +615,9 @@ class FlexAttentionMetadata:
self.doc_ids, : cdiv(self.max_seq_len, self.block_size) self.doc_ids, : cdiv(self.max_seq_len, self.block_size)
] ]
if self.sliding_window and self.causal: custom_hint = self.block_sparsity_hint is not None
if self.sliding_window or custom_hint:
device = used_pages.device device = used_pages.device
assert self.doc_ids is not None assert self.doc_ids is not None
token_indices = torch.arange( token_indices = torch.arange(
...@@ -602,10 +628,24 @@ class FlexAttentionMetadata: ...@@ -602,10 +628,24 @@ class FlexAttentionMetadata:
- self.query_start_loc[self.doc_ids] - self.query_start_loc[self.doc_ids]
+ self.decode_offset[self.doc_ids] + self.decode_offset[self.doc_ids]
) )
min_kv_idx = torch.clamp(logical_q_idx - (self.sliding_window - 1), min=0)
min_block_idx = min_kv_idx // self.block_size if self.sliding_window:
sliding_mask = self.logical_block_ids >= min_block_idx[:, None] assert self.sliding_window is not None
used_pages.masked_fill_(~sliding_mask, 0) min_kv_idx = torch.clamp(
logical_q_idx - (self.sliding_window - 1), min=0
)
min_block_idx = min_kv_idx // self.block_size
sliding_mask = self.logical_block_ids >= min_block_idx[:, None]
used_pages.masked_fill_(~sliding_mask, 0)
if custom_hint:
assert self.block_sparsity_hint is not None
q_block_idx = logical_q_idx // self.block_size
hint_mask = self.block_sparsity_hint.hint_fn(
q_block_idx[:, None],
self.logical_block_ids[None, :],
self.block_size,
)
used_pages.masked_fill_(~hint_mask, 0)
used_pages_padded = pad_to_multiple( used_pages_padded = pad_to_multiple(
used_pages, multiple=self.q_block_size, dim=0 used_pages, multiple=self.q_block_size, dim=0
...@@ -660,11 +700,6 @@ class FlexAttentionMetadata: ...@@ -660,11 +700,6 @@ class FlexAttentionMetadata:
self.mask_mod = self.get_mask_mod() self.mask_mod = self.get_mask_mod()
self.transformed_score_mod = self.get_transformed_score_mod() self.transformed_score_mod = self.get_transformed_score_mod()
if self.direct_build and self.causal:
self.block_mask = self._build_block_mask_direct()
else:
self.block_mask = self.build_block_mask()
class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]): class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]):
def __init__( def __init__(
...@@ -770,6 +805,8 @@ class FlexAttentionImpl(AttentionImpl): ...@@ -770,6 +805,8 @@ class FlexAttentionImpl(AttentionImpl):
alibi_slopes: torch.Tensor | None alibi_slopes: torch.Tensor | None
logits_soft_cap: float | None logits_soft_cap: float | None
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
logical_mask_mod: _mask_mod_signature | None = None
block_sparsity_hint: BlockSparsityHint | None = None
def __init__( def __init__(
self, self,
...@@ -907,8 +944,25 @@ class FlexAttentionImpl(AttentionImpl): ...@@ -907,8 +944,25 @@ class FlexAttentionImpl(AttentionImpl):
attn_metadata.mask_mod = attn_metadata.get_mask_mod() attn_metadata.mask_mod = attn_metadata.get_mask_mod()
needs_rebuild_block_mask = True needs_rebuild_block_mask = True
if needs_rebuild_block_mask: layer_mask_mod = getattr(layer, "logical_mask_mod", None)
if attn_metadata.direct_build and attn_metadata.causal: if (
layer_mask_mod is not None
and attn_metadata.logical_mask_mod is not layer_mask_mod
):
attn_metadata.logical_mask_mod = layer_mask_mod
attn_metadata.mask_mod = attn_metadata.get_mask_mod()
needs_rebuild_block_mask = True
layer_hint = getattr(layer, "block_sparsity_hint", None)
if (
layer_hint is not None
and attn_metadata.block_sparsity_hint is not layer_hint
):
attn_metadata.block_sparsity_hint = layer_hint
needs_rebuild_block_mask = True
if needs_rebuild_block_mask or attn_metadata.block_mask is None:
if attn_metadata.direct_build:
attn_metadata.block_mask = attn_metadata._build_block_mask_direct() attn_metadata.block_mask = attn_metadata._build_block_mask_direct()
else: else:
attn_metadata.block_mask = attn_metadata.build_block_mask() attn_metadata.block_mask = attn_metadata.build_block_mask()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment