Unverified Commit bd7eb020 authored by Binyao Jiang's avatar Binyao Jiang Committed by GitHub
Browse files

[Performance] Qwen3-Next: optimize causal_conv1d_fn triton kernel - up to 9% faster (#10680)

parent 74cd6e39
......@@ -362,6 +362,7 @@ class MambaAttnBackend(AttentionBackend):
has_initial_state=has_initial_states,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
).transpose(0, 1)[:seq_len]
key_split_dim = key_dim // attn_tp_size
......
......@@ -23,6 +23,7 @@ def causal_conv1d_fn(
conv_states: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID,
**kwargs,
):
"""
x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
......
......@@ -2,7 +2,7 @@
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
from typing import Optional, Union
from typing import List, Optional, Union
import numpy as np
import torch
......@@ -22,11 +22,8 @@ def _causal_conv1d_fwd_kernel( # continuous batching
cache_indices_ptr, # conv_state_indices_ptr
has_initial_states_ptr,
query_start_loc_ptr,
batch_ptr,
token_chunk_offset_ptr,
o_ptr, # (dim, seqlen) - actually pointing to x_ptr
# Matrix dimensions
batch: tl.int32, # actually padded_batch
dim: tl.constexpr,
seqlen: tl.int32, # cu_seqlen
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
......@@ -69,11 +66,11 @@ def _causal_conv1d_fwd_kernel( # continuous batching
# rather than mixing sequences - to make updating initial_states across sequences efficiently
# single-sequence id
idx_seq = tl.load(batch_ptr + tl.program_id(0))
chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))
idx_seq = tl.program_id(0)
chunk_offset = tl.program_id(1)
# BLOCK_N elements along the feature-dimension (channel)
idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
idx_feats = tl.program_id(2) * BLOCK_N + tl.arange(0, BLOCK_N)
if idx_seq == pad_slot_id:
return
......@@ -86,6 +83,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching
token_offset = BLOCK_M * chunk_offset
segment_len = min(BLOCK_M, seqlen - token_offset)
if segment_len <= 0:
return
# base of the sequence
x_base = (
x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim
......@@ -382,12 +382,13 @@ def causal_conv1d_fn(
bias: Union[torch.Tensor, None],
conv_states: torch.Tensor,
query_start_loc: torch.Tensor,
seq_lens_cpu: List[int],
cache_indices: Optional[torch.Tensor] = None,
has_initial_state: Optional[torch.Tensor] = None,
activation: Optional[str] = "silu",
pad_slot_id: int = PAD_SLOT_ID,
metadata=None,
validate_data=False,
**kwargs,
):
"""support varlen + continuous batching when x is 2D tensor
......@@ -413,6 +414,8 @@ def causal_conv1d_fn(
[length(query_start_loc)-1 == batch]
for example: query_start_loc = torch.Tensor([0,10,16,17]),
x.shape=(dim,17)
seq_lens_cpu: (batch) int32
The sequence lengths of the sequences in the batch
cache_indices: (batch) int32
indicates the corresponding state index,
like so: conv_state = conv_states[cache_indices[batch_id]]
......@@ -434,26 +437,7 @@ def causal_conv1d_fn(
if isinstance(activation, bool) and activation:
activation = "silu"
args = None
out = torch.empty_like(x)
if metadata is not None:
cu_seqlen = metadata.cu_seqlen
nums_dict = metadata.nums_dict
# x = metadata.x
args = nums_dict
batch_ptr = metadata.batch_ptr
token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
else:
seqlens = np.diff(query_start_loc.to("cpu"))
args = seqlens
MAX_NUM_PROGRAMS = 1024
batch_ptr = torch.full(
(MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
) # tracking which seq-idx the Triton program is handling
token_chunk_offset_ptr = torch.full(
(MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
) # tracking BLOCK_M-based index in the sequence the Triton program is handling
is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)
dim, cu_seqlen = x.shape
......@@ -461,7 +445,6 @@ def causal_conv1d_fn(
state_len = width - 1
np2_statelen = triton.next_power_of_2(state_len)
padded_batch = query_start_loc.size(0) - 1
stride_x_seq = 0
stride_x_dim = x.stride(0)
stride_x_token = x.stride(1)
......@@ -501,6 +484,7 @@ def causal_conv1d_fn(
assert query_start_loc is not None
assert query_start_loc.dim() == 1
assert x.stride(0) == 1 or x.stride(1) == 1
padded_batch = query_start_loc.size(0) - 1
if bias is not None:
assert bias.dim() == 1
assert dim == bias.size(0)
......@@ -516,78 +500,14 @@ def causal_conv1d_fn(
assert (dim, width) == weight.shape
assert is_channel_last, "Need to run in channel-last layout"
if metadata is None:
def num_program(META, seqlens):
tot = 0
mlist = []
offsetlist = [] # type: ignore
nums = -(-seqlens // META["BLOCK_M"])
tot = nums.sum().item()
mlist = np.repeat(np.arange(len(nums)), nums)
for idx, num in enumerate(nums):
offsetlist.extend(
range(num)
) # chunk-idx if a sequence is split into multiple chunks
if META["batch_ptr"].nelement() < len(mlist):
newlen = len(mlist) + 1
META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
if META["batch_ptr"].nelement() >= len(mlist):
META["batch_ptr"][0 : len(mlist)].copy_(
torch.from_numpy(np.array(mlist))
)
META["token_chunk_offset_ptr"][0 : len(mlist)].copy_(
torch.from_numpy(np.array(offsetlist))
)
META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device)
META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to(
META["x_ptr"].device
)
return tot
else:
def num_program(META, nums_dict):
tot = nums_dict[META["BLOCK_M"]]["tot"]
mlist = nums_dict[META["BLOCK_M"]]["mlist"]
mlist_len = nums_dict[META["BLOCK_M"]]["mlist_len"]
offsetlist = nums_dict[META["BLOCK_M"]]["offsetlist"]
if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None:
META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"]
META["token_chunk_offset_ptr"] = nums_dict[META["BLOCK_M"]][
"token_chunk_offset_ptr"
]
else:
if META["batch_ptr"].nelement() < mlist_len:
newlen = mlist_len + 1
META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
if META["batch_ptr"].nelement() >= mlist_len:
META["batch_ptr"][0:mlist_len].copy_(mlist)
META["token_chunk_offset_ptr"][0:mlist_len].copy_(offsetlist)
return tot
def grid(META):
max_seq_len = max(seq_lens_cpu)
return (
num_program(META, args),
len(seq_lens_cpu), # batch_size
(max_seq_len + META["BLOCK_M"] - 1) // META["BLOCK_M"],
triton.cdiv(dim, META["BLOCK_N"]),
)
if batch_ptr.device != x.device:
batch_ptr = batch_ptr.to(x.device)
token_chunk_offset_ptr = token_chunk_offset_ptr.to(x.device)
_causal_conv1d_fwd_kernel[grid](
# Pointers to matrices
x,
......@@ -597,11 +517,8 @@ def causal_conv1d_fn(
cache_indices,
has_initial_state,
query_start_loc,
batch_ptr,
token_chunk_offset_ptr,
out,
# Matrix dimensions
padded_batch,
dim,
cu_seqlen,
num_cache_lines,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment