Unverified Commit 188f0955 authored by Qingquan Song, committed by GitHub

Add Speculative Decoding Eagle3 topk > 1 (#5318)


Co-authored-by: Stefan He <hebiaobuaa@gmail.com>
Co-authored-by: Yubo Wang <yubowang2019@gmail.com>
parent eef9433b
...@@ -16,6 +16,7 @@ if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sgl_kernel import merge_state_v2
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
...@@ -30,7 +31,7 @@ class FlashAttentionMetadata:
# Sequence lengths for the forward batch
cache_seqlens_int32: torch.Tensor = None
# Maximum sequence length for query
max_seq_len_q: int = 1
# Maximum sequence length for key
max_seq_len_k: int = 0
# Cumulative sequence lengths for query
...@@ -267,6 +268,12 @@ def cdiv(a: int, b: int) -> int:
return -(a // -b)
# TODO(hebiao064): remove this once we have a better way to handle the merge_state_v2 torch.compile issue
@torch._dynamo.disable()
def merge_state_v2_wrapper(o, s_a, o_exp, s_b):
return merge_state_v2(o, s_a, o_exp, s_b)
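For reference, merging two partial attention outputs by their log-sum-exp (LSE) statistics is equivalent to computing softmax attention over the union of the two key sets. A minimal PyTorch sketch of that merge rule (illustrative only; it mirrors what the fused merge_state_v2 kernel computes, with assumed tensor shapes):

import torch

def merge_states_reference(o_a, lse_a, o_b, lse_b):
    # o_a, o_b: partial attention outputs, shape (num_tokens, num_heads, head_dim)
    # lse_a, lse_b: per-token, per-head log-sum-exp of the attention logits, shape (num_tokens, num_heads)
    lse = torch.logaddexp(lse_a, lse_b)             # combined normalizer
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)      # weight of the first partial result
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)      # weight of the second partial result
    return w_a * o_a + w_b * o_b, lse               # merged output and merged LSE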
class FlashAttentionBackend(AttentionBackend):
"""FlashAttention backend implementation.
...@@ -301,6 +308,8 @@ class FlashAttentionBackend(AttentionBackend):
), "Sliding window and cross attention are not supported together"
self.forward_metadata: FlashAttentionMetadata = None
# extra metadata for handling speculative decoding topk > 1, extended draft decode and verify
self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None
self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device
self.decode_cuda_graph_metadata = {}
...@@ -311,8 +320,7 @@ class FlashAttentionBackend(AttentionBackend):
self.page_size = model_runner.page_size
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
self.skip_prefill = skip_prefill
self.topk = model_runner.server_args.speculative_eagle_topk or 0
self.speculative_num_steps = speculative_num_steps
self.speculative_num_draft_tokens = (
model_runner.server_args.speculative_num_draft_tokens
...@@ -336,6 +344,7 @@ class FlashAttentionBackend(AttentionBackend):
if forward_batch.forward_mode.is_decode_or_idle():
# Draft Decode
if forward_batch.spec_info is not None:
if self.topk <= 1:
metadata.cache_seqlens_int32 = (
seqlens_in_batch + (self.speculative_step_id + 1)
).to(torch.int32)
...@@ -354,8 +363,57 @@ class FlashAttentionBackend(AttentionBackend):
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
else:
metadata.cache_seqlens_int32 = (seqlens_in_batch).to(torch.int32)
metadata.max_seq_len_q = self.topk
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.cu_seqlens_q = torch.arange(
0,
batch_size * self.topk + 1,
step=self.topk,
dtype=torch.int32,
device=device,
)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
metadata_expand = FlashAttentionMetadata()
decode_length = self.speculative_step_id + 1
metadata_expand.cache_seqlens_int32 = torch.full(
(seqlens_in_batch.numel() * self.topk,),
decode_length,
device=device,
dtype=torch.int32,
)
metadata_expand.max_seq_len_q = 1
metadata_expand.max_seq_len_k = self.speculative_step_id + 1
metadata_expand.cu_seqlens_q = torch.arange(
0,
metadata_expand.cache_seqlens_int32.numel() + 1,
dtype=torch.int32,
device=device,
)
metadata_expand.cu_seqlens_k = torch.arange(
0,
metadata_expand.cache_seqlens_int32.numel() * decode_length + 1,
step=decode_length,
dtype=torch.int32,
device=device,
)
cache_loc = forward_batch.out_cache_loc.view(
self.speculative_num_steps, -1
).T.contiguous()
metadata_expand.page_table = (
cache_loc[:, :decode_length].contiguous().to(torch.int32)
)
self.forward_metadata_spec_decode_expand = metadata_expand
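A small worked example of how the two metadata sets line up for draft decode with topk > 1 (illustrative values: batch size 2, topk 2, speculative_step_id 1, prefix lengths 5 and 7). The normal metadata gives each request topk query rows over its cached prefix; the expand metadata gives each of the bs * topk branches one query row over its own decode_length draft tokens:

import torch

bs, topk, decode_length = 2, 2, 2                      # illustrative values
seq_lens = torch.tensor([5, 7], dtype=torch.int32)

# Prefix pass: topk query rows per request, keys = each request's cached prefix.
cu_q = torch.arange(0, bs * topk + 1, step=topk, dtype=torch.int32)
cu_k = torch.nn.functional.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))
print(cu_q.tolist(), cu_k.tolist())                    # [0, 2, 4] [0, 5, 12]

# Expand pass: one query row per (request, branch), keys = that branch's draft tokens.
n = bs * topk
cu_q_exp = torch.arange(0, n + 1, dtype=torch.int32)
cu_k_exp = torch.arange(0, n * decode_length + 1, step=decode_length, dtype=torch.int32)
print(cu_q_exp.tolist(), cu_k_exp.tolist())            # [0, 1, 2, 3, 4] [0, 2, 4, 6, 8]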
else:
# Normal Decode
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
...@@ -369,9 +427,10 @@ class FlashAttentionBackend(AttentionBackend):
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
# TODO: we need to test this part for llama 4 eagle case
self._init_local_attn_metadata(metadata, device)
elif forward_batch.forward_mode.is_target_verify():
if self.topk <= 1:
metadata.cache_seqlens_int32 = (
forward_batch.seq_lens + self.speculative_num_draft_tokens
).to(torch.int32)
...@@ -388,13 +447,112 @@ class FlashAttentionBackend(AttentionBackend):
device=device,
)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
self._init_local_attn_metadata(metadata, device)
else:
metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
metadata.max_seq_len_q = self.speculative_num_draft_tokens
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.cu_seqlens_q = torch.arange(
0,
batch_size * self.speculative_num_draft_tokens + 1,
step=self.speculative_num_draft_tokens,
dtype=torch.int32,
device=device,
)
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
metadata_expand = FlashAttentionMetadata()
metadata_expand.max_seq_len_q = 1
metadata_expand.cu_seqlens_q = torch.arange(
0,
forward_batch.seq_lens.numel() * self.speculative_num_draft_tokens
+ 1,
dtype=torch.int32,
device=device,
)
# create expand page table
offsets = torch.arange(
self.speculative_num_draft_tokens, device=device
).unsqueeze(
0
) # shape: (1, self.speculative_num_draft_tokens)
cols = offsets.expand(
forward_batch.seq_lens.numel(), -1
) + forward_batch.seq_lens.unsqueeze(1)
cum_len = torch.nn.functional.pad(
torch.cumsum(
(
forward_batch.seq_lens + self.speculative_num_draft_tokens
).repeat_interleave(self.speculative_num_draft_tokens),
dim=0,
),
(1, 0),
)[:-1]
mask_extraction_indices = (
cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+ cum_len[:, None]
).view(1, -1)
mask = forward_batch.spec_info.custom_mask[
mask_extraction_indices
].view(
-1, self.speculative_num_draft_tokens
) # (bsz * draft_num, draft_num)
# shift table indices to avoid padding
# non_masked_page_table [[8, 9, 10], mask (display with int format) [[1, 0, 0],
# [8, 9, 10], [1, 1, 0],
# [8, 9, 10]] [1, 0, 1]]
# if masked with padding [[8, 0, 0], our mask without padding [[8, 9, 10],
# [8, 9, 0], [8, 9, 10],
# [8, 0, 10]] [8, 10, 9]]
# note here cache_seqlens_int32 is [1, 2, 2] so extra page indices will be ignored in each row
col_indices = offsets.expand(
mask.shape[0], self.speculative_num_draft_tokens
)
# Build keys: if an entry is valid (mask==True), keep its original index;
# if not, add self.speculative_num_draft_tokens so that it sorts after all valid entries.
keys = torch.where(
mask, col_indices, col_indices + self.speculative_num_draft_tokens
)
_, sort_order = torch.sort(keys, dim=1)
non_masked_page_table = (
forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, :
]
.gather(1, cols)
.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
) # (bsz, draft_num)
metadata_expand.page_table = non_masked_page_table.gather(1, sort_order)
metadata_expand.cache_seqlens_int32 = mask.sum(dim=1).to(torch.int32)
metadata_expand.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata_expand.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
metadata_expand.max_seq_len_k = (
metadata_expand.cache_seqlens_int32.max().item()
)
self.forward_metadata_spec_decode_expand = metadata_expand
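The sort trick above can be checked in isolation. Instead of zeroing out masked page entries (which would leave holes), invalid columns are pushed to the end of each row and cache_seqlens is set to the number of valid entries, so the kernel simply never reads the tail. A small standalone sketch with the same toy numbers as the comment above:

import torch

draft_num = 3                                           # illustrative
mask = torch.tensor([[True, False, False],
                     [True, True,  False],
                     [True, False, True ]])
page_table = torch.tensor([[8, 9, 10]]).repeat(3, 1)    # candidate pages per row

cols = torch.arange(draft_num).unsqueeze(0).expand(mask.shape[0], -1)
keys = torch.where(mask, cols, cols + draft_num)        # invalid columns sort last
_, order = torch.sort(keys, dim=1)
compacted = page_table.gather(1, order)
seqlens = mask.sum(dim=1)
print(compacted.tolist())                               # [[8, 9, 10], [8, 9, 10], [8, 10, 9]]
print(seqlens.tolist())                                 # [1, 2, 2]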
elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
...@@ -514,6 +672,11 @@ class FlashAttentionBackend(AttentionBackend):
and (hasattr(layer, "use_irope") and layer.use_irope)
)
# We do cascade attention for Target Verify with topk > 1
use_cascade_attn = (
forward_batch.forward_mode.is_target_verify() and self.topk > 1
)
# Get the appropriate page table based on whether we're using local attention
if use_local_attn:
local_metadata = metadata.local_attn_metadata
...@@ -548,7 +711,7 @@ class FlashAttentionBackend(AttentionBackend):
cu_seqlens_k = metadata.encoder_cu_seqlens_k
window_size = (-1, -1)
result = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
...@@ -558,13 +721,41 @@ class FlashAttentionBackend(AttentionBackend):
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
)
o, _ = merge_state_v2_wrapper(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
else:
o = result
else:
if (
not global_server_args_dict["disable_chunked_prefix_cache"]
...@@ -627,7 +818,8 @@ class FlashAttentionBackend(AttentionBackend):
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
result = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
...@@ -638,11 +830,42 @@ class FlashAttentionBackend(AttentionBackend):
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = (
flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
)
)
o, _ = merge_state_v2_wrapper(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
else:
o = result
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
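The cascade path above relies on the fact that running attention separately over the prefix keys and over the draft keys, then merging by LSE, reproduces single-pass attention over the concatenation of the two key sets. A self-contained sanity check in plain PyTorch (toy sizes, no causality, not part of the patch):

import torch

torch.manual_seed(0)
d = 8
q = torch.randn(1, d)
k_prefix, v_prefix = torch.randn(5, d), torch.randn(5, d)
k_draft, v_draft = torch.randn(3, d), torch.randn(3, d)

def attn(q, k, v):
    s = (q @ k.T) / d**0.5
    return torch.softmax(s, dim=-1) @ v, torch.logsumexp(s, dim=-1)

o_full, _ = attn(q, torch.cat([k_prefix, k_draft]), torch.cat([v_prefix, v_draft]))
o_a, lse_a = attn(q, k_prefix, v_prefix)
o_b, lse_b = attn(q, k_draft, v_draft)
lse = torch.logaddexp(lse_a, lse_b)
o_merged = torch.exp(lse_a - lse).unsqueeze(-1) * o_a + torch.exp(lse_b - lse).unsqueeze(-1) * o_b
print(torch.allclose(o_full, o_merged, atol=1e-6))      # True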
...@@ -681,6 +904,8 @@ class FlashAttentionBackend(AttentionBackend):
use_local_attention = (
self.attention_chunk_size is not None and local_attn_metadata is not None
)
# We do cascade attention for Draft Decode with topk > 1
use_cascade_attn = self.topk > 1
# Calculate window size (can be moved to metadata if layer properties don't change)
# we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
...@@ -752,23 +977,61 @@ class FlashAttentionBackend(AttentionBackend):
v_descale=v_descale,
)
else:
page_table = metadata.page_table
cache_seqlens = metadata.cache_seqlens_int32
cu_seqlens_k = metadata.cu_seqlens_k
max_seqlen_q = metadata.max_seq_len_q
q_reshaped = q.contiguous().view(
-1, layer.tp_q_head_num, layer.head_dim
)
# Default: single-token self-attention
result = flash_attn_with_kvcache(
q=q_reshaped,
k_cache=key_cache,
v_cache=value_cache,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = (
flash_attn_with_kvcache(
q=q_reshaped,
k_cache=key_cache,
v_cache=value_cache,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
)
)
o, _ = merge_state_v2(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
else:
o = result
else:
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
...@@ -787,8 +1050,9 @@ class FlashAttentionBackend(AttentionBackend):
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
q_nope = q_all[:, :, : layer.v_head_dim]
q_rope = q_all[:, :, layer.v_head_dim :]
max_seqlen_q = metadata.max_seq_len_q
result = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
...@@ -797,13 +1061,43 @@ class FlashAttentionBackend(AttentionBackend):
cache_seqlens=metadata.cache_seqlens_int32,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=metadata.cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
q=q_rope,
k_cache=k_rope_cache,
v_cache=c_kv_cache,
qv=q_nope,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
)
o, _ = merge_state_v2(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
else:
o = result
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
def init_cuda_graph_state(self, max_bs: int):
...@@ -815,6 +1109,8 @@ class FlashAttentionBackend(AttentionBackend):
This creates fixed-size tensors that will be reused during CUDA graph replay
to avoid memory allocations.
"""
# This is being used by normal decode and draft decode when topk == 1
self.decode_cuda_graph_metadata = {
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
"cu_seqlens_q": torch.arange(
...@@ -840,11 +1136,75 @@ class FlashAttentionBackend(AttentionBackend):
),
}
# This is used by draft decode's first half of metadata when topk > 1
if self.topk > 1:
self.draft_decode_metadata_topk_normal = {
"cache_seqlens": torch.zeros(
max_bs, dtype=torch.int32, device=self.device
),
"cu_seqlens_q": torch.arange(
0,
max_bs * self.topk + 1,
step=self.topk,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
"page_table": torch.zeros(
max_bs,
self.max_context_len,
dtype=torch.int32,
device=self.device,
),
}
# This is used by draft decode's second half of metadata when topk > 1
decode_length = self.speculative_step_id + 1
self.draft_decode_metadata_topk_expand = {
"cache_seqlens": torch.full(
(max_bs * self.topk,),
decode_length,
device=self.device,
dtype=torch.int32,
),
"cu_seqlens_q": torch.arange(
0,
max_bs * self.topk + 1,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_k": torch.arange(
0,
max_bs * self.topk * decode_length + 1,
step=decode_length,
dtype=torch.int32,
device=self.device,
),
"page_table": torch.zeros(
max_bs * self.topk,
decode_length,
dtype=torch.int32,
device=self.device,
),
}
if (
self.speculative_num_draft_tokens is not None
and self.speculative_num_draft_tokens > 0
):
self.target_verify_metadata = {
"cache_seqlens": torch.zeros(
max_bs, dtype=torch.int32, device=self.device
),
"cu_seqlens_q": torch.arange(
0,
max_bs * self.speculative_num_draft_tokens + 1,
step=self.speculative_num_draft_tokens,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_k": torch.zeros( "cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device max_bs + 1, dtype=torch.int32, device=self.device
), ),
...@@ -859,6 +1219,54 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -859,6 +1219,54 @@ class FlashAttentionBackend(AttentionBackend):
), ),
} }
if self.topk > 1:
self.target_verify_metadata_topk_normal = {
"cache_seqlens": torch.zeros(
max_bs, dtype=torch.int32, device=self.device
),
"cu_seqlens_q": torch.arange(
0,
max_bs * self.speculative_num_draft_tokens + 1,
step=self.speculative_num_draft_tokens,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
"page_table": torch.zeros(
max_bs,
self.max_context_len,
dtype=torch.int32,
device=self.device,
),
}
self.target_verify_metadata_topk_expand = {
"cache_seqlens": torch.zeros(
max_bs * self.speculative_num_draft_tokens,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_k": torch.zeros(
max_bs * self.speculative_num_draft_tokens + 1,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_q": torch.arange(
0,
max_bs * self.speculative_num_draft_tokens + 1,
dtype=torch.int32,
device=self.device,
),
"page_table": torch.zeros(
max_bs * self.speculative_num_draft_tokens,
self.speculative_num_draft_tokens,
dtype=torch.int32,
device=self.device,
),
}
self.encoder_metadata = {
"encoder_page_table": torch.zeros(
max_bs,
...@@ -886,19 +1294,25 @@ class FlashAttentionBackend(AttentionBackend):
):
"""Initialize forward metadata for capturing CUDA graph."""
metadata = FlashAttentionMetadata()
# metadata_expand is needed for Spec Decoding when top k > 1
metadata_expand = FlashAttentionMetadata()
device = seq_lens.device
if forward_mode.is_decode_or_idle():
if spec_info is not None:
# Draft Decode
if self.topk <= 1:
# When topk = 1, we use the normal decode metadata
metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
"cache_seqlens"
][:bs]
metadata.max_seq_len_k = seq_lens.max().item() + (
self.speculative_step_id + 1
)
metadata.cu_seqlens_q = self.decode_cuda_graph_metadata[
"cu_seqlens_q"
][: bs + 1]
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
...@@ -908,6 +1322,50 @@ class FlashAttentionBackend(AttentionBackend):
metadata.page_table = self.decode_cuda_graph_metadata[
"page_table_draft_decode"
][req_pool_indices, :]
self.decode_cuda_graph_metadata[bs] = metadata
else:
# When top k > 1, we need two specific draft decode metadata, and then merge states
# 1. The first half of metadata for prefix tokens
metadata.cache_seqlens_int32 = (
self.draft_decode_metadata_topk_normal["cache_seqlens"][:bs]
)
metadata.max_seq_len_q = self.topk
metadata.max_seq_len_k = seq_lens.max().item()
metadata.cu_seqlens_q = self.draft_decode_metadata_topk_normal[
"cu_seqlens_q"
][: bs + 1]
metadata.cu_seqlens_k = self.draft_decode_metadata_topk_normal[
"cu_seqlens_k"
][: bs + 1]
metadata.page_table = self.draft_decode_metadata_topk_normal[
"page_table"
][req_pool_indices, :]
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
metadata_expand.cache_seqlens_int32 = (
self.draft_decode_metadata_topk_expand["cache_seqlens"][
: bs * self.topk
]
)
metadata_expand.max_seq_len_q = 1
metadata_expand.max_seq_len_k = (
self.speculative_step_id + 1
) # , do this in replay
metadata_expand.cu_seqlens_q = (
self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
: bs * self.topk + 1
]
)
metadata_expand.cu_seqlens_k = (
self.draft_decode_metadata_topk_expand["cu_seqlens_k"][
: bs * self.topk + 1
]
)
metadata_expand.page_table = self.draft_decode_metadata_topk_expand[
"page_table"
][: bs * self.topk]
self.draft_decode_metadata_topk_normal[bs] = metadata
self.draft_decode_metadata_topk_expand[bs] = metadata_expand
else:
# Normal Decode
# Get sequence information
...@@ -928,10 +1386,12 @@ class FlashAttentionBackend(AttentionBackend):
0, batch_size + 1, dtype=torch.int32, device=device
)
self.decode_cuda_graph_metadata[bs] = metadata
elif forward_mode.is_target_verify():
if self.topk <= 1:
metadata.cache_seqlens_int32 = self.target_verify_metadata[
"cache_seqlens"
][:bs]
metadata.cache_seqlens_int32.copy_(
(seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
)
...@@ -958,6 +1418,44 @@ class FlashAttentionBackend(AttentionBackend):
]
self.target_verify_metadata[bs] = metadata
else:
# When topk > 1, we need two specific target verify metadata, and then merge states
# 1. The first half of metadata for prefix tokens
metadata.cache_seqlens_int32 = self.target_verify_metadata_topk_normal[
"cache_seqlens"
][:bs]
metadata.max_seq_len_q = self.speculative_num_draft_tokens
# metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item(), do this in replay
metadata.cu_seqlens_q = self.target_verify_metadata_topk_normal[
"cu_seqlens_q"
][: bs + 1]
metadata.cu_seqlens_k = self.target_verify_metadata_topk_normal[
"cu_seqlens_k"
][: bs + 1]
metadata.page_table = self.target_verify_metadata_topk_normal[
"page_table"
][req_pool_indices, :]
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
metadata_expand.cache_seqlens_int32 = (
self.target_verify_metadata_topk_expand["cache_seqlens"][
: bs * self.speculative_num_draft_tokens
]
)
metadata_expand.max_seq_len_q = 1
metadata_expand.cu_seqlens_q = self.target_verify_metadata_topk_expand[
"cu_seqlens_q"
][: bs * self.speculative_num_draft_tokens + 1]
metadata_expand.cu_seqlens_k = self.target_verify_metadata_topk_expand[
"cu_seqlens_k"
][: bs * self.speculative_num_draft_tokens + 1]
metadata_expand.page_table = self.target_verify_metadata_topk_expand[
"page_table"
][: bs * self.speculative_num_draft_tokens]
self.target_verify_metadata_topk_normal[bs] = metadata
self.target_verify_metadata_topk_expand[bs] = metadata_expand
if encoder_lens is not None:
encoder_bs = encoder_lens.numel()
...@@ -973,6 +1471,7 @@ class FlashAttentionBackend(AttentionBackend):
]
self.forward_metadata = metadata
self.forward_metadata_spec_decode_expand = metadata_expand
def init_forward_metadata_replay_cuda_graph(
self,
...@@ -986,17 +1485,21 @@ class FlashAttentionBackend(AttentionBackend):
seq_lens_cpu: Optional[torch.Tensor],
out_cache_loc: torch.Tensor = None,
):
# """Initialize forward metadata for replaying CUDA graph.""" """Initialize forward metadata for replaying CUDA graph."""
seq_lens = seq_lens[:bs] seq_lens = seq_lens[:bs]
seq_lens_cpu = seq_lens_cpu[:bs] seq_lens_cpu = seq_lens_cpu[:bs]
req_pool_indices = req_pool_indices[:bs] req_pool_indices = req_pool_indices[:bs]
device = seq_lens.device device = seq_lens.device
metadata = None
metadata_expand = None
if forward_mode.is_decode_or_idle():
if spec_info is not None:
# Draft Decode
if self.topk <= 1:
metadata = self.decode_cuda_graph_metadata[bs]
# When topk = 1, we use the normal decode metadata
metadata.cache_seqlens_int32.copy_(
(seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
)
...@@ -1013,14 +1516,54 @@ class FlashAttentionBackend(AttentionBackend):
)
)
max_seq_pages = (
metadata.max_seq_len_k + self.page_size - 1
) // self.page_size
page_indices = self.req_to_token[
req_pool_indices[:, None],
self.decode_cuda_graph_metadata["strided_indices"][
:max_seq_pages
],
]
page_indices //= self.page_size
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
else:
# When top k > 1, we need two specific draft decode metadata, and then merge states
# 1. The first half of metadata for prefix tokens
metadata = self.draft_decode_metadata_topk_normal[bs]
metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
# metadata.max_seq_len_q = self.topk, already set in capture
metadata.max_seq_len_k = seq_lens_cpu.max().item()
# metadata.cu_seqlens_q already set in capture
metadata.cu_seqlens_k.copy_(
torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
)
page_table = self.req_to_token[
req_pool_indices, : metadata.max_seq_len_k
]
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
metadata_expand = self.draft_decode_metadata_topk_expand[bs]
decode_length = self.speculative_step_id + 1
cache_loc = out_cache_loc.view(
self.speculative_num_steps, -1
).T.contiguous()
metadata_expand.page_table[: cache_loc.shape[0]].copy_(
cache_loc[:, :decode_length].contiguous().to(torch.int32)
)
# TODO: we need to test this part for llama 4 eagle case
self._init_local_attn_metadata(metadata, device)
else:
metadata = self.decode_cuda_graph_metadata[bs]
# Normal Decode
max_len = seq_lens_cpu.max().item()
metadata.max_seq_len_k = max_len
...@@ -1045,6 +1588,7 @@ class FlashAttentionBackend(AttentionBackend):
self._init_local_attn_metadata(metadata, device)
elif forward_mode.is_target_verify():
if self.topk <= 1:
metadata = self.target_verify_metadata[bs]
metadata.cache_seqlens_int32.copy_(
(seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
...@@ -1061,9 +1605,101 @@ class FlashAttentionBackend(AttentionBackend):
(1, 0),
)
)
max_seq_pages = (
metadata.max_seq_len_k + self.page_size - 1
) // self.page_size
page_indices = self.req_to_token[
req_pool_indices[:, None],
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
]
page_indices //= self.page_size
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
else:
# When topk > 1, we need two specific target verify metadata, and then merge states
# 1. The first half of metadata for prefix tokens
metadata = self.target_verify_metadata_topk_normal[bs]
metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
# metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
metadata.max_seq_len_k = seq_lens_cpu.max().item()
# metadata.cu_seqlens_q already set in capture
metadata.cu_seqlens_k.copy_(
torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
)
page_table = self.req_to_token[
req_pool_indices, : metadata.max_seq_len_k
]
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
metadata_expand = self.target_verify_metadata_topk_expand[bs]
# metadata_expand.max_seq_len_q = 1, already set in capture
# metadata_expand.cu_seqlens_q already set in capture
offsets = torch.arange(
self.speculative_num_draft_tokens, device=device
).unsqueeze(
0
) # shape: (1, self.speculative_num_draft_tokens)
cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
cum_len = torch.nn.functional.pad(
torch.cumsum(
(
seq_lens + self.speculative_num_draft_tokens
).repeat_interleave(self.speculative_num_draft_tokens),
dim=0,
),
(1, 0),
)[:-1]
mask_extraction_indices = (
cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
+ cum_len[:, None]
).view(1, -1)
# avoid extracting padded seq indices which will be out of boundary
mask_extraction_indices[
:, spec_info.positions.numel() * self.speculative_num_draft_tokens :
].fill_(0)
mask = spec_info.custom_mask[mask_extraction_indices].view(
-1, self.speculative_num_draft_tokens
) # (bsz * draft_num, draft_num)
col_indices = offsets.expand(
mask.shape[0], self.speculative_num_draft_tokens
)
keys = torch.where(
mask, col_indices, col_indices + self.speculative_num_draft_tokens
)
_, sort_order = torch.sort(keys, dim=1)
non_masked_page_table = (
self.req_to_token[req_pool_indices, :]
.gather(1, cols)
.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
) # (bsz, draft_num)
metadata_expand.page_table.copy_(
non_masked_page_table.gather(1, sort_order)
)
metadata_expand.cache_seqlens_int32.copy_(
mask.sum(dim=1).to(torch.int32)
)
metadata_expand.cu_seqlens_k.copy_(
torch.nn.functional.pad(
torch.cumsum(
metadata_expand.cache_seqlens_int32,
dim=0,
dtype=torch.int32,
),
(1, 0),
)
)
metadata_expand.max_seq_len_k = (
metadata_expand.cache_seqlens_int32.max().item()
)
if encoder_lens is not None:
# Only support encoder size 1 for now
metadata.encoder_max_seq_len_k = encoder_lens[0]
...@@ -1089,6 +1725,7 @@ class FlashAttentionBackend(AttentionBackend):
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
self.forward_metadata = metadata
self.forward_metadata_spec_decode_expand = metadata_expand
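The replay path only does in-place copy_ into the buffers that were captured, because a CUDA graph replays fixed kernels over fixed memory addresses; reassigning a metadata tensor would silently break the captured graph. A toy illustration of the pattern (assumes a CUDA device; warmup on a side stream as recommended by PyTorch, not part of the patch):

import torch

static_in = torch.zeros(4, device="cuda")
static_out = torch.zeros(4, device="cuda")

# Warm up on a side stream, then capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(static_in * 2)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(static_in * 2)

# New inputs are written into the captured buffer, then the graph is replayed.
static_in.copy_(torch.arange(4.0, device="cuda"))
g.replay()
print(static_out)                                       # tensor([0., 2., 4., 6.], device='cuda:0')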
def get_cuda_graph_seq_len_fill_value(self):
"""Get the fill value for sequence length in CUDA graph."""
...@@ -1139,12 +1776,6 @@ class FlashAttentionMultiStepBackend:
self.model_runner = model_runner
self.topk = topk
self.speculative_num_steps = speculative_num_steps
self.attn_backends = []
for i in range(self.speculative_num_steps):
self.attn_backends.append(
...
...@@ -221,7 +221,16 @@ class ModelRunner:
server_args = self.server_args
if server_args.attention_backend is None:
"""
We auto-select the fastest attention backend based on the current hardware and configuration:
1. Models with MHA architecture (e.g. Llama, Qwen)
1.1 On Hopper we turn on FA3, unless the user uses speculative decoding with topk > 1 or page_size > 1.
1.2 In other cases, we use flashinfer if available, otherwise triton.
2. Models with MLA architecture
2.1 On Hopper we use the FA3 backend.
2.2 Otherwise, we use the triton backend.
"""
if not self.use_mla_backend:
if (
is_hopper_with_cuda_12_3()
...@@ -234,9 +243,7 @@ class ModelRunner:
"flashinfer" if is_flashinfer_available() else "triton"
)
else:
if is_hopper_with_cuda_12_3():
server_args.attention_backend = "fa3" server_args.attention_backend = "fa3"
else: else:
server_args.attention_backend = "triton" server_args.attention_backend = "triton"
......
...@@ -359,7 +359,18 @@ class ServerArgs:
if self.page_size > 1 and self.speculative_eagle_topk > 1:
self.speculative_eagle_topk = 1
logger.info(
"speculative_eagle_topk is adjusted to 1 when page_size > 1"
)
if (
self.speculative_eagle_topk == 1
and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
):
logger.info(
"speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1"
)
self.speculative_num_draft_tokens = self.speculative_num_steps + 1
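A quick check of the adjustment above with illustrative numbers: with topk == 1 the draft tree is a single chain, so the verify pass scores the chain's num_steps draft tokens plus the token produced by the verify step itself.

speculative_num_steps = 5
speculative_eagle_topk = 1
if speculative_eagle_topk == 1:
    speculative_num_draft_tokens = speculative_num_steps + 1    # -> 6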
# The token generated from the verify step is counted.
# If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
...
...@@ -1909,6 +1909,8 @@ def is_page_size_one(server_args):
return server_args.page_size == 1
# TODO(hebiao064): Accelerate FA3 Spec Decode with topk > 1.
# TODO(hebiao064): Improve the acc rate for FA3 Spec Decode with topk == 1 and page_size > 1.
def is_no_spec_infer_or_topk_one(server_args):
return server_args.speculative_eagle_topk is None or (
server_args.speculative_eagle_topk is not None
...
...@@ -29,7 +29,7 @@ suites = {
TestFile("test_chunked_prefill.py", 336),
TestFile("test_eagle_infer.py", 500),
TestFile("test_ebnf_constrained.py"),
TestFile("test_fa3.py", 200),
TestFile("test_fp8_kernel.py", 8),
TestFile("test_embedding_openai_server.py", 36),
TestFile("test_hidden_states.py", 55),
...
...@@ -173,6 +173,60 @@ class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
self.assertGreater(avg_spec_accept_length, 1.5)
class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
"""Test FlashAttention3 with speculative decode enabled, topk > 1"""
model = "meta-llama/Llama-3.1-8B-Instruct"
@classmethod
def get_server_args(cls):
args = super().get_server_args()
args.extend(
[
"--cuda-graph-max-bs",
"2",
"--speculative-algorithm",
"EAGLE3",
"--speculative-draft",
"jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
"--speculative-num-steps",
"5",
"--speculative-eagle-topk",
"4",
"--speculative-num-draft-tokens",
"8",
"--dtype",
"float16",
]
)
return args
def test_gsm8k(self):
"""
Override the test_gsm8k to further test for average speculative accept length.
"""
requests.get(self.base_url + "/flush_cache")
args = SimpleNamespace(
num_shots=5,
data_path=DATA_PATH,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.60)
server_info = requests.get(self.base_url + "/get_server_info")
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 1.8)
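For manual experimentation outside the test harness, the same configuration can be launched directly. This is a hypothetical invocation mirroring the test args above (the --model-path flag and the sglang.launch_server module path are assumptions about the usual sglang launcher; adjust for your setup):

import subprocess

subprocess.run([
    "python", "-m", "sglang.launch_server",
    "--model-path", "meta-llama/Llama-3.1-8B-Instruct",
    "--speculative-algorithm", "EAGLE3",
    "--speculative-draft", "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
    "--speculative-num-steps", "5",
    "--speculative-eagle-topk", "4",
    "--speculative-num-draft-tokens", "8",
    "--dtype", "float16",
    "--cuda-graph-max-bs", "2",
])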
class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
"""Test FlashAttention3 with speculative decode enabled."""
...