Unverified Commit 2dae104d authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Minor cleanup of fa3 backend (#6999)

parent cef6655b
......@@ -1469,7 +1469,7 @@ class FlashAttentionBackend(AttentionBackend):
"cache_seqlens"
][:bs]
metadata.cache_seqlens_int32.copy_(
(seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
(seq_lens + self.speculative_num_draft_tokens)
)
metadata.max_seq_len_q = self.speculative_num_draft_tokens
......@@ -1536,7 +1536,7 @@ class FlashAttentionBackend(AttentionBackend):
metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
:bs
]
metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
metadata.cache_seqlens_int32.copy_(seq_lens)
num_tokens_per_bs = num_tokens // bs
metadata.max_seq_len_q = num_tokens_per_bs
......@@ -1600,38 +1600,32 @@ class FlashAttentionBackend(AttentionBackend):
if spec_info is not None:
# Draft Decode
if self.topk <= 1:
metadata = self.decode_cuda_graph_metadata[bs]
# When topk = 1, we use the normal decode metadata
metadata.cache_seqlens_int32.copy_(
(seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
)
metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
self.speculative_step_id + 1
)
metadata.cu_seqlens_k[1:].copy_(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
)
)
metadata = self.decode_cuda_graph_metadata[bs]
max_len = seq_lens_cpu.max().item()
metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
max_seq_pages = (
metadata.max_seq_len_k + self.page_size - 1
) // self.page_size
page_indices = self.req_to_token[
req_pool_indices[:, None],
self.decode_cuda_graph_metadata["strided_indices"][
:max_seq_pages
],
]
page_indices //= self.page_size
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
normal_decode_set_medadata(
metadata.cache_seqlens_int32,
metadata.cu_seqlens_k,
metadata.page_table,
self.req_to_token,
req_pool_indices,
self.decode_cuda_graph_metadata["strided_indices"],
max_seq_pages,
seq_lens,
self.speculative_step_id + 1,
self.page_size,
)
else:
# When top k > 1, we need two specific draft decode metadata, and then merge states
# 1. The first half of metadata for prefix tokens
metadata = self.draft_decode_metadata_topk_normal[bs]
metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
metadata.cache_seqlens_int32.copy_(seq_lens)
# metadata.max_seq_len_q = self.topk, already set in capture
metadata.max_seq_len_k = seq_lens_cpu.max().item()
# metadata.cu_seqlens_q already set in capture
......@@ -1654,7 +1648,7 @@ class FlashAttentionBackend(AttentionBackend):
self.speculative_num_steps, -1
).T.contiguous()
metadata_expand.page_table[: cache_loc.shape[0]].copy_(
cache_loc[:, :decode_length].contiguous().to(torch.int32)
cache_loc[:, :decode_length]
)
# TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
else:
......@@ -1665,12 +1659,15 @@ class FlashAttentionBackend(AttentionBackend):
metadata.max_seq_len_k = max_len
normal_decode_set_medadata(
metadata,
metadata.cache_seqlens_int32,
metadata.cu_seqlens_k,
metadata.page_table,
self.req_to_token,
req_pool_indices,
self.decode_cuda_graph_metadata["strided_indices"],
max_seq_pages,
seq_lens,
0,
self.page_size,
)
......@@ -1679,7 +1676,7 @@ class FlashAttentionBackend(AttentionBackend):
if self.topk <= 1:
metadata = self.target_verify_metadata[bs]
metadata.cache_seqlens_int32.copy_(
(seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
(seq_lens + self.speculative_num_draft_tokens)
)
metadata.max_seq_len_k = (
......@@ -1701,7 +1698,7 @@ class FlashAttentionBackend(AttentionBackend):
# When topk > 1, we need two specific target verify metadata, and then merge states
# 1. The first half of metadata for prefix tokens
metadata = self.target_verify_metadata_topk_normal[bs]
metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
metadata.cache_seqlens_int32.copy_(seq_lens)
# metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
metadata.max_seq_len_k = seq_lens_cpu.max().item()
# metadata.cu_seqlens_q already set in capture
......@@ -1761,9 +1758,7 @@ class FlashAttentionBackend(AttentionBackend):
metadata_expand.page_table.copy_(
non_masked_page_table.gather(1, sort_order)
)
metadata_expand.cache_seqlens_int32.copy_(
mask.sum(dim=1).to(torch.int32)
)
metadata_expand.cache_seqlens_int32.copy_(mask.sum(dim=1))
metadata_expand.cu_seqlens_k[1:].copy_(
torch.cumsum(
metadata_expand.cache_seqlens_int32,
......@@ -1776,14 +1771,14 @@ class FlashAttentionBackend(AttentionBackend):
)
elif forward_mode.is_draft_extend():
metadata = self.draft_extend_metadata[bs]
metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
metadata.cache_seqlens_int32.copy_(seq_lens)
metadata.max_seq_len_k = seq_lens_cpu.max().item()
metadata.cu_seqlens_k[1:].copy_(
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
)
accept_length = spec_info.accept_length[:bs]
metadata.max_seq_len_q = accept_length.max().item()
metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
metadata.cu_seqlens_q[1:].copy_(
torch.cumsum(accept_length, dim=0, dtype=torch.int32)
)
......@@ -1795,8 +1790,7 @@ class FlashAttentionBackend(AttentionBackend):
req_pool_indices[:, None],
self.draft_extend_metadata["strided_indices"][:max_seq_pages],
]
page_indices //= self.page_size
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
if encoder_lens is not None:
# Only support encoder size 1 for now
......@@ -2045,6 +2039,8 @@ class FlashAttentionMultiStepBackend:
assert isinstance(forward_batch.spec_info, EagleDraftInput)
for i in range(self.speculative_num_steps - 1):
# TODO: incrementally update the metadata for the later steps,
# so that they do not need to recompute everything from scratch.
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
bs,
forward_batch.req_pool_indices,
......@@ -2058,21 +2054,25 @@ class FlashAttentionMultiStepBackend:
)
@torch.compile(dynamic=True, backend=get_compiler_backend())
# @torch.compile(dynamic=True, backend=get_compiler_backend())
# TODO: fuse these kernels
# NOTE: torch.compile makes it slower in speculative decoding
def normal_decode_set_medadata(
metadata,
req_to_token,
req_pool_indices,
strided_indices,
max_seq_pages,
seq_lens,
page_size,
cache_seqlens_int32: torch.Tensor,
cu_seqlens_k: torch.Tensor,
page_table: torch.Tensor,
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
strided_indices: torch.Tensor,
max_seq_pages: torch.Tensor,
seq_lens: torch.Tensor,
seq_len_delta: int,
page_size: int,
):
metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
metadata.cu_seqlens_k[1:].copy_(torch.cumsum(seq_lens, dim=0, dtype=torch.int32))
cache_seqlens_int32.copy_(seq_lens + seq_len_delta)
cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32))
page_indices = req_to_token[
req_pool_indices[:, None],
strided_indices[:max_seq_pages][None, :],
]
metadata.page_table[:, :max_seq_pages].copy_(page_indices // page_size)
metadata.page_table[:, max_seq_pages:].fill_(0)
page_table[:, :max_seq_pages].copy_(page_indices // page_size)
......@@ -920,19 +920,18 @@ def fast_mla_decode_plan(
self._page_size = page_size
self._sm_scale = sm_scale
with self.device as device:
try:
# Standard version with just the required arguments (no use_profiler)
self._cached_module.plan.default(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
qo_indptr_cpu,
kv_indptr_cpu,
kv_len_arr_cpu,
num_heads,
head_dim_ckv,
causal,
)
except Exception as e:
raise RuntimeError(f"Error in alternate MLA plan: {e}")
try:
# Standard version with just the required arguments (no use_profiler)
self._cached_module.plan.default(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
qo_indptr_cpu,
kv_indptr_cpu,
kv_len_arr_cpu,
num_heads,
head_dim_ckv,
causal,
)
except Exception as e:
raise RuntimeError(f"Error in alternate MLA plan: {e}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment