Unverified Commit 2dae104d authored by Lianmin Zheng, committed by GitHub

Minor cleanup of fa3 backend (#6999)

parent cef6655b
@@ -1469,7 +1469,7 @@ class FlashAttentionBackend(AttentionBackend):
                 "cache_seqlens"
             ][:bs]
             metadata.cache_seqlens_int32.copy_(
-                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+                (seq_lens + self.speculative_num_draft_tokens)
             )
             metadata.max_seq_len_q = self.speculative_num_draft_tokens
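
Note (not part of the diff): several of these changes drop an explicit `.to(torch.int32)` before `copy_`. A minimal standalone sketch of why the cast is redundant — `Tensor.copy_` converts the source to the destination's dtype on its own; the buffer name and values below are illustrative only:

```python
import torch

# Destination buffer with a fixed dtype, analogous to metadata.cache_seqlens_int32.
cache_seqlens_int32 = torch.zeros(4, dtype=torch.int32)

# Source in int64, analogous to seq_lens + self.speculative_num_draft_tokens.
seq_lens = torch.tensor([3, 5, 7, 9], dtype=torch.int64)

# copy_ casts the source to the destination dtype implicitly,
# so an explicit .to(torch.int32) before the copy adds nothing.
cache_seqlens_int32.copy_(seq_lens + 2)
assert cache_seqlens_int32.dtype == torch.int32
print(cache_seqlens_int32)  # tensor([ 5,  7,  9, 11], dtype=torch.int32)
```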
@@ -1536,7 +1536,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
                 :bs
             ]
-            metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+            metadata.cache_seqlens_int32.copy_(seq_lens)
             num_tokens_per_bs = num_tokens // bs
             metadata.max_seq_len_q = num_tokens_per_bs
@@ -1600,38 +1600,32 @@ class FlashAttentionBackend(AttentionBackend):
             if spec_info is not None:
                 # Draft Decode
                 if self.topk <= 1:
-                    metadata = self.decode_cuda_graph_metadata[bs]
                     # When topk = 1, we use the normal decode metadata
-                    metadata.cache_seqlens_int32.copy_(
-                        (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
-                    )
-                    metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
-                        self.speculative_step_id + 1
-                    )
-                    metadata.cu_seqlens_k[1:].copy_(
-                        torch.cumsum(
-                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                        )
-                    )
+                    metadata = self.decode_cuda_graph_metadata[bs]
+                    max_len = seq_lens_cpu.max().item()
+                    metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
                     max_seq_pages = (
                         metadata.max_seq_len_k + self.page_size - 1
                     ) // self.page_size
-                    page_indices = self.req_to_token[
-                        req_pool_indices[:, None],
-                        self.decode_cuda_graph_metadata["strided_indices"][
-                            :max_seq_pages
-                        ],
-                    ]
-                    page_indices //= self.page_size
-                    metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+                    normal_decode_set_medadata(
+                        metadata.cache_seqlens_int32,
+                        metadata.cu_seqlens_k,
+                        metadata.page_table,
+                        self.req_to_token,
+                        req_pool_indices,
+                        self.decode_cuda_graph_metadata["strided_indices"],
+                        max_seq_pages,
+                        seq_lens,
+                        self.speculative_step_id + 1,
+                        self.page_size,
+                    )
                 else:
                     # When top k > 1, we need two specific draft decode metadata, and then merge states
                     # 1. The first half of metadata for prefix tokens
                     metadata = self.draft_decode_metadata_topk_normal[bs]
-                    metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+                    metadata.cache_seqlens_int32.copy_(seq_lens)
                     # metadata.max_seq_len_q = self.topk, already set in capture
                     metadata.max_seq_len_k = seq_lens_cpu.max().item()
                     # metadata.cu_seqlens_q already set in capture
@@ -1654,7 +1648,7 @@ class FlashAttentionBackend(AttentionBackend):
                         self.speculative_num_steps, -1
                     ).T.contiguous()
                     metadata_expand.page_table[: cache_loc.shape[0]].copy_(
-                        cache_loc[:, :decode_length].contiguous().to(torch.int32)
+                        cache_loc[:, :decode_length]
                     )
                     # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
             else:
@@ -1665,12 +1659,15 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.max_seq_len_k = max_len
                 normal_decode_set_medadata(
-                    metadata,
+                    metadata.cache_seqlens_int32,
+                    metadata.cu_seqlens_k,
+                    metadata.page_table,
                     self.req_to_token,
                     req_pool_indices,
                     self.decode_cuda_graph_metadata["strided_indices"],
                     max_seq_pages,
                     seq_lens,
+                    0,
                     self.page_size,
                 )
@@ -1679,7 +1676,7 @@ class FlashAttentionBackend(AttentionBackend):
             if self.topk <= 1:
                 metadata = self.target_verify_metadata[bs]
                 metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+                    (seq_lens + self.speculative_num_draft_tokens)
                 )
                 metadata.max_seq_len_k = (
@@ -1701,7 +1698,7 @@ class FlashAttentionBackend(AttentionBackend):
                 # When topk > 1, we need two specific target verify metadata, and then merge states
                 # 1. The first half of metadata for prefix tokens
                 metadata = self.target_verify_metadata_topk_normal[bs]
-                metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+                metadata.cache_seqlens_int32.copy_(seq_lens)
                 # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
                 metadata.max_seq_len_k = seq_lens_cpu.max().item()
                 # metadata.cu_seqlens_q already set in capture
@@ -1761,9 +1758,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata_expand.page_table.copy_(
                     non_masked_page_table.gather(1, sort_order)
                 )
-                metadata_expand.cache_seqlens_int32.copy_(
-                    mask.sum(dim=1).to(torch.int32)
-                )
+                metadata_expand.cache_seqlens_int32.copy_(mask.sum(dim=1))
                 metadata_expand.cu_seqlens_k[1:].copy_(
                     torch.cumsum(
                         metadata_expand.cache_seqlens_int32,
@@ -1776,14 +1771,14 @@ class FlashAttentionBackend(AttentionBackend):
                 )
             elif forward_mode.is_draft_extend():
                 metadata = self.draft_extend_metadata[bs]
-                metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+                metadata.cache_seqlens_int32.copy_(seq_lens)
                 metadata.max_seq_len_k = seq_lens_cpu.max().item()
                 metadata.cu_seqlens_k[1:].copy_(
                     torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
                 )
                 accept_length = spec_info.accept_length[:bs]
-                metadata.max_seq_len_q = accept_length.max().item()
+                metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
                 metadata.cu_seqlens_q[1:].copy_(
                     torch.cumsum(accept_length, dim=0, dtype=torch.int32)
                 )
@@ -1795,8 +1790,7 @@ class FlashAttentionBackend(AttentionBackend):
                     req_pool_indices[:, None],
                     self.draft_extend_metadata["strided_indices"][:max_seq_pages],
                 ]
-                page_indices //= self.page_size
-                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+                metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)

         if encoder_lens is not None:
             # Only support encoder size 1 for now
@@ -2045,6 +2039,8 @@ class FlashAttentionMultiStepBackend:
        assert isinstance(forward_batch.spec_info, EagleDraftInput)

        for i in range(self.speculative_num_steps - 1):
+            # TODO: incrementally update the metadata for the later steps,
+            # so that they do not need to recompute everything from scratch.
            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                bs,
                forward_batch.req_pool_indices,
@@ -2058,21 +2054,25 @@ class FlashAttentionMultiStepBackend:
             )


-@torch.compile(dynamic=True, backend=get_compiler_backend())
+# @torch.compile(dynamic=True, backend=get_compiler_backend())
+# TODO: fuse these kernels
+# NOTE: torch.compile makes it slower in speculative decoding
 def normal_decode_set_medadata(
-    metadata,
-    req_to_token,
-    req_pool_indices,
-    strided_indices,
-    max_seq_pages,
-    seq_lens,
-    page_size,
+    cache_seqlens_int32: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    page_table: torch.Tensor,
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    strided_indices: torch.Tensor,
+    max_seq_pages: torch.Tensor,
+    seq_lens: torch.Tensor,
+    seq_len_delta: int,
+    page_size: int,
 ):
-    metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
-    metadata.cu_seqlens_k[1:].copy_(torch.cumsum(seq_lens, dim=0, dtype=torch.int32))
+    cache_seqlens_int32.copy_(seq_lens + seq_len_delta)
+    cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32))
     page_indices = req_to_token[
         req_pool_indices[:, None],
         strided_indices[:max_seq_pages][None, :],
     ]
-    metadata.page_table[:, :max_seq_pages].copy_(page_indices // page_size)
-    metadata.page_table[:, max_seq_pages:].fill_(0)
+    page_table[:, :max_seq_pages].copy_(page_indices // page_size)
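
Note (not part of the diff): a small self-contained sketch of what the refactored `normal_decode_set_medadata` helper computes, using hypothetical shapes, pool indices, and sequence lengths made up for illustration. It fills the per-request cache sequence lengths, the cumulative key lengths, and a page table derived from `req_to_token` by sampling one token slot per page and dividing by the page size:

```python
import torch

# Hypothetical sizes for illustration only.
bs, max_tokens, page_size = 2, 16, 4
max_pages = max_tokens // page_size

# Per-request token-slot table and preallocated metadata buffers.
req_to_token = torch.arange(4 * max_tokens, dtype=torch.int32).reshape(4, max_tokens)
cache_seqlens_int32 = torch.zeros(bs, dtype=torch.int32)
cu_seqlens_k = torch.zeros(bs + 1, dtype=torch.int32)
page_table = torch.zeros(bs, max_pages, dtype=torch.int32)

req_pool_indices = torch.tensor([2, 0])
seq_lens = torch.tensor([5, 9])
strided_indices = torch.arange(0, max_tokens, page_size)  # one slot per page
max_seq_pages = (int(seq_lens.max()) + page_size - 1) // page_size

# Same steps as the helper, with seq_len_delta = 0 (the normal decode case).
cache_seqlens_int32.copy_(seq_lens + 0)
cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32))
page_indices = req_to_token[
    req_pool_indices[:, None],
    strided_indices[:max_seq_pages][None, :],
]
page_table[:, :max_seq_pages].copy_(page_indices // page_size)
```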
@@ -920,19 +920,18 @@ def fast_mla_decode_plan(
    self._page_size = page_size
    self._sm_scale = sm_scale

-    with self.device as device:
-        try:
-            # Standard version with just the required arguments (no use_profiler)
-            self._cached_module.plan.default(
-                self._float_workspace_buffer,
-                self._int_workspace_buffer,
-                self._pin_memory_int_workspace_buffer,
-                qo_indptr_cpu,
-                kv_indptr_cpu,
-                kv_len_arr_cpu,
-                num_heads,
-                head_dim_ckv,
-                causal,
-            )
-        except Exception as e:
-            raise RuntimeError(f"Error in alternate MLA plan: {e}")
+    try:
+        # Standard version with just the required arguments (no use_profiler)
+        self._cached_module.plan.default(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            qo_indptr_cpu,
+            kv_indptr_cpu,
+            kv_len_arr_cpu,
+            num_heads,
+            head_dim_ckv,
+            causal,
+        )
+    except Exception as e:
+        raise RuntimeError(f"Error in alternate MLA plan: {e}")