"vscode:/vscode.git/clone" did not exist on "eb934bdf3ba1f8d14c35d7808ba53c7afbb6531f"
Unverified Commit 6c18ab46 authored by Stefan He, committed by GitHub

[Qwen3-Next] switch to triton and cache conv states to accelerate MTP from 300 tok/s to 341 tok/s (#10335)
Co-authored-by: Binyao Jiang <byjiang1996@gmail.com>
parent 4a0e0be2
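In short: during target verify, the new Triton causal_conv1d_update kernel snapshots the conv window after every draft token, so the post-verify fix-up no longer re-runs the convolution; it only copies the snapshot at the last accepted step back into the persistent caches. A rough sketch of that copy, using names from the diff below (illustration only, per request i with cache line ci):

last_step = accepted_length[i] - 1  # 0-based index of the last accepted draft token
conv_states[:, ci] = intermediate_conv_window_cache[:, ci, last_step].to(conv_states.dtype)
ssm_states[:, ci] = intermediate_state_cache[:, ci, last_step].to(ssm_states.dtype)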
@@ -13,7 +13,7 @@ from sglang.srt.layers.attention.fla.fused_recurrent import (
from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
fused_sigmoid_gating_delta_rule_update,
)
from sglang.srt.layers.attention.mamba.causal_conv1d import (
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
causal_conv1d_fn,
causal_conv1d_update,
)
@@ -195,7 +195,9 @@ class MambaAttnBackend(AttentionBackend):
dt_bias = kwargs["dt_bias"]
layer_id = kwargs["layer_id"]
conv_states, ssm_states = self.req_to_token_pool.get_mamba_params(layer_id)
conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
layer_id
)
query_start_loc = self.forward_metadata.query_start_loc
cache_indices = self.forward_metadata.mamba_cache_indices
@@ -277,12 +279,9 @@ class MambaAttnBackend(AttentionBackend):
(
conv_states,
ssm_states,
mixed_qkv_cache,
intermediate_state_cache,
intermediate_conv_window_cache,
) = self.req_to_token_pool.get_mamba_params(layer_id)
mixed_qkv_cache[cache_indices] = mixed_qkv.view(
(-1,) + mixed_qkv_cache.shape[1:]
).clone()
has_initial_states = torch.ones(
seq_len // forward_batch.spec_info.draft_token_num,
dtype=torch.bool,
@@ -295,16 +294,38 @@ class MambaAttnBackend(AttentionBackend):
)
has_initial_states = forward_batch.extend_prefix_lens > 0
conv_states_to_use = conv_states
mixed_qkv = causal_conv1d_fn(
mixed_qkv.transpose(0, 1),
conv_weights,
bias,
activation=activation,
conv_states=conv_states_to_use,
has_initial_state=has_initial_states,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
).transpose(0, 1)[:seq_len]
if is_target_verify:
batch_size = seq_len // forward_batch.spec_info.draft_token_num
draft_token_num = forward_batch.spec_info.draft_token_num
mixed_qkv_reshaped = (
mixed_qkv.view(batch_size, draft_token_num, -1)
.transpose(1, 2)
.contiguous()
)
mixed_qkv_processed = causal_conv1d_update(
mixed_qkv_reshaped,
conv_states_to_use,
conv_weights,
bias,
activation,
conv_state_indices=cache_indices[:batch_size],
intermediate_conv_window=intermediate_conv_window_cache,
)
mixed_qkv = (
mixed_qkv_processed.transpose(1, 2).contiguous().view(seq_len, -1)
)
else:
mixed_qkv = causal_conv1d_fn(
mixed_qkv.transpose(0, 1),
conv_weights,
bias,
activation=activation,
conv_states=conv_states_to_use,
has_initial_state=has_initial_states,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
).transpose(0, 1)[:seq_len]
key_split_dim = key_dim // attn_tp_size
value_split_dim = value_dim // attn_tp_size
@@ -507,26 +528,6 @@ class HybridLinearAttnBackend(AttentionBackend):
def update_mamba_state_after_mtp_verify(self, accepted_length, model):
request_number = accepted_length.shape[0]
# NOTE: step == speculative num_draft_tokens
num_draft_tokens = (
self.attn_backend_list[1]
.req_to_token_pool.mamba_pool.mamba_cache[2]
.shape[2]
)
query_start_loc = accepted_length.cumsum(-1, dtype=accepted_length.dtype)
query_start_loc = torch.cat(
[
torch.zeros(
1,
dtype=query_start_loc.dtype,
device=query_start_loc.device,
),
query_start_loc,
]
)
mask = torch.arange(num_draft_tokens, device=accepted_length.device).unsqueeze(
0
) < accepted_length.unsqueeze(1)
state_indices_tensor = self.attn_backend_list[
1
@@ -536,46 +537,48 @@ class HybridLinearAttnBackend(AttentionBackend):
1
].req_to_token_pool.get_mamba_params_all_layers()
conv_states, ssm_states, mix_qkv_cache, intermediate_state_cache = mamba_caches
mixed_qkvs = mix_qkv_cache[:, state_indices_tensor][:, mask]
mamba_map = self.attn_backend_list[1].req_to_token_pool.mamba_map
has_initial_states = torch.ones(
request_number, dtype=torch.bool, device=accepted_length.device
)
(
conv_states,
ssm_states,
intermediate_state_cache,
intermediate_conv_window_cache,
) = mamba_caches
# Batch SSM state updates (outside the loop for efficiency)
# SSM state updates (chunked to reduce peak memory)
valid_mask = accepted_length > 0
if intermediate_state_cache is not None:
last_steps = (accepted_length - 1).to(torch.int64)
valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)
ssm_states[:, valid_state_indices, :] = intermediate_state_cache[
:, valid_state_indices, last_steps
].to(ssm_states.dtype)
# For loop conv state updates (can be optimized)
for i in range(len(model.model.layers)):
layer = model.model.layers[i]
if isinstance(layer, Qwen3HybridLinearDecoderLayer):
conv_weights = layer.linear_attn.conv1d.weight.view(
layer.linear_attn.conv1d.weight.size(0),
layer.linear_attn.conv1d.weight.size(2),
)
layer_id = mamba_map[i]
conv_state = conv_states[layer_id]
mixed_qkv = mixed_qkvs[layer_id]
_ = causal_conv1d_fn(
mixed_qkv.transpose(0, 1),
conv_weights,
layer.linear_attn.conv1d.bias,
activation=layer.linear_attn.activation,
conv_states=conv_state,
has_initial_state=has_initial_states,
cache_indices=state_indices_tensor,
query_start_loc=query_start_loc,
)
# Compute common indices once to avoid duplication
last_steps_all = (accepted_length - 1).to(torch.int64)
valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)
last_steps = last_steps_all[valid_mask].to(torch.int64)
if valid_state_indices.numel() > 0:
chunk = 256
num_valid = valid_state_indices.numel()
# SSM state updates
for i in range(0, num_valid, chunk):
idx = valid_state_indices[i : i + chunk]
steps = last_steps[i : i + chunk]
# per (cache line, step)
for j in range(idx.numel()):
ci = idx[j].item()
st = steps[j].item()
ssm_states[:, ci, :].copy_(
intermediate_state_cache[:, ci, st].to(
ssm_states.dtype, copy=False
)
)
# Conv window updates
for i in range(0, num_valid, chunk):
idx = valid_state_indices[i : i + chunk]
steps = last_steps[i : i + chunk]
for j in range(idx.numel()):
ci = idx[j].item()
st = steps[j].item()
conv_states[:, ci, :, :].copy_(
intermediate_conv_window_cache[:, ci, st].to(
conv_states.dtype, copy=False
)
)
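For reference, when peak memory allows, the chunked loops above are equivalent to the vectorized advanced-indexing form that the removed code used for the SSM states; a sketch of that equivalence (illustration only, same tensor names as above):

ssm_states[:, valid_state_indices] = intermediate_state_cache[
    :, valid_state_indices, last_steps
].to(ssm_states.dtype)
conv_states[:, valid_state_indices] = intermediate_conv_window_cache[
    :, valid_state_indices, last_steps
].to(conv_states.dtype)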
# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
from typing import Optional, Union
import numpy as np
import torch
PAD_SLOT_ID = -1
import triton
import triton.language as tl
@triton.jit()
def _causal_conv1d_fwd_kernel( # continuous batching
# Pointers to matrices
x_ptr, # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences
w_ptr, # (dim, width)
bias_ptr,
initial_states_ptr, # conv_states_ptr
cache_indices_ptr, # conv_state_indices_ptr
has_initial_states_ptr,
query_start_loc_ptr,
batch_ptr,
token_chunk_offset_ptr,
o_ptr, # (dim, seqlen) - actually pointing to x_ptr
# Matrix dimensions
batch: tl.int32, # actually padded_batch
dim: tl.constexpr,
seqlen: tl.int32, # cu_seqlen
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
# Strides
stride_x_seq: tl.constexpr, # stride to get to next sequence,
stride_x_dim: tl.constexpr, # stride to get to next feature-value,
stride_x_token: tl.constexpr, # stride to get to next token (same feature-index, same sequence-index)
stride_w_dim: tl.constexpr, # stride to get to next dim-axis value
stride_w_width: tl.constexpr, # stride to get to next width-axis value
stride_istate_seq: tl.constexpr,
stride_istate_dim: tl.constexpr,
stride_istate_token: tl.constexpr,
stride_o_seq: tl.constexpr,
stride_o_dim: tl.constexpr,
stride_o_token: tl.constexpr,
# others
pad_slot_id: tl.constexpr,
# Meta-parameters
HAS_BIAS: tl.constexpr,
KERNEL_WIDTH: tl.constexpr,
SILU_ACTIVATION: tl.constexpr,
HAS_INITIAL_STATES: tl.constexpr,
HAS_CACHE: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
USE_PAD_SLOT: tl.constexpr,
NP2_STATELEN: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
conv_states_ptr = initial_states_ptr
conv_state_indices_ptr = cache_indices_ptr
stride_conv_state_seq = stride_istate_seq
stride_conv_state_dim = stride_istate_dim
stride_conv_state_tok = stride_istate_token
state_len = (
KERNEL_WIDTH - 1
) # can be passed via argument if it's not the same as this value
# one program handles one chunk in a single sequence
# rather than mixing sequences - so that updating initial_states across sequences stays efficient
# single-sequence id
idx_seq = tl.load(batch_ptr + tl.program_id(0))
chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))
# BLOCK_N elements along the feature-dimension (channel)
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
if idx_seq == pad_slot_id:
return
sequence_start_index = tl.load(query_start_loc_ptr + idx_seq)
sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1)
# find the actual sequence length
seqlen = sequence_end_index - sequence_start_index
token_offset = BLOCK_M * chunk_offset
segment_len = min(BLOCK_M, seqlen - token_offset)
# base of the sequence
x_base = (
x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim
) # [BLOCK_N,]
if IS_CONTINUOUS_BATCHING:
# cache_idx
conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(tl.int64)
else:
# cache_idx
conv_state_batch_coord = idx_seq
if USE_PAD_SLOT: # noqa
if conv_state_batch_coord == pad_slot_id:
# not processing as this is not the actual sequence
return
conv_states_base = (
conv_states_ptr
+ (conv_state_batch_coord * stride_conv_state_seq)
+ (idx_feats * stride_conv_state_dim)
) # [BLOCK_N,]
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
# Does 2 things:
# 1. READ prior-block init-state data [done by every Triton program]
# 2. update conv_state with new data [only by the Triton program that handles chunk_offset=0]
if chunk_offset == 0:
# read from conv_states
load_init_state = False
if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES
load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1)
if load_init_state:
# load from conv_states
prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok
mask_w = idx_feats < dim
if KERNEL_WIDTH == 2:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH == 3:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH == 4:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N]
col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH == 5:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N]
col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N]
col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
conv_states_ptrs = prior_tokens - 3 * stride_conv_state_tok # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
else:
# prior-tokens are zeros
if KERNEL_WIDTH >= 2: # STRATEGY1
# first chunk and does not have prior-token, so just set to 0
col0 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
if KERNEL_WIDTH >= 3: # STRATEGY1
col1 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
if KERNEL_WIDTH >= 4: # STRATEGY1
col2 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
if KERNEL_WIDTH >= 5: # STRATEGY1
col3 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
# STEP 2:
# here prepare data for updating conv_state
if (
state_len <= seqlen
): # SMALL_CACHE=True (only move part of 'x' into conv_state cache)
# just read from 'x'
# copy 'x' data to conv_state
# load only 'x' data (and set 0 before 'x' if seqlen < state_len)
idx_tokens_last = (seqlen - state_len) + tl.arange(
0, NP2_STATELEN
) # [BLOCK_M]
x_ptrs = (
x_ptr
+ ((sequence_start_index + idx_tokens_last) * stride_x_token)[:, None]
+ (idx_feats * stride_x_dim)[None, :]
) # [BLOCK_M,BLOCK_N,]
mask_x = (
(idx_tokens_last >= 0)[:, None]
& (idx_tokens_last < seqlen)[:, None]
& (idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
conv_states_ptrs_target = (
conv_states_base[None, :]
+ (idx_tokens_conv * stride_conv_state_tok)[:, None]
)
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :]
tl.debug_barrier() # NOTE: use this due to bug in Triton compiler
tl.store(conv_states_ptrs_target, new_conv_state, mask)
else:
if load_init_state:
# update conv_state by shifting left, i.e. take last few cols from conv_state + cols from 'x'
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
conv_states_ptrs_source = (
conv_states_ptr
+ (conv_state_batch_coord * stride_conv_state_seq)
+ (idx_feats * stride_conv_state_dim)[None, :]
+ ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None]
) # [BLOCK_M, BLOCK_N]
mask = (
(conv_state_batch_coord < num_cache_lines)
& ((idx_tokens_conv + seqlen) < state_len)[:, None]
& (idx_feats < dim)[None, :]
)
conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0)
VAL = state_len - seqlen
x_ptrs = (
x_base[None, :]
+ ((idx_tokens_conv - VAL) * stride_x_token)[:, None]
) # [BLOCK_M, BLOCK_N]
mask_x = (
(idx_tokens_conv - VAL >= 0)[:, None]
& (idx_tokens_conv - VAL < seqlen)[:, None]
& (idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
tl.debug_barrier() # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load
new_conv_state = tl.where(
mask, conv_state, loaded_x
) # BUG in 'tl.where' which requires a barrier before this
conv_states_ptrs_target = (
conv_states_base
+ (idx_tokens_conv * stride_conv_state_tok)[:, None]
) # [BLOCK_M, BLOCK_N]
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[
None, :
]
tl.store(conv_states_ptrs_target, new_conv_state, mask)
else: # load_init_state == False
# update conv_state by shifting left, BUT
# set cols prior to 'x' as zeros + cols from 'x'
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
VAL = state_len - seqlen
x_ptrs = (
x_base[None, :]
+ ((idx_tokens_conv - VAL) * stride_x_token)[:, None]
) # [BLOCK_M, BLOCK_N]
mask_x = (
(idx_tokens_conv - VAL >= 0)[:, None]
& (idx_tokens_conv - VAL < seqlen)[:, None]
& (idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
conv_states_ptrs_target = (
conv_states_base
+ (idx_tokens_conv * stride_conv_state_tok)[:, None]
) # [BLOCK_M, BLOCK_N]
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[
None, :
]
tl.store(conv_states_ptrs_target, new_conv_state, mask)
else: # chunk_offset > 0
# read prior-token data from `x`
load_init_state = True
prior_tokens = x_base + (token_offset - 1) * stride_x_token
mask_w = idx_feats < dim
if KERNEL_WIDTH == 2:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
if KERNEL_WIDTH == 3:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
if KERNEL_WIDTH == 4:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N]
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
if KERNEL_WIDTH == 5:
# ruff: noqa: F841
conv_states_ptrs = prior_tokens # [BLOCK_N]
col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N]
col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N]
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
if HAS_BIAS:
bias = bias_ptr + idx_feats
mask_bias = idx_feats < dim
acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(
tl.float32
) # [BLOCK_N]
else:
acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32)
x_base_1d = x_base + token_offset * stride_x_token # starting of chunk
# PRE-LOAD WEIGHTS
mask_w = idx_feats < dim
if KERNEL_WIDTH >= 2:
w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor
w_col0 = tl.load(w_ptrs, mask_w, other=0.0)
w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor
w_col1 = tl.load(w_ptrs, mask_w, other=0.0)
if KERNEL_WIDTH >= 3:
w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor
w_col2 = tl.load(w_ptrs, mask_w, other=0.0)
if KERNEL_WIDTH >= 4:
w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor
w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
mask_x_1d = idx_feats < dim
for idx_token in range(segment_len):
acc = acc_preload
matrix_w = w_col0
matrix_x = col0
for j in tl.static_range(KERNEL_WIDTH):
if KERNEL_WIDTH == 2:
if j == 1: # KERNEL_WIDTH-1:
matrix_w = w_col1
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
elif KERNEL_WIDTH == 3:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
elif KERNEL_WIDTH == 4:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
matrix_x = col2
elif j == 3:
matrix_w = w_col3
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
acc += matrix_x * matrix_w # [BLOCK_N]
if KERNEL_WIDTH == 2:
col0 = matrix_x
elif KERNEL_WIDTH == 3:
col0 = col1
col1 = matrix_x
elif KERNEL_WIDTH == 4:
col0 = col1
col1 = col2
col2 = matrix_x
if SILU_ACTIVATION:
acc = acc / (1 + tl.exp(-acc))
mask_1d = (idx_token < segment_len) & (
idx_feats < dim
) # token-index # feature-index
o_ptrs = (
o_ptr
+ (sequence_start_index + token_offset + idx_token) * stride_o_token
+ (idx_feats * stride_o_dim)
)
tl.store(o_ptrs, acc, mask=mask_1d)
def causal_conv1d_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: Union[torch.Tensor, None],
conv_states: torch.Tensor,
query_start_loc: torch.Tensor,
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID,
metadata=None,
validate_data=False,
):
"""support varlen + continuous batching when x is 2D tensor
x: (dim,cu_seq_len)
cu_seq_len = total tokens of all seqs in that batch
sequences are concatenated from left to right for varlen
weight: (dim, width)
conv_states: (...,dim,width - 1) itype
updated inplace if provided
[the kernel uses `cache_indices` to get the index into the conv_state cache for each sequence:
conv_state[cache_indices[i]] for seq-i is used as the initial state when has_initial_state[i] = True,
and afterwards conv_state[cache_indices[i]] is shifted left and updated with values from 'x'
]
query_start_loc: (batch + 1) int32
The cumulative sequence lengths of the sequences in
the batch, used to index into x; prepended by 0.
For example, if the sequence lengths are [5, 1, 1, 1]
(continuous batching, batch=4), then
query_start_loc = [0, 5, 6, 7, 8]: entry i is the starting index of
sequence i, and the last value is the ending index of the last sequence
[length(query_start_loc)-1 == batch]
for example: query_start_loc = torch.Tensor([0,10,16,17]),
x.shape=(dim,17)
cache_indices: (batch) int32
indicates the corresponding state index,
like so: conv_state = conv_states[cache_indices[batch_id]]
has_initial_state: (batch) bool
indicates whether the kernel should take the current state as the initial
state for the calculations
[single boolean for each sequence in the batch: True or False]
bias: (dim,)
activation: either None or "silu" or "swish" or True
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: same shape as `x`
"""
if isinstance(activation, bool) and activation:
activation = "silu"
args = None
out = torch.empty_like(x)
if metadata is not None:
cu_seqlen = metadata.cu_seqlen
nums_dict = metadata.nums_dict
# x = metadata.x
args = nums_dict
batch_ptr = metadata.batch_ptr
token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
else:
seqlens = np.diff(query_start_loc.to("cpu"))
args = seqlens
MAX_NUM_PROGRAMS = 1024
batch_ptr = torch.full(
(MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
) # tracking which seq-idx the Triton program is handling
token_chunk_offset_ptr = torch.full(
(MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
) # tracking BLOCK_M-based index in the sequence the Triton program is handling
is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)
dim, cu_seqlen = x.shape
_, width = weight.shape
state_len = width - 1
np2_statelen = triton.next_power_of_2(state_len)
padded_batch = query_start_loc.size(0) - 1
stride_x_seq = 0
stride_x_dim = x.stride(0)
stride_x_token = x.stride(1)
stride_w_dim = weight.stride(0)
stride_w_width = weight.stride(1)
stride_istate_seq = 0
stride_istate_dim = 0
stride_istate_token = 0
num_cache_lines = 0
if conv_states is not None:
# extensions to support vLLM:
# 1. conv_states is used to replace initial_states
# 2. conv_states serves as a cache whose number of cache lines can be larger than the batch size
# 3. sequence x[idx] is mapped to the cache line at the index given by cache_indices[idx]
# 4. computation can be skipped if cache_indices[idx] == pad_slot_id
num_cache_lines = conv_states.size(0)
assert (
num_cache_lines == conv_states.shape[0]
and dim == conv_states.shape[1]
and width - 1 <= conv_states.shape[2]
)
stride_istate_seq = conv_states.stride(0)
stride_istate_dim = conv_states.stride(1)
stride_istate_token = conv_states.stride(2)
# assert stride_istate_dim == 1
if out.dim() == 2:
stride_o_seq = 0
stride_o_dim = out.stride(0)
stride_o_token = out.stride(1)
else:
stride_o_seq = out.stride(0)
stride_o_dim = out.stride(1)
stride_o_token = out.stride(2)
if validate_data:
assert x.dim() == 2
assert query_start_loc is not None
assert query_start_loc.dim() == 1
assert x.stride(0) == 1 or x.stride(1) == 1
if bias is not None:
assert bias.dim() == 1
assert dim == bias.size(0)
if cache_indices is not None:
assert cache_indices.dim() == 1
assert padded_batch == cache_indices.size(0)
if has_initial_state is not None:
assert has_initial_state.size() == (padded_batch,)
assert (
conv_states is not None
), "ERROR: `has_initial_state` is used, which needs also `conv_states`"
assert weight.stride(1) == 1
assert (dim, width) == weight.shape
assert is_channel_last, "Need to run in channel-last layout"
if metadata is None:
def num_program(META, seqlens):
tot = 0
mlist = []
offsetlist = [] # type: ignore
nums = -(-seqlens // META["BLOCK_M"])
tot = nums.sum().item()
mlist = np.repeat(np.arange(len(nums)), nums)
for idx, num in enumerate(nums):
offsetlist.extend(
range(num)
) # chunk-idx if a sequence is split into multiple chunks
if META["batch_ptr"].nelement() < len(mlist):
newlen = len(mlist) + 1
META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
if META["batch_ptr"].nelement() >= len(mlist):
META["batch_ptr"][0 : len(mlist)].copy_(
torch.from_numpy(np.array(mlist))
)
META["token_chunk_offset_ptr"][0 : len(mlist)].copy_(
torch.from_numpy(np.array(offsetlist))
)
META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device)
META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to(
META["x_ptr"].device
)
return tot
else:
def num_program(META, nums_dict):
tot = nums_dict[META["BLOCK_M"]]["tot"]
mlist = nums_dict[META["BLOCK_M"]]["mlist"]
mlist_len = nums_dict[META["BLOCK_M"]]["mlist_len"]
offsetlist = nums_dict[META["BLOCK_M"]]["offsetlist"]
if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None:
META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"]
META["token_chunk_offset_ptr"] = nums_dict[META["BLOCK_M"]][
"token_chunk_offset_ptr"
]
else:
if META["batch_ptr"].nelement() < mlist_len:
newlen = mlist_len + 1
META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
if META["batch_ptr"].nelement() >= mlist_len:
META["batch_ptr"][0:mlist_len].copy_(mlist)
META["token_chunk_offset_ptr"][0:mlist_len].copy_(offsetlist)
return tot
def grid(META):
return (
num_program(META, args),
triton.cdiv(dim, META["BLOCK_N"]),
)
if batch_ptr.device != x.device:
batch_ptr = batch_ptr.to(x.device)
token_chunk_offset_ptr = token_chunk_offset_ptr.to(x.device)
_causal_conv1d_fwd_kernel[grid](
# Pointers to matrices
x,
weight,
bias,
conv_states,
cache_indices,
has_initial_state,
query_start_loc,
batch_ptr,
token_chunk_offset_ptr,
out,
# Matrix dimensions
padded_batch,
dim,
cu_seqlen,
num_cache_lines,
# stride
stride_x_seq,
stride_x_dim,
stride_x_token,
stride_w_dim,
stride_w_width,
stride_istate_seq,
stride_istate_dim,
stride_istate_token,
stride_o_seq,
stride_o_dim,
stride_o_token,
# others
pad_slot_id,
# META
HAS_BIAS=bias is not None,
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
HAS_INITIAL_STATES=has_initial_state is not None,
HAS_CACHE=conv_states is not None,
IS_CONTINUOUS_BATCHING=cache_indices is not None,
USE_PAD_SLOT=pad_slot_id is not None,
NP2_STATELEN=np2_statelen,
# launch_cooperative_grid=True
BLOCK_M=8,
BLOCK_N=256,
num_stages=2,
)
return out
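A minimal usage sketch for the varlen entry point, with made-up shapes (two sequences of 5 and 3 tokens; assumes a CUDA device with Triton available). The call sites in this PR pass mixed_qkv.transpose(0, 1) so that x is channel-last, which the sketch mirrors:

dim, width, num_cache_lines = 64, 4, 8  # hypothetical sizes
x = torch.randn(8, dim, device="cuda", dtype=torch.bfloat16).transpose(0, 1)  # (dim, 8 tokens), channel-last
weight = torch.randn(dim, width, device="cuda", dtype=torch.bfloat16)
bias = torch.randn(dim, device="cuda", dtype=torch.bfloat16)
conv_states = torch.zeros(num_cache_lines, dim, width - 1, device="cuda", dtype=torch.bfloat16)
query_start_loc = torch.tensor([0, 5, 8], dtype=torch.int32, device="cuda")  # sequence lengths 5 and 3
cache_indices = torch.tensor([2, 5], dtype=torch.int32, device="cuda")  # cache lines backing the two sequences
has_initial_state = torch.tensor([False, True], device="cuda")
out = causal_conv1d_fn(
    x,
    weight,
    bias,
    conv_states=conv_states,
    query_start_loc=query_start_loc,
    cache_indices=cache_indices,
    has_initial_state=has_initial_state,
    activation="silu",
)  # out keeps the (dim, total_tokens) layout of x; conv_states rows 2 and 5 are updated in place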
@triton.jit()
def _causal_conv1d_update_kernel(
# Pointers to matrices
x_ptr, # (batch, dim, seqlen)
w_ptr, # (dim, width)
bias_ptr,
conv_state_ptr,
cache_seqlens_ptr, # circular buffer
conv_state_indices_ptr,
num_accepted_tokens_ptr,
intermediate_conv_window_ptr,
o_ptr, # (batch, dim, seqlen)
# Matrix dimensions
batch: int,
dim: tl.constexpr,
seqlen: tl.constexpr,
state_len: tl.constexpr,
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
# Strides
stride_x_seq: tl.constexpr,
stride_x_dim: tl.constexpr,
stride_x_token: tl.constexpr,
stride_w_dim: tl.constexpr,
stride_w_width: tl.constexpr,
stride_conv_state_seq: tl.constexpr,
stride_conv_state_dim: tl.constexpr,
stride_conv_state_tok: tl.constexpr,
stride_state_indices: tl.constexpr,
stride_inter_seq: tl.constexpr,
stride_inter_step: tl.constexpr,
stride_inter_dim: tl.constexpr,
stride_inter_win: tl.constexpr,
stride_o_seq: tl.constexpr,
stride_o_dim: tl.constexpr,
stride_o_token: tl.constexpr,
# others
pad_slot_id: tl.constexpr,
# Meta-parameters
HAS_BIAS: tl.constexpr,
KERNEL_WIDTH: tl.constexpr,
SILU_ACTIVATION: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
NP2_STATELEN: tl.constexpr,
USE_PAD_SLOT: tl.constexpr,
BLOCK_N: tl.constexpr,
SAVE_INTERMEDIATE: tl.constexpr,
):
# ruff: noqa: E501
idx_seq = tl.program_id(0)
if idx_seq >= batch:
return
# [BLOCK_N,] elements along the feature-dimension (channel)
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
if IS_CONTINUOUS_BATCHING:
# mask = idx_seq < batch
conv_state_batch_coord = tl.load(
conv_state_indices_ptr + idx_seq * stride_state_indices
).to(tl.int64)
else:
conv_state_batch_coord = idx_seq
if USE_PAD_SLOT: # noqa
if conv_state_batch_coord == pad_slot_id:
# not processing as this is not the actual sequence
return
if IS_SPEC_DECODING:
# The rolling of conv state:
#
# Before forward, the conv_state is:
# [history1, history2, ..., historyM].
#
# After forward, the conv_state becomes:
# [history2, ..., historyM, draft1, draft2, ..., draftN].
#
# After acceptance, it becomes:
#
# - accept 1 tokens: [history2, ..., historyM, draft1]
# - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
# - and so on.
conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq) - 1
else:
conv_state_token_offset = 0
# STEP 1: READ init_state data
conv_states_base = (
conv_state_ptr
+ (conv_state_batch_coord * stride_conv_state_seq)
+ (idx_feats * stride_conv_state_dim)
)
mask_w = idx_feats < dim
prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
if KERNEL_WIDTH >= 2:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH >= 3:
conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N]
col1 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH >= 4:
conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N]
col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
if KERNEL_WIDTH == 5:
conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N]
col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
# STEP 2: assume state_len > seqlen
idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
# The conv_state update works in a sliding-window manner:
# at each forward pass the tokens are shifted by 1, so we
# load starting from idx_tokens + 1.
conv_state_ptrs_source = (
conv_state_ptr
+ (conv_state_batch_coord * stride_conv_state_seq)
+ conv_state_token_offset * stride_conv_state_tok
+ (idx_feats * stride_conv_state_dim)[None, :]
+ ((idx_tokens + 1) * stride_conv_state_tok)[:, None]
) # [BLOCK_M, BLOCK_N]
mask = (
(conv_state_batch_coord < num_cache_lines)
& ((idx_tokens + seqlen) < state_len)[:, None]
& (idx_feats < dim)[None, :]
)
conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)
VAL = state_len - seqlen
x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim) # [BLOCK_N]
x_ptrs = (
x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None]
) # [BLOCK_M, BLOCK_N]
mask_x = (
(idx_tokens - VAL >= 0)[:, None]
& (idx_tokens - VAL < seqlen)[:, None]
& (idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
tl.debug_barrier()
new_conv_state = tl.where(mask, conv_state, loaded_x)
conv_state_base = (
conv_state_ptr
+ (conv_state_batch_coord * stride_conv_state_seq)
+ (idx_feats * stride_conv_state_dim)
) # [BLOCK_N,]
conv_state_ptrs_target = (
conv_state_base + (idx_tokens * stride_conv_state_tok)[:, None]
) # [BLOCK_M, BLOCK_N]
mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
tl.store(conv_state_ptrs_target, new_conv_state, mask)
# STEP 3: init accumulator
if HAS_BIAS:
bias = bias_ptr + idx_feats
mask_bias = idx_feats < dim
acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(
tl.float32
) # [BLOCK_N]
else:
acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32)
# STEP 4:
# PRE-LOAD WEIGHTS
# first kernel column, configured for weights to handle BLOCK_N features in range
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
mask_w = idx_feats < dim
if KERNEL_WIDTH >= 2:
w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor
w_col0 = tl.load(w_ptrs, mask_w, other=0.0)
w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor
w_col1 = tl.load(w_ptrs, mask_w, other=0.0)
if KERNEL_WIDTH >= 3:
w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor
w_col2 = tl.load(w_ptrs, mask_w, other=0.0)
if KERNEL_WIDTH >= 4:
w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor
w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
x_base_1d = x_base # starting of chunk [BLOCK_N]
mask_x_1d = idx_feats < dim
# STEP 5: compute each token
for idx_token in tl.static_range(seqlen):
acc = acc_preload
matrix_w = w_col0
matrix_x = col0
for j in tl.static_range(KERNEL_WIDTH):
if KERNEL_WIDTH == 2:
if j == 1: # KERNEL_WIDTH-1:
matrix_w = w_col1
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
elif KERNEL_WIDTH == 3:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
elif KERNEL_WIDTH == 4:
if j == 1:
matrix_w = w_col1
matrix_x = col1
elif j == 2:
matrix_w = w_col2
matrix_x = col2
elif j == 3:
matrix_w = w_col3
x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
acc += matrix_x * matrix_w # [BLOCK_N]
if KERNEL_WIDTH == 2:
col0 = matrix_x
elif KERNEL_WIDTH == 3:
col0 = col1
col1 = matrix_x
elif KERNEL_WIDTH == 4:
col0 = col1
col1 = col2
col2 = matrix_x
if SILU_ACTIVATION:
acc = acc / (1 + tl.exp(-acc))
mask_1d = (idx_token < seqlen) & (
idx_feats < dim
) # token-index # feature-index
o_ptrs = (
o_ptr
+ (idx_seq) * stride_o_seq
+ idx_token * stride_o_token
+ (idx_feats * stride_o_dim)
)
tl.store(o_ptrs, acc, mask=mask_1d)
if SAVE_INTERMEDIATE:
# Save the window state after consuming this token
# Layout: [seq(cache line), step, dim, win(K-1)]
base_ptr = (
intermediate_conv_window_ptr
+ conv_state_batch_coord * stride_inter_seq
+ idx_token * stride_inter_step
+ idx_feats * stride_inter_dim
)
if KERNEL_WIDTH >= 2:
tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w)
if KERNEL_WIDTH >= 3:
tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w)
if KERNEL_WIDTH >= 4:
tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w)
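The acceptance-time rolling described in the IS_SPEC_DECODING comment above boils down to the following plain-PyTorch reference for a single cache line (illustration only; the kernel performs the shift in place through pointer offsets):

def rolled_window(conv_state, accepted_drafts):
    # conv_state: (dim, state_len) history window; accepted_drafts: (dim, k) accepted draft tokens
    window = torch.cat([conv_state, accepted_drafts], dim=-1)
    return window[:, -conv_state.shape[-1]:]  # keep only the most recent state_len tokens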
def causal_conv1d_update(
x: torch.Tensor,
conv_state: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
activation: Union[bool, str, None] = None,
cache_seqlens: Optional[torch.Tensor] = None,
conv_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
intermediate_conv_window: Optional[torch.Tensor] = None,
pad_slot_id: int = PAD_SLOT_ID,
metadata=None,
validate_data=False,
):
"""
x: (batch, dim) or (batch, dim, seqlen)
[shape=2: single token prediction]
[shape=3: single or multiple tokens prediction]
conv_state: (..., dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the conv_state
starting at the index
@cache_seqlens % state_len.
conv_state_indices: (batch,), dtype int32
If not None, the conv_state is a larger tensor along the batch dim,
and we are selecting the batch coords specified by conv_state_indices.
Useful for a continuous batching scenario.
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: (batch, dim) or (batch, dim, seqlen)
"""
if validate_data:
assert cache_seqlens is None # not implemented yet - ok for vLLM
assert pad_slot_id is not None
assert x.stride(1) == 1
if isinstance(activation, bool):
activation = "silu" if activation is True else None
elif activation is not None:
assert activation in ["silu", "swish"]
unsqueeze = x.dim() == 2
if unsqueeze:
# make it (batch, dim, seqlen) with seqlen == 1
x = x.unsqueeze(-1)
batch, dim, seqlen = x.shape
_, width = weight.shape
# conv_state: (..., dim, state_len), where state_len >= width - 1
num_cache_lines, _, state_len = conv_state.size()
if validate_data:
assert dim == weight.size(0)
assert (
conv_state.stride(-2) == 1
), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
assert state_len >= width - 1
# when above happens, we don't shift-left to keep any records in conv_state
assert dim == conv_state.size(1)
if conv_state_indices is None:
assert conv_state.size(0) >= batch
else:
assert (batch,) == conv_state_indices.shape
assert num_cache_lines >= batch
assert weight.stride(1) == 1 # Need this
assert cache_seqlens is None # not needed for vLLM - circular buffer
# adopt the strategy from vLLM of overwriting 'x' directly, rather than creating a new tensor 'o'
out = x
stride_w_dim, stride_w_width = weight.stride()
stride_x_seq, stride_x_dim, stride_x_token = x.stride() # X (batch, dim, seqlen)
stride_o_seq, stride_o_dim, stride_o_token = out.stride()
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride()
stride_state_indices = (
conv_state_indices.stride(0) if conv_state_indices is not None else 0
)
state_len = width - 1 + (seqlen - 1) # effective state_len needed
np2_statelen = triton.next_power_of_2(state_len)
def grid(META):
return (
batch,
triton.cdiv(dim, META["BLOCK_N"]),
)
# prepare intermediate buffer strides if provided
if intermediate_conv_window is not None:
stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = (
intermediate_conv_window.stride(0),
intermediate_conv_window.stride(1),
intermediate_conv_window.stride(2),
intermediate_conv_window.stride(3),
)
else:
stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0
_causal_conv1d_update_kernel[grid](
# Pointers to matrices
x,
weight,
bias,
conv_state,
cache_seqlens,
conv_state_indices,
num_accepted_tokens,
intermediate_conv_window if intermediate_conv_window is not None else x,
out,
# Matrix dimensions
batch,
dim,
seqlen,
state_len,
num_cache_lines,
# stride
stride_x_seq,
stride_x_dim,
stride_x_token,
stride_w_dim,
stride_w_width,
stride_istate_seq,
stride_istate_dim,
stride_istate_token,
stride_state_indices,
stride_inter_seq,
stride_inter_step,
stride_inter_dim,
stride_inter_win,
stride_o_seq,
stride_o_dim,
stride_o_token,
# others
pad_slot_id,
# META
HAS_BIAS=bias is not None,
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
IS_SPEC_DECODING=num_accepted_tokens is not None,
NP2_STATELEN=np2_statelen,
USE_PAD_SLOT=pad_slot_id is not None,
BLOCK_N=256,
SAVE_INTERMEDIATE=intermediate_conv_window is not None,
)
if unsqueeze:
out = out.squeeze(-1)
return out
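A usage sketch for the target-verify path, again with made-up shapes (2 requests, 3 draft tokens each, kernel width 4; CUDA and Triton assumed). The extra room in the conv_state last dimension and the cache-line count are assumptions made for the illustration:

batch, dim, seqlen, width, num_cache_lines = 2, 64, 3, 4, 8
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.bfloat16)
bias = torch.randn(dim, device="cuda", dtype=torch.bfloat16)
conv_state = torch.zeros(  # width - 1 history slots plus room for the multi-token shift
    num_cache_lines, dim, width - 1 + seqlen - 1, device="cuda", dtype=torch.bfloat16
)
conv_state_indices = torch.tensor([4, 7], dtype=torch.int32, device="cuda")
intermediate_conv_window = torch.zeros(  # one window snapshot per draft token: (cache line, step, dim, width - 1)
    num_cache_lines, seqlen, dim, width - 1, device="cuda", dtype=torch.bfloat16
)
out = causal_conv1d_update(
    x,
    conv_state,
    weight,
    bias,
    activation="silu",
    conv_state_indices=conv_state_indices,
    intermediate_conv_window=intermediate_conv_window,
)  # out aliases x (updated in place); rows 4 and 7 of conv_state now hold the shifted windows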
@@ -125,16 +125,6 @@ class MambaPool:
device=device,
)
if speculative_num_draft_tokens is not None:
mixed_qkv_cache = torch.empty(
size=(
num_mamba_layers,
size + 1,
speculative_num_draft_tokens,
conv_state_shape[0],
),
dtype=conv_dtype,
device="cuda",
)
# Cache intermediate SSM states per draft token during target verify
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
intermediate_ssm_state_cache = torch.empty(
@@ -149,11 +139,24 @@ class MambaPool:
dtype=ssm_dtype,
device="cuda",
)
# Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
intermediate_conv_window_cache = torch.empty(
size=(
num_mamba_layers,
size + 1,
speculative_num_draft_tokens,
conv_state_shape[0],
conv_state_shape[1],
),
dtype=conv_dtype,
device="cuda",
)
self.mamba_cache = (
conv_state,
temporal_state,
mixed_qkv_cache,
intermediate_ssm_state_cache,
intermediate_conv_window_cache,
)
else:
self.mamba_cache = (conv_state, temporal_state)
......
import bisect
from typing import TYPE_CHECKING, Callable
import torch
import torch.nn.functional as F
from sglang.srt.layers.attention.fla.fused_recurrent import (
fused_recurrent_gated_delta_rule_update,
)
from sglang.srt.layers.attention.mamba.causal_conv1d import causal_conv1d_fn
from sglang.srt.model_executor.cuda_graph_runner import (
CUDA_GRAPH_CAPTURE_FAILED_MSG,
CudaGraphRunner,
get_batch_sizes_to_capture,
get_global_graph_memory_pool,
model_capture_mode,
set_global_graph_memory_pool,
)
from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer
if TYPE_CHECKING:
from sglang.srt.speculative.eagle_worker import EAGLEWorker
class MambaStateUpdateCudaGraphRunner:
def __init__(self, eagle_worker: "EAGLEWorker"):
self.eagle_worker = eagle_worker
model_runner = eagle_worker.target_worker.model_runner
self.model_runner = model_runner
self.attn_backend = model_runner.attn_backend.attn_backend_list[1]
self.req_to_token_pool = self.attn_backend.req_to_token_pool
self.graphs = {}
self.output_buffers = {}
self.graph_input_buffer = None
self.stream = torch.cuda.Stream()
self.model = model_runner.model
self.enable_profile_cuda_graph = (
model_runner.server_args.enable_profile_cuda_graph
)
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
self.max_bs = self.capture_bs[-1]
self.init_cuda_graph_state()
# Capture
try:
with model_capture_mode():
self.capture()
except RuntimeError as e:
raise Exception(
f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
)
def init_cuda_graph_state(self):
self.mamba_cache = self.req_to_token_pool.mamba_pool.mamba_cache
self.num_tokens_per_bs = self.max_accepted_tokens = self.mamba_cache[2].shape[2]
num_mamba_layers = self.mamba_cache[0].shape[0]
conv_dtype = torch.bfloat16
conv_shape = self.mamba_cache[0].shape[2]
total_token_number = self.max_accepted_tokens * self.max_bs
self.mixed_qkv_cache = torch.empty(
size=(
num_mamba_layers,
total_token_number,
conv_shape,
),
dtype=conv_dtype,
device="cuda",
)
self.query_start_loc = torch.zeros(
(self.max_bs + 1,), dtype=torch.int32, device="cuda"
)
self.state_indices = torch.zeros(
(self.max_bs + 1,), dtype=torch.int32, device="cuda"
)
self.has_initial_states = torch.ones(
self.max_bs, dtype=torch.bool, device="cuda"
)
def capture(self):
CudaGraphRunner.capture(self)
def capture_one_batch_size(self, bs: int, forward: Callable):
"""
Capture CUDA Graph for a typical workload
"""
graph = torch.cuda.CUDAGraph()
stream = self.stream
total_token_number = bs * self.max_accepted_tokens
mixed_qkvs = self.mixed_qkv_cache[:, :total_token_number]
query_start_loc = self.query_start_loc[: bs + 1]
state_indices = self.state_indices[:bs]
has_initial_states = self.has_initial_states[:bs]
mamba_caches = self.req_to_token_pool.get_mamba_params_all_layers()
conv_states = mamba_caches[0]
mamba_map = self.req_to_token_pool.mamba_map
def run_once():
for i in range(len(self.model.model.layers)):
layer = self.model.model.layers[i]
if not isinstance(layer, Qwen3HybridLinearDecoderLayer):
continue
conv_weights = layer.linear_attn.conv1d.weight.view(
layer.linear_attn.conv1d.weight.size(0),
layer.linear_attn.conv1d.weight.size(2),
)
layer_id = mamba_map[i]
causal_conv1d_fn(
mixed_qkvs[layer_id].transpose(0, 1),
conv_weights,
layer.linear_attn.conv1d.bias,
activation=layer.linear_attn.activation,
conv_states=conv_states[layer_id],
has_initial_state=has_initial_states,
cache_indices=state_indices,
query_start_loc=query_start_loc,
)
return None
for _ in range(2):
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
run_once()
with torch.cuda.graph(
graph, pool=get_global_graph_memory_pool(), stream=stream
):
out = run_once()
set_global_graph_memory_pool(graph.pool())
return graph, out
def can_run(self, accepted_length):
bs = accepted_length.shape[0]
return bs <= self.max_bs
def replay_repare(self, accepted_length):
request_number = accepted_length.shape[0]
# NOTE: step == speculative num_draft_tokens
num_draft_tokens = self.req_to_token_pool.mamba_pool.mamba_cache[2].shape[2]
query_start_loc = accepted_length.cumsum(-1, dtype=accepted_length.dtype)
query_start_loc = torch.cat(
[
torch.zeros(
1,
dtype=query_start_loc.dtype,
device=query_start_loc.device,
),
query_start_loc,
]
)
mask = torch.arange(num_draft_tokens, device=accepted_length.device).unsqueeze(
0
) < accepted_length.unsqueeze(1)
state_indices_tensor = self.attn_backend.forward_metadata.mamba_cache_indices[
:request_number
]
mamba_caches = self.req_to_token_pool.get_mamba_params_all_layers()
_, ssm_states, mix_qkv_cache, intermediate_state_cache = mamba_caches
mixed_qkvs = mamba_caches[2][:, state_indices_tensor][:, mask]
self.mixed_qkv_cache[:, : mixed_qkvs.shape[1]].copy_(mixed_qkvs)
self.query_start_loc[: request_number + 1] = query_start_loc
self.query_start_loc[request_number + 1 :] = self.query_start_loc[
request_number
]
self.state_indices[:request_number] = state_indices_tensor
self.state_indices[request_number:] = -1
valid_mask = accepted_length > 0
if intermediate_state_cache is not None:
last_steps = (accepted_length - 1).to(torch.int64)
valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)
ssm_states[:, valid_state_indices, :] = intermediate_state_cache[
:, valid_state_indices, last_steps
].to(ssm_states.dtype)
def replay(self, accepted_length):
# batch_size and num_seqs can be different in case there are finished examples
# in the batch, which are not counted in num_seqs
raw_bs = accepted_length.shape[0]
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
self.replay_repare(accepted_length)
# Replay
self.graphs[bs].replay()
@@ -407,15 +407,6 @@ class EAGLEWorker(TpModelWorker):
f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
)
if self.target_worker.model_runner.is_hybrid_gdn:
from sglang.srt.speculative.eagle_target_verify_cuda_graph_runner import (
MambaStateUpdateCudaGraphRunner,
)
self.cuda_graph_runner_for_target_verify = MambaStateUpdateCudaGraphRunner(
self
)
@property
def draft_model_runner(self):
return self.model_runner
@@ -848,12 +839,9 @@ class EAGLEWorker(TpModelWorker):
)
+ 1
)
if self.cuda_graph_runner_for_target_verify.can_run(accepted_length):
self.cuda_graph_runner_for_target_verify.replay(accepted_length)
else:
self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
accepted_length, self.target_worker.model_runner.model
)
self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
accepted_length, self.target_worker.model_runner.model
)
if batch.return_logprob:
self.add_logprob_values(batch, res, logits_output)