Unverified Commit 9d5fa68b authored by Lianmin Zheng, committed by GitHub

Use torch.compile to fuse flash attention decode metadata preparation (#6973)

parent 2c186425
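
The change moves the per-step decode metadata preparation (int32 cast, cumulative-length prefix sum, page-table gather) out of the replay path and into a standalone function wrapped in `torch.compile`, so the chain of small tensor ops can be fused into fewer kernel launches. A minimal, self-contained sketch of that pattern (the function name, shapes, and values below are illustrative, not from the repo):

```python
import torch


@torch.compile(dynamic=True)
def prepare_decode_metadata(seq_lens, cu_seqlens_k):
    # Cast lengths to int32 and build the cumulative-length prefix that
    # varlen attention kernels consume; dynamic=True lets the compiled
    # graph handle varying batch sizes without recompiling each step.
    cache_seqlens = seq_lens.to(torch.int32)
    cu_seqlens_k[1:].copy_(torch.cumsum(seq_lens, dim=0, dtype=torch.int32))
    return cache_seqlens


seq_lens = torch.tensor([5, 9, 3])
cu_seqlens_k = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
print(prepare_decode_metadata(seq_lens, cu_seqlens_k))  # tensor([5, 9, 3], dtype=torch.int32)
print(cu_seqlens_k)  # tensor([ 0,  5, 14, 17], dtype=torch.int32)
```

Eagerly, each of these small ops is its own kernel launch; under `torch.compile` they can be fused, which matters during decode where the preparation runs every step.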
@@ -11,6 +11,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.utils import get_compiler_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -1657,30 +1658,22 @@ class FlashAttentionBackend(AttentionBackend):
                 )
                 # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
             else:
-                metadata = self.decode_cuda_graph_metadata[bs]
                 # Normal Decode
+                metadata = self.decode_cuda_graph_metadata[bs]
                 max_len = seq_lens_cpu.max().item()
+                max_seq_pages = (max_len + self.page_size - 1) // self.page_size
                 metadata.max_seq_len_k = max_len
-                metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
-                # Optimize cumulative sequence length calculation
-                metadata.cu_seqlens_k[1:].copy_(
-                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32)
+                normal_decode_set_medadata(
+                    metadata,
+                    self.req_to_token,
+                    req_pool_indices,
+                    self.decode_cuda_graph_metadata["strided_indices"],
+                    max_seq_pages,
+                    seq_lens,
+                    self.page_size,
                 )
-
-                max_seq_pages = (
-                    metadata.max_seq_len_k + self.page_size - 1
-                ) // self.page_size
-                page_indices = self.req_to_token[
-                    req_pool_indices[:, None],
-                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
-                        None, :
-                    ],
-                ]
-                page_indices //= self.page_size
-                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
-                metadata.page_table[:, max_seq_pages:].fill_(0)
 
             self._update_local_attn_metadata_for_replay(metadata, bs)
 
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
@@ -2063,3 +2056,23 @@ class FlashAttentionMultiStepBackend:
             seq_lens_cpu=forward_batch.seq_lens_cpu,
             out_cache_loc=forward_batch.out_cache_loc,
         )
+
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def normal_decode_set_medadata(
+    metadata,
+    req_to_token,
+    req_pool_indices,
+    strided_indices,
+    max_seq_pages,
+    seq_lens,
+    page_size,
+):
+    metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+    metadata.cu_seqlens_k[1:].copy_(torch.cumsum(seq_lens, dim=0, dtype=torch.int32))
+    page_indices = req_to_token[
+        req_pool_indices[:, None],
+        strided_indices[:max_seq_pages][None, :],
+    ]
+    metadata.page_table[:, :max_seq_pages].copy_(page_indices // page_size)
+    metadata.page_table[:, max_seq_pages:].fill_(0)
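
For readers unfamiliar with the paged KV-cache layout, here is a hedged, standalone illustration of the gather that `normal_decode_set_medadata` performs (toy sizes; the real `req_to_token` pool is far larger):

```python
import torch

# req_to_token maps (request slot, token position) -> KV-cache slot.
# Sampling it every `page_size` tokens and dividing by page_size yields
# the page id for each (request, page) pair.
page_size = 2
req_to_token = torch.arange(24).reshape(3, 8)    # toy pool: 3 slots x 8 positions
req_pool_indices = torch.tensor([2, 0])          # the 2 running requests
strided_indices = torch.arange(0, 8, page_size)  # positions 0, 2, 4, 6
max_seq_pages = 3

page_indices = req_to_token[
    req_pool_indices[:, None],                   # shape (bs, 1)
    strided_indices[:max_seq_pages][None, :],    # shape (1, max_seq_pages)
]                                                # broadcasts to (bs, max_seq_pages)
print(page_indices // page_size)
# tensor([[ 8,  9, 10],
#         [ 0,  1,  2]])
```

Because the gather, the division, and the two `page_table` writes now live in one compiled function, the compiler is free to fuse them; `backend=get_compiler_backend()` defers the choice of `torch.compile` backend to a repo utility (presumably so non-default accelerators can substitute their own backend).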