Unverified commit 3451fc32, authored by Binyao Jiang, committed by GitHub
Browse files

[Feature] Qwen3-Next & FLA: Support MTP topk>1; Up to 6% faster (#11133)


Co-authored-by: Stefan He <hebiaobuaa@gmail.com>
parent c550ab91
......@@ -330,12 +330,30 @@ def fused_recurrent_gated_delta_rule(
return o, final_state
# HAS_EAGLE_TREE_CUSTOM_ATTN_MASK is added to support eagle tree attention mask
# retrieve_parent_token_ptr: [N, NP2_T], retrieve_next_sibling_ptr: [N, NP2_T]
# e.g. for a sequence of length 4, the eagle tree attention structure is:
# retrieve_next_token=[1, 3, -1, -1] -> retrieve_next_token[i]: the 1st child token of token i
# retrieve_next_sibling=[-1, 2, -1, -1] -> retrieve_next_sibling[i]: the 1st tree sibling token of token i
# retrieve_parent_token=[n/a, 0, 0, 1] -> retrieve_parent_token[i]: the parent token of token i
# Tree:
# 0
# / \
# 1 2
# /
# 3
# When calculating token 3's attention, it should attend to token 1 (parent) and token 0 (grand-parent)
# When calculating token 2's attention, it should attend to token 0 (parent)
@triton.heuristics(
{
"USE_INITIAL_STATE": lambda args: args["h0_source"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
"CACHE_INTERMEDIATE_STATES": lambda args: args["intermediate_states_buffer"]
is not None,
"HAS_EAGLE_TREE_CUSTOM_ATTN_MASK": lambda args: args[
"retrieve_parent_token_ptr"
]
is not None,
}
)
@triton.jit(do_not_specialize=["T"])
......@@ -352,7 +370,11 @@ def fused_recurrent_gated_delta_rule_update_fwd_kernel(
scale,
intermediate_states_buffer,
cache_steps,
retrieve_parent_token_ptr,
stride_retrieve_parent_token_seq: tl.constexpr,
stride_retrieve_parent_token_token: tl.constexpr,
T,
NP2_T: tl.constexpr,
B: tl.constexpr,
H: tl.constexpr,
HV: tl.constexpr,
......@@ -367,6 +389,7 @@ def fused_recurrent_gated_delta_rule_update_fwd_kernel(
DISABLE_STATE_UPDATE: tl.constexpr, # whether to disable final state update
DISABLE_OUTPUT_CALCULATION: tl.constexpr, # whether to disable output calculation
CACHE_INTERMEDIATE_STATES: tl.constexpr,
HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr,
):
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_n, i_hv = i_nh // HV, i_nh % HV
......@@ -393,6 +416,16 @@ def fused_recurrent_gated_delta_rule_update_fwd_kernel(
p_g = g + bos * HV + i_hv
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
token_indices = tl.arange(0, NP2_T)
mask_retrieve = token_indices < T
retrieve_parent_token_base = (
retrieve_parent_token_ptr
+ (i_n * stride_retrieve_parent_token_seq)
+ token_indices * stride_retrieve_parent_token_token
)
parent_idx_tokens = tl.load(retrieve_parent_token_base, mask_retrieve)
mask_k = o_k < K
mask_v = o_v < V
mask_h = mask_k[:, None] & mask_v[None, :]
......@@ -418,6 +451,24 @@ def fused_recurrent_gated_delta_rule_update_fwd_kernel(
step_idx = 0
for _ in range(0, T):
if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
# step_idx = 0 should use the b_h from USE_INITIAL_STATE
if step_idx != 0 and cache_idx >= 0:
# when calculating current step's attention, load the state from the parent token
parent_step_idx = tl.sum(
tl.where(token_indices == step_idx, parent_idx_tokens, 0)
)
step_offset = parent_step_idx * HV * K * V
cache_ptr = (
intermediate_states_buffer
+ cache_idx * cache_steps * HV * K * V
+ step_offset
+ i_hv * K * V
+ o_k[:, None] * V
+ o_v[None, :]
)
b_h = tl.load(cache_ptr, mask=mask_h, other=0).to(tl.float32)
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
......@@ -498,6 +549,7 @@ def fused_recurrent_gated_delta_rule_update_fwd(
disable_output_calculation: bool = False,
intermediate_states_buffer: Optional[torch.Tensor] = None,
cache_steps: Optional[int] = None,
retrieve_parent_token: Optional[torch.Tensor] = None,
) -> torch.Tensor:
B, T, H, K, V = *k.shape, v.shape[-1]
HV = v.shape[2]
......@@ -516,6 +568,16 @@ def fused_recurrent_gated_delta_rule_update_fwd(
grid = (NK, NV, N * HV)
# prepare retrieve next token buffer strides if provided
if retrieve_parent_token is not None:
stride_retrieve_parent_token_seq, stride_retrieve_parent_token_token = (
retrieve_parent_token.stride(0),
retrieve_parent_token.stride(1),
)
else:
stride_retrieve_parent_token_seq = stride_retrieve_parent_token_token = 0
NP2_T = triton.next_power_of_2(T)
fused_recurrent_gated_delta_rule_update_fwd_kernel[grid](
q=q,
k=k,
......@@ -529,7 +591,11 @@ def fused_recurrent_gated_delta_rule_update_fwd(
scale=scale,
intermediate_states_buffer=intermediate_states_buffer,
cache_steps=0 if cache_steps is None else cache_steps,
retrieve_parent_token_ptr=retrieve_parent_token,
stride_retrieve_parent_token_seq=stride_retrieve_parent_token_seq,
stride_retrieve_parent_token_token=stride_retrieve_parent_token_token,
T=T,
NP2_T=NP2_T,
B=B,
H=H,
HV=HV,
......@@ -568,6 +634,7 @@ class FusedRecurrentUpdateFunction(torch.autograd.Function):
disable_output_calculation: bool = False,
intermediate_states_buffer: Optional[torch.Tensor] = None,
cache_steps: Optional[int] = None,
retrieve_parent_token: Optional[torch.Tensor] = None,
):
o = fused_recurrent_gated_delta_rule_update_fwd(
q=q,
......@@ -584,6 +651,7 @@ class FusedRecurrentUpdateFunction(torch.autograd.Function):
disable_output_calculation=disable_output_calculation,
intermediate_states_buffer=intermediate_states_buffer,
cache_steps=cache_steps,
retrieve_parent_token=retrieve_parent_token,
)
return o
......@@ -613,6 +681,7 @@ def fused_recurrent_gated_delta_rule_update(
disable_output_calculation: bool = False,
intermediate_states_buffer: Optional[torch.Tensor] = None,
cache_steps: Optional[int] = None,
retrieve_parent_token: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if cu_seqlens is not None:
if q.shape[0] != 1:
......@@ -649,5 +718,6 @@ def fused_recurrent_gated_delta_rule_update(
disable_output_calculation,
intermediate_states_buffer,
cache_steps,
retrieve_parent_token,
)
return o
......@@ -66,12 +66,19 @@ class MambaAttnBackendBase(AttentionBackend):
self.forward_metadata: ForwardMetadata = None
self.state_indices_list = []
self.query_start_loc_list = []
self.retrieve_next_token_list = []
self.retrieve_next_sibling_list = []
self.retrieve_parent_token_list = []
self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None
def _forward_metadata(self, forward_batch: ForwardBatch):
bs = forward_batch.batch_size
retrieve_next_token = None
retrieve_next_sibling = None
retrieve_parent_token = None
if forward_batch.forward_mode.is_decode_or_idle():
query_start_loc = torch.arange(
0, bs + 1, dtype=torch.int32, device=self.device
......@@ -85,6 +92,11 @@ class MambaAttnBackendBase(AttentionBackend):
dtype=torch.int32,
device=forward_batch.input_ids.device,
)
if forward_batch.spec_info.topk > 1:
retrieve_next_token = forward_batch.spec_info.retrive_next_token
retrieve_next_sibling = forward_batch.spec_info.retrive_next_sibling
retrieve_parent_token = torch.empty_like(retrieve_next_token)
else:
query_start_loc = torch.empty(
(bs + 1,), dtype=torch.int32, device=self.device
......@@ -102,6 +114,9 @@ class MambaAttnBackendBase(AttentionBackend):
return ForwardMetadata(
query_start_loc=query_start_loc,
mamba_cache_indices=mamba_cache_indices,
retrieve_next_token=retrieve_next_token,
retrieve_next_sibling=retrieve_next_sibling,
retrieve_parent_token=retrieve_parent_token,
)
def init_forward_metadata(self, forward_batch: ForwardBatch):
......@@ -118,7 +133,7 @@ class MambaAttnBackendBase(AttentionBackend):
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
self.forward_metadata = self._capture_metadata(
bs, req_pool_indices, forward_mode
bs, req_pool_indices, forward_mode, spec_info
)
def init_forward_metadata_replay_cuda_graph(
......@@ -140,7 +155,7 @@ class MambaAttnBackendBase(AttentionBackend):
assert (
max_num_tokens % max_bs == 0
), f"max_num_tokens={max_num_tokens} must be divisible by max_bs={max_bs}"
verify_step = max_num_tokens / max_bs
draft_token_num = max_num_tokens // max_bs
for i in range(max_bs):
self.state_indices_list.append(
torch.full(
......@@ -150,19 +165,38 @@ class MambaAttnBackendBase(AttentionBackend):
self.query_start_loc_list.append(
torch.empty((i + 2,), dtype=torch.int32, device=self.device)
)
self.retrieve_next_token_list.append(
torch.zeros(
(i + 1, draft_token_num), dtype=torch.int32, device=self.device
)
)
self.retrieve_next_sibling_list.append(
torch.zeros(
(i + 1, draft_token_num), dtype=torch.int32, device=self.device
)
)
self.retrieve_parent_token_list.append(
torch.zeros(
(i + 1, draft_token_num), dtype=torch.int32, device=self.device
)
)
self.cached_cuda_graph_decode_query_start_loc = torch.arange(
0, max_bs + 1, dtype=torch.int32, device=self.device
)
self.cached_cuda_graph_verify_query_start_loc = torch.arange(
0,
max_bs * verify_step + 1,
step=verify_step,
max_bs * draft_token_num + 1,
step=draft_token_num,
dtype=torch.int32,
device=self.device,
)
def _capture_metadata(
self, bs: int, req_pool_indices: torch.Tensor, forward_mode: ForwardMode
self,
bs: int,
req_pool_indices: torch.Tensor,
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
if forward_mode.is_decode_or_idle():
self.query_start_loc_list[bs - 1].copy_(
......@@ -176,10 +210,24 @@ class MambaAttnBackendBase(AttentionBackend):
raise ValueError(f"Invalid forward mode: {forward_mode=}")
mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
return ForwardMetadata(
query_start_loc=self.query_start_loc_list[bs - 1],
mamba_cache_indices=self.state_indices_list[bs - 1],
)
# If topk > 1, we need to use retrieve_next_token and retrieve_next_sibling to handle the eagle tree custom attention mask
if forward_mode.is_target_verify() and spec_info.topk > 1:
# They are None during cuda graph capture so skip the copy_...
# self.retrieve_next_token_list[bs - 1].copy_(spec_info.retrive_next_token)
# self.retrieve_next_sibling_list[bs - 1].copy_(spec_info.retrive_next_sibling)
return ForwardMetadata(
query_start_loc=self.query_start_loc_list[bs - 1],
mamba_cache_indices=self.state_indices_list[bs - 1],
retrieve_next_token=self.retrieve_next_token_list[bs - 1],
retrieve_next_sibling=self.retrieve_next_sibling_list[bs - 1],
retrieve_parent_token=self.retrieve_parent_token_list[bs - 1],
)
else:
return ForwardMetadata(
query_start_loc=self.query_start_loc_list[bs - 1],
mamba_cache_indices=self.state_indices_list[bs - 1],
)
def _replay_metadata(
self,
......@@ -224,10 +272,28 @@ class MambaAttnBackendBase(AttentionBackend):
else:
raise ValueError(f"Invalid forward mode: {forward_mode=}")
return ForwardMetadata(
query_start_loc=self.query_start_loc_list[bs - 1],
mamba_cache_indices=self.state_indices_list[bs - 1],
)
# If topk > 1, we need to use retrieve_next_token and retrieve_next_sibling to handle the eagle tree custom attention mask
if forward_mode.is_target_verify() and spec_info.topk > 1:
bs_without_pad = spec_info.retrive_next_token.shape[0]
# print(spec_info.retrive_next_token, spec_info.retrive_next_sibling)
self.retrieve_next_token_list[bs - 1][:bs_without_pad].copy_(
spec_info.retrive_next_token
)
self.retrieve_next_sibling_list[bs - 1][:bs_without_pad].copy_(
spec_info.retrive_next_sibling
)
return ForwardMetadata(
query_start_loc=self.query_start_loc_list[bs - 1],
mamba_cache_indices=self.state_indices_list[bs - 1],
retrieve_next_token=self.retrieve_next_token_list[bs - 1],
retrieve_next_sibling=self.retrieve_next_sibling_list[bs - 1],
retrieve_parent_token=self.retrieve_parent_token_list[bs - 1],
)
else:
return ForwardMetadata(
query_start_loc=self.query_start_loc_list[bs - 1],
mamba_cache_indices=self.state_indices_list[bs - 1],
)
def get_cuda_graph_seq_len_fill_value(self):
    """Return the placeholder sequence length used during CUDA graph capture.

    Mamba-style attention never indexes the KV cache by sequence length,
    so the concrete fill value is irrelevant; a constant 1 is used.
    """
    return 1
......@@ -557,6 +623,9 @@ class GDNAttnBackend(MambaAttnBackendBase):
query_start_loc = self.forward_metadata.query_start_loc
cache_indices = self.forward_metadata.mamba_cache_indices
retrieve_next_token = self.forward_metadata.retrieve_next_token
retrieve_next_sibling = self.forward_metadata.retrieve_next_sibling
retrieve_parent_token = self.forward_metadata.retrieve_parent_token
mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
conv_states = mamba_cache_params.conv
......@@ -591,6 +660,9 @@ class GDNAttnBackend(MambaAttnBackendBase):
activation,
conv_state_indices=cache_indices[:batch_size],
intermediate_conv_window=intermediate_conv_window_cache,
retrieve_next_token=retrieve_next_token,
retrieve_next_sibling=retrieve_next_sibling,
retrieve_parent_token=retrieve_parent_token,
)
mixed_qkv = (
mixed_qkv_processed.transpose(1, 2).contiguous().view(seq_len, -1)
......@@ -645,6 +717,7 @@ class GDNAttnBackend(MambaAttnBackendBase):
disable_state_update=True,
intermediate_states_buffer=intermediate_state_cache,
cache_steps=forward_batch.spec_info.draft_token_num,
retrieve_parent_token=retrieve_parent_token,
)
else:
recurrent_state = ssm_states[cache_indices]
......@@ -694,7 +767,7 @@ class Mamba2AttnBackend(MambaAttnBackendBase):
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
metadata = self._capture_metadata(bs, req_pool_indices, forward_mode)
metadata = self._capture_metadata(bs, req_pool_indices, forward_mode, spec_info)
self.forward_metadata = Mamba2Metadata.prepare_decode(
metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
)
......@@ -891,8 +964,8 @@ class HybridLinearAttnBackend(AttentionBackend):
**kwargs,
)
def update_mamba_state_after_mtp_verify(self, accepted_length, model):
request_number = accepted_length.shape[0]
def update_mamba_state_after_mtp_verify(self, accepted_indices, model):
request_number = accepted_indices.shape[0]
state_indices_tensor = (
self.linear_attn_backend.forward_metadata.mamba_cache_indices[
......@@ -910,12 +983,11 @@ class HybridLinearAttnBackend(AttentionBackend):
intermediate_conv_window_cache = mamba_caches.intermediate_conv_window
# SSM state updates (chunked to reduce peak memory)
valid_mask = accepted_length > 0
valid_mask = accepted_indices >= 0
# Compute common indices once to avoid duplication
last_steps_all = (accepted_length - 1).to(torch.int64)
valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64) # [N]
last_steps = last_steps_all[valid_mask].to(torch.int64) # [N]
last_steps = accepted_indices[valid_mask].to(torch.int64) # [N]
# scatter into ssm_states at the chosen cache lines
ssm_states[:, valid_state_indices, :] = intermediate_state_cache[
......
......@@ -186,7 +186,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
)
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :]
tl.debug_barrier() # NOTE: use this due to bug in Triton compiler
# tl.debug_barrier() # NOTE: use this due to bug in Triton compiler
tl.store(conv_states_ptrs_target, new_conv_state, mask)
else:
......@@ -221,7 +221,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
) # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
tl.debug_barrier() # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load
# tl.debug_barrier() # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load
new_conv_state = tl.where(
mask, conv_state, loaded_x
) # BUG in 'tl.where' which requires a barrier before this
......@@ -552,6 +552,21 @@ def causal_conv1d_fn(
return out
# HAS_EAGLE_TREE_CUSTOM_ATTN_MASK is added to support eagle tree attention mask
# retrieve_next_token_ptr: [N, NP2_T], retrieve_next_sibling_ptr: [N, NP2_T]
# e.g. for a sequence of length 4, the eagle tree attention structure is:
# retrieve_next_token=[1, 3, -1, -1] -> retrieve_next_token[i]: the 1st child token of token i
# retrieve_next_sibling=[-1, 2, -1, -1] -> retrieve_next_sibling[i]: the 1st tree sibling token of token i
# retrieve_parent_token=[n/a, 0, 0, 1] -> retrieve_parent_token[i]: the parent token of token i
# Tree:
# 0
# / \
# 1 2
# /
# 3
# When calculating token 3's convolution, it should conv to token 1 (parent) and token 0 (grand-parent)
# When calculating token 2's convolution, it should conv to token 0 (parent)
# This kernel is a fused kernel which will also produce retrieve_parent_token based on retrieve_next_token & retrieve_next_sibling
@triton.jit()
def _causal_conv1d_update_kernel(
# Pointers to matrices
......@@ -563,6 +578,9 @@ def _causal_conv1d_update_kernel(
conv_state_indices_ptr,
num_accepted_tokens_ptr,
intermediate_conv_window_ptr,
retrieve_next_token_ptr,
retrieve_next_sibling_ptr,
retrieve_parent_token_ptr,
o_ptr, # (batch, dim, seqlen)
# Matrix dimensions
batch: int,
......@@ -584,6 +602,12 @@ def _causal_conv1d_update_kernel(
stride_inter_step: tl.constexpr,
stride_inter_dim: tl.constexpr,
stride_inter_win: tl.constexpr,
stride_retrieve_next_token_seq: tl.constexpr,
stride_retrieve_next_token_token: tl.constexpr,
stride_retrieve_next_sibling_seq: tl.constexpr,
stride_retrieve_next_sibling_token: tl.constexpr,
stride_retrieve_parent_token_seq: tl.constexpr,
stride_retrieve_parent_token_token: tl.constexpr,
stride_o_seq: tl.constexpr,
stride_o_dim: tl.constexpr,
stride_o_token: tl.constexpr,
......@@ -596,9 +620,11 @@ def _causal_conv1d_update_kernel(
IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
NP2_STATELEN: tl.constexpr,
NP2_SEQLEN: tl.constexpr,
USE_PAD_SLOT: tl.constexpr,
BLOCK_N: tl.constexpr,
SAVE_INTERMEDIATE: tl.constexpr,
HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr,
):
# ruff: noqa: E501
idx_seq = tl.program_id(0)
......@@ -695,7 +721,7 @@ def _causal_conv1d_update_kernel(
& (idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
tl.debug_barrier()
# tl.debug_barrier()
new_conv_state = tl.where(mask, conv_state, loaded_x)
......@@ -723,6 +749,24 @@ def _causal_conv1d_update_kernel(
# STEP 4:
# PRE-LOAD WEIGHTS
# first kernel column, configured for weights to handle BLOCK_N features in range
if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
idx_tokens = tl.arange(0, NP2_SEQLEN) # [BLOCK_M]
# Update parent mapping for all tokens at once using vectorized operations
mask_retrieve = idx_tokens < seqlen
retrieve_next_token_base = (
retrieve_next_token_ptr
+ (idx_seq * stride_retrieve_next_token_seq)
+ idx_tokens * stride_retrieve_next_token_token
)
retrieve_next_tokens = tl.load(retrieve_next_token_base, mask_retrieve)
retrieve_next_sibling_base = (
retrieve_next_sibling_ptr
+ (idx_seq * stride_retrieve_next_sibling_seq)
+ idx_tokens * stride_retrieve_next_sibling_token
)
retrieve_next_siblings = tl.load(retrieve_next_sibling_base, mask_retrieve)
parent_idx_tokens = tl.zeros((NP2_SEQLEN,), dtype=tl.int32)
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
mask_w = idx_feats < dim
if KERNEL_WIDTH >= 2:
......@@ -744,45 +788,162 @@ def _causal_conv1d_update_kernel(
for idx_token in tl.static_range(seqlen):
acc = acc_preload
matrix_w = w_col0
matrix_x = col0
for j in tl.static_range(KERNEL_WIDTH):
if KERNEL_WIDTH == 2:
if j == 1: # KERNEL_WIDTH-1:
matrix_w = w_col1
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
# set the parent index of the next token in the eagle tree
# next token's parent is the current token
retrieve_next_token_idx = tl.sum(
tl.where(idx_tokens == idx_token, retrieve_next_tokens, 0)
)
if retrieve_next_token_idx != -1: # pad slot id
parent_idx_tokens = tl.where(
idx_tokens == retrieve_next_token_idx,
idx_token,
parent_idx_tokens,
)
# next token's parent is the parent of the current token
retrieve_sibling_token_idx = tl.sum(
tl.where(idx_tokens == idx_token, retrieve_next_siblings, 0)
)
if retrieve_sibling_token_idx != -1: # pad slot id
parent_idx_token = tl.sum(
tl.where(idx_tokens == idx_token, parent_idx_tokens, 0)
)
parent_idx_tokens = tl.where(
idx_tokens == retrieve_sibling_token_idx,
parent_idx_token,
parent_idx_tokens,
)
# tl.device_print("am", parent_idx_tokens)
_idx_token = idx_token
x_ptrs_1d = x_base_1d + _idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
# convolution operation: itself * wcol[-1] + parent * wcol[-2] + grand-parent * wcol[-3] + ...
for j in tl.static_range(KERNEL_WIDTH):
if KERNEL_WIDTH == 2:
if j == 0:
matrix_w = w_col1
else:
matrix_w = w_col0
elif KERNEL_WIDTH == 3:
if j == 0:
matrix_w = w_col2
elif j == 1:
matrix_w = w_col1
else:
matrix_w = w_col0
elif KERNEL_WIDTH == 4:
if j == 0:
matrix_w = w_col3
elif j == 1:
matrix_w = w_col2
elif j == 2:
matrix_w = w_col1
else:
matrix_w = w_col0
if SAVE_INTERMEDIATE:
# Save the window state after consuming this token
# Layout: [seq(cache line), step, dim, win(K-1)]
base_ptr = (
intermediate_conv_window_ptr
+ conv_state_batch_coord * stride_inter_seq
+ idx_token * stride_inter_step
+ idx_feats * stride_inter_dim
)
# store itself in KERNEL_WIDTH-2 slot, parent in KERNEL_WIDTH-3 slot, grand-parent in KERNEL_WIDTH-4 slot, ...
if KERNEL_WIDTH - j - 2 >= 0:
tl.store(
base_ptr + (KERNEL_WIDTH - j - 2) * stride_inter_win,
matrix_x,
mask=mask_w,
)
acc += matrix_x * matrix_w
# move to parent for next iteration
if _idx_token > 0:
_idx_token = tl.sum(
tl.where(idx_tokens == _idx_token, parent_idx_tokens, 0)
)
x_ptrs_1d = x_base_1d + _idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
else:
# no parent within the current chunk, load from prev conv state: col[-1] (idx 0's parent), col[-2] (idx 0's grand parent), ...
if KERNEL_WIDTH == 2:
if _idx_token == 0:
matrix_x = col0
elif KERNEL_WIDTH == 3:
if _idx_token == 0:
matrix_x = col1
else:
matrix_x = col0
elif KERNEL_WIDTH == 4:
if _idx_token == 0:
matrix_x = col2
elif _idx_token == -1:
matrix_x = col1
else:
matrix_x = col0
_idx_token = _idx_token - 1
else:
matrix_w = w_col0
matrix_x = col0
for j in tl.static_range(KERNEL_WIDTH):
if KERNEL_WIDTH == 2:
if j == 1: # KERNEL_WIDTH-1:
matrix_w = w_col1
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
elif KERNEL_WIDTH == 3:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
elif KERNEL_WIDTH == 4:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
matrix_x = col2
elif j == 3:
matrix_w = w_col3
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
acc += matrix_x * matrix_w # [BLOCK_N]
if KERNEL_WIDTH == 2:
col0 = matrix_x
elif KERNEL_WIDTH == 3:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
col0 = col1
col1 = matrix_x
elif KERNEL_WIDTH == 4:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
matrix_x = col2
elif j == 3:
matrix_w = w_col3
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
acc += matrix_x * matrix_w # [BLOCK_N]
if KERNEL_WIDTH == 2:
col0 = matrix_x
elif KERNEL_WIDTH == 3:
col0 = col1
col1 = matrix_x
elif KERNEL_WIDTH == 4:
col0 = col1
col1 = col2
col2 = matrix_x
col0 = col1
col1 = col2
col2 = matrix_x
if SAVE_INTERMEDIATE:
# Save the window state after consuming this token
# Layout: [seq(cache line), step, dim, win(K-1)]
base_ptr = (
intermediate_conv_window_ptr
+ conv_state_batch_coord * stride_inter_seq
+ idx_token * stride_inter_step
+ idx_feats * stride_inter_dim
)
if KERNEL_WIDTH >= 2:
tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w)
if KERNEL_WIDTH >= 3:
tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w)
if KERNEL_WIDTH >= 4:
tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w)
if SILU_ACTIVATION:
acc = acc / (1 + tl.exp(-acc))
......@@ -798,21 +959,15 @@ def _causal_conv1d_update_kernel(
tl.store(o_ptrs, acc, mask=mask_1d)
if SAVE_INTERMEDIATE:
# Save the window state after consuming this token
# Layout: [seq(cache line), step, dim, win(K-1)]
base_ptr = (
intermediate_conv_window_ptr
+ conv_state_batch_coord * stride_inter_seq
+ idx_token * stride_inter_step
+ idx_feats * stride_inter_dim
# fuse: store calculated retrieve_parent_token to tensor
if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
tl.store(
retrieve_parent_token_ptr
+ idx_seq * stride_retrieve_parent_token_seq
+ idx_tokens * stride_retrieve_parent_token_token,
parent_idx_tokens,
mask=mask_retrieve,
)
if KERNEL_WIDTH >= 2:
tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w)
if KERNEL_WIDTH >= 3:
tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w)
if KERNEL_WIDTH >= 4:
tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w)
def causal_conv1d_update(
......@@ -825,6 +980,9 @@ def causal_conv1d_update(
conv_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
intermediate_conv_window: Optional[torch.Tensor] = None,
retrieve_next_token: Optional[torch.Tensor] = None,
retrieve_next_sibling: Optional[torch.Tensor] = None,
retrieve_parent_token: Optional[torch.Tensor] = None,
pad_slot_id: int = PAD_SLOT_ID,
metadata=None,
validate_data=False,
......@@ -888,7 +1046,7 @@ def causal_conv1d_update(
assert cache_seqlens is None # not needed for vLLM - circular buffer
# adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
out = x
out = torch.empty_like(x)
stride_w_dim, stride_w_width = weight.stride()
stride_x_seq, stride_x_dim, stride_x_token = x.stride() # X (batch, dim, seqlen)
......@@ -903,6 +1061,7 @@ def causal_conv1d_update(
else:
state_len = width - 1
np2_statelen = triton.next_power_of_2(state_len)
np2_seqlen = triton.next_power_of_2(seqlen)
def grid(META):
return (
......@@ -921,6 +1080,33 @@ def causal_conv1d_update(
else:
stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0
# prepare retrieve next token buffer strides if provided
if retrieve_next_token is not None:
stride_retrieve_next_token_seq, stride_retrieve_next_token_token = (
retrieve_next_token.stride(0),
retrieve_next_token.stride(1),
)
else:
stride_retrieve_next_token_seq = stride_retrieve_next_token_token = 0
# prepare retrieve next sibling buffer strides if provided
if retrieve_next_sibling is not None:
stride_retrieve_next_sibling_seq, stride_retrieve_next_sibling_token = (
retrieve_next_sibling.stride(0),
retrieve_next_sibling.stride(1),
)
else:
stride_retrieve_next_sibling_seq = stride_retrieve_next_sibling_token = 0
# prepare retrieve parent token buffer strides if provided
if retrieve_parent_token is not None:
stride_retrieve_parent_token_seq, stride_retrieve_parent_token_token = (
retrieve_parent_token.stride(0),
retrieve_parent_token.stride(1),
)
else:
stride_retrieve_parent_token_seq = stride_retrieve_parent_token_token = 0
_causal_conv1d_update_kernel[grid](
# Pointers to matrices
x,
......@@ -931,6 +1117,9 @@ def causal_conv1d_update(
conv_state_indices,
num_accepted_tokens,
intermediate_conv_window if intermediate_conv_window is not None else x,
retrieve_next_token,
retrieve_next_sibling,
retrieve_parent_token,
out,
# Matrix dimensions
batch,
......@@ -952,6 +1141,12 @@ def causal_conv1d_update(
stride_inter_step,
stride_inter_dim,
stride_inter_win,
stride_retrieve_next_token_seq,
stride_retrieve_next_token_token,
stride_retrieve_next_sibling_seq,
stride_retrieve_next_sibling_token,
stride_retrieve_parent_token_seq,
stride_retrieve_parent_token_token,
stride_o_seq,
stride_o_dim,
stride_o_token,
......@@ -964,9 +1159,11 @@ def causal_conv1d_update(
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
IS_SPEC_DECODING=num_accepted_tokens is not None,
NP2_STATELEN=np2_statelen,
NP2_SEQLEN=np2_seqlen,
USE_PAD_SLOT=pad_slot_id is not None,
BLOCK_N=256,
SAVE_INTERMEDIATE=intermediate_conv_window is not None,
HAS_EAGLE_TREE_CUSTOM_ATTN_MASK=retrieve_next_token is not None,
)
if unsqueeze:
out = out.squeeze(-1)
......
......@@ -16,6 +16,7 @@
import math
from dataclasses import dataclass
from typing import Optional
import torch
......@@ -26,6 +27,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class ForwardMetadata:
    """Per-batch metadata consumed by the Mamba/linear attention backends."""

    # Cumulative query start offsets per request, shape [batch_size + 1].
    query_start_loc: torch.Tensor
    # Mamba state cache line index for each request.
    mamba_cache_indices: torch.Tensor
    # Eagle-tree speculative decoding buffers; populated only when
    # speculative topk > 1 (target-verify with a tree attention mask),
    # otherwise left as None.
    # retrieve_next_token[i]: first child of token i in the eagle tree (-1 if none).
    retrieve_next_token: Optional[torch.Tensor] = None
    # retrieve_next_sibling[i]: next sibling of token i in the eagle tree (-1 if none).
    retrieve_next_sibling: Optional[torch.Tensor] = None
    # retrieve_parent_token[i]: parent of token i; derived from the two
    # buffers above inside the fused conv1d update kernel.
    retrieve_parent_token: Optional[torch.Tensor] = None
@dataclass(kw_only=True)
......
......@@ -694,19 +694,45 @@ class EAGLEWorker(TpModelWorker):
]
logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
# QQ: can be optimized
if self.target_worker.model_runner.hybrid_gdn_config is not None:
# res.draft_input.accept_length is on GPU but may be empty for last verify?
accepted_length = (
torch.tensor(
res.accept_length_per_req_cpu,
device=logits_output.hidden_states.device,
dtype=torch.int32,
dtype=torch.int64,
)
+ 1
)
# If topk > 1, we need to use retrieve_next_token and retrieve_next_sibling to handle the eagle tree custom attention mask
# res.accepted_indices.shape[0] > 0 skips DP attn idle batch
if spec_info.topk > 1 and res.accepted_indices.shape[0] > 0:
# accepted_indices=[0,2,3,4,5,7,9,10,11], accepted_length=[4, 3, 2], cumulative_accepted_lengths=[4, 7, 9]
# first_token_indices_per_req=prepend(0, accepted_indices[cumulative_accepted_lengths[:-1]]) = [0, 5, 10]
# last_token_indices_per_req=accepted_indices[cumulative_accepted_lengths - 1] = [4, 9, 11] (last token ID of each req)
# max_relative_indices_per_req = [4,4,1]; those are the per-req spec-decoding step offsets that contain the correct mamba caches
cumulative_accepted_lengths = torch.cumsum(accepted_length, dim=0)
req_start_positions = torch.cat(
[
torch.zeros(
1,
dtype=cumulative_accepted_lengths.dtype,
device=cumulative_accepted_lengths.device,
),
cumulative_accepted_lengths[:-1],
]
)
first_token_indices_per_req = res.accepted_indices[req_start_positions]
last_token_indices_per_req = res.accepted_indices[
cumulative_accepted_lengths - 1
]
max_relative_indices_per_req = (
last_token_indices_per_req - first_token_indices_per_req
)
else:
max_relative_indices_per_req = accepted_length - 1
self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
accepted_length, self.target_worker.model_runner.model
max_relative_indices_per_req, self.target_worker.model_runner.model
)
if batch.return_logprob:
......
......@@ -59,11 +59,56 @@ class TestQwen3NextMTP(CustomTestCase):
"--speculative-algorithm",
"NEXTN",
"--speculative-num-steps",
"1",
"3",
"--speculative-eagle-topk",
"1",
"--speculative-num-draft-tokens",
"2",
"4",
"--mem-fraction-static",
"0.8",
"--tp",
"4",
],
)
@classmethod
def tearDownClass(cls):
    # Shut down the server launched in setUpClass, including any child
    # processes it spawned, so the port is freed for subsequent tests.
    kill_process_tree(cls.process.pid)
def test_gsm8k(self):
    """Run the few-shot GSM8K evaluation against the launched server.

    Asserts that end-to-end accuracy with MTP speculative decoding stays
    above the 0.93 regression threshold.
    """
    args = SimpleNamespace(
        num_shots=5,
        data_path=None,
        num_questions=200,
        max_new_tokens=512,
        parallel=128,
        host="http://127.0.0.1",
        # Reuse the port from the base URL configured in setUpClass.
        port=int(self.base_url.split(":")[-1]),
    )
    metrics = run_eval(args)
    print(f"{metrics=}")
    self.assertGreater(metrics["accuracy"], 0.93)
class TestQwen3NextMTPTopk(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "Qwen/Qwen3-Next-80B-A3B-Instruct"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--speculative-algorithm",
"NEXTN",
"--speculative-num-steps",
"5",
"--speculative-eagle-topk",
"4",
"--speculative-num-draft-tokens",
"8",
"--mem-fraction-static",
"0.8",
"--tp",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment