Unverified Commit 188f0955 authored by Qingquan Song, committed by GitHub

Add Speculative Decoding Eagle3 topk > 1 (#5318)


Co-authored-by: Stefan He <hebiaobuaa@gmail.com>
Co-authored-by: Yubo Wang <yubowang2019@gmail.com>
parent eef9433b
......@@ -16,6 +16,7 @@ if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sgl_kernel import merge_state_v2
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
......@@ -30,7 +31,7 @@ class FlashAttentionMetadata:
# Sequence lengths for the forward batch
cache_seqlens_int32: torch.Tensor = None
# Maximum sequence length for query
max_seq_len_q: int = 0
max_seq_len_q: int = 1
# Maximum sequence length for key
max_seq_len_k: int = 0
# Cumulative sequence lengths for query
......@@ -267,6 +268,12 @@ def cdiv(a: int, b: int) -> int:
return -(a // -b)
# TODO(hebiao064): remove this once we have a better way to handle the merge_state_v2 torch.compile issue
@torch._dynamo.disable()
def merge_state_v2_wrapper(o, s_a, o_exp, s_b):
return merge_state_v2(o, s_a, o_exp, s_b)
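# Editorial sketch (not part of the change): the assumed semantics of merge_state_v2 is the
# standard log-sum-exp merge of two partial attention states over disjoint key sets; the
# sgl_kernel version is a fused CUDA kernel. The check below shows why the cascade paths in
# this backend can run a prefix-only pass and a draft-token-only pass with causal=False and
# return_softmax_lse=True and still recover attention over the full key set.
import torch


def merge_states_reference(o_a, lse_a, o_b, lse_b):
    # Combine two partial attention outputs via their per-query log-sum-exp weights.
    lse = torch.logaddexp(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)
    return w_a * o_a + w_b * o_b, lse


def naive_attention(q, k, v):
    # Plain softmax attention that also returns the log-sum-exp of the scores.
    scores = (q @ k.transpose(-1, -2)) / q.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v, torch.logsumexp(scores, dim=-1)


torch.manual_seed(0)
q = torch.randn(4, 64)                                # e.g. 4 draft-token queries
k1, v1 = torch.randn(16, 64), torch.randn(16, 64)     # committed prefix KV
k2, v2 = torch.randn(4, 64), torch.randn(4, 64)       # freshly drafted KV
o1, lse1 = naive_attention(q, k1, v1)                 # pass 1: prefix only
o2, lse2 = naive_attention(q, k2, v2)                 # pass 2: draft tokens only
merged, _ = merge_states_reference(o1, lse1, o2, lse2)
full, _ = naive_attention(q, torch.cat([k1, k2]), torch.cat([v1, v2]))
assert torch.allclose(merged, full, atol=1e-5)        # merged == attention over all keys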
class FlashAttentionBackend(AttentionBackend):
"""FlashAttention backend implementation.
......@@ -301,6 +308,8 @@ class FlashAttentionBackend(AttentionBackend):
), "Sliding window and cross attention are not supported together"
self.forward_metadata: FlashAttentionMetadata = None
# extra metadata for handling speculative decoding with topk > 1 (expanded draft decode and target verify)
self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None
self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device
self.decode_cuda_graph_metadata = {}
......@@ -311,8 +320,7 @@ class FlashAttentionBackend(AttentionBackend):
self.page_size = model_runner.page_size
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
self.skip_prefill = skip_prefill
self.topk = topk
self.topk = model_runner.server_args.speculative_eagle_topk or 0
self.speculative_num_steps = speculative_num_steps
self.speculative_num_draft_tokens = (
model_runner.server_args.speculative_num_draft_tokens
......@@ -336,14 +344,107 @@ class FlashAttentionBackend(AttentionBackend):
if forward_batch.forward_mode.is_decode_or_idle():
# Draft Decode
if forward_batch.spec_info is not None:
if self.topk <= 1:
metadata.cache_seqlens_int32 = (
seqlens_in_batch + (self.speculative_step_id + 1)
).to(torch.int32)
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
self.speculative_step_id + 1
)
metadata.cu_seqlens_q = torch.arange(
0, batch_size + 1, dtype=torch.int32, device=device
)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
else:
metadata.cache_seqlens_int32 = (seqlens_in_batch).to(torch.int32)
metadata.max_seq_len_q = self.topk
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.cu_seqlens_q = torch.arange(
0,
batch_size * self.topk + 1,
step=self.topk,
dtype=torch.int32,
device=device,
)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
metadata_expand = FlashAttentionMetadata()
decode_length = self.speculative_step_id + 1
metadata_expand.cache_seqlens_int32 = torch.full(
(seqlens_in_batch.numel() * self.topk,),
decode_length,
device=device,
dtype=torch.int32,
)
metadata_expand.max_seq_len_q = 1
metadata_expand.max_seq_len_k = self.speculative_step_id + 1
metadata_expand.cu_seqlens_q = torch.arange(
0,
metadata_expand.cache_seqlens_int32.numel() + 1,
dtype=torch.int32,
device=device,
)
metadata_expand.cu_seqlens_k = torch.arange(
0,
metadata_expand.cache_seqlens_int32.numel() * decode_length + 1,
step=decode_length,
dtype=torch.int32,
device=device,
)
cache_loc = forward_batch.out_cache_loc.view(
self.speculative_num_steps, -1
).T.contiguous()
metadata_expand.page_table = (
cache_loc[:, :decode_length].contiguous().to(torch.int32)
)
self.forward_metadata_spec_decode_expand = metadata_expand
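                    # Worked example of this topk > 1 draft-decode layout (hypothetical values:
                    # batch_size=2, topk=4, speculative_step_id=1, so decode_length=2):
                    #   cu_seqlens_q          -> [0, 4, 8]        (topk query rows per request)
                    #   expand cache_seqlens  -> [2, 2, ..., 2]   (8 rows, each sees decode_length drafts)
                    #   expand cu_seqlens_q   -> [0, 1, ..., 8]   (one query token per expanded row)
                    #   expand cu_seqlens_k   -> [0, 2, 4, ..., 16]
                    #   expand page_table     -> shape (8, 2): the fresh out_cache_loc slots for each
                    #                            (request, candidate) pair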
else:
# Normal Decode
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.cu_seqlens_q = torch.arange(
0, batch_size + 1, dtype=torch.int32, device=device
)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
# TODO: we need to test this part for llama 4 eagle case
self._init_local_attn_metadata(metadata, device)
elif forward_batch.forward_mode.is_target_verify():
if self.topk <= 1:
metadata.cache_seqlens_int32 = (
seqlens_in_batch + (self.speculative_step_id + 1)
forward_batch.seq_lens + self.speculative_num_draft_tokens
).to(torch.int32)
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
self.speculative_step_id + 1
metadata.max_seq_len_q = self.speculative_num_draft_tokens
metadata.max_seq_len_k = (
forward_batch.seq_lens_cpu.max().item()
+ self.speculative_num_draft_tokens
)
metadata.cu_seqlens_q = torch.arange(
0, batch_size + 1, dtype=torch.int32, device=device
0,
batch_size * self.speculative_num_draft_tokens + 1,
self.speculative_num_draft_tokens,
dtype=torch.int32,
device=device,
)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
......@@ -357,44 +458,101 @@ class FlashAttentionBackend(AttentionBackend):
self._init_local_attn_metadata(metadata, device)
else:
# Normal Decode
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
metadata.max_seq_len_q = self.speculative_num_draft_tokens
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.cu_seqlens_q = torch.arange(
0, batch_size + 1, dtype=torch.int32, device=device
0,
batch_size * self.speculative_num_draft_tokens + 1,
step=self.speculative_num_draft_tokens,
dtype=torch.int32,
device=device,
)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
self._init_local_attn_metadata(metadata, device)
elif forward_batch.forward_mode.is_target_verify():
metadata.cache_seqlens_int32 = (
forward_batch.seq_lens + self.speculative_num_draft_tokens
).to(torch.int32)
metadata.max_seq_len_q = self.speculative_num_draft_tokens
metadata.max_seq_len_k = (
forward_batch.seq_lens_cpu.max().item()
+ self.speculative_num_draft_tokens
)
metadata.cu_seqlens_q = torch.arange(
0,
batch_size * self.speculative_num_draft_tokens + 1,
self.speculative_num_draft_tokens,
dtype=torch.int32,
device=device,
)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
metadata_expand = FlashAttentionMetadata()
metadata_expand.max_seq_len_q = 1
metadata_expand.cu_seqlens_q = torch.arange(
0,
forward_batch.seq_lens.numel() * self.speculative_num_draft_tokens
+ 1,
dtype=torch.int32,
device=device,
)
# create expand page table
offsets = torch.arange(
self.speculative_num_draft_tokens, device=device
).unsqueeze(
0
) # shape: (1, self.speculative_num_draft_tokens)
cols = offsets.expand(
forward_batch.seq_lens.numel(), -1
) + forward_batch.seq_lens.unsqueeze(1)
cum_len = torch.nn.functional.pad(
torch.cumsum(
(
forward_batch.seq_lens + self.speculative_num_draft_tokens
).repeat_interleave(self.speculative_num_draft_tokens),
dim=0,
),
(1, 0),
)[:-1]
mask_extraction_indices = (
cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+ cum_len[:, None]
).view(1, -1)
mask = forward_batch.spec_info.custom_mask[
mask_extraction_indices
].view(
-1, self.speculative_num_draft_tokens
) # (bsz * draft_num, draft_num)
# shift table indices to avoid padding
                # non_masked_page_table [[8, 9, 10],   mask (displayed in int format) [[1, 0, 0],
                #                        [8, 9, 10],                                   [1, 1, 0],
                #                        [8, 9, 10]]                                   [1, 0, 1]]
                # if masked with padding [[8, 0, 0],   our mask without padding       [[8, 9, 10],
                #                         [8, 9, 0],                                    [8, 9, 10],
                #                         [8, 0, 10]]                                   [8, 10, 9]]
                # note: here cache_seqlens_int32 is [1, 2, 2], so the extra page indices in each row are ignored
col_indices = offsets.expand(
mask.shape[0], self.speculative_num_draft_tokens
)
# Build keys: if an entry is valid (mask==True), keep its original index;
# if not, add self.speculative_num_draft_tokens so that it sorts after all valid entries.
keys = torch.where(
mask, col_indices, col_indices + self.speculative_num_draft_tokens
)
_, sort_order = torch.sort(keys, dim=1)
non_masked_page_table = (
forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, :
]
.gather(1, cols)
.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
) # (bsz, draft_num)
metadata_expand.page_table = non_masked_page_table.gather(1, sort_order)
metadata_expand.cache_seqlens_int32 = mask.sum(dim=1).to(torch.int32)
metadata_expand.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata_expand.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata_expand.max_seq_len_k = (
metadata_expand.cache_seqlens_int32.max().item()
)
self.forward_metadata_spec_decode_expand = metadata_expand
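            # Worked example of the keys/sort/gather steps above for one row of the diagram
            # (draft_num=3, page row [8, 9, 10], mask row [1, 0, 1]):
            #   col_indices  = [0, 1, 2]
            #   keys         = [0, 1 + 3, 2] = [0, 4, 2]   (the masked-out entry sorts last)
            #   sort_order   = [0, 2, 1]
            #   page row     = [8, 9, 10] gathered by sort_order = [8, 10, 9]
            #   cache_seqlen = mask.sum() = 2, so the trailing index 9 is never attended to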
elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
......@@ -514,6 +672,11 @@ class FlashAttentionBackend(AttentionBackend):
and (hasattr(layer, "use_irope") and layer.use_irope)
)
# We do cascade attention for Target Verify with topk > 1
use_cascade_attn = (
forward_batch.forward_mode.is_target_verify() and self.topk > 1
)
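        # Cascade here means two non-causal kernel calls whose states are merged:
        # the first uses self.forward_metadata (each request's committed prefix KV), the
        # second uses self.forward_metadata_spec_decode_expand (only the draft tokens each
        # query is allowed to see, per the custom mask), and both return the softmax LSE so
        # merge_state_v2_wrapper can combine them into attention over the full key set.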
# Get the appropriate page table based on whether we're using local attention
if use_local_attn:
local_metadata = metadata.local_attn_metadata
......@@ -548,7 +711,7 @@ class FlashAttentionBackend(AttentionBackend):
cu_seqlens_k = metadata.encoder_cu_seqlens_k
window_size = (-1, -1)
o = flash_attn_with_kvcache(
result = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
......@@ -558,13 +721,41 @@ class FlashAttentionBackend(AttentionBackend):
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=causal,
causal=False if use_cascade_attn else causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
)
o, _ = merge_state_v2_wrapper(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
else:
o = result
else:
if (
not global_server_args_dict["disable_chunked_prefix_cache"]
......@@ -627,7 +818,8 @@ class FlashAttentionBackend(AttentionBackend):
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
o = flash_attn_with_kvcache(
result = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
......@@ -638,13 +830,44 @@ class FlashAttentionBackend(AttentionBackend):
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=True,
causal=False if use_cascade_attn else causal,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = (
flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
)
)
o, _ = merge_state_v2_wrapper(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
else:
o = result
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
def forward_decode(
self,
......@@ -681,6 +904,8 @@ class FlashAttentionBackend(AttentionBackend):
use_local_attention = (
self.attention_chunk_size is not None and local_attn_metadata is not None
)
# We do cascade attention for Draft Decode with topk > 1
use_cascade_attn = self.topk > 1
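        # As in the target-verify path, cascade attention runs the regular metadata over the
        # prefix KV and self.forward_metadata_spec_decode_expand over the freshly drafted
        # tokens, then merges the two states.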
# Calculate window size (can be moved to metadata if layer properties don't change)
# we don't subtract 1 from layer.sliding_window_size here because model.get_attention_sliding_window_size() already subtracts 1
......@@ -752,23 +977,61 @@ class FlashAttentionBackend(AttentionBackend):
v_descale=v_descale,
)
else:
page_table = metadata.page_table
cache_seqlens = metadata.cache_seqlens_int32
cu_seqlens_k = metadata.cu_seqlens_k
max_seqlen_q = metadata.max_seq_len_q
q_reshaped = q.contiguous().view(
-1, layer.tp_q_head_num, layer.head_dim
)
# Default: single-token self-attention
o = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
result = flash_attn_with_kvcache(
q=q_reshaped,
k_cache=key_cache,
v_cache=value_cache,
page_table=metadata.page_table,
cache_seqlens=metadata.cache_seqlens_int32,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=metadata.cu_seqlens_k,
max_seqlen_q=1,
cu_seqlens_k_new=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=True,
causal=False if use_cascade_attn else causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = (
flash_attn_with_kvcache(
q=q_reshaped,
k_cache=key_cache,
v_cache=value_cache,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
)
)
o, _ = merge_state_v2(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
else:
o = result
else:
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
......@@ -787,8 +1050,9 @@ class FlashAttentionBackend(AttentionBackend):
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
max_seqlen_q = metadata.max_seq_len_q
o = flash_attn_with_kvcache(
result = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
......@@ -797,13 +1061,43 @@ class FlashAttentionBackend(AttentionBackend):
cache_seqlens=metadata.cache_seqlens_int32,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=metadata.cu_seqlens_k,
max_seqlen_q=1,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=True,
causal=False if use_cascade_attn else causal,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
)
o, _ = merge_state_v2(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
else:
o = result
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
def init_cuda_graph_state(self, max_bs: int):
......@@ -815,6 +1109,8 @@ class FlashAttentionBackend(AttentionBackend):
This creates fixed-size tensors that will be reused during CUDA graph replay
to avoid memory allocations.
"""
# Used by normal decode, and by draft decode when topk == 1
self.decode_cuda_graph_metadata = {
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
"cu_seqlens_q": torch.arange(
......@@ -840,24 +1136,136 @@ class FlashAttentionBackend(AttentionBackend):
),
}
self.target_verify_metadata = {
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
"cu_seqlens_q": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
"cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
"page_table": torch.zeros(
max_bs,
(self.max_context_len + self.page_size - 1) // self.page_size,
dtype=torch.int32,
device=self.device,
),
"strided_indices": torch.arange(
0, self.max_context_len, self.page_size, device=self.device
),
}
# Used for the first half of the draft decode metadata when topk > 1 (prefix tokens)
if self.topk > 1:
self.draft_decode_metadata_topk_normal = {
"cache_seqlens": torch.zeros(
max_bs, dtype=torch.int32, device=self.device
),
"cu_seqlens_q": torch.arange(
0,
max_bs * self.topk + 1,
step=self.topk,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
"page_table": torch.zeros(
max_bs,
self.max_context_len,
dtype=torch.int32,
device=self.device,
),
}
# Used for the second half of the draft decode metadata when topk > 1 (draft tokens)
decode_length = self.speculative_step_id + 1
self.draft_decode_metadata_topk_expand = {
"cache_seqlens": torch.full(
(max_bs * self.topk,),
decode_length,
device=self.device,
dtype=torch.int32,
),
"cu_seqlens_q": torch.arange(
0,
max_bs * self.topk + 1,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_k": torch.arange(
0,
max_bs * self.topk * decode_length + 1,
step=decode_length,
dtype=torch.int32,
device=self.device,
),
"page_table": torch.zeros(
max_bs * self.topk,
decode_length,
dtype=torch.int32,
device=self.device,
),
}
if (
self.speculative_num_draft_tokens is not None
and self.speculative_num_draft_tokens > 0
):
self.target_verify_metadata = {
"cache_seqlens": torch.zeros(
max_bs, dtype=torch.int32, device=self.device
),
"cu_seqlens_q": torch.arange(
0,
max_bs * self.speculative_num_draft_tokens + 1,
step=self.speculative_num_draft_tokens,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
"page_table": torch.zeros(
max_bs,
(self.max_context_len + self.page_size - 1) // self.page_size,
dtype=torch.int32,
device=self.device,
),
"strided_indices": torch.arange(
0, self.max_context_len, self.page_size, device=self.device
),
}
if self.topk > 1:
self.target_verify_metadata_topk_normal = {
"cache_seqlens": torch.zeros(
max_bs, dtype=torch.int32, device=self.device
),
"cu_seqlens_q": torch.arange(
0,
max_bs * self.speculative_num_draft_tokens + 1,
step=self.speculative_num_draft_tokens,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
"page_table": torch.zeros(
max_bs,
self.max_context_len,
dtype=torch.int32,
device=self.device,
),
}
self.target_verify_metadata_topk_expand = {
"cache_seqlens": torch.zeros(
max_bs * self.speculative_num_draft_tokens,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_k": torch.zeros(
max_bs * self.speculative_num_draft_tokens + 1,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_q": torch.arange(
0,
max_bs * self.speculative_num_draft_tokens + 1,
dtype=torch.int32,
device=self.device,
),
"page_table": torch.zeros(
max_bs * self.speculative_num_draft_tokens,
self.speculative_num_draft_tokens,
dtype=torch.int32,
device=self.device,
),
}
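            # Shape check for the expand buffers above (hypothetical: max_bs=2,
            # speculative_num_draft_tokens=8):
            #   cache_seqlens: (16,)   one entry per expanded query row (bs * draft_num)
            #   cu_seqlens_q:  (17,)   each expanded row contributes exactly one query token
            #   page_table:    (16, 8) every row indexes at most its own request's draft tokens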
self.encoder_metadata = {
"encoder_page_table": torch.zeros(
......@@ -886,28 +1294,78 @@ class FlashAttentionBackend(AttentionBackend):
):
"""Initialize forward metadata for capturing CUDA graph."""
metadata = FlashAttentionMetadata()
# metadata_expand is needed for Spec Decoding when top k > 1
metadata_expand = FlashAttentionMetadata()
device = seq_lens.device
if forward_mode.is_decode_or_idle():
if spec_info is not None:
# Draft Decode
metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
"cache_seqlens"
][:bs]
metadata.max_seq_len_k = seq_lens.max().item() + (
self.speculative_step_id + 1
)
metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
: bs + 1
]
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata.page_table = self.decode_cuda_graph_metadata[
"page_table_draft_decode"
][req_pool_indices, :]
if self.topk <= 1:
# When topk = 1, we use the normal decode metadata
metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
"cache_seqlens"
][:bs]
metadata.max_seq_len_k = seq_lens.max().item() + (
self.speculative_step_id + 1
)
metadata.cu_seqlens_q = self.decode_cuda_graph_metadata[
"cu_seqlens_q"
][: bs + 1]
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata.page_table = self.decode_cuda_graph_metadata[
"page_table_draft_decode"
][req_pool_indices, :]
self.decode_cuda_graph_metadata[bs] = metadata
else:
# When topk > 1, we need two draft-decode metadata sets and then merge their attention states
# 1. The first set covers the prefix tokens
metadata.cache_seqlens_int32 = (
self.draft_decode_metadata_topk_normal["cache_seqlens"][:bs]
)
metadata.max_seq_len_q = self.topk
metadata.max_seq_len_k = seq_lens.max().item()
metadata.cu_seqlens_q = self.draft_decode_metadata_topk_normal[
"cu_seqlens_q"
][: bs + 1]
metadata.cu_seqlens_k = self.draft_decode_metadata_topk_normal[
"cu_seqlens_k"
][: bs + 1]
metadata.page_table = self.draft_decode_metadata_topk_normal[
"page_table"
][req_pool_indices, :]
# 2. The second set covers the draft tokens (per_batch_num_tokens = topk)
metadata_expand.cache_seqlens_int32 = (
self.draft_decode_metadata_topk_expand["cache_seqlens"][
: bs * self.topk
]
)
metadata_expand.max_seq_len_q = 1
metadata_expand.max_seq_len_k = (
self.speculative_step_id + 1
) # , do this in replay
metadata_expand.cu_seqlens_q = (
self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
: bs * self.topk + 1
]
)
metadata_expand.cu_seqlens_k = (
self.draft_decode_metadata_topk_expand["cu_seqlens_k"][
: bs * self.topk + 1
]
)
metadata_expand.page_table = self.draft_decode_metadata_topk_expand[
"page_table"
][: bs * self.topk]
self.draft_decode_metadata_topk_normal[bs] = metadata
self.draft_decode_metadata_topk_expand[bs] = metadata_expand
else:
# Normal Decode
# Get sequence information
......@@ -927,37 +1385,77 @@ class FlashAttentionBackend(AttentionBackend):
metadata.cu_seqlens_q = torch.arange(
0, batch_size + 1, dtype=torch.int32, device=device
)
self.decode_cuda_graph_metadata[bs] = metadata
self.decode_cuda_graph_metadata[bs] = metadata
elif forward_mode.is_target_verify():
metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
:bs
]
metadata.cache_seqlens_int32.copy_(
(seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
)
if self.topk <= 1:
metadata.cache_seqlens_int32 = self.target_verify_metadata[
"cache_seqlens"
][:bs]
metadata.cache_seqlens_int32.copy_(
(seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
)
metadata.max_seq_len_q = self.speculative_num_draft_tokens
metadata.max_seq_len_k = (
seq_lens.max().item() + self.speculative_num_draft_tokens
)
metadata.max_seq_len_q = self.speculative_num_draft_tokens
metadata.max_seq_len_k = (
seq_lens.max().item() + self.speculative_num_draft_tokens
)
metadata.cu_seqlens_q = torch.arange(
0,
bs * self.speculative_num_draft_tokens + 1,
self.speculative_num_draft_tokens,
dtype=torch.int32,
device=device,
)
metadata.cu_seqlens_q = torch.arange(
0,
bs * self.speculative_num_draft_tokens + 1,
self.speculative_num_draft_tokens,
dtype=torch.int32,
device=device,
)
metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
: (bs + 1)
]
metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
: (bs + 1)
]
metadata.page_table = self.target_verify_metadata["page_table"][
req_pool_indices, :
]
metadata.page_table = self.target_verify_metadata["page_table"][
req_pool_indices, :
]
self.target_verify_metadata[bs] = metadata
self.target_verify_metadata[bs] = metadata
else:
# When topk > 1, we need two target-verify metadata sets and then merge their attention states
# 1. The first set covers the prefix tokens
metadata.cache_seqlens_int32 = self.target_verify_metadata_topk_normal[
"cache_seqlens"
][:bs]
metadata.max_seq_len_q = self.speculative_num_draft_tokens
# metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item(), do this in replay
metadata.cu_seqlens_q = self.target_verify_metadata_topk_normal[
"cu_seqlens_q"
][: bs + 1]
metadata.cu_seqlens_k = self.target_verify_metadata_topk_normal[
"cu_seqlens_k"
][: bs + 1]
metadata.page_table = self.target_verify_metadata_topk_normal[
"page_table"
][req_pool_indices, :]
# 2. The second set covers the draft tokens (per_batch_num_tokens = speculative_num_draft_tokens)
metadata_expand.cache_seqlens_int32 = (
self.target_verify_metadata_topk_expand["cache_seqlens"][
: bs * self.speculative_num_draft_tokens
]
)
metadata_expand.max_seq_len_q = 1
metadata_expand.cu_seqlens_q = self.target_verify_metadata_topk_expand[
"cu_seqlens_q"
][: bs * self.speculative_num_draft_tokens + 1]
metadata_expand.cu_seqlens_k = self.target_verify_metadata_topk_expand[
"cu_seqlens_k"
][: bs * self.speculative_num_draft_tokens + 1]
metadata_expand.page_table = self.target_verify_metadata_topk_expand[
"page_table"
][: bs * self.speculative_num_draft_tokens]
self.target_verify_metadata_topk_normal[bs] = metadata
self.target_verify_metadata_topk_expand[bs] = metadata_expand
if encoder_lens is not None:
encoder_bs = encoder_lens.numel()
......@@ -973,6 +1471,7 @@ class FlashAttentionBackend(AttentionBackend):
]
self.forward_metadata = metadata
self.forward_metadata_spec_decode_expand = metadata_expand
def init_forward_metadata_replay_cuda_graph(
self,
......@@ -986,41 +1485,85 @@ class FlashAttentionBackend(AttentionBackend):
seq_lens_cpu: Optional[torch.Tensor],
out_cache_loc: torch.Tensor = None,
):
# """Initialize forward metadata for replaying CUDA graph."""
"""Initialize forward metadata for replaying CUDA graph."""
seq_lens = seq_lens[:bs]
seq_lens_cpu = seq_lens_cpu[:bs]
req_pool_indices = req_pool_indices[:bs]
device = seq_lens.device
metadata = None
metadata_expand = None
if forward_mode.is_decode_or_idle():
metadata = self.decode_cuda_graph_metadata[bs]
if spec_info is not None:
# Draft Decode
metadata.cache_seqlens_int32.copy_(
(seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
)
if self.topk <= 1:
metadata = self.decode_cuda_graph_metadata[bs]
# When topk = 1, we use the normal decode metadata
metadata.cache_seqlens_int32.copy_(
(seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
)
metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
self.speculative_step_id + 1
)
metadata.cu_seqlens_k.copy_(
torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
self.speculative_step_id + 1
)
metadata.cu_seqlens_k.copy_(
torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
)
)
page_table = self.req_to_token[
req_pool_indices, : metadata.max_seq_len_k
]
max_seq_pages = (
metadata.max_seq_len_k + self.page_size - 1
) // self.page_size
page_indices = self.req_to_token[
req_pool_indices[:, None],
self.decode_cuda_graph_metadata["strided_indices"][
:max_seq_pages
],
]
page_indices //= self.page_size
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
else:
# When topk > 1, we need two draft-decode metadata sets and then merge their attention states
# 1. The first set covers the prefix tokens
metadata = self.draft_decode_metadata_topk_normal[bs]
metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
# metadata.max_seq_len_q = self.topk, already set in capture
metadata.max_seq_len_k = seq_lens_cpu.max().item()
# metadata.cu_seqlens_q already set in capture
metadata.cu_seqlens_k.copy_(
torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
)
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
page_table = self.req_to_token[
req_pool_indices, : metadata.max_seq_len_k
]
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
# 2. The second set covers the draft tokens (per_batch_num_tokens = topk)
metadata_expand = self.draft_decode_metadata_topk_expand[bs]
decode_length = self.speculative_step_id + 1
cache_loc = out_cache_loc.view(
self.speculative_num_steps, -1
).T.contiguous()
metadata_expand.page_table[: cache_loc.shape[0]].copy_(
cache_loc[:, :decode_length].contiguous().to(torch.int32)
)
# TODO: we need to test this part for llama 4 eagle case
self._init_local_attn_metadata(metadata, device)
else:
metadata = self.decode_cuda_graph_metadata[bs]
# Normal Decode
max_len = seq_lens_cpu.max().item()
metadata.max_seq_len_k = max_len
......@@ -1045,24 +1588,117 @@ class FlashAttentionBackend(AttentionBackend):
self._init_local_attn_metadata(metadata, device)
elif forward_mode.is_target_verify():
metadata = self.target_verify_metadata[bs]
metadata.cache_seqlens_int32.copy_(
(seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
)
if self.topk <= 1:
metadata = self.target_verify_metadata[bs]
metadata.cache_seqlens_int32.copy_(
(seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
)
metadata.max_seq_len_k = (
seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
)
metadata.cu_seqlens_k.copy_(
torch.nn.functional.pad(
metadata.max_seq_len_k = (
seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
)
metadata.cu_seqlens_k.copy_(
torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
)
max_seq_pages = (
metadata.max_seq_len_k + self.page_size - 1
) // self.page_size
page_indices = self.req_to_token[
req_pool_indices[:, None],
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
]
page_indices //= self.page_size
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
else:
# When topk > 1, we need two target-verify metadata sets and then merge their attention states
# 1. The first set covers the prefix tokens
metadata = self.target_verify_metadata_topk_normal[bs]
metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
# metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
metadata.max_seq_len_k = seq_lens_cpu.max().item()
# metadata.cu_seqlens_q already set in capture
metadata.cu_seqlens_k.copy_(
torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
)
page_table = self.req_to_token[
req_pool_indices, : metadata.max_seq_len_k
]
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
# 2. The second set covers the draft tokens (per_batch_num_tokens = speculative_num_draft_tokens)
metadata_expand = self.target_verify_metadata_topk_expand[bs]
# metadata_expand.max_seq_len_q = 1, already set in capture
# metadata_expand.cu_seqlens_q already set in capture
offsets = torch.arange(
self.speculative_num_draft_tokens, device=device
).unsqueeze(
0
) # shape: (1, self.speculative_num_draft_tokens)
cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
cum_len = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
(
seq_lens + self.speculative_num_draft_tokens
).repeat_interleave(self.speculative_num_draft_tokens),
dim=0,
),
(1, 0),
)[:-1]
mask_extraction_indices = (
cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+ cum_len[:, None]
).view(1, -1)
# avoid extracting padded seq indices, which would be out of bounds
mask_extraction_indices[
:, spec_info.positions.numel() * self.speculative_num_draft_tokens :
].fill_(0)
mask = spec_info.custom_mask[mask_extraction_indices].view(
-1, self.speculative_num_draft_tokens
) # (bsz * draft_num, draft_num)
col_indices = offsets.expand(
mask.shape[0], self.speculative_num_draft_tokens
)
keys = torch.where(
mask, col_indices, col_indices + self.speculative_num_draft_tokens
)
_, sort_order = torch.sort(keys, dim=1)
non_masked_page_table = (
self.req_to_token[req_pool_indices, :]
.gather(1, cols)
.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
) # (bsz, draft_num)
metadata_expand.page_table.copy_(
non_masked_page_table.gather(1, sort_order)
)
metadata_expand.cache_seqlens_int32.copy_(
mask.sum(dim=1).to(torch.int32)
)
metadata_expand.cu_seqlens_k.copy_(
torch.nn.functional.pad(
torch.cumsum(
metadata_expand.cache_seqlens_int32,
dim=0,
dtype=torch.int32,
),
(1, 0),
)
)
metadata_expand.max_seq_len_k = (
metadata_expand.cache_seqlens_int32.max().item()
)
)
page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
if encoder_lens is not None:
# Only support encoder size 1 for now
......@@ -1089,6 +1725,7 @@ class FlashAttentionBackend(AttentionBackend):
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
self.forward_metadata = metadata
self.forward_metadata_spec_decode_expand = metadata_expand
def get_cuda_graph_seq_len_fill_value(self):
"""Get the fill value for sequence length in CUDA graph."""
......@@ -1139,12 +1776,6 @@ class FlashAttentionMultiStepBackend:
self.model_runner = model_runner
self.topk = topk
self.speculative_num_steps = speculative_num_steps
# TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
assert (
self.topk == 1
), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend"
self.attn_backends = []
for i in range(self.speculative_num_steps):
self.attn_backends.append(
......
......@@ -221,7 +221,16 @@ class ModelRunner:
server_args = self.server_args
if server_args.attention_backend is None:
# By default, use flashinfer for non-mla attention and triton for mla attention
"""
Auto-select the fastest attention backend for the current setup:
1. Models with an MHA architecture (e.g., Llama, Qwen)
   1.1 Use FA3 on Hopper, unless the user enables speculative decoding with topk > 1 or page_size > 1.
   1.2 Otherwise, use flashinfer if available, else triton.
2. Models with an MLA architecture
   2.1 Use FA3 on Hopper.
   2.2 Otherwise, use triton.
"""
if not self.use_mla_backend:
if (
is_hopper_with_cuda_12_3()
......@@ -234,9 +243,7 @@ class ModelRunner:
"flashinfer" if is_flashinfer_available() else "triton"
)
else:
if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(
server_args
):
if is_hopper_with_cuda_12_3():
server_args.attention_backend = "fa3"
else:
server_args.attention_backend = "triton"
......
......@@ -359,7 +359,18 @@ class ServerArgs:
if self.page_size > 1 and self.speculative_eagle_topk > 1:
self.speculative_eagle_topk = 1
logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
logger.info(
"speculative_eagle_topk is adjusted to 1 when page_size > 1"
)
if (
self.speculative_eagle_topk == 1
and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
):
logger.info(
"speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1"
)
self.speculative_num_draft_tokens = self.speculative_num_steps + 1
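        # For example, with speculative_num_steps = 5 and speculative_eagle_topk = 1, the single
        # draft chain yields 5 drafted tokens plus the token produced by the verify step, so
        # speculative_num_draft_tokens is forced to 6.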
# The token generated from the verify step is counted.
# If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
......
......@@ -1909,6 +1909,8 @@ def is_page_size_one(server_args):
return server_args.page_size == 1
# TODO(hebiao064): Accelerate FA3 Spec Decode with topk > 1.
# TODO(hebiao064): Improve the acc rate for FA3 Spec Decode with topk == 1 and page_size > 1.
def is_no_spec_infer_or_topk_one(server_args):
return server_args.speculative_eagle_topk is None or (
server_args.speculative_eagle_topk is not None
......
......@@ -29,7 +29,7 @@ suites = {
TestFile("test_chunked_prefill.py", 336),
TestFile("test_eagle_infer.py", 500),
TestFile("test_ebnf_constrained.py"),
TestFile("test_fa3.py", 5),
TestFile("test_fa3.py", 200),
TestFile("test_fp8_kernel.py", 8),
TestFile("test_embedding_openai_server.py", 36),
TestFile("test_hidden_states.py", 55),
......
......@@ -173,6 +173,60 @@ class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
self.assertGreater(avg_spec_accept_length, 1.5)
class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
"""Test FlashAttention3 with speculative decode enabled, topk > 1"""
model = "meta-llama/Llama-3.1-8B-Instruct"
@classmethod
def get_server_args(cls):
args = super().get_server_args()
args.extend(
[
"--cuda-graph-max-bs",
"2",
"--speculative-algorithm",
"EAGLE3",
"--speculative-draft",
"jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
"--speculative-num-steps",
"5",
"--speculative-eagle-topk",
"4",
"--speculative-num-draft-tokens",
"8",
"--dtype",
"float16",
]
)
return args
def test_gsm8k(self):
"""
Override test_gsm8k to also check the average speculative accept length.
"""
requests.get(self.base_url + "/flush_cache")
args = SimpleNamespace(
num_shots=5,
data_path=DATA_PATH,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.60)
server_info = requests.get(self.base_url + "/get_server_info")
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 1.8)
class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
"""Test FlashAttention3 with speculative decode enabled."""
......