sglang · Commits · dd408ee4

Unverified commit dd408ee4, authored Apr 30, 2025 by Ke Bao, committed by GitHub Apr 29, 2025
Parent: 9419e75d

Auto set draft model path for MTP (#5793)

Showing 6 changed files with 115 additions and 287 deletions (+115, -287).
Changed files:

python/sglang/srt/configs/model_config.py         +7    -0
python/sglang/srt/managers/tp_worker.py           +1    -0
python/sglang/srt/model_executor/model_runner.py  +11   -2
python/sglang/srt/models/deepseek_nextn.py        +1    -257
python/sglang/srt/models/deepseek_v2.py           +74   -17
python/sglang/srt/server_args.py                  +21   -11
python/sglang/srt/configs/model_config.py

@@ -47,6 +47,7 @@ class ModelConfig:
         dtype: str = "auto",
         quantization: Optional[str] = None,
         override_config_file: Optional[str] = None,
+        is_draft_model: bool = False,
     ) -> None:
         self.model_path = model_path
@@ -85,6 +86,12 @@ class ModelConfig:
         else:
             enable_multimodal = True
 
+        if (
+            is_draft_model
+            and self.hf_config.architectures[0] == "DeepseekV3ForCausalLM"
+        ):
+            self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
+
         # Check model type
         self.is_generation = is_generation_model(
             self.hf_config.architectures, is_embedding
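The effect of the new is_draft_model flag is easiest to see in isolation: when the target model reports the DeepseekV3ForCausalLM architecture, the draft worker's config is rewritten to the NextN (MTP) variant, so the draft loads DeepseekV3ForCausalLMNextN instead. A minimal standalone sketch of that remapping (the HfConfigStub class and example values are hypothetical, not part of the commit):

from dataclasses import dataclass, field
from typing import List


@dataclass
class HfConfigStub:
    # Stand-in for the HuggingFace config object; only the field used here.
    architectures: List[str] = field(default_factory=lambda: ["DeepseekV3ForCausalLM"])


def maybe_remap_draft_arch(hf_config: HfConfigStub, is_draft_model: bool) -> None:
    # Mirrors the check added in ModelConfig.__init__: only a draft model whose
    # target architecture is DeepseekV3ForCausalLM is remapped to the MTP module.
    if is_draft_model and hf_config.architectures[0] == "DeepseekV3ForCausalLM":
        hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"


cfg = HfConfigStub()
maybe_remap_draft_arch(cfg, is_draft_model=True)
print(cfg.architectures[0])  # DeepseekV3ForCausalLMNextN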
python/sglang/srt/managers/tp_worker.py

@@ -71,6 +71,7 @@ class TpModelWorker:
             enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
+            is_draft_model=is_draft_worker,
         )
         self.model_runner = ModelRunner(
             model_config=self.model_config,
python/sglang/srt/model_executor/model_runner.py

@@ -692,9 +692,14 @@ class ModelRunner:
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
         if self.use_mla_backend:
+            num_layers = (
+                self.model_config.num_hidden_layers
+                if not self.is_draft_worker
+                else self.model_config.hf_config.num_nextn_predict_layers
+            )
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
-                * self.model_config.num_hidden_layers
+                * num_layers
                 * torch._utils._element_size(self.kv_cache_dtype)
             )
         else:
@@ -809,7 +814,11 @@ class ModelRunner:
                 dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
-                layer_num=self.model_config.num_hidden_layers,
+                layer_num=(
+                    self.model_config.num_hidden_layers
+                    if not self.is_draft_worker
+                    else self.model_config.hf_config.num_nextn_predict_layers
+                ),
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
             )
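With the MLA backend, the per-token KV-cache cell size scales with the number of cached layers, so a NextN draft worker (num_nextn_predict_layers is 1 for DeepSeek-V3) needs far less KV memory than the full 61-layer target model. A back-of-the-envelope sketch of the cell_size formula using DeepSeek-V3-like values (kv_lora_rank=512, qk_rope_head_dim=64) and assuming a 1-byte (fp8) KV-cache element; the helper below is illustrative, not the ModelRunner code itself:

def mla_cell_size(kv_lora_rank: int, qk_rope_head_dim: int,
                  num_layers: int, element_size: int) -> int:
    # Bytes of MLA KV cache reserved per token slot across all cached layers.
    return (kv_lora_rank + qk_rope_head_dim) * num_layers * element_size


# Target model: every hidden layer is cached.
target = mla_cell_size(kv_lora_rank=512, qk_rope_head_dim=64, num_layers=61, element_size=1)
# Draft (NextN) worker: only the MTP layer is cached.
draft = mla_cell_size(kv_lora_rank=512, qk_rope_head_dim=64, num_layers=1, element_size=1)
print(target, draft)  # 35136 576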
python/sglang/srt/models/deepseek_nextn.py

@@ -177,263 +177,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        super().load_weights(weights, is_nextn=True)
-        if hasattr(self.config, "num_nextn_predict_layers"):
-            num_nextn_layers = self.config.num_nextn_predict_layers
-            assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
-            assert num_nextn_layers == self.config.num_hidden_layers
-        else:
-            raise ValueError("num_nextn_predict_layers is not in the config")
-
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        if self.n_share_experts_fusion > 0:
-            logger.info(
-                f"Cloning {self.n_share_experts_fusion} "
-                "replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
-            )
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
-                suffix_list = [
-                    "down_proj.weight",
-                    "down_proj.weight_scale",
-                    "gate_proj.weight",
-                    "gate_proj.weight_scale",
-                    "up_proj.weight",
-                    "up_proj.weight_scale",
-                ]
-            else:
-                suffix_list = [
-                    "down_proj.weight",
-                    "down_proj.weight_scale_inv",
-                    "gate_proj.weight",
-                    "gate_proj.weight_scale_inv",
-                    "up_proj.weight",
-                    "up_proj.weight_scale_inv",
-                ]
-            names_to_remove = []
-            for suffix in suffix_list:
-                shared_expert_weight_name = (
-                    f"model.layers.0.mlp.shared_experts.{suffix}"
-                )
-                for num_repeat in range(self.n_share_experts_fusion):
-                    weights_list.append(
-                        (
-                            f"model.layers.0."
-                            f"mlp.experts."
-                            f"{self.config.n_routed_experts + num_repeat}"
-                            f".{suffix}",
-                            weights_dict[shared_expert_weight_name],
-                        )
-                    )
-                names_to_remove += [shared_expert_weight_name]
-            weights = [w for w in weights_list if w[0] not in names_to_remove]
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
-        expert_params_mapping = MoEImpl.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
-        )
-
-        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
-        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
-            self.config.q_lora_rank is not None
-        )
-        cached_a_proj = {} if fuse_qkv_a_proj else None
-
-        nextn_layer_prefix = "model.layers.0"
-        nextn_spec_weight_names = [
-            "shared_head.norm",
-            "eh_proj",
-            "enorm",
-            "hnorm",
-        ]
-
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            if not name.startswith(nextn_layer_prefix):
-                continue
-
-            # Use shared head and embed weights from target model
-            if "shared_head.head" in name or "embed_tokens" in name:
-                continue
-
-            is_decoder = True
-            # For nextn specific weights
-            for weight_name in nextn_spec_weight_names:
-                if weight_name in name:
-                    name = name.replace(nextn_layer_prefix, "model")
-                    is_decoder = False
-                    break
-            # For decoder layer weights
-            if is_decoder:
-                name = name.replace(nextn_layer_prefix, "model.decoder")
-
-            if "rotary_emb.inv_freq" in name:
-                continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                # Skip non-stacked layers and experts (experts handled below).
-                if weight_name not in name:
-                    continue
-                # We have mlp.experts[0].gate_proj in the checkpoint.
-                # Since we handle the experts below in expert_params_mapping,
-                # we need to skip here BEFORE we update the name, otherwise
-                # name will be updated to mlp.experts[0].gate_up_proj, which
-                # will then be updated below in expert_params_mapping
-                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if ("mlp.experts." in name) and name not in params_dict:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
-                        loaded_weight,
-                        name,
-                        shard_id=shard_id,
-                        expert_id=expert_id,
-                    )
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-
-                    # Handle fused_qkv_a_proj
-                    if fuse_qkv_a_proj and (
-                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
-                    ):
-                        cached_a_proj[name] = loaded_weight
-                        q_a_proj_name = (
-                            name
-                            if "q_a_proj" in name
-                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
-                        )
-                        kv_a_proj_name = (
-                            name
-                            if "kv_a_proj_with_mqa" in name
-                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
-                        )
-
-                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
-                        if (
-                            q_a_proj_name in cached_a_proj
-                            and kv_a_proj_name in cached_a_proj
-                        ):
-                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
-                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
-                            fused_weight = torch.cat(
-                                [q_a_proj_weight, kv_a_proj_weight], dim=0
-                            )
-                            param_name = name.replace(
-                                "q_a_proj", "fused_qkv_a_proj_with_mqa"
-                            )
-                            param = params_dict[param_name]
-                            weight_loader = getattr(
-                                param, "weight_loader", default_weight_loader
-                            )
-                            weight_loader(param, fused_weight)
-                            cached_a_proj.pop(q_a_proj_name)
-                            cached_a_proj.pop(kv_a_proj_name)
-                    else:
-                        param = params_dict[name]
-                        weight_loader = getattr(
-                            param, "weight_loader", default_weight_loader
-                        )
-                        weight_loader(param, loaded_weight)
-
-        self_attn = self.model.decoder.self_attn
-        if hasattr(self_attn.kv_b_proj, "qweight"):
-            # AWQ compatible
-            if _is_cuda:
-                w = awq_dequantize(
-                    self_attn.kv_b_proj.qweight,
-                    self_attn.kv_b_proj.scales,
-                    self_attn.kv_b_proj.qzeros,
-                ).T
-            else:
-                w = awq_dequantize(
-                    self_attn.kv_b_proj.qweight,
-                    self_attn.kv_b_proj.scales,
-                    self_attn.kv_b_proj.qzeros,
-                    0,
-                    0,
-                    0,
-                ).T
-        else:
-            w = self_attn.kv_b_proj.weight
-        # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-        # This may affect the accuracy of fp8 model.
-        if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
-            torch.float8_e4m3fn,
-            torch.float8_e4m3fnuz,
-        ):
-            weight_block_size = self.quant_config.weight_block_size
-            if weight_block_size is not None:
-                assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                if _is_hip:
-                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=w,
-                        weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                        input_scale=None,
-                    )
-                else:
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                w, scale = block_quant_to_tensor_quant(
-                    weight, weight_scale, weight_block_size
-                )
-                self_attn.w_scale = scale
-        if w.dtype == torch.int8:
-            if hasattr(self.quant_config, "weight_block_size"):
-                # block-wise int8 need it
-                weight_block_size = self.quant_config.weight_block_size
-                if weight_block_size is not None:
-                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                    weight = w
-                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                    w = int8_block_dequant(
-                        weight, weight_scale, weight_block_size
-                    ).to(torch.bfloat16)
-            else:
-                # channel-wise int8 need it
-                assert hasattr(self_attn.kv_b_proj, "weight_scale")
-                w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
-                    torch.bfloat16
-                )
-        w_kc, w_vc = w.unflatten(
-            0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-        ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-        self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-        self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-        if hasattr(self_attn.kv_b_proj, "weight_scale") and self_attn.w_scale is None:
-            self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-            if _is_hip:
-                self_attn.w_scale *= 2.0
 
 
 EntryClass = [DeepseekV3ForCausalLMNextN]
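The net effect in this file is that DeepseekV3ForCausalLMNextN.load_weights no longer carries its own copy of the DeepSeek weight-loading logic; it delegates to the parent with is_nextn=True, and all NextN-specific handling now lives in DeepseekV2ForCausalLM.load_weights (next file). A toy sketch of the same delegation pattern, with hypothetical class and method bodies:

from typing import Iterable, Tuple


class BaseLM:
    def load_weights(self, weights: Iterable[Tuple[str, object]], is_nextn: bool = False) -> None:
        # One shared implementation; branches on is_nextn wherever the draft
        # (MTP) layer needs different name remapping or layer selection.
        mode = "nextn" if is_nextn else "target"
        for name, _tensor in weights:
            print(f"[{mode}] loading {name}")


class NextNLM(BaseLM):
    def load_weights(self, weights: Iterable[Tuple[str, object]]) -> None:
        # The subclass keeps its public signature but forwards to the shared
        # loader instead of duplicating ~250 lines of logic.
        super().load_weights(weights, is_nextn=True)


NextNLM().load_weights([("model.layers.61.eh_proj.weight", None)])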
python/sglang/srt/models/deepseek_v2.py

@@ -1502,11 +1502,20 @@ class DeepseekV2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
-    def post_load_weights(self):
+    def post_load_weights(self, is_nextn=False):
 
         # Perform post-processing after loading weights
-        for layer_id in range(self.config.num_hidden_layers):
-            self_attn = self.model.layers[layer_id].self_attn
+        layer_ids = (
+            range(self.config.num_hidden_layers)
+            if not is_nextn
+            else [self.config.num_hidden_layers]
+        )
+        for layer_id in layer_ids:
+            self_attn = (
+                self.model.layers[layer_id].self_attn
+                if not is_nextn
+                else self.model.decoder.self_attn
+            )
             if hasattr(self_attn.kv_b_proj, "qweight"):
                 # AWQ compatible
                 if _is_cuda:
@@ -1612,7 +1621,20 @@ class DeepseekV2ForCausalLM(nn.Module):
             self_attn.w_vc = w_vc.contiguous()
             self_attn.use_deep_gemm_bmm = True
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+        if is_nextn:
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
+                # compatible with old design
+                nextn_layer_id = (
+                    0
+                    if self.config.num_hidden_layers == 1
+                    else self.config.num_hidden_layers
+                )
+            else:
+                raise ValueError("num_nextn_predict_layers is not in the config")
+
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
@@ -1640,12 +1662,19 @@ class DeepseekV2ForCausalLM(nn.Module):
                     "up_proj.weight_scale_inv",
                 ]
             names_to_remove = []
-            for moe_layer in tqdm(
-                range(
+
+            moe_layers = (
+                range(
                     self.config.first_k_dense_replace,
                     self.config.num_hidden_layers,
                     self.config.moe_layer_freq,
-                ),
+                )
+                if not is_nextn
+                else [nextn_layer_id]
+            )
+
+            for moe_layer in tqdm(
+                moe_layers,
                 desc=f"Cloning {self.n_share_experts_fusion} "
                 "replicas of the shared expert into MoE",
             ):
@@ -1686,18 +1715,46 @@ class DeepseekV2ForCausalLM(nn.Module):
         )
         cached_a_proj = {} if fuse_qkv_a_proj else None
 
+        if is_nextn:
+            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
+            nextn_spec_weight_names = [
+                "shared_head.norm",
+                "eh_proj",
+                "enorm",
+                "hnorm",
+            ]
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
-            # TODO(HandH1998): Modify it when nextn is supported.
-            if hasattr(self.config, "num_nextn_predict_layers"):
-                num_nextn_layers = self.config.num_nextn_predict_layers
-                if num_nextn_layers > 0 and name.startswith("model.layers"):
-                    name_list = name.split(".")
-                    if (
-                        len(name_list) >= 3
-                        and int(name_list[2]) >= self.config.num_hidden_layers
-                    ):
-                        continue
+            if not is_nextn:
+                if hasattr(self.config, "num_nextn_predict_layers"):
+                    num_nextn_layers = self.config.num_nextn_predict_layers
+                    if num_nextn_layers > 0 and name.startswith("model.layers"):
+                        name_list = name.split(".")
+                        if (
+                            len(name_list) >= 3
+                            and int(name_list[2]) >= self.config.num_hidden_layers
+                        ):
+                            continue
+            else:
+                if not name.startswith(nextn_layer_prefix):
+                    continue
+
+                # Use shared head and embed weights from target model
+                if "shared_head.head" in name or "embed_tokens" in name:
+                    continue
+
+                is_decoder = True
+                # For nextn specific weights
+                for weight_name in nextn_spec_weight_names:
+                    if weight_name in name:
+                        name = name.replace(nextn_layer_prefix, "model")
+                        is_decoder = False
+                        break
+                # For decoder layer weights
+                if is_decoder:
+                    name = name.replace(nextn_layer_prefix, "model.decoder")
+
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -1786,7 +1843,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                         )
                         weight_loader(param, loaded_weight)
 
-        self.post_load_weights()
+        self.post_load_weights(is_nextn=is_nextn)
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
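For the NextN path, the shared loader has to translate checkpoint names under the MTP layer prefix (model.layers.{nextn_layer_id}, e.g. model.layers.61 for DeepSeek-V3) into the draft module's namespace: NextN-specific tensors such as eh_proj, enorm, hnorm and shared_head.norm move under "model", ordinary decoder weights move under "model.decoder", and the shared head and embeddings are skipped because they are reused from the target model. A small sketch of that renaming rule, using a hypothetical nextn_layer_id of 61:

from typing import Optional

NEXTN_SPEC_WEIGHT_NAMES = ["shared_head.norm", "eh_proj", "enorm", "hnorm"]


def remap_nextn_name(name: str, nextn_layer_id: int = 61) -> Optional[str]:
    # Returns the parameter name inside the NextN draft module,
    # or None if the weight should be skipped.
    prefix = f"model.layers.{nextn_layer_id}"
    if not name.startswith(prefix):
        return None
    if "shared_head.head" in name or "embed_tokens" in name:
        return None  # reused from the target model
    for spec in NEXTN_SPEC_WEIGHT_NAMES:
        if spec in name:
            return name.replace(prefix, "model")
    return name.replace(prefix, "model.decoder")


print(remap_nextn_name("model.layers.61.eh_proj.weight"))              # model.eh_proj.weight
print(remap_nextn_name("model.layers.61.self_attn.q_a_proj.weight"))   # model.decoder.self_attn.q_a_proj.weight
print(remap_nextn_name("model.layers.61.embed_tokens.weight"))         # None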
python/sglang/srt/server_args.py

@@ -22,7 +22,7 @@ import random
 import tempfile
 from typing import List, Literal, Optional
 
-from sglang.srt.hf_transformers_utils import check_gguf_file
+from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.utils import (
     configure_ipv6,
@@ -333,6 +333,14 @@ class ServerArgs:
                 "eagle speculative decoding."
             )
 
+            model_arch = get_model_arch(self)
+
+            # Auto set draft_model_path DeepSeek-V3/R1
+            if self.speculative_draft_model_path is None and model_arch in [
+                "DeepseekV3ForCausalLM"
+            ]:
+                self.speculative_draft_model_path = self.model_path
+
             # Auto choose parameters
             if self.speculative_num_steps is None:
                 assert (
@@ -343,7 +351,7 @@ class ServerArgs:
                 self.speculative_num_steps,
                 self.speculative_eagle_topk,
                 self.speculative_num_draft_tokens,
-            ) = auto_choose_speculative_params(self)
+            ) = auto_choose_speculative_params(model_arch)
 
             if self.page_size > 1 and self.speculative_eagle_topk > 1:
                 self.speculative_eagle_topk = 1
@@ -1367,20 +1375,22 @@ class DeprecatedAction(argparse.Action):
         raise ValueError(self.help)
 
 
-def auto_choose_speculative_params(self: ServerArgs):
+def get_model_arch(args: ServerArgs):
+    hf_config = get_config(
+        args.model_path,
+        trust_remote_code=args.trust_remote_code,
+        revision=args.revision,
+        model_override_args=json.loads(args.json_model_override_args),
+    )
+    return hf_config.architectures[0]
+
+
+def auto_choose_speculative_params(arch: str):
     """
     Automatically choose the parameters for speculative decoding.
 
     You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
     """
-    config_path = os.path.join(self.model_path, "config.json")
-    if not os.path.exists(config_path):
-        raise ValueError(f"{config_path} is not found.")
-
-    config = json.load(open(config_path))
-
-    arch = config.get("architectures", ["Unknown"])[0]
-
     if arch in ["LlamaForCausalLM"]:
         # The default value for llama
         return (5, 4, 8)
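Put together, the server-args change means a DeepSeek-V3/R1 target no longer needs an explicit draft model path: when speculative_draft_model_path is unset and the architecture resolved from the model's config is DeepseekV3ForCausalLM, it falls back to model_path itself (the released DeepSeek-V3/R1 checkpoints already ship the MTP/NextN layer weights), and auto_choose_speculative_params now receives that architecture string instead of re-reading config.json from disk. A self-contained sketch of the defaulting rule; the SpecArgs dataclass and function name are illustrative, not the real ServerArgs:

from dataclasses import dataclass
from typing import Optional


@dataclass
class SpecArgs:
    model_path: str
    speculative_draft_model_path: Optional[str] = None


def auto_set_draft_model_path(args: SpecArgs, model_arch: str) -> None:
    # Mirrors the new rule: for DeepSeek-V3/R1 the target checkpoint
    # doubles as the draft model, so default the path to model_path.
    if args.speculative_draft_model_path is None and model_arch in ["DeepseekV3ForCausalLM"]:
        args.speculative_draft_model_path = args.model_path


args = SpecArgs(model_path="deepseek-ai/DeepSeek-V3")
auto_set_draft_model_path(args, model_arch="DeepseekV3ForCausalLM")
print(args.speculative_draft_model_path)  # deepseek-ai/DeepSeek-V3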