Unverified Commit 6c18ab46 authored by Stefan He, committed by GitHub

[Qwen3-Next] switch to triton and cache conv states to accelerate MTP from 300 tok/s to 341 tok/s (#10335)

Co-authored-by: Binyao Jiang <byjiang1996@gmail.com>
parent 4a0e0be2
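In summary: during MTP target verify the short depthwise causal conv over each request's draft tokens now runs through the Triton `causal_conv1d_update` kernel, which also records the conv window (the last K-1 inputs) after every draft token in a new `intermediate_conv_window_cache`. After verification only the window at the last accepted token has to be copied back into the persistent conv state, so the old path that cached `mixed_qkv` and re-ran `causal_conv1d_fn` through a dedicated CUDA-graph runner can be removed. The sketch below is a minimal pure-PyTorch illustration of that bookkeeping; the names and shapes are illustrative, not the SGLang Triton kernel.

```python
import torch
import torch.nn.functional as F


def conv_update_with_window_cache(x, conv_state, weight, bias=None):
    # x:          [dim, T]    draft-token inputs for one request
    # conv_state: [dim, K-1]  window left over from decoding so far
    # weight:     [dim, K]    depthwise causal conv weights
    # Returns y [dim, T] and windows [T, dim, K-1] (window *after* each step).
    dim, T = x.shape
    K = weight.shape[1]
    buf = torch.cat([conv_state, x], dim=-1)            # [dim, K-1+T]
    windows = buf.unfold(-1, K - 1, 1)[:, 1 : T + 1]    # [dim, T, K-1]
    y = F.conv1d(buf.unsqueeze(0), weight.unsqueeze(1), bias, groups=dim)
    return y.squeeze(0), windows.permute(1, 0, 2).contiguous()


# Toy usage: keep only the window after the last accepted draft token.
dim, K, T = 4, 4, 3
x = torch.randn(dim, T)
conv_state = torch.zeros(dim, K - 1)
weight = torch.randn(dim, K)
y, windows = conv_update_with_window_cache(x, conv_state, weight)
accepted = 2                          # draft tokens 0 and 1 were accepted
conv_state = windows[accepted - 1]    # restored state, no conv re-run needed
```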
@@ -13,7 +13,7 @@ from sglang.srt.layers.attention.fla.fused_recurrent import (
 from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
     fused_sigmoid_gating_delta_rule_update,
 )
-from sglang.srt.layers.attention.mamba.causal_conv1d import (
+from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
     causal_conv1d_fn,
     causal_conv1d_update,
 )
@@ -195,7 +195,9 @@ class MambaAttnBackend(AttentionBackend):
         dt_bias = kwargs["dt_bias"]
         layer_id = kwargs["layer_id"]
-        conv_states, ssm_states = self.req_to_token_pool.get_mamba_params(layer_id)
+        conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
+            layer_id
+        )
         query_start_loc = self.forward_metadata.query_start_loc
         cache_indices = self.forward_metadata.mamba_cache_indices
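With speculative decoding enabled, `get_mamba_params` now returns four tensors instead of two, so the plain decode path above unpacks with `*rest` and ignores the extra caches it does not need. A toy illustration of that unpacking, using a stand-in function rather than the real pool API:

```python
# Toy stand-in for req_to_token_pool.get_mamba_params: the same unpacking
# works whether the pool returns 2 tensors (no speculation) or 4 (with the
# intermediate caches added by this commit).
def get_mamba_params_toy(speculative: bool):
    conv, ssm = "conv", "ssm"
    if speculative:
        return conv, ssm, "intermediate_ssm", "intermediate_conv_window"
    return conv, ssm


conv_states, ssm_states, *rest = get_mamba_params_toy(speculative=True)
assert len(rest) == 2
conv_states, ssm_states, *rest = get_mamba_params_toy(speculative=False)
assert rest == []
```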
@@ -277,12 +279,9 @@ class MambaAttnBackend(AttentionBackend):
             (
                 conv_states,
                 ssm_states,
-                mixed_qkv_cache,
                 intermediate_state_cache,
+                intermediate_conv_window_cache,
             ) = self.req_to_token_pool.get_mamba_params(layer_id)
-            mixed_qkv_cache[cache_indices] = mixed_qkv.view(
-                (-1,) + mixed_qkv_cache.shape[1:]
-            ).clone()
             has_initial_states = torch.ones(
                 seq_len // forward_batch.spec_info.draft_token_num,
                 dtype=torch.bool,
@@ -295,16 +294,38 @@ class MambaAttnBackend(AttentionBackend):
             )
             has_initial_states = forward_batch.extend_prefix_lens > 0
         conv_states_to_use = conv_states
-        mixed_qkv = causal_conv1d_fn(
-            mixed_qkv.transpose(0, 1),
-            conv_weights,
-            bias,
-            activation=activation,
-            conv_states=conv_states_to_use,
-            has_initial_state=has_initial_states,
-            cache_indices=cache_indices,
-            query_start_loc=query_start_loc,
-        ).transpose(0, 1)[:seq_len]
+
+        if is_target_verify:
+            batch_size = seq_len // forward_batch.spec_info.draft_token_num
+            draft_token_num = forward_batch.spec_info.draft_token_num
+            mixed_qkv_reshaped = (
+                mixed_qkv.view(batch_size, draft_token_num, -1)
+                .transpose(1, 2)
+                .contiguous()
+            )
+            mixed_qkv_processed = causal_conv1d_update(
+                mixed_qkv_reshaped,
+                conv_states_to_use,
+                conv_weights,
+                bias,
+                activation,
+                conv_state_indices=cache_indices[:batch_size],
+                intermediate_conv_window=intermediate_conv_window_cache,
+            )
+            mixed_qkv = (
+                mixed_qkv_processed.transpose(1, 2).contiguous().view(seq_len, -1)
+            )
+        else:
+            mixed_qkv = causal_conv1d_fn(
+                mixed_qkv.transpose(0, 1),
+                conv_weights,
+                bias,
+                activation=activation,
+                conv_states=conv_states_to_use,
+                has_initial_state=has_initial_states,
+                cache_indices=cache_indices,
+                query_start_loc=query_start_loc,
+            ).transpose(0, 1)[:seq_len]
 
         key_split_dim = key_dim // attn_tp_size
         value_split_dim = value_dim // attn_tp_size
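In the target-verify branch above, the flat [seq_len, dim] activations are regrouped into [batch, dim, draft_token_num] so the Triton `causal_conv1d_update` kernel can treat each request's draft tokens as one short sequence, and `intermediate_conv_window` is assumed to receive the conv window after every draft token. A shape-only sketch of that round trip, with toy sizes and no kernel call:

```python
import torch

batch_size, draft_token_num, dim = 3, 4, 8
seq_len = batch_size * draft_token_num
mixed_qkv = torch.randn(seq_len, dim)

reshaped = (
    mixed_qkv.view(batch_size, draft_token_num, -1)  # [bs, draft, dim]
    .transpose(1, 2)                                 # [bs, dim, draft]
    .contiguous()
)
# ... the Triton update kernel would run here, writing one conv window per
# draft token into intermediate_conv_window_cache[layer_id, cache_line, step] ...
restored = reshaped.transpose(1, 2).contiguous().view(seq_len, -1)
assert torch.equal(restored, mixed_qkv)
```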
@@ -507,26 +528,6 @@ class HybridLinearAttnBackend(AttentionBackend):
     def update_mamba_state_after_mtp_verify(self, accepted_length, model):
         request_number = accepted_length.shape[0]
-        # QQ: step = spec num_draft token num
-        num_draft_tokens = (
-            self.attn_backend_list[1]
-            .req_to_token_pool.mamba_pool.mamba_cache[2]
-            .shape[2]
-        )
-        query_start_loc = accepted_length.cumsum(-1, dtype=accepted_length.dtype)
-        query_start_loc = torch.cat(
-            [
-                torch.zeros(
-                    1,
-                    dtype=query_start_loc.dtype,
-                    device=query_start_loc.device,
-                ),
-                query_start_loc,
-            ]
-        )
-        mask = torch.arange(num_draft_tokens, device=accepted_length.device).unsqueeze(
-            0
-        ) < accepted_length.unsqueeze(1)
         state_indices_tensor = self.attn_backend_list[
             1
@@ -536,46 +537,48 @@ class HybridLinearAttnBackend(AttentionBackend):
             1
         ].req_to_token_pool.get_mamba_params_all_layers()
-        conv_states, ssm_states, mix_qkv_cache, intermediate_state_cache = mamba_caches
-
-        mixed_qkvs = mix_qkv_cache[:, state_indices_tensor][:, mask]
-
-        mamba_map = self.attn_backend_list[1].req_to_token_pool.mamba_map
-
-        has_initial_states = torch.ones(
-            request_number, dtype=torch.bool, device=accepted_length.device
-        )
-        # Batch SSM state updates (outside the loop for efficiency)
+        (
+            conv_states,
+            ssm_states,
+            intermediate_state_cache,
+            intermediate_conv_window_cache,
+        ) = mamba_caches
+
+        # SSM state updates (chunked to reduce peak memory)
         valid_mask = accepted_length > 0
-        if intermediate_state_cache is not None:
-            last_steps = (accepted_length - 1).to(torch.int64)
-            valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)
-            ssm_states[:, valid_state_indices, :] = intermediate_state_cache[
-                :, valid_state_indices, last_steps
-            ].to(ssm_states.dtype)
-        # For loop conv state updates (can be optimized)
-        for i in range(len(model.model.layers)):
-            layer = model.model.layers[i]
-            if isinstance(layer, Qwen3HybridLinearDecoderLayer):
-                conv_weights = layer.linear_attn.conv1d.weight.view(
-                    layer.linear_attn.conv1d.weight.size(0),
-                    layer.linear_attn.conv1d.weight.size(2),
-                )
-                layer_id = mamba_map[i]
-                conv_state = conv_states[layer_id]
-                mixed_qkv = mixed_qkvs[layer_id]
-
-                _ = causal_conv1d_fn(
-                    mixed_qkv.transpose(0, 1),
-                    conv_weights,
-                    layer.linear_attn.conv1d.bias,
-                    activation=layer.linear_attn.activation,
-                    conv_states=conv_state,
-                    has_initial_state=has_initial_states,
-                    cache_indices=state_indices_tensor,
-                    query_start_loc=query_start_loc,
-                )
+
+        # Compute common indices once to avoid duplication
+        last_steps_all = (accepted_length - 1).to(torch.int64)
+        valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)
+        last_steps = last_steps_all[valid_mask].to(torch.int64)
+
+        if valid_state_indices.numel() > 0:
+            chunk = 256
+            num_valid = valid_state_indices.numel()
+
+            # SSM state updates: one vectorized copy per chunk of cache lines,
+            # pairing each cache line with its last accepted step
+            for i in range(0, num_valid, chunk):
+                idx = valid_state_indices[i : i + chunk]
+                steps = last_steps[i : i + chunk]
+                ssm_states[:, idx, :] = intermediate_state_cache[:, idx, steps].to(
+                    ssm_states.dtype
+                )
+
+            # Conv window updates
+            for i in range(0, num_valid, chunk):
+                idx = valid_state_indices[i : i + chunk]
+                steps = last_steps[i : i + chunk]
+                conv_states[:, idx, :, :] = intermediate_conv_window_cache[
+                    :, idx, steps
+                ].to(conv_states.dtype)
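The hunk above is the post-verify restore: for every request that accepted at least one draft token, the state cached at step `accepted_length - 1` is copied back into the persistent per-request caches, for both the SSM state and the new conv window. A toy demonstration of the paired advanced-indexing copy, with made-up shapes:

```python
import torch

num_layers, num_lines, draft_tokens, d = 2, 5, 4, 3
ssm_states = torch.zeros(num_layers, num_lines, d)
intermediate = torch.arange(
    num_layers * num_lines * draft_tokens * d, dtype=torch.float32
).reshape(num_layers, num_lines, draft_tokens, d)

accepted_length = torch.tensor([2, 0, 4])   # accepted tokens per request
state_indices = torch.tensor([1, 3, 4])     # cache line per request

valid = accepted_length > 0
idx = state_indices[valid].to(torch.int64)            # cache lines to touch
steps = (accepted_length[valid] - 1).to(torch.int64)  # last accepted step

# Paired advanced indexing: element k of idx goes with element k of steps.
ssm_states[:, idx, :] = intermediate[:, idx, steps]
assert torch.equal(ssm_states[:, 1], intermediate[:, 1, 1])  # request 0, step 1
assert torch.equal(ssm_states[:, 4], intermediate[:, 4, 3])  # request 2, step 3
```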
@@ -125,16 +125,6 @@ class MambaPool:
                 device=device,
             )
             if speculative_num_draft_tokens is not None:
-                mixed_qkv_cache = torch.empty(
-                    size=(
-                        num_mamba_layers,
-                        size + 1,
-                        speculative_num_draft_tokens,
-                        conv_state_shape[0],
-                    ),
-                    dtype=conv_dtype,
-                    device="cuda",
-                )
                 # Cache intermediate SSM states per draft token during target verify
                 # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
                 intermediate_ssm_state_cache = torch.empty(
@@ -149,11 +139,24 @@ class MambaPool:
                     dtype=ssm_dtype,
                     device="cuda",
                 )
+                # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
+                # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
+                intermediate_conv_window_cache = torch.empty(
+                    size=(
+                        num_mamba_layers,
+                        size + 1,
+                        speculative_num_draft_tokens,
+                        conv_state_shape[0],
+                        conv_state_shape[1],
+                    ),
+                    dtype=conv_dtype,
+                    device="cuda",
+                )
                 self.mamba_cache = (
                     conv_state,
                     temporal_state,
-                    mixed_qkv_cache,
                     intermediate_ssm_state_cache,
+                    intermediate_conv_window_cache,
                 )
             else:
                 self.mamba_cache = (conv_state, temporal_state)
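The new `intermediate_conv_window_cache` stores one [dim, K-1] window per draft token per cache line; it replaces the deleted `mixed_qkv_cache`, which had the same layout without the trailing K-1 dimension. A back-of-the-envelope size estimate with made-up placeholder dimensions, not Qwen3-Next's real sizes:

```python
# shape = [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
num_mamba_layers = 12
size = 256                          # number of mamba cache lines
speculative_num_draft_tokens = 4
conv_dim, kernel = 4096, 4          # conv_state_shape = (dim, K-1)

elems = (
    num_mamba_layers
    * (size + 1)
    * speculative_num_draft_tokens
    * conv_dim
    * (kernel - 1)
)
bytes_bf16 = elems * 2
print(f"{bytes_bf16 / 1e9:.2f} GB")  # ~0.30 GB for these toy numbers
```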
Deleted file (entire module removed by this commit): sglang.srt.speculative.eagle_target_verify_cuda_graph_runner
import bisect
from typing import TYPE_CHECKING, Callable
import torch
import torch.nn.functional as F
from sglang.srt.layers.attention.fla.fused_recurrent import (
fused_recurrent_gated_delta_rule_update,
)
from sglang.srt.layers.attention.mamba.causal_conv1d import causal_conv1d_fn
from sglang.srt.model_executor.cuda_graph_runner import (
CUDA_GRAPH_CAPTURE_FAILED_MSG,
CudaGraphRunner,
get_batch_sizes_to_capture,
get_global_graph_memory_pool,
model_capture_mode,
set_global_graph_memory_pool,
)
from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer
if TYPE_CHECKING:
from sglang.srt.speculative.eagle_worker import EAGLEWorker
class MambaStateUpdateCudaGraphRunner:
def __init__(self, eagle_worker: "EAGLEWorker"):
self.eagle_worker = eagle_worker
model_runner = eagle_worker.target_worker.model_runner
self.model_runner = model_runner
self.attn_backend = model_runner.attn_backend.attn_backend_list[1]
self.req_to_token_pool = self.attn_backend.req_to_token_pool
self.graphs = {}
self.output_buffers = {}
self.graph_input_buffer = None
self.stream = torch.cuda.Stream()
self.model = model_runner.model
self.enable_profile_cuda_graph = (
model_runner.server_args.enable_profile_cuda_graph
)
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
self.max_bs = self.capture_bs[-1]
self.init_cuda_graph_state()
# Capture
try:
with model_capture_mode():
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
)
def init_cuda_graph_state(self):
self.mamba_cache = self.req_to_token_pool.mamba_pool.mamba_cache
self.num_tokens_per_bs = self.max_accepted_tokens = self.mamba_cache[2].shape[2]
num_mamba_layers = self.mamba_cache[0].shape[0]
conv_dtype = torch.bfloat16
conv_shape = self.mamba_cache[0].shape[2]
total_token_number = self.max_accepted_tokens * self.max_bs
self.mixed_qkv_cache = torch.empty(
size=(
num_mamba_layers,
total_token_number,
conv_shape,
),
dtype=conv_dtype,
device="cuda",
)
self.query_start_loc = torch.zeros(
(self.max_bs + 1,), dtype=torch.int32, device="cuda"
)
self.state_indices = torch.zeros(
(self.max_bs + 1,), dtype=torch.int32, device="cuda"
)
self.has_initial_states = torch.ones(
self.max_bs, dtype=torch.bool, device="cuda"
)
def capture(self):
CudaGraphRunner.capture(self)
def capture_one_batch_size(self, bs: int, forward: Callable):
"""
Capture CUDA Graph for a typical workload
"""
graph = torch.cuda.CUDAGraph()
stream = self.stream
total_token_number = bs * self.max_accepted_tokens
mixed_qkvs = self.mixed_qkv_cache[:, :total_token_number]
query_start_loc = self.query_start_loc[: bs + 1]
state_indices = self.state_indices[:bs]
has_initial_states = self.has_initial_states[:bs]
mamba_caches = self.req_to_token_pool.get_mamba_params_all_layers()
conv_states = mamba_caches[0]
mamba_map = self.req_to_token_pool.mamba_map
def run_once():
for i in range(len(self.model.model.layers)):
layer = self.model.model.layers[i]
if not isinstance(layer, Qwen3HybridLinearDecoderLayer):
continue
conv_weights = layer.linear_attn.conv1d.weight.view(
layer.linear_attn.conv1d.weight.size(0),
layer.linear_attn.conv1d.weight.size(2),
)
layer_id = mamba_map[i]
causal_conv1d_fn(
mixed_qkvs[layer_id].transpose(0, 1),
conv_weights,
layer.linear_attn.conv1d.bias,
activation=layer.linear_attn.activation,
conv_states=conv_states[layer_id],
has_initial_state=has_initial_states,
cache_indices=state_indices,
query_start_loc=query_start_loc,
)
return None
for _ in range(2):
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
run_once()
with torch.cuda.graph(
graph, pool=get_global_graph_memory_pool(), stream=stream
):
out = run_once()
set_global_graph_memory_pool(graph.pool())
return graph, out
def can_run(self, accepted_length):
bs = accepted_length.shape[0]
return bs <= self.max_bs
def replay_repare(self, accepted_length):
request_number = accepted_length.shape[0]
# QQ: step = spec num_draft token num
num_draft_tokens = self.req_to_token_pool.mamba_pool.mamba_cache[2].shape[2]
query_start_loc = accepted_length.cumsum(-1, dtype=accepted_length.dtype)
query_start_loc = torch.cat(
[
torch.zeros(
1,
dtype=query_start_loc.dtype,
device=query_start_loc.device,
),
query_start_loc,
]
)
mask = torch.arange(num_draft_tokens, device=accepted_length.device).unsqueeze(
0
) < accepted_length.unsqueeze(1)
state_indices_tensor = self.attn_backend.forward_metadata.mamba_cache_indices[
:request_number
]
mamba_caches = self.req_to_token_pool.get_mamba_params_all_layers()
_, ssm_states, mix_qkv_cache, intermediate_state_cache = mamba_caches
mixed_qkvs = mamba_caches[2][:, state_indices_tensor][:, mask]
self.mixed_qkv_cache[:, : mixed_qkvs.shape[1]].copy_(mixed_qkvs)
self.query_start_loc[: request_number + 1] = query_start_loc
self.query_start_loc[request_number + 1 :] = self.query_start_loc[
request_number
]
self.state_indices[:request_number] = state_indices_tensor
self.state_indices[request_number:] = -1
valid_mask = accepted_length > 0
if intermediate_state_cache is not None:
last_steps = (accepted_length - 1).to(torch.int64)
valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)
ssm_states[:, valid_state_indices, :] = intermediate_state_cache[
:, valid_state_indices, last_steps
].to(ssm_states.dtype)
def replay(self, accepted_length):
# batch_size and num_seqs can be different in case there are finished examples
# in the batch, which will not be counted as num_seqs
raw_bs = accepted_length.shape[0]
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
self.replay_repare(accepted_length)
# Replay
self.graphs[bs].replay()
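The class above, removed by this commit, follows the standard torch.cuda.CUDAGraph recipe: allocate fixed-size input buffers, warm the workload up, capture it into a graph under a shared memory pool, then replay after copying fresh data into the buffers. A minimal generic sketch of that pattern, not tied to SGLang's runner API:

```python
import torch

if torch.cuda.is_available():
    static_in = torch.zeros(64, 64, device="cuda")
    static_out = torch.zeros(64, 64, device="cuda")

    def run_once():
        # Replayed work must only read/write pre-allocated buffers.
        static_out.copy_(static_in @ static_in)

    # Warm up on a side stream (allocator state, kernel selection, etc.).
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(2):
            run_once()
    torch.cuda.current_stream().wait_stream(s)

    # Capture once, then replay with new data copied into the static input.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        run_once()

    static_in.copy_(torch.randn(64, 64, device="cuda"))
    graph.replay()
    torch.cuda.synchronize()
```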
@@ -407,15 +407,6 @@ class EAGLEWorker(TpModelWorker):
             f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
         )
-        if self.target_worker.model_runner.is_hybrid_gdn:
-            from sglang.srt.speculative.eagle_target_verify_cuda_graph_runner import (
-                MambaStateUpdateCudaGraphRunner,
-            )
-
-            self.cuda_graph_runner_for_target_verify = MambaStateUpdateCudaGraphRunner(
-                self
-            )
-
     @property
     def draft_model_runner(self):
         return self.model_runner
@@ -848,12 +839,9 @@ class EAGLEWorker(TpModelWorker):
                 )
                 + 1
             )
-            if self.cuda_graph_runner_for_target_verify.can_run(accepted_length):
-                self.cuda_graph_runner_for_target_verify.replay(accepted_length)
-            else:
-                self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
-                    accepted_length, self.target_worker.model_runner.model
-                )
+            self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
+                accepted_length, self.target_worker.model_runner.model
+            )
 
         if batch.return_logprob:
             self.add_logprob_values(batch, res, logits_output)