Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
3451fc32
Unverified
Commit
3451fc32
authored
Nov 01, 2025
by
Binyao Jiang
Committed by
GitHub
Nov 01, 2025
Browse files
[Feature] Qwen3-Next & FLA: Support MTP topk>1; Up to 6% faster (#11133)
Co-authored-by:
Stefan He
<
hebiaobuaa@gmail.com
>
parent
c550ab91
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
493 additions
and
79 deletions
+493
-79
python/sglang/srt/layers/attention/fla/fused_recurrent.py
python/sglang/srt/layers/attention/fla/fused_recurrent.py
+70
-0
python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
...sglang/srt/layers/attention/hybrid_linear_attn_backend.py
+91
-19
python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py
...sglang/srt/layers/attention/mamba/causal_conv1d_triton.py
+251
-54
python/sglang/srt/layers/attention/mamba/mamba2_metadata.py
python/sglang/srt/layers/attention/mamba/mamba2_metadata.py
+4
-0
python/sglang/srt/speculative/eagle_worker.py
python/sglang/srt/speculative/eagle_worker.py
+30
-4
test/srt/models/test_qwen3_next_models.py
test/srt/models/test_qwen3_next_models.py
+47
-2
No files found.
python/sglang/srt/layers/attention/fla/fused_recurrent.py
View file @
3451fc32
...
...
@@ -330,12 +330,30 @@ def fused_recurrent_gated_delta_rule(
return
o
,
final_state
# HAS_EAGLE_TREE_CUSTOM_ATTN_MASK is added to support eagle tree attention mask
# retrieve_parent_token_ptr: [N, NP2_T], retrieve_next_sibling_ptr: [N, NP2_T]
# e.g. for a sequence of length 4, the eagle tree attention structure is:
# retrieve_next_token=[1, 3, -1, -1] -> retrieve_next_token[i]: the 1st child token of token i
# retrieve_next_sibling=[-1, 2, -1, -1] -> retrieve_next_sibling[i]: the 1st tree sibling token of token i
# retrieve_parent_token=[n/a, 0, 0, 1] -> retrieve_parent_token[i]: the parent token of token i
# Tree:
# 0
# / \
# 1 2
# /
# 3
# When calculating token 3's attention, it should attend to token 1 (parent) and token 0 (grand-parent)
# When calculating token 2's attention, it should attend to token 0 (parent)
@
triton
.
heuristics
(
{
"USE_INITIAL_STATE"
:
lambda
args
:
args
[
"h0_source"
]
is
not
None
,
"IS_VARLEN"
:
lambda
args
:
args
[
"cu_seqlens"
]
is
not
None
,
"CACHE_INTERMEDIATE_STATES"
:
lambda
args
:
args
[
"intermediate_states_buffer"
]
is
not
None
,
"HAS_EAGLE_TREE_CUSTOM_ATTN_MASK"
:
lambda
args
:
args
[
"retrieve_parent_token_ptr"
]
is
not
None
,
}
)
@
triton
.
jit
(
do_not_specialize
=
[
"T"
])
...
...
@@ -352,7 +370,11 @@ def fused_recurrent_gated_delta_rule_update_fwd_kernel(
scale
,
intermediate_states_buffer
,
cache_steps
,
retrieve_parent_token_ptr
,
stride_retrieve_parent_token_seq
:
tl
.
constexpr
,
stride_retrieve_parent_token_token
:
tl
.
constexpr
,
T
,
NP2_T
:
tl
.
constexpr
,
B
:
tl
.
constexpr
,
H
:
tl
.
constexpr
,
HV
:
tl
.
constexpr
,
...
...
@@ -367,6 +389,7 @@ def fused_recurrent_gated_delta_rule_update_fwd_kernel(
DISABLE_STATE_UPDATE
:
tl
.
constexpr
,
# whether to disable final state update
DISABLE_OUTPUT_CALCULATION
:
tl
.
constexpr
,
# whether to disable output calculation
CACHE_INTERMEDIATE_STATES
:
tl
.
constexpr
,
HAS_EAGLE_TREE_CUSTOM_ATTN_MASK
:
tl
.
constexpr
,
):
i_k
,
i_v
,
i_nh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
),
tl
.
program_id
(
2
)
i_n
,
i_hv
=
i_nh
//
HV
,
i_nh
%
HV
...
...
@@ -393,6 +416,16 @@ def fused_recurrent_gated_delta_rule_update_fwd_kernel(
p_g
=
g
+
bos
*
HV
+
i_hv
p_o
=
o
+
((
i_k
*
all
+
bos
)
*
HV
+
i_hv
)
*
V
+
o_v
if
HAS_EAGLE_TREE_CUSTOM_ATTN_MASK
:
token_indices
=
tl
.
arange
(
0
,
NP2_T
)
mask_retrieve
=
token_indices
<
T
retrieve_parent_token_base
=
(
retrieve_parent_token_ptr
+
(
i_n
*
stride_retrieve_parent_token_seq
)
+
token_indices
*
stride_retrieve_parent_token_token
)
parent_idx_tokens
=
tl
.
load
(
retrieve_parent_token_base
,
mask_retrieve
)
mask_k
=
o_k
<
K
mask_v
=
o_v
<
V
mask_h
=
mask_k
[:,
None
]
&
mask_v
[
None
,
:]
...
...
@@ -418,6 +451,24 @@ def fused_recurrent_gated_delta_rule_update_fwd_kernel(
step_idx
=
0
for
_
in
range
(
0
,
T
):
if
HAS_EAGLE_TREE_CUSTOM_ATTN_MASK
:
# step_idx = 0 should use the b_h from USE_INITIAL_STATE
if
step_idx
!=
0
and
cache_idx
>=
0
:
# when calculating current step's attention, load the state from the parent token
parent_step_idx
=
tl
.
sum
(
tl
.
where
(
token_indices
==
step_idx
,
parent_idx_tokens
,
0
)
)
step_offset
=
parent_step_idx
*
HV
*
K
*
V
cache_ptr
=
(
intermediate_states_buffer
+
cache_idx
*
cache_steps
*
HV
*
K
*
V
+
step_offset
+
i_hv
*
K
*
V
+
o_k
[:,
None
]
*
V
+
o_v
[
None
,
:]
)
b_h
=
tl
.
load
(
cache_ptr
,
mask
=
mask_h
,
other
=
0
).
to
(
tl
.
float32
)
b_q
=
tl
.
load
(
p_q
,
mask
=
mask_k
,
other
=
0
).
to
(
tl
.
float32
)
b_k
=
tl
.
load
(
p_k
,
mask
=
mask_k
,
other
=
0
).
to
(
tl
.
float32
)
b_v
=
tl
.
load
(
p_v
,
mask
=
mask_v
,
other
=
0
).
to
(
tl
.
float32
)
...
...
@@ -498,6 +549,7 @@ def fused_recurrent_gated_delta_rule_update_fwd(
disable_output_calculation
:
bool
=
False
,
intermediate_states_buffer
:
Optional
[
torch
.
Tensor
]
=
None
,
cache_steps
:
Optional
[
int
]
=
None
,
retrieve_parent_token
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
B
,
T
,
H
,
K
,
V
=
*
k
.
shape
,
v
.
shape
[
-
1
]
HV
=
v
.
shape
[
2
]
...
...
@@ -516,6 +568,16 @@ def fused_recurrent_gated_delta_rule_update_fwd(
grid
=
(
NK
,
NV
,
N
*
HV
)
# prepare retrieve next token buffer strides if provided
if
retrieve_parent_token
is
not
None
:
stride_retrieve_parent_token_seq
,
stride_retrieve_parent_token_token
=
(
retrieve_parent_token
.
stride
(
0
),
retrieve_parent_token
.
stride
(
1
),
)
else
:
stride_retrieve_parent_token_seq
=
stride_retrieve_parent_token_token
=
0
NP2_T
=
triton
.
next_power_of_2
(
T
)
fused_recurrent_gated_delta_rule_update_fwd_kernel
[
grid
](
q
=
q
,
k
=
k
,
...
...
@@ -529,7 +591,11 @@ def fused_recurrent_gated_delta_rule_update_fwd(
scale
=
scale
,
intermediate_states_buffer
=
intermediate_states_buffer
,
cache_steps
=
0
if
cache_steps
is
None
else
cache_steps
,
retrieve_parent_token_ptr
=
retrieve_parent_token
,
stride_retrieve_parent_token_seq
=
stride_retrieve_parent_token_seq
,
stride_retrieve_parent_token_token
=
stride_retrieve_parent_token_token
,
T
=
T
,
NP2_T
=
NP2_T
,
B
=
B
,
H
=
H
,
HV
=
HV
,
...
...
@@ -568,6 +634,7 @@ class FusedRecurrentUpdateFunction(torch.autograd.Function):
disable_output_calculation
:
bool
=
False
,
intermediate_states_buffer
:
Optional
[
torch
.
Tensor
]
=
None
,
cache_steps
:
Optional
[
int
]
=
None
,
retrieve_parent_token
:
Optional
[
torch
.
Tensor
]
=
None
,
):
o
=
fused_recurrent_gated_delta_rule_update_fwd
(
q
=
q
,
...
...
@@ -584,6 +651,7 @@ class FusedRecurrentUpdateFunction(torch.autograd.Function):
disable_output_calculation
=
disable_output_calculation
,
intermediate_states_buffer
=
intermediate_states_buffer
,
cache_steps
=
cache_steps
,
retrieve_parent_token
=
retrieve_parent_token
,
)
return
o
...
...
@@ -613,6 +681,7 @@ def fused_recurrent_gated_delta_rule_update(
disable_output_calculation
:
bool
=
False
,
intermediate_states_buffer
:
Optional
[
torch
.
Tensor
]
=
None
,
cache_steps
:
Optional
[
int
]
=
None
,
retrieve_parent_token
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
cu_seqlens
is
not
None
:
if
q
.
shape
[
0
]
!=
1
:
...
...
@@ -649,5 +718,6 @@ def fused_recurrent_gated_delta_rule_update(
disable_output_calculation
,
intermediate_states_buffer
,
cache_steps
,
retrieve_parent_token
,
)
return
o
python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
View file @
3451fc32
...
...
@@ -66,12 +66,19 @@ class MambaAttnBackendBase(AttentionBackend):
self
.
forward_metadata
:
ForwardMetadata
=
None
self
.
state_indices_list
=
[]
self
.
query_start_loc_list
=
[]
self
.
retrieve_next_token_list
=
[]
self
.
retrieve_next_sibling_list
=
[]
self
.
retrieve_parent_token_list
=
[]
self
.
cached_cuda_graph_decode_query_start_loc
:
torch
.
Tensor
=
None
self
.
cached_cuda_graph_verify_query_start_loc
:
torch
.
Tensor
=
None
def
_forward_metadata
(
self
,
forward_batch
:
ForwardBatch
):
bs
=
forward_batch
.
batch_size
retrieve_next_token
=
None
retrieve_next_sibling
=
None
retrieve_parent_token
=
None
if
forward_batch
.
forward_mode
.
is_decode_or_idle
():
query_start_loc
=
torch
.
arange
(
0
,
bs
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
...
...
@@ -85,6 +92,11 @@ class MambaAttnBackendBase(AttentionBackend):
dtype
=
torch
.
int32
,
device
=
forward_batch
.
input_ids
.
device
,
)
if
forward_batch
.
spec_info
.
topk
>
1
:
retrieve_next_token
=
forward_batch
.
spec_info
.
retrive_next_token
retrieve_next_sibling
=
forward_batch
.
spec_info
.
retrive_next_sibling
retrieve_parent_token
=
torch
.
empty_like
(
retrieve_next_token
)
else
:
query_start_loc
=
torch
.
empty
(
(
bs
+
1
,),
dtype
=
torch
.
int32
,
device
=
self
.
device
...
...
@@ -102,6 +114,9 @@ class MambaAttnBackendBase(AttentionBackend):
return
ForwardMetadata
(
query_start_loc
=
query_start_loc
,
mamba_cache_indices
=
mamba_cache_indices
,
retrieve_next_token
=
retrieve_next_token
,
retrieve_next_sibling
=
retrieve_next_sibling
,
retrieve_parent_token
=
retrieve_parent_token
,
)
def
init_forward_metadata
(
self
,
forward_batch
:
ForwardBatch
):
...
...
@@ -118,7 +133,7 @@ class MambaAttnBackendBase(AttentionBackend):
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
]],
):
self
.
forward_metadata
=
self
.
_capture_metadata
(
bs
,
req_pool_indices
,
forward_mode
bs
,
req_pool_indices
,
forward_mode
,
spec_info
)
def
init_forward_metadata_replay_cuda_graph
(
...
...
@@ -140,7 +155,7 @@ class MambaAttnBackendBase(AttentionBackend):
assert
(
max_num_tokens
%
max_bs
==
0
),
f
"max_num_tokens=
{
max_num_tokens
}
must be divisible by max_bs=
{
max_bs
}
"
verify_step
=
max_num_tokens
/
max_bs
draft_token_num
=
max_num_tokens
/
/
max_bs
for
i
in
range
(
max_bs
):
self
.
state_indices_list
.
append
(
torch
.
full
(
...
...
@@ -150,19 +165,38 @@ class MambaAttnBackendBase(AttentionBackend):
self
.
query_start_loc_list
.
append
(
torch
.
empty
((
i
+
2
,),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
)
self
.
retrieve_next_token_list
.
append
(
torch
.
zeros
(
(
i
+
1
,
draft_token_num
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
)
self
.
retrieve_next_sibling_list
.
append
(
torch
.
zeros
(
(
i
+
1
,
draft_token_num
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
)
self
.
retrieve_parent_token_list
.
append
(
torch
.
zeros
(
(
i
+
1
,
draft_token_num
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
)
self
.
cached_cuda_graph_decode_query_start_loc
=
torch
.
arange
(
0
,
max_bs
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
self
.
cached_cuda_graph_verify_query_start_loc
=
torch
.
arange
(
0
,
max_bs
*
verify_step
+
1
,
step
=
verify_step
,
max_bs
*
draft_token_num
+
1
,
step
=
draft_token_num
,
dtype
=
torch
.
int32
,
device
=
self
.
device
,
)
def
_capture_metadata
(
self
,
bs
:
int
,
req_pool_indices
:
torch
.
Tensor
,
forward_mode
:
ForwardMode
self
,
bs
:
int
,
req_pool_indices
:
torch
.
Tensor
,
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
]],
):
if
forward_mode
.
is_decode_or_idle
():
self
.
query_start_loc_list
[
bs
-
1
].
copy_
(
...
...
@@ -176,10 +210,24 @@ class MambaAttnBackendBase(AttentionBackend):
raise
ValueError
(
f
"Invalid forward mode:
{
forward_mode
=
}
"
)
mamba_indices
=
self
.
req_to_token_pool
.
get_mamba_indices
(
req_pool_indices
)
self
.
state_indices_list
[
bs
-
1
][:
len
(
mamba_indices
)].
copy_
(
mamba_indices
)
return
ForwardMetadata
(
query_start_loc
=
self
.
query_start_loc_list
[
bs
-
1
],
mamba_cache_indices
=
self
.
state_indices_list
[
bs
-
1
],
)
# If topk > 1, we need to use retrieve_next_token and retrieve_next_sibling to handle the eagle tree custom attention mask
if
forward_mode
.
is_target_verify
()
and
spec_info
.
topk
>
1
:
# They are None during cuda graph capture so skip the copy_...
# self.retrieve_next_token_list[bs - 1].copy_(spec_info.retrive_next_token)
# self.retrieve_next_sibling_list[bs - 1].copy_(spec_info.retrive_next_sibling)
return
ForwardMetadata
(
query_start_loc
=
self
.
query_start_loc_list
[
bs
-
1
],
mamba_cache_indices
=
self
.
state_indices_list
[
bs
-
1
],
retrieve_next_token
=
self
.
retrieve_next_token_list
[
bs
-
1
],
retrieve_next_sibling
=
self
.
retrieve_next_sibling_list
[
bs
-
1
],
retrieve_parent_token
=
self
.
retrieve_parent_token_list
[
bs
-
1
],
)
else
:
return
ForwardMetadata
(
query_start_loc
=
self
.
query_start_loc_list
[
bs
-
1
],
mamba_cache_indices
=
self
.
state_indices_list
[
bs
-
1
],
)
def
_replay_metadata
(
self
,
...
...
@@ -224,10 +272,28 @@ class MambaAttnBackendBase(AttentionBackend):
else
:
raise
ValueError
(
f
"Invalid forward mode:
{
forward_mode
=
}
"
)
return
ForwardMetadata
(
query_start_loc
=
self
.
query_start_loc_list
[
bs
-
1
],
mamba_cache_indices
=
self
.
state_indices_list
[
bs
-
1
],
)
# If topk > 1, we need to use retrieve_next_token and retrieve_next_sibling to handle the eagle tree custom attention mask
if
forward_mode
.
is_target_verify
()
and
spec_info
.
topk
>
1
:
bs_without_pad
=
spec_info
.
retrive_next_token
.
shape
[
0
]
# print(spec_info.retrive_next_token, spec_info.retrive_next_sibling)
self
.
retrieve_next_token_list
[
bs
-
1
][:
bs_without_pad
].
copy_
(
spec_info
.
retrive_next_token
)
self
.
retrieve_next_sibling_list
[
bs
-
1
][:
bs_without_pad
].
copy_
(
spec_info
.
retrive_next_sibling
)
return
ForwardMetadata
(
query_start_loc
=
self
.
query_start_loc_list
[
bs
-
1
],
mamba_cache_indices
=
self
.
state_indices_list
[
bs
-
1
],
retrieve_next_token
=
self
.
retrieve_next_token_list
[
bs
-
1
],
retrieve_next_sibling
=
self
.
retrieve_next_sibling_list
[
bs
-
1
],
retrieve_parent_token
=
self
.
retrieve_parent_token_list
[
bs
-
1
],
)
else
:
return
ForwardMetadata
(
query_start_loc
=
self
.
query_start_loc_list
[
bs
-
1
],
mamba_cache_indices
=
self
.
state_indices_list
[
bs
-
1
],
)
def
get_cuda_graph_seq_len_fill_value
(
self
):
return
1
# Mamba attn does not use seq lens to index kv cache
...
...
@@ -557,6 +623,9 @@ class GDNAttnBackend(MambaAttnBackendBase):
query_start_loc
=
self
.
forward_metadata
.
query_start_loc
cache_indices
=
self
.
forward_metadata
.
mamba_cache_indices
retrieve_next_token
=
self
.
forward_metadata
.
retrieve_next_token
retrieve_next_sibling
=
self
.
forward_metadata
.
retrieve_next_sibling
retrieve_parent_token
=
self
.
forward_metadata
.
retrieve_parent_token
mamba_cache_params
=
self
.
req_to_token_pool
.
mamba2_layer_cache
(
layer_id
)
conv_states
=
mamba_cache_params
.
conv
...
...
@@ -591,6 +660,9 @@ class GDNAttnBackend(MambaAttnBackendBase):
activation
,
conv_state_indices
=
cache_indices
[:
batch_size
],
intermediate_conv_window
=
intermediate_conv_window_cache
,
retrieve_next_token
=
retrieve_next_token
,
retrieve_next_sibling
=
retrieve_next_sibling
,
retrieve_parent_token
=
retrieve_parent_token
,
)
mixed_qkv
=
(
mixed_qkv_processed
.
transpose
(
1
,
2
).
contiguous
().
view
(
seq_len
,
-
1
)
...
...
@@ -645,6 +717,7 @@ class GDNAttnBackend(MambaAttnBackendBase):
disable_state_update
=
True
,
intermediate_states_buffer
=
intermediate_state_cache
,
cache_steps
=
forward_batch
.
spec_info
.
draft_token_num
,
retrieve_parent_token
=
retrieve_parent_token
,
)
else
:
recurrent_state
=
ssm_states
[
cache_indices
]
...
...
@@ -694,7 +767,7 @@ class Mamba2AttnBackend(MambaAttnBackendBase):
forward_mode
:
ForwardMode
,
spec_info
:
Optional
[
Union
[
EagleDraftInput
,
EagleVerifyInput
]],
):
metadata
=
self
.
_capture_metadata
(
bs
,
req_pool_indices
,
forward_mode
)
metadata
=
self
.
_capture_metadata
(
bs
,
req_pool_indices
,
forward_mode
,
spec_info
)
self
.
forward_metadata
=
Mamba2Metadata
.
prepare_decode
(
metadata
.
query_start_loc
,
metadata
.
mamba_cache_indices
,
seq_lens
)
...
...
@@ -891,8 +964,8 @@ class HybridLinearAttnBackend(AttentionBackend):
**
kwargs
,
)
def
update_mamba_state_after_mtp_verify
(
self
,
accepted_
length
,
model
):
request_number
=
accepted_
length
.
shape
[
0
]
def
update_mamba_state_after_mtp_verify
(
self
,
accepted_
indices
,
model
):
request_number
=
accepted_
indices
.
shape
[
0
]
state_indices_tensor
=
(
self
.
linear_attn_backend
.
forward_metadata
.
mamba_cache_indices
[
...
...
@@ -910,12 +983,11 @@ class HybridLinearAttnBackend(AttentionBackend):
intermediate_conv_window_cache
=
mamba_caches
.
intermediate_conv_window
# SSM state updates (chunked to reduce peak memory)
valid_mask
=
accepted_
length
>
0
valid_mask
=
accepted_
indices
>
=
0
# Compute common indices once to avoid duplication
last_steps_all
=
(
accepted_length
-
1
).
to
(
torch
.
int64
)
valid_state_indices
=
state_indices_tensor
[
valid_mask
].
to
(
torch
.
int64
)
# [N]
last_steps
=
last_steps_all
[
valid_mask
].
to
(
torch
.
int64
)
# [N]
last_steps
=
accepted_indices
[
valid_mask
].
to
(
torch
.
int64
)
# [N]
# scatter into ssm_states at the chosen cache lines
ssm_states
[:,
valid_state_indices
,
:]
=
intermediate_state_cache
[
...
...
python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py
View file @
3451fc32
...
...
@@ -186,7 +186,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
)
mask
=
(
idx_tokens_conv
<
state_len
)[:,
None
]
&
(
idx_feats
<
dim
)[
None
,
:]
tl
.
debug_barrier
()
# NOTE: use this due to bug in Triton compiler
#
tl.debug_barrier() # NOTE: use this due to bug in Triton compiler
tl
.
store
(
conv_states_ptrs_target
,
new_conv_state
,
mask
)
else
:
...
...
@@ -221,7 +221,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
)
# token-index # token-index # feature-index
loaded_x
=
tl
.
load
(
x_ptrs
,
mask_x
,
0.0
)
tl
.
debug_barrier
()
# need this due to the bug in tl.where not enforcing this when data is the result of another tl.load
#
tl.debug_barrier() # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load
new_conv_state
=
tl
.
where
(
mask
,
conv_state
,
loaded_x
)
# BUG in 'tl.where' which requires a barrier before this
...
...
@@ -552,6 +552,21 @@ def causal_conv1d_fn(
return
out
# HAS_EAGLE_TREE_CUSTOM_ATTN_MASK is added to support eagle tree attention mask
# retrieve_next_token_ptr: [N, NP2_T], retrieve_next_sibling_ptr: [N, NP2_T]
# e.g. for a sequence of length 4, the eagle tree attention structure is:
# retrieve_next_token=[1, 3, -1, -1] -> retrieve_next_token[i]: the 1st child token of token i
# retrieve_next_sibling=[-1, 2, -1, -1] -> retrieve_next_sibling[i]: the 1st tree sibling token of token i
# retrieve_parent_token=[n/a, 0, 0, 1] -> retrieve_parent_token[i]: the parent token of token i
# Tree:
# 0
# / \
# 1 2
# /
# 3
# When calculating token 3's convolution, it should conv to token 1 (parent) and token 0 (grand-parent)
# When calculating token 2's convolution, it should conv to token 0 (parent)
# This kernel is a fused kernel which will also produce retrieve_parent_token based on retrieve_next_token & retrieve_next_sibling
@
triton
.
jit
()
def
_causal_conv1d_update_kernel
(
# Pointers to matrices
...
...
@@ -563,6 +578,9 @@ def _causal_conv1d_update_kernel(
conv_state_indices_ptr
,
num_accepted_tokens_ptr
,
intermediate_conv_window_ptr
,
retrieve_next_token_ptr
,
retrieve_next_sibling_ptr
,
retrieve_parent_token_ptr
,
o_ptr
,
# (batch, dim, seqlen)
# Matrix dimensions
batch
:
int
,
...
...
@@ -584,6 +602,12 @@ def _causal_conv1d_update_kernel(
stride_inter_step
:
tl
.
constexpr
,
stride_inter_dim
:
tl
.
constexpr
,
stride_inter_win
:
tl
.
constexpr
,
stride_retrieve_next_token_seq
:
tl
.
constexpr
,
stride_retrieve_next_token_token
:
tl
.
constexpr
,
stride_retrieve_next_sibling_seq
:
tl
.
constexpr
,
stride_retrieve_next_sibling_token
:
tl
.
constexpr
,
stride_retrieve_parent_token_seq
:
tl
.
constexpr
,
stride_retrieve_parent_token_token
:
tl
.
constexpr
,
stride_o_seq
:
tl
.
constexpr
,
stride_o_dim
:
tl
.
constexpr
,
stride_o_token
:
tl
.
constexpr
,
...
...
@@ -596,9 +620,11 @@ def _causal_conv1d_update_kernel(
IS_CONTINUOUS_BATCHING
:
tl
.
constexpr
,
IS_SPEC_DECODING
:
tl
.
constexpr
,
NP2_STATELEN
:
tl
.
constexpr
,
NP2_SEQLEN
:
tl
.
constexpr
,
USE_PAD_SLOT
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
SAVE_INTERMEDIATE
:
tl
.
constexpr
,
HAS_EAGLE_TREE_CUSTOM_ATTN_MASK
:
tl
.
constexpr
,
):
# ruff: noqa: E501
idx_seq
=
tl
.
program_id
(
0
)
...
...
@@ -695,7 +721,7 @@ def _causal_conv1d_update_kernel(
&
(
idx_feats
<
dim
)[
None
,
:]
)
# token-index # token-index # feature-index
loaded_x
=
tl
.
load
(
x_ptrs
,
mask_x
,
0.0
)
tl
.
debug_barrier
()
#
tl.debug_barrier()
new_conv_state
=
tl
.
where
(
mask
,
conv_state
,
loaded_x
)
...
...
@@ -723,6 +749,24 @@ def _causal_conv1d_update_kernel(
# STEP 4:
# PRE-LOAD WEIGHTS
# first kernel column, configured for weights to handle BLOCK_N features in range
if
HAS_EAGLE_TREE_CUSTOM_ATTN_MASK
:
idx_tokens
=
tl
.
arange
(
0
,
NP2_SEQLEN
)
# [BLOCK_M]
# Update parent mapping for all tokens at once using vectorized operations
mask_retrieve
=
idx_tokens
<
seqlen
retrieve_next_token_base
=
(
retrieve_next_token_ptr
+
(
idx_seq
*
stride_retrieve_next_token_seq
)
+
idx_tokens
*
stride_retrieve_next_token_token
)
retrieve_next_tokens
=
tl
.
load
(
retrieve_next_token_base
,
mask_retrieve
)
retrieve_next_sibling_base
=
(
retrieve_next_sibling_ptr
+
(
idx_seq
*
stride_retrieve_next_sibling_seq
)
+
idx_tokens
*
stride_retrieve_next_sibling_token
)
retrieve_next_siblings
=
tl
.
load
(
retrieve_next_sibling_base
,
mask_retrieve
)
parent_idx_tokens
=
tl
.
zeros
((
NP2_SEQLEN
,),
dtype
=
tl
.
int32
)
w_base
=
w_ptr
+
(
idx_feats
*
stride_w_dim
)
# [BLOCK_N,]
mask_w
=
idx_feats
<
dim
if
KERNEL_WIDTH
>=
2
:
...
...
@@ -744,45 +788,162 @@ def _causal_conv1d_update_kernel(
for
idx_token
in
tl
.
static_range
(
seqlen
):
acc
=
acc_preload
matrix_w
=
w_col0
matrix_x
=
col0
for
j
in
tl
.
static_range
(
KERNEL_WIDTH
):
if
KERNEL_WIDTH
==
2
:
if
j
==
1
:
# KERNEL_WIDTH-1:
matrix_w
=
w_col1
x_ptrs_1d
=
x_base_1d
+
idx_token
*
stride_x_token
# [BLOCK_N]
if
HAS_EAGLE_TREE_CUSTOM_ATTN_MASK
:
# set the parent index of the next token in the eagle tree
# next token's parent is the current token
retrieve_next_token_idx
=
tl
.
sum
(
tl
.
where
(
idx_tokens
==
idx_token
,
retrieve_next_tokens
,
0
)
)
if
retrieve_next_token_idx
!=
-
1
:
# pad slot id
parent_idx_tokens
=
tl
.
where
(
idx_tokens
==
retrieve_next_token_idx
,
idx_token
,
parent_idx_tokens
,
)
# next token's parent is the parent of the current token
retrieve_sibling_token_idx
=
tl
.
sum
(
tl
.
where
(
idx_tokens
==
idx_token
,
retrieve_next_siblings
,
0
)
)
if
retrieve_sibling_token_idx
!=
-
1
:
# pad slot id
parent_idx_token
=
tl
.
sum
(
tl
.
where
(
idx_tokens
==
idx_token
,
parent_idx_tokens
,
0
)
)
parent_idx_tokens
=
tl
.
where
(
idx_tokens
==
retrieve_sibling_token_idx
,
parent_idx_token
,
parent_idx_tokens
,
)
# tl.device_print("am", parent_idx_tokens)
_idx_token
=
idx_token
x_ptrs_1d
=
x_base_1d
+
_idx_token
*
stride_x_token
# [BLOCK_N]
matrix_x
=
tl
.
load
(
x_ptrs_1d
,
mask
=
mask_x_1d
)
# convolution operation: itself * wcol[-1] + parent * wcol[-2] + grand-parent * wcol[-3] + ...
for
j
in
tl
.
static_range
(
KERNEL_WIDTH
):
if
KERNEL_WIDTH
==
2
:
if
j
==
0
:
matrix_w
=
w_col1
else
:
matrix_w
=
w_col0
elif
KERNEL_WIDTH
==
3
:
if
j
==
0
:
matrix_w
=
w_col2
elif
j
==
1
:
matrix_w
=
w_col1
else
:
matrix_w
=
w_col0
elif
KERNEL_WIDTH
==
4
:
if
j
==
0
:
matrix_w
=
w_col3
elif
j
==
1
:
matrix_w
=
w_col2
elif
j
==
2
:
matrix_w
=
w_col1
else
:
matrix_w
=
w_col0
if
SAVE_INTERMEDIATE
:
# Save the window state after consuming this token
# Layout: [seq(cache line), step, dim, win(K-1)]
base_ptr
=
(
intermediate_conv_window_ptr
+
conv_state_batch_coord
*
stride_inter_seq
+
idx_token
*
stride_inter_step
+
idx_feats
*
stride_inter_dim
)
# store itself in KERNEL_WIDTH-2 slot, parent in KERNEL_WIDTH-3 slot, grand-parent in KERNEL_WIDTH-4 slot, ...
if
KERNEL_WIDTH
-
j
-
2
>=
0
:
tl
.
store
(
base_ptr
+
(
KERNEL_WIDTH
-
j
-
2
)
*
stride_inter_win
,
matrix_x
,
mask
=
mask_w
,
)
acc
+=
matrix_x
*
matrix_w
# move to parent for next iteration
if
_idx_token
>
0
:
_idx_token
=
tl
.
sum
(
tl
.
where
(
idx_tokens
==
_idx_token
,
parent_idx_tokens
,
0
)
)
x_ptrs_1d
=
x_base_1d
+
_idx_token
*
stride_x_token
# [BLOCK_N]
matrix_x
=
tl
.
load
(
x_ptrs_1d
,
mask
=
mask_x_1d
)
else
:
# no parent within the current chunk, load from prev conv state: col[-1] (idx 0's parent), col[-2] (idx 0's grand parent), ...
if
KERNEL_WIDTH
==
2
:
if
_idx_token
==
0
:
matrix_x
=
col0
elif
KERNEL_WIDTH
==
3
:
if
_idx_token
==
0
:
matrix_x
=
col1
else
:
matrix_x
=
col0
elif
KERNEL_WIDTH
==
4
:
if
_idx_token
==
0
:
matrix_x
=
col2
elif
_idx_token
==
-
1
:
matrix_x
=
col1
else
:
matrix_x
=
col0
_idx_token
=
_idx_token
-
1
else
:
matrix_w
=
w_col0
matrix_x
=
col0
for
j
in
tl
.
static_range
(
KERNEL_WIDTH
):
if
KERNEL_WIDTH
==
2
:
if
j
==
1
:
# KERNEL_WIDTH-1:
matrix_w
=
w_col1
x_ptrs_1d
=
x_base_1d
+
idx_token
*
stride_x_token
# [BLOCK_N]
matrix_x
=
tl
.
load
(
x_ptrs_1d
,
mask
=
mask_x_1d
)
elif
KERNEL_WIDTH
==
3
:
if
j
==
1
:
matrix_w
=
w_col1
matrix_x
=
col1
elif
j
==
2
:
matrix_w
=
w_col2
x_ptrs_1d
=
x_base_1d
+
idx_token
*
stride_x_token
# [BLOCK_N]
matrix_x
=
tl
.
load
(
x_ptrs_1d
,
mask
=
mask_x_1d
)
elif
KERNEL_WIDTH
==
4
:
if
j
==
1
:
matrix_w
=
w_col1
matrix_x
=
col1
elif
j
==
2
:
matrix_w
=
w_col2
matrix_x
=
col2
elif
j
==
3
:
matrix_w
=
w_col3
x_ptrs_1d
=
x_base_1d
+
idx_token
*
stride_x_token
# [BLOCK_N]
matrix_x
=
tl
.
load
(
x_ptrs_1d
,
mask
=
mask_x_1d
)
acc
+=
matrix_x
*
matrix_w
# [BLOCK_N]
if
KERNEL_WIDTH
==
2
:
col0
=
matrix_x
elif
KERNEL_WIDTH
==
3
:
if
j
==
1
:
matrix_w
=
w_col1
matrix_x
=
col1
elif
j
==
2
:
matrix_w
=
w_col2
x_ptrs_1d
=
x_base_1d
+
idx_token
*
stride_x_token
# [BLOCK_N]
matrix_x
=
tl
.
load
(
x_ptrs_1d
,
mask
=
mask_x_1d
)
col0
=
col1
col1
=
matrix_x
elif
KERNEL_WIDTH
==
4
:
if
j
==
1
:
matrix_w
=
w_col1
matrix_x
=
col1
elif
j
==
2
:
matrix_w
=
w_col2
matrix_x
=
col2
elif
j
==
3
:
matrix_w
=
w_col3
x_ptrs_1d
=
x_base_1d
+
idx_token
*
stride_x_token
# [BLOCK_N]
matrix_x
=
tl
.
load
(
x_ptrs_1d
,
mask
=
mask_x_1d
)
acc
+=
matrix_x
*
matrix_w
# [BLOCK_N]
if
KERNEL_WIDTH
==
2
:
col0
=
matrix_x
elif
KERNEL_WIDTH
==
3
:
col0
=
col1
col1
=
matrix_x
elif
KERNEL_WIDTH
==
4
:
col0
=
col1
col1
=
col2
col2
=
matrix_x
col0
=
col1
col1
=
col2
col2
=
matrix_x
if
SAVE_INTERMEDIATE
:
# Save the window state after consuming this token
# Layout: [seq(cache line), step, dim, win(K-1)]
base_ptr
=
(
intermediate_conv_window_ptr
+
conv_state_batch_coord
*
stride_inter_seq
+
idx_token
*
stride_inter_step
+
idx_feats
*
stride_inter_dim
)
if
KERNEL_WIDTH
>=
2
:
tl
.
store
(
base_ptr
+
0
*
stride_inter_win
,
col0
,
mask
=
mask_w
)
if
KERNEL_WIDTH
>=
3
:
tl
.
store
(
base_ptr
+
1
*
stride_inter_win
,
col1
,
mask
=
mask_w
)
if
KERNEL_WIDTH
>=
4
:
tl
.
store
(
base_ptr
+
2
*
stride_inter_win
,
col2
,
mask
=
mask_w
)
if
SILU_ACTIVATION
:
acc
=
acc
/
(
1
+
tl
.
exp
(
-
acc
))
...
...
@@ -798,21 +959,15 @@ def _causal_conv1d_update_kernel(
tl
.
store
(
o_ptrs
,
acc
,
mask
=
mask_1d
)
if
SAVE_INTERMEDIATE
:
# Save the window state after consuming this token
# Layout: [seq(cache line), step, dim, win(K-1)]
base_ptr
=
(
intermediate_conv_window_ptr
+
conv_state_batch_coord
*
stride_inter_seq
+
idx_token
*
stride_inter_step
+
idx_feats
*
stride_inter_dim
# fuse: store calculated retrieve_parent_token to tensor
if
HAS_EAGLE_TREE_CUSTOM_ATTN_MASK
:
tl
.
store
(
retrieve_parent_token_ptr
+
idx_seq
*
stride_retrieve_parent_token_seq
+
idx_tokens
*
stride_retrieve_parent_token_token
,
parent_
idx_token
s
,
mask
=
mask_retrieve
,
)
if
KERNEL_WIDTH
>=
2
:
tl
.
store
(
base_ptr
+
0
*
stride_inter_win
,
col0
,
mask
=
mask_w
)
if
KERNEL_WIDTH
>=
3
:
tl
.
store
(
base_ptr
+
1
*
stride_inter_win
,
col1
,
mask
=
mask_w
)
if
KERNEL_WIDTH
>=
4
:
tl
.
store
(
base_ptr
+
2
*
stride_inter_win
,
col2
,
mask
=
mask_w
)
def
causal_conv1d_update
(
...
...
@@ -825,6 +980,9 @@ def causal_conv1d_update(
conv_state_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
num_accepted_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
intermediate_conv_window
:
Optional
[
torch
.
Tensor
]
=
None
,
retrieve_next_token
:
Optional
[
torch
.
Tensor
]
=
None
,
retrieve_next_sibling
:
Optional
[
torch
.
Tensor
]
=
None
,
retrieve_parent_token
:
Optional
[
torch
.
Tensor
]
=
None
,
pad_slot_id
:
int
=
PAD_SLOT_ID
,
metadata
=
None
,
validate_data
=
False
,
...
...
@@ -888,7 +1046,7 @@ def causal_conv1d_update(
assert
cache_seqlens
is
None
# not needed for vLLM - circular buffer
# adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
out
=
x
out
=
torch
.
empty_like
(
x
)
stride_w_dim
,
stride_w_width
=
weight
.
stride
()
stride_x_seq
,
stride_x_dim
,
stride_x_token
=
x
.
stride
()
# X (batch, dim, seqlen)
...
...
@@ -903,6 +1061,7 @@ def causal_conv1d_update(
else
:
state_len
=
width
-
1
np2_statelen
=
triton
.
next_power_of_2
(
state_len
)
np2_seqlen
=
triton
.
next_power_of_2
(
seqlen
)
def
grid
(
META
):
return
(
...
...
@@ -921,6 +1080,33 @@ def causal_conv1d_update(
else
:
stride_inter_seq
=
stride_inter_step
=
stride_inter_dim
=
stride_inter_win
=
0
# prepare retrieve next token buffer strides if provided
if
retrieve_next_token
is
not
None
:
stride_retrieve_next_token_seq
,
stride_retrieve_next_token_token
=
(
retrieve_next_token
.
stride
(
0
),
retrieve_next_token
.
stride
(
1
),
)
else
:
stride_retrieve_next_token_seq
=
stride_retrieve_next_token_token
=
0
# prepare retrieve next sibling buffer strides if provided
if
retrieve_next_sibling
is
not
None
:
stride_retrieve_next_sibling_seq
,
stride_retrieve_next_sibling_token
=
(
retrieve_next_sibling
.
stride
(
0
),
retrieve_next_sibling
.
stride
(
1
),
)
else
:
stride_retrieve_next_sibling_seq
=
stride_retrieve_next_sibling_token
=
0
# prepare retrieve parent token buffer strides if provided
if
retrieve_parent_token
is
not
None
:
stride_retrieve_parent_token_seq
,
stride_retrieve_parent_token_token
=
(
retrieve_parent_token
.
stride
(
0
),
retrieve_parent_token
.
stride
(
1
),
)
else
:
stride_retrieve_parent_token_seq
=
stride_retrieve_parent_token_token
=
0
_causal_conv1d_update_kernel
[
grid
](
# Pointers to matrices
x
,
...
...
@@ -931,6 +1117,9 @@ def causal_conv1d_update(
conv_state_indices
,
num_accepted_tokens
,
intermediate_conv_window
if
intermediate_conv_window
is
not
None
else
x
,
retrieve_next_token
,
retrieve_next_sibling
,
retrieve_parent_token
,
out
,
# Matrix dimensions
batch
,
...
...
@@ -952,6 +1141,12 @@ def causal_conv1d_update(
stride_inter_step
,
stride_inter_dim
,
stride_inter_win
,
stride_retrieve_next_token_seq
,
stride_retrieve_next_token_token
,
stride_retrieve_next_sibling_seq
,
stride_retrieve_next_sibling_token
,
stride_retrieve_parent_token_seq
,
stride_retrieve_parent_token_token
,
stride_o_seq
,
stride_o_dim
,
stride_o_token
,
...
...
@@ -964,9 +1159,11 @@ def causal_conv1d_update(
IS_CONTINUOUS_BATCHING
=
conv_state_indices
is
not
None
,
IS_SPEC_DECODING
=
num_accepted_tokens
is
not
None
,
NP2_STATELEN
=
np2_statelen
,
NP2_SEQLEN
=
np2_seqlen
,
USE_PAD_SLOT
=
pad_slot_id
is
not
None
,
BLOCK_N
=
256
,
SAVE_INTERMEDIATE
=
intermediate_conv_window
is
not
None
,
HAS_EAGLE_TREE_CUSTOM_ATTN_MASK
=
retrieve_next_token
is
not
None
,
)
if
unsqueeze
:
out
=
out
.
squeeze
(
-
1
)
...
...
python/sglang/srt/layers/attention/mamba/mamba2_metadata.py
View file @
3451fc32
...
...
@@ -16,6 +16,7 @@
import
math
from
dataclasses
import
dataclass
from
typing
import
Optional
import
torch
...
...
@@ -26,6 +27,9 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class
ForwardMetadata
:
query_start_loc
:
torch
.
Tensor
mamba_cache_indices
:
torch
.
Tensor
retrieve_next_token
:
Optional
[
torch
.
Tensor
]
=
None
retrieve_next_sibling
:
Optional
[
torch
.
Tensor
]
=
None
retrieve_parent_token
:
Optional
[
torch
.
Tensor
]
=
None
@
dataclass
(
kw_only
=
True
)
...
...
python/sglang/srt/speculative/eagle_worker.py
View file @
3451fc32
...
...
@@ -694,19 +694,45 @@ class EAGLEWorker(TpModelWorker):
]
logits_output
.
hidden_states
=
logits_output
.
hidden_states
[
res
.
accepted_indices
]
# QQ: can be optimized
if
self
.
target_worker
.
model_runner
.
hybrid_gdn_config
is
not
None
:
# res.draft_input.accept_length is on GPU but may be empty for last verify?
accepted_length
=
(
torch
.
tensor
(
res
.
accept_length_per_req_cpu
,
device
=
logits_output
.
hidden_states
.
device
,
dtype
=
torch
.
int
32
,
dtype
=
torch
.
int
64
,
)
+
1
)
# If topk > 1, we need to use retrieve_next_token and retrieve_next_sibling to handle the eagle tree custom attention mask
# res.accepted_indices.shape[0] > 0 skips DP attn idle batch
if
spec_info
.
topk
>
1
and
res
.
accepted_indices
.
shape
[
0
]
>
0
:
# accepted_indices=[0,2,3,4,5,7,9,10,11], accepted_length=[4, 3, 2], cumulative_accepted_lengths=[4, 7, 9]
# first_token_indices_per_req=prepend(0, accepted_indices[cumulative_accepted_lengths[:-1]]) = [0, 5, 10]
# last_token_indices_per_req=accepted_indices[cumulative_accepted_lengths - 1] = [4, 9, 11] (last token ID of each req)
# max_relative_indices_per_req = [4,4,1]; those are the per-req spec-decoding step offsets that contain the correct mamba caches
cumulative_accepted_lengths
=
torch
.
cumsum
(
accepted_length
,
dim
=
0
)
req_start_positions
=
torch
.
cat
(
[
torch
.
zeros
(
1
,
dtype
=
cumulative_accepted_lengths
.
dtype
,
device
=
cumulative_accepted_lengths
.
device
,
),
cumulative_accepted_lengths
[:
-
1
],
]
)
first_token_indices_per_req
=
res
.
accepted_indices
[
req_start_positions
]
last_token_indices_per_req
=
res
.
accepted_indices
[
cumulative_accepted_lengths
-
1
]
max_relative_indices_per_req
=
(
last_token_indices_per_req
-
first_token_indices_per_req
)
else
:
max_relative_indices_per_req
=
accepted_length
-
1
self
.
target_worker
.
model_runner
.
attn_backend
.
update_mamba_state_after_mtp_verify
(
accepted_length
,
self
.
target_worker
.
model_runner
.
model
max_relative_indices_per_req
,
self
.
target_worker
.
model_runner
.
model
)
if
batch
.
return_logprob
:
...
...
test/srt/models/test_qwen3_next_models.py
View file @
3451fc32
...
...
@@ -59,11 +59,56 @@ class TestQwen3NextMTP(CustomTestCase):
"--speculative-algorithm"
,
"NEXTN"
,
"--speculative-num-steps"
,
"
1
"
,
"
3
"
,
"--speculative-eagle-topk"
,
"1"
,
"--speculative-num-draft-tokens"
,
"2"
,
"4"
,
"--mem-fraction-static"
,
"0.8"
,
"--tp"
,
"4"
,
],
)
@
classmethod
def
tearDownClass
(
cls
):
kill_process_tree
(
cls
.
process
.
pid
)
def
test_gsm8k
(
self
):
args
=
SimpleNamespace
(
num_shots
=
5
,
data_path
=
None
,
num_questions
=
200
,
max_new_tokens
=
512
,
parallel
=
128
,
host
=
"http://127.0.0.1"
,
port
=
int
(
self
.
base_url
.
split
(
":"
)[
-
1
]),
)
metrics
=
run_eval
(
args
)
print
(
f
"
{
metrics
=
}
"
)
self
.
assertGreater
(
metrics
[
"accuracy"
],
0.93
)
class
TestQwen3NextMTPTopk
(
CustomTestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
"Qwen/Qwen3-Next-80B-A3B-Instruct"
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
process
=
popen_launch_server
(
cls
.
model
,
cls
.
base_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
other_args
=
[
"--trust-remote-code"
,
"--speculative-algorithm"
,
"NEXTN"
,
"--speculative-num-steps"
,
"5"
,
"--speculative-eagle-topk"
,
"4"
,
"--speculative-num-draft-tokens"
,
"8"
,
"--mem-fraction-static"
,
"0.8"
,
"--tp"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment