Unverified Commit b95db244 authored by Isotr0py Committed by GitHub
Browse files

[v1] Add real sliding window calculation to FlexAttention direct BlockMask building (#26015)


Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
Co-authored-by: baonudesifeizhai <baonudesifeizhai@gmail.com>
parent ad9d656b
...@@ -74,6 +74,9 @@ BATCH_SPECS = { ...@@ -74,6 +74,9 @@ BATCH_SPECS = {
), ),
"large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32), "large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
"large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), "large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
"mixed_large": BatchSpec(
seq_lens=[1024, 2048, 4096, 1024, 2048, 4096], query_lens=[1, 1, 1, 32, 32, 32]
),
"single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]), "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
"single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]), "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
} }
...@@ -587,7 +590,14 @@ SLIDING_WINDOW_BACKENDS_TO_TEST = [ ...@@ -587,7 +590,14 @@ SLIDING_WINDOW_BACKENDS_TO_TEST = [
@pytest.mark.parametrize( @pytest.mark.parametrize(
"batch_spec_name", "batch_spec_name",
["small_decode", "small_prefill", "mixed_medium", "large_decode", "large_prefill"], [
"small_decode",
"small_prefill",
"mixed_medium",
"large_decode",
"large_prefill",
"mixed_large",
],
) )
@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"]) @pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4]) @pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property
from typing import ClassVar from typing import ClassVar
import torch import torch
...@@ -315,6 +316,14 @@ class FlexAttentionMetadata: ...@@ -315,6 +316,14 @@ class FlexAttentionMetadata:
transformed_score_mod: _score_mod_signature | None = None transformed_score_mod: _score_mod_signature | None = None
sliding_window: int | None = None sliding_window: int | None = None
@cached_property
def logical_block_ids(self):
    """Return the logical KV block indices ``0 .. num_blocks - 1``.

    ``num_blocks`` is ``cdiv(self.max_seq_len, self.block_size)``. The
    result is a 1-D ``torch.long`` tensor placed on the same device as
    ``self.block_table``. Because this is a ``cached_property``, the
    tensor is built on first access and reused for this instance.
    """
    num_logical_blocks = cdiv(self.max_seq_len, self.block_size)
    return torch.arange(
        num_logical_blocks,
        dtype=torch.long,
        device=self.block_table.device,
    )
def _convert_physical_to_logical( def _convert_physical_to_logical(
self, self,
request_lookup: torch.Tensor, request_lookup: torch.Tensor,
...@@ -493,6 +502,7 @@ class FlexAttentionMetadata: ...@@ -493,6 +502,7 @@ class FlexAttentionMetadata:
The direct path works as follows: The direct path works as follows:
1. For each query token, fetch blocks from block_table using max_seq_len 1. For each query token, fetch blocks from block_table using max_seq_len
and exclude blocks that fall outside the sliding window, if needed.
(this fetches more blocks than needed for shorter sequences) (this fetches more blocks than needed for shorter sequences)
2. Group query tokens into chunks of q_block_size 2. Group query tokens into chunks of q_block_size
3. For each group, deduplicate the blocks using unique_static_unsorted 3. For each group, deduplicate the blocks using unique_static_unsorted
...@@ -517,6 +527,23 @@ class FlexAttentionMetadata: ...@@ -517,6 +527,23 @@ class FlexAttentionMetadata:
used_pages = self.block_table[ used_pages = self.block_table[
self.doc_ids, : cdiv(self.max_seq_len, self.block_size) self.doc_ids, : cdiv(self.max_seq_len, self.block_size)
] ]
if self.sliding_window and self.causal:
device = used_pages.device
assert self.doc_ids is not None
token_indices = torch.arange(
self.doc_ids.shape[0], device=device, dtype=torch.long
)
logical_q_idx = (
token_indices
- self.query_start_loc[self.doc_ids]
+ self.decode_offset[self.doc_ids]
)
min_kv_idx = torch.clamp(logical_q_idx - (self.sliding_window - 1), min=0)
min_block_idx = min_kv_idx // self.block_size
sliding_mask = self.logical_block_ids >= min_block_idx[:, None]
used_pages.masked_fill_(~sliding_mask, 0)
used_pages_padded = pad_to_multiple( used_pages_padded = pad_to_multiple(
used_pages, multiple=self.q_block_size, dim=0 used_pages, multiple=self.q_block_size, dim=0
) )
...@@ -785,12 +812,6 @@ class FlexAttentionImpl(AttentionImpl): ...@@ -785,12 +812,6 @@ class FlexAttentionImpl(AttentionImpl):
if attn_metadata.sliding_window != self.sliding_window: if attn_metadata.sliding_window != self.sliding_window:
attn_metadata.sliding_window = self.sliding_window attn_metadata.sliding_window = self.sliding_window
if attn_metadata.direct_build: if attn_metadata.direct_build:
# TODO: Support skipping the computation of sliding window
# in direct block mask building code path.
logger.warning_once(
"Using direct block mask building with sliding window, "
"which is suboptimal now. Performance may be degraded."
)
# update mask mod in attention metadata # update mask mod in attention metadata
attn_metadata.mask_mod = attn_metadata.get_mask_mod() attn_metadata.mask_mod = attn_metadata.get_mask_mod()
attn_metadata.block_mask = attn_metadata._build_block_mask_direct() attn_metadata.block_mask = attn_metadata._build_block_mask_direct()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment