Unverified commit 1b9175cb authored by Stefan He, committed by GitHub

[FA3 Attn Backend] Remove Unnecessary Device Sync for FA3 (#4745)


Co-authored-by: Yubo Wang <yubowang2019@gmail.com>
parent 92bb49a7
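A note for readers (not part of the original commit message): the "device sync" in the title comes from calling `.max().item()` on CUDA tensors. `.item()` copies a scalar back to the host and blocks the CPU until all queued GPU work finishes; the patch reads the same values from host-side copies instead. A minimal sketch of the difference, assuming a CUDA device and illustrative lengths:

```python
import torch

seq_lens_gpu = torch.tensor([5, 9, 3], device="cuda")
seq_lens_cpu = [5, 9, 3]  # host-side mirror kept by the batch bookkeeping

# Before: .item() triggers a device-to-host copy, so the CPU stalls until
# every kernel queued ahead of it has finished.
max_seq_len = seq_lens_gpu.max().item()

# After: the scalar comes from host memory; no synchronization point.
max_seq_len = max(seq_lens_cpu)
```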
@@ -29,11 +29,11 @@ class FlashAttentionMetadata:
     cu_seqlens_q: torch.Tensor = None
     cu_seqlens_k: torch.Tensor = None
-    max_seq_len_q: int = 0
     max_seq_len_k: int = 0
     window_size: tuple = (-1, -1)
     page_table: torch.Tensor = None
     cache_seqlens_int32: torch.Tensor = None
+    max_seq_len_q: int = 0
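For orientation, this is roughly how the dataclass reads after the hunk above is applied (a sketch reconstructed from the visible lines; any fields outside the hunk are omitted):

```python
from dataclasses import dataclass

import torch


@dataclass
class FlashAttentionMetadata:
    cu_seqlens_q: torch.Tensor = None
    cu_seqlens_k: torch.Tensor = None
    max_seq_len_k: int = 0
    window_size: tuple = (-1, -1)
    page_table: torch.Tensor = None
    cache_seqlens_int32: torch.Tensor = None
    # Moved below the tensor fields; now filled from host-side lengths.
    max_seq_len_q: int = 0
```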
class FlashAttentionBackend(AttentionBackend):
@@ -63,7 +63,6 @@ class FlashAttentionBackend(AttentionBackend):
         # Create metadata based on forward mode
         metadata = FlashAttentionMetadata()
-        extend_seq_lens = forward_batch.extend_seq_lens
         # Get sequence information
         seqlens_in_batch = forward_batch.seq_lens
         # Precompute int32 version of sequence lengths
@@ -85,15 +84,16 @@ class FlashAttentionBackend(AttentionBackend):
                 0, batch_size + 1, dtype=torch.int32, device=device
             )
         else:
-            extend_no_prefix = not any(forward_batch.extend_prefix_lens)
             # Precompute cumulative sequence lengths
-            if not extend_no_prefix:
+            if any(forward_batch.extend_prefix_lens_cpu):
+                extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                 )
+                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
             else:
                 metadata.cu_seqlens_q = metadata.cu_seqlens_k
-                metadata.max_seq_len_q = seqlens_in_batch.max().item()
+                metadata.max_seq_len_q = metadata.max_seq_len_k
         self.forward_metadata = metadata
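Two things in the hunk above are worth spelling out. First, `pad(cumsum(...), (1, 0))` turns per-request lengths into the cumulative offsets the varlen kernel consumes, entirely on the GPU. Second, `max_seq_len_q` is now taken with Python's `max()` over `extend_seq_lens_cpu`, a host-side list, replacing the `seqlens_in_batch.max().item()` call that forced a sync. A sketch with made-up lengths:

```python
import torch

extend_seq_lens = torch.tensor([4, 2, 7], device="cuda", dtype=torch.int32)
extend_seq_lens_cpu = [4, 2, 7]  # host-side mirror of the same lengths

# Pure GPU op: [4, 2, 7] -> cumsum [4, 6, 13] -> left-pad -> [0, 4, 6, 13].
cu_seqlens_q = torch.nn.functional.pad(
    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
)

# The scalar max is read from the host list, so no device round-trip.
max_seq_len_q = max(extend_seq_lens_cpu)  # 7
```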
    def forward_extend(
@@ -274,20 +274,26 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         # """Initialize forward metadata for replaying CUDA graph."""
-        seqlens_in_batch = seq_lens[:bs]
         metadata = self.decode_cuda_graph_metadata[bs]
-        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+        # For CPU operations
+        max_len = seq_lens_cpu[:bs].max().item()
+        metadata.max_seq_len_k = max_len
+        # For GPU operations
+        seq_lens_in_batch = seq_lens[:bs]
+        metadata.cache_seqlens_int32 = seq_lens_in_batch.to(torch.int32)
         metadata.cu_seqlens_k = torch.nn.functional.pad(
-            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+            torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0)
         )
-        # Precompute maximum sequence length
-        metadata.max_seq_len_k = seqlens_in_batch.max().item()
         # Only zero out the part out of max_len_k
         metadata.page_table[:, metadata.max_seq_len_k :].fill_(0)
         # Then do the copy
         metadata.page_table[:, : metadata.max_seq_len_k].copy_(
             self.req_to_token[req_pool_indices[:bs], : metadata.max_seq_len_k]
         )
         self.forward_decode_metadata = metadata

    def get_cuda_graph_seq_len_fill_value(self):
...
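The replay hunk above makes the same separation explicit: the one scalar the host needs (`max_seq_len_k`) is read from `seq_lens_cpu`, a tensor that already lives in host memory, while every tensor the kernel consumes is derived from the device-side `seq_lens` without copying anything back. A sketch of that split, with hypothetical values padded to the CUDA-graph batch size:

```python
import torch

bs = 2
seq_lens = torch.tensor([6, 3, 0, 0], device="cuda")  # padded for the graph
seq_lens_cpu = torch.tensor([6, 3, 0, 0])             # host mirror

# CPU side: .item() on a CPU tensor is a plain memory read, not a sync.
max_seq_len_k = seq_lens_cpu[:bs].max().item()  # 6

# GPU side: everything stays on the device; nothing is copied back.
seq_lens_in_batch = seq_lens[:bs]
cache_seqlens_int32 = seq_lens_in_batch.to(torch.int32)
cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0)
)
```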
@@ -1376,6 +1376,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if (
             global_server_args_dict["enable_flashinfer_mla"]
             or global_server_args_dict["enable_flashmla"]
+            or global_server_args_dict["attention_backend"] == "fa3"
         ):
             decode_seq_lens = self.seq_lens.cpu()
         else:
...
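The scheduler change above is what makes the CPU-side reads possible for fa3: it pays one explicit device-to-host copy per batch, and the attention backend then works from that host copy (presumably what arrives as the `seq_lens_cpu` argument in the replay hunk). A hedged sketch of the trade:

```python
import torch

seq_lens = torch.tensor([6, 3], device="cuda")

# One up-front D2H copy when the batch is scheduled...
decode_seq_lens = seq_lens.cpu()

# ...so later scalar reads are host-only and never stall on the device.
max_seq_len_k = decode_seq_lens.max().item()
```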