Unverified commit 1940cdec, authored by Chang Su and committed by GitHub

[Bugfix] Fix Llama4 gibberish output with long context and CUDA graph (#6162)

parent 63484f9f
@@ -913,8 +913,10 @@ class FlashAttentionBackend(AttentionBackend):
         # Use precomputed metadata across all layers
         metadata = self.forward_metadata
         local_attn_metadata = getattr(metadata, "local_attn_metadata", None)
-        use_local_attention = (
-            self.attention_chunk_size is not None and local_attn_metadata is not None
-        )
+        use_local_attn = (
+            self.attention_chunk_size is not None
+            and local_attn_metadata is not None
+            and (hasattr(layer, "use_irope") and layer.use_irope)
+        )
         # We do cascade attention for Draft Decode with topk > 1
         use_cascade_attn = self.topk > 1
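
The gist of this hunk: the chunked (local) attention path should only be taken by layers that actually use it. Previously the condition was true for every layer whenever local metadata existed, so global-attention layers could also be routed through the local path. A minimal sketch of the new predicate as a hypothetical helper, with `attention_chunk_size`, `local_attn_metadata`, and `layer` standing in for the backend's real objects:

    # Illustrative sketch of the predicate above (not the backend's code).
    def should_use_local_attn(attention_chunk_size, local_attn_metadata, layer) -> bool:
        return (
            attention_chunk_size is not None        # model configured for chunked attention
            and local_attn_metadata is not None     # local virtual-batch metadata was built
            and getattr(layer, "use_irope", False)  # only layers flagged for local attention
        )

Here `getattr(layer, "use_irope", False)` is equivalent to the `hasattr(...) and layer.use_irope` form in the diff when the attribute is a plain bool.
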
@@ -970,7 +972,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
             )
-        elif use_local_attention:
+        elif use_local_attn:
             # Use chunked (local) attention batching for self-attention
             o = flash_attn_with_kvcache(
                 q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -979,7 +981,7 @@ class FlashAttentionBackend(AttentionBackend):
                 page_table=local_attn_metadata.local_block_table,
                 cache_seqlens=local_attn_metadata.local_seqused_k,
                 cu_seqlens_q=local_attn_metadata.local_query_start_loc,
-                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                cu_seqlens_k_new=None,
                 max_seqlen_q=local_attn_metadata.local_max_query_len,
                 softmax_scale=layer.scaling,
                 causal=True,
@@ -1127,7 +1129,6 @@ class FlashAttentionBackend(AttentionBackend):
         This creates fixed-size tensors that will be reused during CUDA graph replay
         to avoid memory allocations.
         """
-
         # This is being used by normal decode and draft decode when topk == 1
         self.decode_cuda_graph_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
@@ -1154,6 +1155,34 @@ class FlashAttentionBackend(AttentionBackend):
             ),
         }
 
+        # Only allocate local attention buffers if local attention is enabled
+        # This prevents OOM errors when local attention is not being used
+        if self.attention_chunk_size is not None:
+            # Estimate maximum sizes for local attention metadata
+            max_seq_len = self.max_context_len
+            page_size = self.page_size or 1
+            attn_chunk_size = self.attention_chunk_size
+
+            max_virtual_batches = max_bs * (
+                (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
+            )
+            max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
+            max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size
+
+            self.decode_cuda_graph_local_attn_metadata = {
+                "local_query_start_loc": torch.zeros(
+                    max_virtual_batches + 1, dtype=torch.int32, device=self.device
+                ),
+                "local_seqused_k": torch.zeros(
+                    max_virtual_batches, dtype=torch.int32, device=self.device
+                ),
+                "local_block_table": torch.zeros(
+                    max_virtual_batches,
+                    max_blocks_per_seq * max_pages_per_block,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+            }
+
         # This is used by draft decode's first half of metadata when topk > 1
         if self.topk > 1:
             self.draft_decode_metadata_topk_normal = {
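
To make the ceil-division sizing above concrete, here is a small worked example with assumed numbers (max_bs=2, max_context_len=8192, attention_chunk_size=2048, page_size=16; illustrative values, not defaults from the code):

    # Ceil-division sizing, mirroring the allocation above (all values assumed).
    max_bs, max_seq_len, attn_chunk_size, page_size = 2, 8192, 2048, 16

    max_virtual_batches = max_bs * ((max_seq_len + attn_chunk_size - 1) // attn_chunk_size)
    max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
    max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size

    print(max_virtual_batches)                       # 8   -> local_query_start_loc holds 8 + 1 entries
    print(max_blocks_per_seq)                        # 4
    print(max_pages_per_block)                       # 128
    print(max_blocks_per_seq * max_pages_per_block)  # 512 columns in local_block_table
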
@@ -1405,6 +1434,21 @@ class FlashAttentionBackend(AttentionBackend):
                 )
            self.decode_cuda_graph_metadata[bs] = metadata
+
+            if self.attention_chunk_size is not None:
+                metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+                    local_query_start_loc=self.decode_cuda_graph_local_attn_metadata[
+                        "local_query_start_loc"
+                    ],
+                    local_seqused_k=self.decode_cuda_graph_local_attn_metadata[
+                        "local_seqused_k"
+                    ],
+                    local_block_table=self.decode_cuda_graph_local_attn_metadata[
+                        "local_block_table"
+                    ],
+                    local_max_query_len=1,
+                    local_max_seq_len=1,
+                )
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata.cache_seqlens_int32 = self.target_verify_metadata[
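
The capture-time wiring in this hunk follows the usual CUDA graph contract: a captured graph replays fixed tensor addresses, so the metadata must point at preallocated buffers that are later mutated in place rather than reassigned. A minimal, self-contained sketch of that pattern in generic PyTorch (not the backend's code; buffer names and sizes are made up):

    import torch

    if torch.cuda.is_available():
        device = torch.device("cuda")
        # Fixed-size buffers, allocated once (analogous to decode_cuda_graph_local_attn_metadata).
        seqlens_buf = torch.zeros(8, dtype=torch.int32, device=device)
        out = torch.zeros(8, dtype=torch.float32, device=device)

        # Warm-up on a side stream, as recommended before graph capture.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            out.copy_(seqlens_buf.float() * 2)
        torch.cuda.current_stream().wait_stream(s)

        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            # The graph records the *address* of seqlens_buf, not its current values.
            out.copy_(seqlens_buf.float() * 2)

        # Before replay: update the preallocated buffer in place (never rebind the name).
        seqlens_buf.copy_(torch.arange(8, dtype=torch.int32, device=device))
        g.replay()
        print(out)  # tensor([ 0., 2., 4., 6., 8., 10., 12., 14.], device='cuda:0')
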
@@ -1572,8 +1616,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata_expand.page_table[: cache_loc.shape[0]].copy_(
                     cache_loc[:, :decode_length].contiguous().to(torch.int32)
                 )
-                # TODO: we need to test this part for llama 4 eagle case
-                self._init_local_attn_metadata(metadata, device)
+                # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
             else:
                 metadata = self.decode_cuda_graph_metadata[bs]
                 # Normal Decode
@@ -1599,7 +1642,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.page_table[:, :max_seq_pages].copy_(page_indices)
                 metadata.page_table[:, max_seq_pages:].fill_(0)
 
-                self._init_local_attn_metadata(metadata, device)
+                self._update_local_attn_metadata_for_replay(metadata, bs)
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
                 metadata = self.target_verify_metadata[bs]
@@ -1755,6 +1798,7 @@ class FlashAttentionBackend(AttentionBackend):
             page_table,
             self.page_size,
         )
+
         local_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
             local_query_start_loc=torch.from_numpy(cu_seqlens_q_local_np).to(device),
             local_seqused_k=torch.from_numpy(seqlens_k_local_np).to(device),
@@ -1764,6 +1808,79 @@ class FlashAttentionBackend(AttentionBackend):
         )
         metadata.local_attn_metadata = local_metadata
 
+    def _update_local_attn_metadata_for_replay(
+        self, metadata: FlashAttentionMetadata, bs: int
+    ):
+        """Update preallocated local attention metadata in-place before CUDA graph replay."""
+        if self.attention_chunk_size is None:
+            return
+
+        # Access preallocated buffers
+        local_q_buf = self.decode_cuda_graph_local_attn_metadata[
+            "local_query_start_loc"
+        ]
+        local_k_buf = self.decode_cuda_graph_local_attn_metadata["local_seqused_k"]
+        local_block_buf = self.decode_cuda_graph_local_attn_metadata[
+            "local_block_table"
+        ]
+        cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"]
+
+        # Create a modified version for local attention that only processes the last token
+        # This mimics the normal decode pattern
+        cu_seqlens_q = torch.arange(
+            bs + 1, device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype
+        )
+        seqlens = metadata.cache_seqlens_int32[:bs]
+
+        # Slice the page_table to match the batch size and actual sequence length
+        # This serves three important purposes:
+        # 1. Ensures we only process the actual batch size (bs) and not the maximum batch size
+        # 2. Limits the sequence length to prevent processing padding tokens or garbage values
+        # 3. Prevents zeros in the block table which can cause garbage output during replay
+        #
+        # Without this slicing, the pre-allocated page_table may contain zeros or invalid indices
+        # beyond the actual sequence length, leading to incorrect attention calculations
+        max_seq_len = int(seqlens.max().item())
+        sliced_page_table = metadata.page_table[:bs, :max_seq_len]
+
+        cu_seqlens_q_np = cu_seqlens_q.cpu().numpy()
+        seqlens_np = seqlens.cpu().numpy()
+        (
+            seqlens_q_local_np,
+            cu_seqlens_q_local_np,
+            seqlens_k_local_np,
+            block_table_local,
+        ) = make_local_attention_virtual_batches(
+            self.attention_chunk_size,
+            cu_seqlens_q_np,
+            seqlens_np,
+            sliced_page_table,
+            self.page_size,
+        )
+
+        # Convert back to tensors
+        device = local_q_buf.device
+        cu_seqlens_q_local = torch.from_numpy(cu_seqlens_q_local_np).to(device)
+        seqlens_k_local = torch.from_numpy(seqlens_k_local_np).to(device)
+        block_table_local = block_table_local.to(device)
+
+        # Get sizes
+        q_len = cu_seqlens_q_local.shape[0]
+        k_len = seqlens_k_local.shape[0]
+        b0, b1 = block_table_local.shape
+
+        # In-place updates into preallocated tensors and zero out the unused space
+        local_q_buf[:q_len].copy_(cu_seqlens_q_local)
+        local_q_buf[q_len:].fill_(0)
+        local_k_buf[:k_len].copy_(seqlens_k_local)
+        local_k_buf[k_len:].fill_(0)
+        local_block_buf[:b0, :b1].copy_(block_table_local)
+        local_block_buf[b0:, :].fill_(0)
+        local_block_buf[:b0, b1:].fill_(0)
+
+        if metadata.local_attn_metadata is not None:
+            lam = metadata.local_attn_metadata
+            lam.local_max_query_len = int(seqlens_q_local_np.max())
+            lam.local_max_seq_len = int(seqlens_k_local_np.max())
 
 
 class FlashAttentionMultiStepBackend:
...
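
For intuition about what the replay path computes: with one query token per request, chunked attention means each decode request contributes a single virtual batch whose visible keys are just the tokens in that request's current chunk. A toy calculation with assumed numbers (chunk size 128, two requests with KV lengths 300 and 80), showing roughly what the local metadata encodes; the exact array layout is up to make_local_attention_virtual_batches:

    # Assumed toy decode batch: chunk size 128, one query token per request.
    attention_chunk_size = 128
    cache_seqlens = [300, 80]           # current KV length of each request

    # Keys visible to the last token = tokens in its own (possibly partial) chunk.
    local_seqused_k = [((s - 1) % attention_chunk_size) + 1 for s in cache_seqlens]
    print(local_seqused_k)              # [44, 80]

    local_query_start_loc = [0, 1, 2]   # one query token per virtual batch
    local_max_query_len = 1
    local_max_seq_len = max(local_seqused_k)  # 80
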