Unverified commit bd7eb020
Authored Sep 20, 2025 by Binyao Jiang; committed by GitHub on Sep 20, 2025
[Performance] Qwen3-Next: optimize causal_conv1d_fn triton kernel - up to 9% faster (#10680)
Parent: 74cd6e39
Showing 3 changed files with 17 additions and 98 deletions (+17 / -98):
    python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py    +1  -0
    python/sglang/srt/layers/attention/mamba/causal_conv1d.py            +1  -0
    python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py     +15 -98

python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
@@ -362,6 +362,7 @@ class MambaAttnBackend(AttentionBackend):
                 has_initial_state=has_initial_states,
                 cache_indices=cache_indices,
                 query_start_loc=query_start_loc,
+                seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
             ).transpose(0, 1)[:seq_len]
         key_split_dim = key_dim // attn_tp_size

python/sglang/srt/layers/attention/mamba/causal_conv1d.py
@@ -23,6 +23,7 @@ def causal_conv1d_fn(
     conv_states: Optional[torch.Tensor] = None,
     activation: Optional[str] = "silu",
     pad_slot_id: int = PAD_SLOT_ID,
+    **kwargs,
 ):
     """
     x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen

python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py
@@ -2,7 +2,7 @@
 # Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
 # and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
-from typing import Optional, Union
+from typing import List, Optional, Union

 import numpy as np
 import torch
@@ -22,11 +22,8 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
     cache_indices_ptr,  # conv_state_indices_ptr
     has_initial_states_ptr,
     query_start_loc_ptr,
-    batch_ptr,
-    token_chunk_offset_ptr,
     o_ptr,  # (dim, seqlen) - actually pointing to x_ptr
     # Matrix dimensions
-    batch: tl.int32,  # actually padded_batch
     dim: tl.constexpr,
     seqlen: tl.int32,  # cu_seqlen
     num_cache_lines: tl.constexpr,  # added to support vLLM larger cache lines
@@ -69,11 +66,11 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
     # rather than mixing sequences - to make updating initial_states across sequences efficiently
     # single-sequence id
-    idx_seq = tl.load(batch_ptr + tl.program_id(0))
-    chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))
+    idx_seq = tl.program_id(0)
+    chunk_offset = tl.program_id(1)
     # BLOCK_N elements along the feature-dimension (channel)
-    idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
+    idx_feats = tl.program_id(2) * BLOCK_N + tl.arange(0, BLOCK_N)
     if idx_seq == pad_slot_id:
         return
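
This hunk is the core of the change: instead of loading a precomputed (sequence, chunk) pair for each program from batch_ptr / token_chunk_offset_ptr, the kernel now reads its sequence index, chunk offset, and feature block directly from a 3-D launch grid. Below is a minimal, self-contained sketch of that indexing pattern only. It is not the sglang conv kernel: _toy_varlen_copy_kernel, toy_varlen_copy, and the block sizes are made-up names for illustration, and conv-state caching, initial states, and the width-sized sliding window are all omitted.

# Toy illustration of the 3-D grid indexing: axis 0 = sequence,
# axis 1 = BLOCK_M-sized chunk within that sequence, axis 2 = BLOCK_N feature block.
import torch
import triton
import triton.language as tl


@triton.jit
def _toy_varlen_copy_kernel(
    x_ptr,                 # (dim, total_tokens)
    o_ptr,                 # same shape and strides as x
    query_start_loc_ptr,   # (batch + 1,) cumulative token offsets, int32
    dim,
    stride_dim,
    stride_token,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    idx_seq = tl.program_id(0)             # which sequence
    chunk_offset = tl.program_id(1)        # which BLOCK_M-sized chunk of that sequence
    idx_feats = tl.program_id(2) * BLOCK_N + tl.arange(0, BLOCK_N)

    seq_start = tl.load(query_start_loc_ptr + idx_seq)
    seq_end = tl.load(query_start_loc_ptr + idx_seq + 1)
    seqlen = seq_end - seq_start

    token_offset = chunk_offset * BLOCK_M
    # Axis 1 is sized for the longest sequence, so shorter sequences simply skip
    # their out-of-range chunks (this mirrors the added `segment_len <= 0` guard).
    if token_offset >= seqlen:
        return

    idx_tokens = token_offset + tl.arange(0, BLOCK_M)
    mask = (idx_tokens[:, None] < seqlen) & (idx_feats[None, :] < dim)
    offs = idx_feats[None, :] * stride_dim + (seq_start + idx_tokens)[:, None] * stride_token
    vals = tl.load(x_ptr + offs, mask=mask, other=0.0)
    tl.store(o_ptr + offs, vals, mask=mask)


def toy_varlen_copy(x, query_start_loc, seq_lens_cpu, BLOCK_M=64, BLOCK_N=64):
    dim, _ = x.shape
    out = torch.empty_like(x)
    assert out.stride() == x.stride()  # the kernel reuses x's strides for the output
    max_seq_len = max(seq_lens_cpu)
    grid = (
        len(seq_lens_cpu),                   # batch
        triton.cdiv(max_seq_len, BLOCK_M),   # chunks of the longest sequence
        triton.cdiv(dim, BLOCK_N),           # feature blocks
    )
    _toy_varlen_copy_kernel[grid](
        x, out, query_start_loc,
        dim, x.stride(0), x.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
    )
    return out

For the docstring's example further down, toy_varlen_copy(x, torch.tensor([0, 10, 16, 17], dtype=torch.int32, device="cuda"), [10, 6, 1]) would launch a 3 x 1 x cdiv(dim, 64) grid and return a copy of x.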
@@ -86,6 +83,9 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
     token_offset = BLOCK_M * chunk_offset
     segment_len = min(BLOCK_M, seqlen - token_offset)
+    if segment_len <= 0:
+        return
     # base of the sequence
     x_base = (
         x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim
@@ -382,12 +382,13 @@ def causal_conv1d_fn(
     bias: Union[torch.Tensor, None],
     conv_states: torch.Tensor,
     query_start_loc: torch.Tensor,
+    seq_lens_cpu: List[int],
     cache_indices: Optional[torch.Tensor] = None,
     has_initial_state: Optional[torch.Tensor] = None,
     activation: Optional[str] = "silu",
     pad_slot_id: int = PAD_SLOT_ID,
     metadata=None,
     validate_data=False,
     **kwargs,
 ):
     """support varlen + continuous batching when x is 2D tensor
@@ -413,6 +414,8 @@ def causal_conv1d_fn(
         [length(query_start_loc)-1 == batch]
         for example: query_start_loc = torch.Tensor([0,10,16,17]),
         x.shape=(dim,17)
+    seq_lens_cpu: (batch) int32
+        The sequence lengths of the sequences in the batch
     cache_indices: (batch) int32
         indicates the corresponding state index,
         like so: conv_state = conv_states[cache_indices[batch_id]]
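
The new seq_lens_cpu argument carries, as a plain host-side list, the same per-sequence lengths that query_start_loc already encodes cumulatively on the device; having them on the CPU lets the launcher size its grid without the per-call query_start_loc.to("cpu") copy visible in the removed code further down. A small sketch of the relationship, reusing the docstring's example values (a hypothetical batch of three sequences):

# Relationship between query_start_loc (cumulative, device) and seq_lens_cpu
# (per-sequence lengths, host), using the docstring's example values.
import numpy as np
import torch

query_start_loc = torch.tensor([0, 10, 16, 17], dtype=torch.int32)   # (batch + 1,)
seq_lens_cpu = np.diff(query_start_loc.numpy()).tolist()             # [10, 6, 1]

assert seq_lens_cpu == [10, 6, 1]
assert len(seq_lens_cpu) == query_start_loc.numel() - 1   # == batch
assert sum(seq_lens_cpu) == 17                            # == x.shape[1], the total token count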
@@ -434,26 +437,7 @@ def causal_conv1d_fn(
     if isinstance(activation, bool) and activation:
         activation = "silu"
-    args = None
     out = torch.empty_like(x)
-    if metadata is not None:
-        cu_seqlen = metadata.cu_seqlen
-        nums_dict = metadata.nums_dict
-        # x = metadata.x
-        args = nums_dict
-        batch_ptr = metadata.batch_ptr
-        token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
-    else:
-        seqlens = np.diff(query_start_loc.to("cpu"))
-        args = seqlens
-        MAX_NUM_PROGRAMS = 1024
-        batch_ptr = torch.full(
-            (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
-        )  # tracking which seq-idx the Triton program is handling
-        token_chunk_offset_ptr = torch.full(
-            (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
-        )  # tracking BLOCK_M-based index in the sequence the Triton program is handling
     is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)
     dim, cu_seqlen = x.shape
@@ -461,7 +445,6 @@ def causal_conv1d_fn(
     state_len = width - 1
     np2_statelen = triton.next_power_of_2(state_len)
-    padded_batch = query_start_loc.size(0) - 1
     stride_x_seq = 0
     stride_x_dim = x.stride(0)
     stride_x_token = x.stride(1)
@@ -501,6 +484,7 @@ def causal_conv1d_fn(
         assert query_start_loc is not None
         assert query_start_loc.dim() == 1
         assert x.stride(0) == 1 or x.stride(1) == 1
+        padded_batch = query_start_loc.size(0) - 1
         if bias is not None:
             assert bias.dim() == 1
             assert dim == bias.size(0)
@@ -516,78 +500,14 @@ def causal_conv1d_fn(
     assert (dim, width) == weight.shape
     assert is_channel_last, "Need to run in channel-last layout"

-    if metadata is None:
-
-        def num_program(META, seqlens):
-            tot = 0
-            mlist = []
-            offsetlist = []  # type: ignore
-            nums = -(-seqlens // META["BLOCK_M"])
-            tot = nums.sum().item()
-            mlist = np.repeat(np.arange(len(nums)), nums)
-            for idx, num in enumerate(nums):
-                offsetlist.extend(range(num))  # chunk-idx if a sequence is split into multiple chunks
-            if META["batch_ptr"].nelement() < len(mlist):
-                newlen = len(mlist) + 1
-                META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
-                META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
-            if META["batch_ptr"].nelement() >= len(mlist):
-                META["batch_ptr"][0 : len(mlist)].copy_(torch.from_numpy(np.array(mlist)))
-                META["token_chunk_offset_ptr"][0 : len(mlist)].copy_(
-                    torch.from_numpy(np.array(offsetlist))
-                )
-            META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device)
-            META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to(META["x_ptr"].device)
-            return tot
-
-    else:
-
-        def num_program(META, nums_dict):
-            tot = nums_dict[META["BLOCK_M"]]["tot"]
-            mlist = nums_dict[META["BLOCK_M"]]["mlist"]
-            mlist_len = nums_dict[META["BLOCK_M"]]["mlist_len"]
-            offsetlist = nums_dict[META["BLOCK_M"]]["offsetlist"]
-            if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None:
-                META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"]
-                META["token_chunk_offset_ptr"] = nums_dict[META["BLOCK_M"]]["token_chunk_offset_ptr"]
-            else:
-                if META["batch_ptr"].nelement() < mlist_len:
-                    newlen = mlist_len + 1
-                    META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
-                    META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
-                if META["batch_ptr"].nelement() >= mlist_len:
-                    META["batch_ptr"][0:mlist_len].copy_(mlist)
-                    META["token_chunk_offset_ptr"][0:mlist_len].copy_(offsetlist)
-            return tot
-
     def grid(META):
+        max_seq_len = max(seq_lens_cpu)
         return (
-            num_program(META, args),
+            len(seq_lens_cpu),  # batch_size
+            (max_seq_len + META["BLOCK_M"] - 1) // META["BLOCK_M"],
             triton.cdiv(dim, META["BLOCK_N"]),
         )

-    if batch_ptr.device != x.device:
-        batch_ptr = batch_ptr.to(x.device)
-        token_chunk_offset_ptr = token_chunk_offset_ptr.to(x.device)

     _causal_conv1d_fwd_kernel[grid](
         # Pointers to matrices
         x,
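
As the removed lines suggest, this hunk is plausibly where most of the claimed up-to-9% saving comes from: the old path rebuilt and copied batch_ptr / token_chunk_offset_ptr index tensors on the host inside the grid callback, while the new grid is a pure function of seq_lens_cpu. The trade-off is some idle programs for sequences shorter than the longest one, which the added segment_len <= 0 guard exits immediately. A rough worked comparison with hypothetical inputs (the docstring's example lengths and assumed BLOCK sizes; real values come from Triton autotuning):

# Hypothetical prefill batch and assumed tile sizes, for illustration only.
seq_lens_cpu = [10, 6, 1]
dim, BLOCK_M, BLOCK_N = 4096, 8, 256
feat_blocks = -(-dim // BLOCK_N)                                  # 16

# Old launch: axis 0 enumerates only the valid (sequence, chunk) pairs,
# computed on the host each call and written into batch_ptr / token_chunk_offset_ptr.
old_axis0 = sum(-(-s // BLOCK_M) for s in seq_lens_cpu)           # 2 + 1 + 1 = 4
old_programs = old_axis0 * feat_blocks                            # 64

# New launch: static batch x max-chunks x feature-blocks grid; chunks past the
# end of a short sequence return early via the segment_len <= 0 guard.
new_grid = (len(seq_lens_cpu), -(-max(seq_lens_cpu) // BLOCK_M), feat_blocks)
new_programs = new_grid[0] * new_grid[1] * new_grid[2]            # 3 * 2 * 16 = 96

print(old_programs, new_programs)  # 64 96: a few more (cheap) programs, no host-side index building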
@@ -597,11 +517,8 @@ def causal_conv1d_fn(
         cache_indices,
         has_initial_state,
         query_start_loc,
-        batch_ptr,
-        token_chunk_offset_ptr,
         out,
         # Matrix dimensions
-        padded_batch,
         dim,
         cu_seqlen,
         num_cache_lines,