Unverified Commit 9d5fa68b authored by Lianmin Zheng, committed by GitHub

Use torch.compile to fuse flash attention decode metadata preparation (#6973)

parent 2c186425
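This commit moves the normal-decode metadata preparation into a standalone helper wrapped in `torch.compile`, so the handful of small tensor ops (an int32 cast, a cumsum written via `copy_()`, a strided gather into the page table, and a fill) can be fused instead of each launching its own tiny kernel. For intuition, here is a minimal, hypothetical sketch of that pattern with toy names and shapes (not the commit's real `metadata` object or tensors):

```python
import torch

# Toy stand-in for the decode-metadata prep: a few tiny ops that, run
# eagerly, each launch their own kernel. Wrapping the whole function in
# torch.compile lets Inductor fuse them and cut per-step launch overhead.
@torch.compile(dynamic=True)
def toy_set_metadata(cu_seqlens_k, page_table, req_to_token, seq_lens, page_size):
    # Cumulative key lengths, written in place past the leading zero.
    cu_seqlens_k[1:].copy_(torch.cumsum(seq_lens, dim=0, dtype=torch.int32))
    max_pages = page_table.shape[1]
    # Take every page_size-th token slot per request, convert to page ids.
    page_table.copy_(req_to_token[:, : max_pages * page_size : page_size] // page_size)
    return seq_lens.to(torch.int32)

bs, max_pages, page_size = 4, 8, 16
seq_lens = torch.randint(1, max_pages * page_size, (bs,))
cu_seqlens_k = torch.zeros(bs + 1, dtype=torch.int32)
page_table = torch.zeros(bs, max_pages, dtype=torch.int64)
req_to_token = torch.arange(bs * max_pages * page_size).reshape(bs, -1)
cache_seqlens = toy_set_metadata(cu_seqlens_k, page_table, req_to_token, seq_lens, page_size)
```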
@@ -11,6 +11,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+from sglang.srt.utils import get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -1657,30 +1658,22 @@ class FlashAttentionBackend(AttentionBackend):
                 )
                 # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
             else:
+                metadata = self.decode_cuda_graph_metadata[bs]
                 # Normal Decode
-                metadata = self.decode_cuda_graph_metadata[bs]
                 max_len = seq_lens_cpu.max().item()
+                max_seq_pages = (max_len + self.page_size - 1) // self.page_size
                 metadata.max_seq_len_k = max_len
-                metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
-                # Optimize cumulative sequence length calculation
-                metadata.cu_seqlens_k[1:].copy_(
-                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32)
+                normal_decode_set_medadata(
+                    metadata,
+                    self.req_to_token,
+                    req_pool_indices,
+                    self.decode_cuda_graph_metadata["strided_indices"],
+                    max_seq_pages,
+                    seq_lens,
+                    self.page_size,
                 )
-                max_seq_pages = (
-                    metadata.max_seq_len_k + self.page_size - 1
-                ) // self.page_size
-                page_indices = self.req_to_token[
-                    req_pool_indices[:, None],
-                    self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
-                        None, :
-                    ],
-                ]
-                page_indices //= self.page_size
-                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
-                metadata.page_table[:, max_seq_pages:].fill_(0)
                 self._update_local_attn_metadata_for_replay(metadata, bs)
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
@@ -2063,3 +2056,23 @@ class FlashAttentionMultiStepBackend:
             seq_lens_cpu=forward_batch.seq_lens_cpu,
             out_cache_loc=forward_batch.out_cache_loc,
         )
+
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def normal_decode_set_medadata(
+    metadata,
+    req_to_token,
+    req_pool_indices,
+    strided_indices,
+    max_seq_pages,
+    seq_lens,
+    page_size,
+):
+    metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+    metadata.cu_seqlens_k[1:].copy_(torch.cumsum(seq_lens, dim=0, dtype=torch.int32))
+    page_indices = req_to_token[
+        req_pool_indices[:, None],
+        strided_indices[:max_seq_pages][None, :],
+    ]
+    metadata.page_table[:, :max_seq_pages].copy_(page_indices // page_size)
+    metadata.page_table[:, max_seq_pages:].fill_(0)
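A hedged reading of the decorator arguments: `dynamic=True` asks Dynamo to compile the helper with symbolic tensor sizes, which matters here because decode replays it across many batch sizes and sequence lengths, and a shape-specialized graph would otherwise recompile as sizes change; `backend=get_compiler_backend()` presumably selects a device-appropriate compile backend (the toy sketch below, not from this commit, just uses the default):

```python
import torch

# Toy stand-in for the cu_seqlens_k update above; names are illustrative.
@torch.compile(dynamic=True)
def cu_seqlens(seq_lens: torch.Tensor) -> torch.Tensor:
    out = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
    out[1:] = torch.cumsum(seq_lens, dim=0, dtype=torch.int32)
    return out

# dynamic=True compiles one symbolic-shape graph up front instead of
# specializing on the first batch size seen, so varying bs does not
# force a recompile on every new size.
for bs in (1, 2, 8, 64):
    print(cu_seqlens(torch.randint(1, 100, (bs,))))
```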