Commit dd408ee4 (Unverified)
Auto set draft model path for MTP (#5793)
Authored Apr 30, 2025 by Ke Bao; committed via GitHub on Apr 29, 2025
Parent: 9419e75d

Showing 6 changed files with 115 additions and 287 deletions (+115 / -287)
python/sglang/srt/configs/model_config.py         +7    -0
python/sglang/srt/managers/tp_worker.py           +1    -0
python/sglang/srt/model_executor/model_runner.py  +11   -2
python/sglang/srt/models/deepseek_nextn.py        +1    -257
python/sglang/srt/models/deepseek_v2.py           +74   -17
python/sglang/srt/server_args.py                  +21   -11
python/sglang/srt/configs/model_config.py

@@ -47,6 +47,7 @@ class ModelConfig:
         dtype: str = "auto",
         quantization: Optional[str] = None,
         override_config_file: Optional[str] = None,
+        is_draft_model: bool = False,
     ) -> None:
         self.model_path = model_path

@@ -85,6 +86,12 @@ class ModelConfig:
         else:
             enable_multimodal = True

+        if (
+            is_draft_model
+            and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM"
+        ):
+            self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
+
         # Check model type
         self.is_generation = is_generation_model(
             self.hf_config.architectures, is_embedding
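A minimal standalone sketch (not the sglang source) of what this change does: when the config is built for a draft worker and the target architecture is DeepseekV3ForCausalLM, the reported architecture is swapped to the NextN (MTP) variant, so the draft worker instantiates DeepseekV3ForCausalLMNextN from the same checkpoint. The helper name below is hypothetical.

    def resolve_draft_architecture(architectures, is_draft_model):
        # Mirrors the new branch in ModelConfig.__init__ (illustrative only).
        archs = list(architectures)
        if is_draft_model and archs[0] == "DeepseekV3ForCausalLM":
            archs[0] = "DeepseekV3ForCausalLMNextN"
        return archs

    assert resolve_draft_architecture(["DeepseekV3ForCausalLM"], True) == ["DeepseekV3ForCausalLMNextN"]
    assert resolve_draft_architecture(["DeepseekV3ForCausalLM"], False) == ["DeepseekV3ForCausalLM"]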
python/sglang/srt/managers/tp_worker.py

@@ -71,6 +71,7 @@ class TpModelWorker:
             enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
+            is_draft_model=is_draft_worker,
         )
         self.model_runner = ModelRunner(
             model_config=self.model_config,
python/sglang/srt/model_executor/model_runner.py

@@ -692,9 +692,14 @@ class ModelRunner:
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
         if self.use_mla_backend:
+            num_layers = (
+                self.model_config.num_hidden_layers
+                if not self.is_draft_worker
+                else self.model_config.hf_config.num_nextn_predict_layers
+            )
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
-                * self.model_config.num_hidden_layers
+                * num_layers
                 * torch._utils._element_size(self.kv_cache_dtype)
             )
         else:

@@ -809,7 +814,11 @@ class ModelRunner:
             dtype=self.kv_cache_dtype,
             kv_lora_rank=self.model_config.kv_lora_rank,
             qk_rope_head_dim=self.model_config.qk_rope_head_dim,
-            layer_num=self.model_config.num_hidden_layers,
+            layer_num=(
+                self.model_config.num_hidden_layers
+                if not self.is_draft_worker
+                else self.model_config.hf_config.num_nextn_predict_layers
+            ),
             device=self.device,
             enable_memory_saver=self.server_args.enable_memory_saver,
         )
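To see why the draft worker needs its own layer count here, a small illustrative calculation of the per-token MLA KV-cache cell size. The numbers below are example DeepSeek-V3-style values, not taken from this diff.

    # Per-token KV-cache bytes under MLA:
    # (kv_lora_rank + qk_rope_head_dim) * num_layers * element_size
    kv_lora_rank = 512       # example value
    qk_rope_head_dim = 64    # example value
    element_size = 1         # e.g. an fp8 KV-cache dtype

    target_cell = (kv_lora_rank + qk_rope_head_dim) * 61 * element_size  # 61 hidden layers -> 35136 bytes/token
    draft_cell = (kv_lora_rank + qk_rope_head_dim) * 1 * element_size    # 1 nextn layer   ->   576 bytes/token

Sizing the draft worker with num_nextn_predict_layers instead of num_hidden_layers keeps its KV-cache reservation proportional to the single MTP layer it actually runs.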
python/sglang/srt/models/deepseek_nextn.py

@@ -177,263 +177,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        if hasattr(self.config, "num_nextn_predict_layers"):
-            num_nextn_layers = self.config.num_nextn_predict_layers
-            assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
-            assert num_nextn_layers == self.config.num_hidden_layers
-        else:
-            raise ValueError("num_nextn_predict_layers is not in the config")
-
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        if self.n_share_experts_fusion > 0:
-            logger.info(
-                f"Cloning {self.n_share_experts_fusion} "
-                "replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
-            )
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
-                suffix_list = [
-                    "down_proj.weight",
-                    "down_proj.weight_scale",
-                    "gate_proj.weight",
-                    "gate_proj.weight_scale",
-                    "up_proj.weight",
-                    "up_proj.weight_scale",
-                ]
-            else:
-                suffix_list = [
-                    "down_proj.weight",
-                    "down_proj.weight_scale_inv",
-                    "gate_proj.weight",
-                    "gate_proj.weight_scale_inv",
-                    "up_proj.weight",
-                    "up_proj.weight_scale_inv",
-                ]
-            names_to_remove = []
-            for suffix in suffix_list:
-                shared_expert_weight_name = (
-                    f"model.layers.0.mlp.shared_experts.{suffix}"
-                )
-                for num_repeat in range(self.n_share_experts_fusion):
-                    weights_list.append(
-                        (
-                            f"model.layers.0."
-                            f"mlp.experts."
-                            f"{self.config.n_routed_experts + num_repeat}"
-                            f".{suffix}",
-                            weights_dict[shared_expert_weight_name],
-                        )
-                    )
-                names_to_remove += [shared_expert_weight_name]
-            weights = [w for w in weights_list if w[0] not in names_to_remove]
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
-        )
-
-        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
-        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
-            self.config.q_lora_rank is not None
-        )
-        cached_a_proj = {} if fuse_qkv_a_proj else None
-
-        nextn_layer_prefix = "model.layers.0"
-        nextn_spec_weight_names = [
-            "shared_head.norm",
-            "eh_proj",
-            "enorm",
-            "hnorm",
-        ]
-
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            if not name.startswith(nextn_layer_prefix):
-                continue
-
-            # Use shared head and embed weights from target model
-            if "shared_head.head" in name or "embed_tokens" in name:
-                continue
-
-            is_decoder = True
-            # For nextn specific weights
-            for weight_name in nextn_spec_weight_names:
-                if weight_name in name:
-                    name = name.replace(nextn_layer_prefix, "model")
-                    is_decoder = False
-                    break
-            # For decoder layer weights
-            if is_decoder:
-                name = name.replace(nextn_layer_prefix, "model.decoder")
-
-            if "rotary_emb.inv_freq" in name:
-                continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                # Skip non-stacked layers and experts (experts handled below).
-                if weight_name not in name:
-                    continue
-                # We have mlp.experts[0].gate_proj in the checkpoint.
-                # Since we handle the experts below in expert_params_mapping,
-                # we need to skip here BEFORE we update the name, otherwise
-                # name will be updated to mlp.experts[0].gate_up_proj, which
-                # will then be updated below in expert_params_mapping
-                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if ("mlp.experts." in name) and name not in params_dict:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-
-                    # Handle fused_qkv_a_proj
-                    if fuse_qkv_a_proj and (
-                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
-                    ):
-                        cached_a_proj[name] = loaded_weight
-                        q_a_proj_name = (
-                            name
-                            if "q_a_proj" in name
-                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
-                        )
-                        kv_a_proj_name = (
-                            name
-                            if "kv_a_proj_with_mqa" in name
-                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
-                        )
-
-                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
-                        if (
-                            q_a_proj_name in cached_a_proj
-                            and kv_a_proj_name in cached_a_proj
-                        ):
-                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
-                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
-                            fused_weight = torch.cat(
-                                [q_a_proj_weight, kv_a_proj_weight], dim=0
-                            )
-                            param_name = name.replace(
-                                "q_a_proj", "fused_qkv_a_proj_with_mqa"
-                            )
-                            param = params_dict[param_name]
-                            weight_loader = getattr(
-                                param, "weight_loader", default_weight_loader
-                            )
-                            weight_loader(param, fused_weight)
-                            cached_a_proj.pop(q_a_proj_name)
-                            cached_a_proj.pop(kv_a_proj_name)
-                    else:
-                        param = params_dict[name]
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        weight_loader(param, loaded_weight)
-
-        self_attn = self.model.decoder.self_attn
-        if hasattr(self_attn.kv_b_proj, "qweight"):
-            # AWQ compatible
-            if _is_cuda:
-                w = awq_dequantize(
-                    self_attn.kv_b_proj.qweight,
-                    self_attn.kv_b_proj.scales,
-                    self_attn.kv_b_proj.qzeros,
-                ).T
-            else:
-                w = awq_dequantize(
-                    self_attn.kv_b_proj.qweight,
-                    self_attn.kv_b_proj.scales,
-                    self_attn.kv_b_proj.qzeros,
-                    0,
-                    0,
-                    0,
-                ).T
-        else:
-            w = self_attn.kv_b_proj.weight
-        # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-        # This may affect the accuracy of fp8 model.
-        if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
-            torch.float8_e4m3fn,
-            torch.float8_e4m3fnuz,
-        ):
-            weight_block_size = self.quant_config.weight_block_size
-            if weight_block_size is not None:
-                assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                if _is_hip:
-                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=w,
-                        weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                        input_scale=None,
-                    )
-                else:
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                w, scale = block_quant_to_tensor_quant(
-                    weight, weight_scale, weight_block_size
-                )
-                self_attn.w_scale = scale
-        if w.dtype == torch.int8:
-            if hasattr(self.quant_config, "weight_block_size"):
-                # block-wise int8 need it
-                weight_block_size = self.quant_config.weight_block_size
-                if weight_block_size is not None:
-                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                    w = int8_block_dequant(
-                        weight, weight_scale, weight_block_size
-                    ).to(torch.bfloat16)
-            else:
-                # channel-wise int8 need it
-                assert hasattr(self_attn.kv_b_proj, "weight_scale")
-                w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
-                    torch.bfloat16
-                )
-        w_kc, w_vc = w.unflatten(
-            0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-        ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-        self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-        self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-        if hasattr(self_attn.kv_b_proj, "weight_scale") and self_attn.w_scale is None:
-            self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-            if _is_hip:
-                self_attn.w_scale *= 2.0
+        super().load_weights(weights, is_nextn=True)


 EntryClass = [DeepseekV3ForCausalLMNextN]
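The deleted block above duplicated the kv_b_proj post-processing that DeepseekV2ForCausalLM.post_load_weights already performs; after this commit the NextN class simply delegates to the parent loader. A sketch of the core split that post-processing does, with illustrative DeepSeek-style shapes (example values, not taken from this diff):

    import torch

    num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 128, 128, 128, 512
    w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

    # Same unflatten/split as in post_load_weights: kv_b_proj is cut into the
    # w_kc and w_vc operands used by the MLA attention kernels.
    w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
        [qk_nope_head_dim, v_head_dim], dim=1
    )
    print(w_kc.shape, w_vc.shape)  # torch.Size([128, 128, 512]) torch.Size([128, 128, 512])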
python/sglang/srt/models/deepseek_v2.py

@@ -1502,11 +1502,20 @@ class DeepseekV2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )

-    def post_load_weights(self):
+    def post_load_weights(self, is_nextn=False):

         # Perform post-processing after loading weights
-        for layer_id in range(self.config.num_hidden_layers):
-            self_attn = self.model.layers[layer_id].self_attn
+        layer_ids = (
+            range(self.config.num_hidden_layers)
+            if not is_nextn
+            else [self.config.num_hidden_layers]
+        )
+        for layer_id in layer_ids:
+            self_attn = (
+                self.model.layers[layer_id].self_attn
+                if not is_nextn
+                else self.model.decoder.self_attn
+            )
             if hasattr(self_attn.kv_b_proj, "qweight"):
                 # AWQ compatible
                 if _is_cuda:

@@ -1612,7 +1621,20 @@ class DeepseekV2ForCausalLM(nn.Module):
                 self_attn.w_vc = w_vc.contiguous()
                 self_attn.use_deep_gemm_bmm = True

-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+        if is_nextn:
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
+                # compatible with old design
+                nextn_layer_id = (
+                    0
+                    if self.config.num_hidden_layers == 1
+                    else self.config.num_hidden_layers
+                )
+            else:
+                raise ValueError("num_nextn_predict_layers is not in the config")
+
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),

@@ -1640,12 +1662,19 @@ class DeepseekV2ForCausalLM(nn.Module):
                 "up_proj.weight_scale_inv",
             ]
             names_to_remove = []
-            for moe_layer in tqdm(
-                range(
-                    self.config.first_k_dense_replace,
-                    self.config.num_hidden_layers,
-                    self.config.moe_layer_freq,
-                ),
+
+            moe_layers = (
+                range(
+                    self.config.first_k_dense_replace,
+                    self.config.num_hidden_layers,
+                    self.config.moe_layer_freq,
+                )
+                if not is_nextn
+                else [nextn_layer_id]
+            )
+
+            for moe_layer in tqdm(
+                moe_layers,
                 desc=f"Cloning {self.n_share_experts_fusion} "
                 "replicas of the shared expert into MoE",
             ):

@@ -1686,18 +1715,46 @@ class DeepseekV2ForCausalLM(nn.Module):
             )
         cached_a_proj = {} if fuse_qkv_a_proj else None

+        if is_nextn:
+            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
+            nextn_spec_weight_names = [
+                "shared_head.norm",
+                "eh_proj",
+                "enorm",
+                "hnorm",
+            ]
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
-            # TODO(HandH1998): Modify it when nextn is supported.
-            if hasattr(self.config, "num_nextn_predict_layers"):
-                num_nextn_layers = self.config.num_nextn_predict_layers
-                if num_nextn_layers > 0 and name.startswith("model.layers"):
-                    name_list = name.split(".")
-                    if (
-                        len(name_list) >= 3
-                        and int(name_list[2]) >= self.config.num_hidden_layers
-                    ):
-                        continue
+            if not is_nextn:
+                if hasattr(self.config, "num_nextn_predict_layers"):
+                    num_nextn_layers = self.config.num_nextn_predict_layers
+                    if num_nextn_layers > 0 and name.startswith("model.layers"):
+                        name_list = name.split(".")
+                        if (
+                            len(name_list) >= 3
+                            and int(name_list[2]) >= self.config.num_hidden_layers
+                        ):
+                            continue
+            else:
+                if not name.startswith(nextn_layer_prefix):
+                    continue
+
+                # Use shared head and embed weights from target model
+                if "shared_head.head" in name or "embed_tokens" in name:
+                    continue
+
+                is_decoder = True
+                # For nextn specific weights
+                for weight_name in nextn_spec_weight_names:
+                    if weight_name in name:
+                        name = name.replace(nextn_layer_prefix, "model")
+                        is_decoder = False
+                        break
+                # For decoder layer weights
+                if is_decoder:
+                    name = name.replace(nextn_layer_prefix, "model.decoder")
+
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:

@@ -1786,7 +1843,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                     )
                     weight_loader(param, loaded_weight)

-        self.post_load_weights()
+        self.post_load_weights(is_nextn=is_nextn)

     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
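The is_nextn branch above remaps the checkpoint names of the single MTP layer onto the NextN module layout. A self-contained sketch of that remapping; the layer id is an illustrative value assuming a 61-layer target model, so nextn_layer_id == num_hidden_layers == 61:

    from typing import Optional

    nextn_layer_prefix = "model.layers.61"  # example value
    nextn_spec_weight_names = ["shared_head.norm", "eh_proj", "enorm", "hnorm"]

    def remap_nextn_name(name: str) -> Optional[str]:
        # Mirrors the is_nextn branch in load_weights (illustrative only).
        if not name.startswith(nextn_layer_prefix):
            return None  # weights of other layers are skipped
        if "shared_head.head" in name or "embed_tokens" in name:
            return None  # reused from the target model
        if any(w in name for w in nextn_spec_weight_names):
            return name.replace(nextn_layer_prefix, "model")
        return name.replace(nextn_layer_prefix, "model.decoder")

    print(remap_nextn_name("model.layers.61.eh_proj.weight"))        # model.eh_proj.weight
    print(remap_nextn_name("model.layers.61.mlp.gate_proj.weight"))  # model.decoder.mlp.gate_proj.weight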
python/sglang/srt/server_args.py

@@ -22,7 +22,7 @@ import random
 import tempfile
 from typing import List, Literal, Optional

-from sglang.srt.hf_transformers_utils import check_gguf_file
+from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.utils import (
     configure_ipv6,

@@ -333,6 +333,14 @@ class ServerArgs:
                 "eagle speculative decoding."
             )

+            model_arch = get_model_arch(self)
+
+            # Auto set draft_model_path DeepSeek-V3/R1
+            if self.speculative_draft_model_path is None and model_arch in [
+                "DeepseekV3ForCausalLM"
+            ]:
+                self.speculative_draft_model_path = self.model_path
+
             # Auto choose parameters
             if self.speculative_num_steps is None:
                 assert (

@@ -343,7 +351,7 @@ class ServerArgs:
                     self.speculative_num_steps,
                     self.speculative_eagle_topk,
                     self.speculative_num_draft_tokens,
-                ) = auto_choose_speculative_params(self)
+                ) = auto_choose_speculative_params(model_arch)

         if self.page_size > 1 and self.speculative_eagle_topk > 1:
             self.speculative_eagle_topk = 1

@@ -1367,20 +1375,22 @@ class DeprecatedAction(argparse.Action):
         raise ValueError(self.help)


-def auto_choose_speculative_params(self: ServerArgs):
+def get_model_arch(args: ServerArgs):
+    hf_config = get_config(
+        args.model_path,
+        trust_remote_code=args.trust_remote_code,
+        revision=args.revision,
+        model_override_args=json.loads(args.json_model_override_args),
+    )
+    return hf_config.architectures[0]
+
+
+def auto_choose_speculative_params(arch: str):
     """
     Automatically choose the parameters for speculative decoding.

     You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
     """
-    config_path = os.path.join(self.model_path, "config.json")
-    if not os.path.exists(config_path):
-        raise ValueError(f"{config_path} is not found.")
-
-    config = json.load(open(config_path))
-
-    arch = config.get("architectures", ["Unknown"])[0]
-
     if arch in ["LlamaForCausalLM"]:
         # The default value for llama
         return (5, 4, 8)
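In practical terms, a DeepSeek-V3/R1 server started with MTP speculative decoding no longer has to pass --speculative-draft-model-path explicitly; the target checkpoint is reused as the draft model. A minimal sketch of the new defaulting behavior (illustrative, not the sglang source):

    def auto_draft_model_path(model_path, speculative_draft_model_path, model_arch):
        # Mirrors the new branch in ServerArgs: fall back to the target checkpoint
        # when the architecture is DeepSeek-V3/R1 and no draft path was given.
        if speculative_draft_model_path is None and model_arch in ["DeepseekV3ForCausalLM"]:
            return model_path
        return speculative_draft_model_path

    print(auto_draft_model_path("deepseek-ai/DeepSeek-V3", None, "DeepseekV3ForCausalLM"))
    # deepseek-ai/DeepSeek-V3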