Unverified commit 1b9175cb authored by Stefan He, committed by GitHub

[FA3 Attn Backend] Remove Unnecessary Device Sync for FA3 (#4745)


Co-authored-by: Yubo Wang <yubowang2019@gmail.com>
parent 92bb49a7
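
Review note: the core change replaces host-blocking reads of GPU tensors with reads of CPU-side copies that the scheduler already maintains. Calling .item(), .max().item(), or any() on a CUDA tensor forces the host to wait for all queued GPU work before the scalar can be returned. A minimal sketch of the pattern (illustrative only, not code from this commit):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    seq_lens = torch.tensor([5, 9, 3], device=device)

    # Host-blocking: pulling a scalar out of a device tensor makes the
    # CPU wait for the GPU stream to drain before it can continue.
    max_len_sync = seq_lens.max().item()

    # Sync-free pattern: keep a host-side copy (here a plain list, like
    # the scheduler's *_cpu fields) and read scalars from it instead.
    seq_lens_cpu = [5, 9, 3]
    max_len_nosync = max(seq_lens_cpu)

    assert max_len_sync == max_len_nosync == 9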
@@ -29,11 +29,11 @@ class FlashAttentionMetadata:
     cu_seqlens_q: torch.Tensor = None
     cu_seqlens_k: torch.Tensor = None
+    max_seq_len_q: int = 0
     max_seq_len_k: int = 0
     window_size: tuple = (-1, -1)
     page_table: torch.Tensor = None
     cache_seqlens_int32: torch.Tensor = None
-    max_seq_len_q: int = 0


 class FlashAttentionBackend(AttentionBackend):
@@ -63,7 +63,6 @@ class FlashAttentionBackend(AttentionBackend):
         # Create metadata based on forward mode
         metadata = FlashAttentionMetadata()
-        extend_seq_lens = forward_batch.extend_seq_lens
         # Get sequence information
         seqlens_in_batch = forward_batch.seq_lens
         # Precompute int32 version of sequence lengths
@@ -85,15 +84,16 @@ class FlashAttentionBackend(AttentionBackend):
                 0, batch_size + 1, dtype=torch.int32, device=device
             )
         else:
-            extend_no_prefix = not any(forward_batch.extend_prefix_lens)
             # Precompute cumulative sequence lengths
-            if not extend_no_prefix:
+            if any(forward_batch.extend_prefix_lens_cpu):
+                extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                 )
+                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
             else:
                 metadata.cu_seqlens_q = metadata.cu_seqlens_k
-                metadata.max_seq_len_q = seqlens_in_batch.max().item()
+                metadata.max_seq_len_q = metadata.max_seq_len_k

         self.forward_metadata = metadata

     def forward_extend(
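
Review note: the extend path above now branches on forward_batch.extend_prefix_lens_cpu (a host-side list) instead of calling any() on a GPU tensor, and derives max_seq_len_q from extend_seq_lens_cpu or the already-computed max_seq_len_k rather than from seqlens_in_batch.max().item(). A simplified sketch of the sync-free branch (field names follow the diff; the function signature is invented for illustration):

    import torch

    def build_extend_metadata(extend_prefix_lens_cpu, extend_seq_lens,
                              extend_seq_lens_cpu, cu_seqlens_k, max_seq_len_k):
        # Branch on a plain Python list: no device synchronization.
        if any(extend_prefix_lens_cpu):
            # The cumulative sum stays on the GPU; nothing is read back.
            cu_seqlens_q = torch.nn.functional.pad(
                torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
            )
            # max() over a host list replaces a blocking .max().item().
            max_seq_len_q = max(extend_seq_lens_cpu)
        else:
            # Without prefixes, query lengths equal key lengths, so the
            # precomputed key-side values can be reused directly.
            cu_seqlens_q = cu_seqlens_k
            max_seq_len_q = max_seq_len_k
        return cu_seqlens_q, max_seq_len_q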
@@ -274,20 +274,26 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         # """Initialize forward metadata for replaying CUDA graph."""
-        seqlens_in_batch = seq_lens[:bs]
         metadata = self.decode_cuda_graph_metadata[bs]
-        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+
+        # For CPU operations
+        max_len = seq_lens_cpu[:bs].max().item()
+        metadata.max_seq_len_k = max_len
+
+        # For GPU operations
+        seq_lens_in_batch = seq_lens[:bs]
+        metadata.cache_seqlens_int32 = seq_lens_in_batch.to(torch.int32)
         metadata.cu_seqlens_k = torch.nn.functional.pad(
-            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+            torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0)
         )
-        # Precompute maximum sequence length
-        metadata.max_seq_len_k = seqlens_in_batch.max().item()
+
+        # Only zero out the part out of max_len_k
+        metadata.page_table[:, metadata.max_seq_len_k :].fill_(0)
+        # Then do the copy
         metadata.page_table[:, : metadata.max_seq_len_k].copy_(
             self.req_to_token[req_pool_indices[:bs], : metadata.max_seq_len_k]
         )
+
         self.forward_decode_metadata = metadata

     def get_cuda_graph_seq_len_fill_value(self):
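
Review note: in the CUDA-graph replay path, the only value that must live on the host is max_seq_len_k (it sizes Python-level slices of page_table), so it is now read from seq_lens_cpu, which the scheduler transfers once per batch; the int32 cast and the cumulative sum remain on the device. A simplified sketch of the same idea (hypothetical helper, signature invented for illustration):

    import torch

    def replay_decode_metadata(seq_lens, seq_lens_cpu, page_table,
                               req_to_token, req_pool_indices, bs):
        # Host-side scalar comes from the CPU tensor: no device sync.
        max_seq_len_k = seq_lens_cpu[:bs].max().item()

        # Device-side metadata is built without leaving the GPU.
        cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
        cu_seqlens_k = torch.nn.functional.pad(
            torch.cumsum(seq_lens[:bs], dim=0, dtype=torch.int32), (1, 0)
        )

        # Clear only the columns past max_seq_len_k, then overwrite the
        # rest, rather than zero-filling the whole table on every replay.
        page_table[:, max_seq_len_k:].fill_(0)
        page_table[:, :max_seq_len_k].copy_(
            req_to_token[req_pool_indices[:bs], :max_seq_len_k]
        )
        return cache_seqlens_int32, cu_seqlens_k, max_seq_len_k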
@@ -1376,6 +1376,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if (
             global_server_args_dict["enable_flashinfer_mla"]
             or global_server_args_dict["enable_flashmla"]
+            or global_server_args_dict["attention_backend"] == "fa3"
         ):
             decode_seq_lens = self.seq_lens.cpu()
         else:
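
Review note: this scheduler-side hunk is what makes the backend changes safe. For the fa3 backend, the batch's sequence lengths are copied to the host once per decode step, paying a single explicit transfer so that later metadata builds can read lengths without implicit per-layer syncs. An illustrative restatement (condition taken from the diff; surrounding code omitted):

    # One explicit device-to-host copy per batch for backends that need
    # host-side lengths (here: fa3); other backends keep the GPU tensor.
    if global_server_args_dict["attention_backend"] == "fa3":
        decode_seq_lens = self.seq_lens.cpu()
    else:
        decode_seq_lens = self.seq_lens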