Unverified Commit e983e432 authored by Qingquan Song, committed by GitHub

Add Eagle Speculative Decoding to FA3 Backend (#4951)


Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
Co-authored-by: zcnrex <zcnrex@gmail.com>
parent e9c6ce46
@@ -45,6 +45,9 @@ class FlashAttentionBackend(AttentionBackend):
         self,
         model_runner: ModelRunner,
         skip_prefill: bool = False,
+        topk=0,
+        speculative_num_steps=0,
+        step_id=0,
     ):
         super().__init__()
@@ -63,6 +66,10 @@ class FlashAttentionBackend(AttentionBackend):
         self.use_mla = (
             model_runner.model_config.attention_arch == AttentionArch.MLA
         ) and (not global_server_args_dict["disable_mla"])
+        self.skip_prefill = skip_prefill
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+        self.step_id = step_id
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata to cache repetitive calculations."""
@@ -72,37 +79,125 @@ class FlashAttentionBackend(AttentionBackend):
         # Get sequence information
         seqlens_in_batch = forward_batch.seq_lens
         # Precompute int32 version of sequence lengths
-        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
         batch_size = len(seqlens_in_batch)
         device = seqlens_in_batch.device
-        metadata.cu_seqlens_k = torch.nn.functional.pad(
-            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
-        )
-        # Precompute maximum sequence length
-        metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
-        # Precompute page table
-        metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
-            forward_batch.req_pool_indices, : metadata.max_seq_len_k
-        ]
-        # Precompute strided indices
-        # [0, page_size, 2 * page_size, ...]
-        if self.page_size > 1:
-            self.strided_indices = torch.arange(
-                0, metadata.page_table.shape[1], self.page_size, device=self.device
-            )
-            metadata.page_table = (
-                metadata.page_table[:, self.strided_indices] // self.page_size
-            )
         if forward_batch.forward_mode == ForwardMode.DECODE:
-            # Precompute cumulative sequence lengths
-            metadata.cu_seqlens_q = torch.arange(
-                0, batch_size + 1, dtype=torch.int32, device=device
-            )
+            if self.skip_prefill:
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size * self.topk + 1, dtype=torch.int32, device=device
+                )
+                seq_lens_with_decode = seqlens_in_batch + (self.step_id + 1)
+                metadata.cache_seqlens_int32 = (
+                    (seq_lens_with_decode).repeat_interleave(self.topk).to(torch.int32)
+                )
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.step_id + 1
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+                metadata.page_table = metadata.page_table.repeat_interleave(
+                    self.topk, dim=0
+                )
+                cache_loc = forward_batch.out_cache_loc.view(
+                    self.speculative_num_steps, -1
+                ).T
+
+                # Calculate page table indices and cache location indices to update the page table.
+                batch_indices = torch.arange(
+                    batch_size, device=device
+                ).repeat_interleave(self.topk * (self.step_id + 1))
+                topk_indices = torch.arange(self.topk, device=device).repeat(
+                    batch_size * (self.step_id + 1)
+                )
+                row_indices = batch_indices * self.topk + topk_indices
+
+                page_table_col_base_indices = seqlens_in_batch.unsqueeze(
+                    1
+                ) + torch.arange(self.step_id + 1, device=device)
+                page_table_col_indices = page_table_col_base_indices.view(-1).repeat(
+                    self.topk
+                )
+
+                cache_loc_col_indices = torch.arange(
+                    self.step_id + 1, device=device, dtype=torch.int32
+                ).repeat(batch_size * self.topk)
+
+                metadata.page_table[row_indices, page_table_col_indices] = cache_loc[
+                    row_indices, cache_loc_col_indices
+                ].to(torch.int32)
+            else:
+                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+                )
+                # Precompute maximum sequence length
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                # Precompute page table
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+        elif forward_batch.forward_mode == ForwardMode.TARGET_VERIFY:
+            draft_token_num = forward_batch.spec_info.draft_token_num
+            metadata.cu_seqlens_q = torch.arange(
+                0, batch_size * draft_token_num + 1, dtype=torch.int32, device=device
+            )
+            aug_seq_lens = (forward_batch.seq_lens + draft_token_num).to(torch.int32)
+            metadata.cache_seqlens_int32 = aug_seq_lens.repeat_interleave(
+                forward_batch.spec_info.draft_token_num
+            )
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
+                (1, 0),
+            )
+            metadata.max_seq_len_k = (
+                forward_batch.seq_lens_cpu.max().item() + draft_token_num
+            )
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.max_seq_len_k
+            ].repeat_interleave(draft_token_num, dim=0)
+            aug_cum_len = torch.nn.functional.pad(
+                torch.cumsum(aug_seq_lens, dim=0, dtype=torch.int32), (1, 0)
+            )
+            for idx, single_seq_len in enumerate(aug_seq_lens):
+                metadata.page_table[
+                    idx * draft_token_num : (idx + 1) * draft_token_num, :single_seq_len
+                ] *= forward_batch.spec_info.custom_mask[
+                    aug_cum_len[idx]
+                    * draft_token_num : aug_cum_len[idx + 1]
+                    * draft_token_num
+                ].view(
+                    draft_token_num, -1
+                )
+            metadata.max_seq_len_q = 1
         else:
+            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+            )
+            # Precompute maximum sequence length
+            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+            # Precompute page table
+            metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices, : metadata.max_seq_len_k
+            ]
             # Precompute cumulative sequence lengths
-            if any(forward_batch.extend_prefix_lens_cpu):
+            if (
+                any(forward_batch.extend_prefix_lens_cpu)
+                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            ):
                 extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
@@ -111,6 +206,16 @@ class FlashAttentionBackend(AttentionBackend):
             else:
                 metadata.cu_seqlens_q = metadata.cu_seqlens_k
                 metadata.max_seq_len_q = metadata.max_seq_len_k
 
+        # Precompute strided indices
+        # [0, page_size, 2 * page_size, ...]
+        if self.page_size > 1:
+            self.strided_indices = torch.arange(
+                0, metadata.page_table.shape[1], self.page_size, device=self.device
+            )
+            metadata.page_table = (
+                metadata.page_table[:, self.strided_indices] // self.page_size
+            )
+
         self.forward_metadata = metadata
 
     def forward_extend(
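The decode branch above packs all top-k draft branches of a speculative step into one flattened batch: sequence lengths and the page table are expanded with repeat_interleave, and the cache slots of the speculative tokens written so far are scattered into the columns right after each request's committed tokens. The following self-contained sketch replays that idea with toy sizes; it is an illustration only, the tensor names merely mirror the fields used in the patch, and the explicit loop stands in for the vectorized row/column indexing.

import torch

# Toy sizes: 2 requests, top-k 2 draft branches per request, currently at draft
# step 1 (so each branch has step_id + 1 = 2 speculative tokens in the KV cache).
batch_size, topk, step_id = 2, 2, 1
seq_lens = torch.tensor([5, 3])                       # committed tokens per request
max_seq_len_k = seq_lens.max().item() + (step_id + 1)

# Every request is expanded into `topk` rows, one per draft branch.
cache_seqlens = (seq_lens + step_id + 1).repeat_interleave(topk)   # tensor([7, 7, 5, 5])
cu_seqlens_q = torch.arange(0, batch_size * topk + 1)              # one query token per row
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(cache_seqlens, dim=0), (1, 0))

# Base page table: one row per request, then repeated once per draft branch.
page_table = torch.arange(batch_size * max_seq_len_k).view(batch_size, -1)
page_table = page_table.repeat_interleave(topk, dim=0)             # [batch*topk, max_len]

# Pretend cache slots for the speculative tokens, one column per completed draft step.
cache_loc = torch.arange(100, 100 + batch_size * topk * (step_id + 1)).view(
    batch_size * topk, step_id + 1
)

# Loop-based stand-in for the vectorized scatter in the patch: branch (b, t)
# writes its speculative slots right after the request's committed tokens.
for b in range(batch_size):
    for t in range(topk):
        row = b * topk + t
        cols = seq_lens[b] + torch.arange(step_id + 1)
        page_table[row, cols] = cache_loc[row]

print(cu_seqlens_q, cu_seqlens_k)
print(page_table)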
@@ -281,8 +386,6 @@ class FlashAttentionBackend(AttentionBackend):
         # Pre-reshape query tensor
         q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-
-        # Run attention with precomputed values
         o = flash_attn_with_kvcache(
             q=q_reshaped,
             k_cache=key_cache,
@@ -346,7 +449,11 @@ class FlashAttentionBackend(AttentionBackend):
         This creates fixed-size tensors that will be reused during CUDA graph replay
         to avoid memory allocations.
         """
-        # Initialize fixed size tensors for decode operations
+        if self.speculative_num_steps > 0:
+            raise NotImplementedError(
+                "FlashAttentionBackend Spec Decoding does not support CUDA graph yet, stay tuned!"
+            )
+
         self.decode_cuda_graph_metadata = {
             # Page table for token mapping (batch_size, max_context_len)
             "page_table": torch.zeros(
@@ -385,7 +492,7 @@ class FlashAttentionBackend(AttentionBackend):
         metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
             req_pool_indices, :
         ]
-        if forward_mode == ForwardMode.DECODE:
+        if forward_mode.is_cuda_graph():
             # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(
                 0, batch_size + 1, dtype=torch.int32, device=device
@@ -432,3 +539,66 @@ class FlashAttentionBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""
         return 0
+
+
+class FlashAttentionMultiStepBackend:
+    def __init__(
+        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
+    ):
+        self.model_runner = model_runner
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+
+        self.attn_backends = []
+        for i in range(self.speculative_num_steps):
+            self.attn_backends.append(
+                FlashAttentionBackend(
+                    model_runner,
+                    skip_prefill=True,
+                    topk=self.topk,
+                    speculative_num_steps=self.speculative_num_steps,
+                    step_id=i,
+                )
+            )
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(max_bs)
+
+    def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                encoder_lens=None,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+            )
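In the TARGET_VERIFY branch earlier in this diff, the tree attention mask is folded into the page table itself: each request's page-table row is repeated once per draft token and multiplied by the 0/1 custom_mask, so the pages a draft token must not attend to are zeroed out. A small sketch of that trick with made-up sizes follows (one request and a simple draft chain rather than a tree; in the patch the mask is a flat tensor sliced per request via aug_cum_len, here it is shown already reshaped).

import torch

# One request: 4 committed tokens plus 3 draft tokens to verify. For simplicity the
# draft tokens form a chain, so draft token i may only attend to draft tokens 0..i.
seq_len, draft_token_num = 4, 3
aug_len = seq_len + draft_token_num

# Page-table row for this request, repeated once per draft token
# (pretend KV pages 10..16 hold its tokens).
page_table = torch.arange(10, 10 + aug_len).repeat(draft_token_num, 1)

# 0/1 visibility mask, analogous to spec_info.custom_mask after reshaping:
# the committed prefix is always visible, the draft-vs-draft block is lower-triangular.
prefix = torch.ones(draft_token_num, seq_len, dtype=torch.int64)
draft_block = torch.tril(
    torch.ones(draft_token_num, draft_token_num, dtype=torch.int64)
)
custom_mask = torch.cat([prefix, draft_block], dim=1)

# Multiplying the repeated rows by the mask zeroes out the pages each draft token
# is not allowed to see, which mirrors the per-request loop in the patch.
masked_page_table = page_table * custom_mask
print(masked_page_table)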
@@ -184,6 +184,19 @@ class EAGLEWorker(TpModelWorker):
             self.draft_extend_attn_backend = None
             self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = True
+        elif self.server_args.attention_backend == "fa3":
+            from sglang.srt.layers.attention.flashattention_backend import (
+                FlashAttentionMultiStepBackend,
+            )
+
+            self.draft_attn_backend = FlashAttentionMultiStepBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = None
+            self.padded_static_len = self.speculative_num_steps + 1
+            self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
                 f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"