Unverified Commit c776234b authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

Enable local attention during decode (#5479)

parent 3bface15
......@@ -142,6 +142,16 @@ def make_local_attention_virtual_batches(
seqlens_k_local: Key sequence lengths for local attention
block_table_local: Block table for local attention
"""
# Adjust attention_chunk_size based on the actual sequence length
# to avoid index out of bounds errors
max_seq_len = seq_lens_np.max()
effective_chunk_size = min(attn_chunk_size, max_seq_len)
# Make sure effective_chunk_size is divisible by page_size
effective_chunk_size = (effective_chunk_size // page_size) * page_size
if effective_chunk_size < page_size:
effective_chunk_size = page_size
attn_chunk_size = effective_chunk_size
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
actual_batch_size = seq_lens_np.shape[0]
......@@ -344,6 +354,8 @@ class FlashAttentionBackend(AttentionBackend):
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
self._init_local_attn_metadata(metadata, device)
else:
# Normal Decode
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
......@@ -357,6 +369,8 @@ class FlashAttentionBackend(AttentionBackend):
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
self._init_local_attn_metadata(metadata, device)
elif forward_batch.forward_mode.is_target_verify():
metadata.cache_seqlens_int32 = (
forward_batch.seq_lens + self.speculative_num_draft_tokens
......@@ -405,49 +419,8 @@ class FlashAttentionBackend(AttentionBackend):
metadata.cu_seqlens_q = metadata.cu_seqlens_k
# Setup local attention if enabled
if (
self.attention_chunk_size is not None
and forward_batch.forward_mode == ForwardMode.EXTEND
):
# Convert tensors to numpy for local attention processing
cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
seq_lens_np = metadata.cache_seqlens_int32.cpu().numpy()
# Adjust attention_chunk_size based on the actual sequence length
# to avoid index out of bounds errors
max_seq_len = seq_lens_np.max()
effective_chunk_size = min(self.attention_chunk_size, max_seq_len)
# Make sure effective_chunk_size is divisible by page_size
effective_chunk_size = (
effective_chunk_size // self.page_size
) * self.page_size
if effective_chunk_size < self.page_size:
effective_chunk_size = self.page_size
# Create local attention metadata
(
seqlens_q_local_np,
cu_seqlens_q_local_np,
seqlens_k_local_np,
block_table_local,
) = make_local_attention_virtual_batches(
effective_chunk_size,
cu_seqlens_q_np,
seq_lens_np,
metadata.page_table,
self.page_size,
)
local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(
device
),
local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
local_block_table=block_table_local,
local_max_query_len=seqlens_q_local_np.max(),
local_max_seq_len=seqlens_k_local_np.max(),
)
metadata.local_attn_metadata = local_metadata
if forward_batch.forward_mode == ForwardMode.EXTEND:
self._init_local_attn_metadata(metadata, device)
# Encoder metadata for cross attention
if forward_batch.encoder_lens is not None:
......@@ -704,6 +677,10 @@ class FlashAttentionBackend(AttentionBackend):
# Use precomputed metadata across all layers
metadata = self.forward_metadata
local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
use_local_attention = (
self.attention_chunk_size is not None and local_attn_metadata is not None
)
# Calculate window size (can be moved to metadata if layer properties don't change)
# we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
......@@ -738,33 +715,60 @@ class FlashAttentionBackend(AttentionBackend):
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
)
q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
if layer.is_cross_attention:
page_table = metadata.encoder_page_table
cache_seqlens = metadata.encoder_lens_int32
cu_seqlens_k = metadata.encoder_cu_seqlens_k
window_size = (-1, -1)
# Always use non-chunked logic for cross-attention
o = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=metadata.encoder_page_table,
cache_seqlens=metadata.encoder_lens_int32,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=metadata.encoder_cu_seqlens_k,
max_seqlen_q=1,
softmax_scale=layer.scaling,
causal=False,
window_size=(-1, -1),
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
)
elif use_local_attention:
# Use chunked (local) attention batching for self-attention
o = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=local_attn_metadata.local_block_table,
cache_seqlens=local_attn_metadata.local_seqused_k,
cu_seqlens_q=local_attn_metadata.local_query_start_loc,
cu_seqlens_k_new=metadata.cu_seqlens_k,
max_seqlen_q=local_attn_metadata.local_max_query_len,
softmax_scale=layer.scaling,
causal=True,
window_size=(-1, -1),
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
)
else:
page_table = metadata.page_table
cache_seqlens = metadata.cache_seqlens_int32
cu_seqlens_k = metadata.cu_seqlens_k
o = flash_attn_with_kvcache(
q=q_reshaped,
k_cache=key_cache,
v_cache=value_cache,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k,
max_seqlen_q=1,
softmax_scale=layer.scaling,
causal=causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
)
# Default: single-token self-attention
o = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=metadata.page_table,
cache_seqlens=metadata.cache_seqlens_int32,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=metadata.cu_seqlens_k,
max_seqlen_q=1,
softmax_scale=layer.scaling,
causal=True,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
)
else:
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
......@@ -986,6 +990,8 @@ class FlashAttentionBackend(AttentionBackend):
seq_lens = seq_lens[:bs]
seq_lens_cpu = seq_lens_cpu[:bs]
req_pool_indices = req_pool_indices[:bs]
device = seq_lens.device
if forward_mode.is_decode_or_idle():
metadata = self.decode_cuda_graph_metadata[bs]
......@@ -1012,6 +1018,8 @@ class FlashAttentionBackend(AttentionBackend):
]
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
self._init_local_attn_metadata(metadata, device)
else:
# Normal Decode
max_len = seq_lens_cpu.max().item()
......@@ -1035,6 +1043,7 @@ class FlashAttentionBackend(AttentionBackend):
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
metadata.page_table[:, max_seq_pages:].fill_(0)
self._init_local_attn_metadata(metadata, device)
elif forward_mode.is_target_verify():
metadata = self.target_verify_metadata[bs]
metadata.cache_seqlens_int32.copy_(
......@@ -1085,6 +1094,42 @@ class FlashAttentionBackend(AttentionBackend):
"""Get the fill value for sequence length in CUDA graph."""
return 0
def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
"""Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
if self.attention_chunk_size is None:
metadata.local_attn_metadata = None
return
cu_seqlens_q = metadata.cu_seqlens_q
cache_seqlens_int32 = metadata.cache_seqlens_int32
page_table = metadata.page_table
if cu_seqlens_q is None or cache_seqlens_int32 is None or page_table is None:
metadata.local_attn_metadata = None
return
cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
seq_lens_np = cache_seqlens_int32.cpu().numpy()
(
seqlens_q_local_np,
cu_seqlens_q_local_np,
seqlens_k_local_np,
block_table_local,
) = make_local_attention_virtual_batches(
self.attention_chunk_size,
cu_seqlens_q_np,
seq_lens_np,
page_table,
self.page_size,
)
local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
local_block_table=block_table_local.to(device),
local_max_query_len=int(seqlens_q_local_np.max()),
local_max_seq_len=int(seqlens_k_local_np.max()),
)
metadata.local_attn_metadata = local_metadata
class FlashAttentionMultiStepBackend:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment