Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
9fafa62d
Unverified
Commit
9fafa62d
authored
Mar 04, 2025
by
Ke Bao
Committed by
GitHub
Mar 03, 2025
Browse files
Share target model embed and head weights for nextn (#4033)
parent
146ac8df
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
47 additions
and
45 deletions
+47
-45
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+2
-1
python/sglang/srt/models/deepseek_nextn.py
python/sglang/srt/models/deepseek_nextn.py
+18
-16
python/sglang/srt/models/deepseek_v2.py
python/sglang/srt/models/deepseek_v2.py
+11
-0
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+5
-4
python/sglang/srt/speculative/eagle_worker.py
python/sglang/srt/speculative/eagle_worker.py
+8
-15
python/sglang/srt/speculative/spec_info.py
python/sglang/srt/speculative/spec_info.py
+1
-9
scripts/export_deepseek_nextn.py
scripts/export_deepseek_nextn.py
+2
-0
No files found.
python/sglang/srt/model_executor/forward_batch_info.py
View file @
9fafa62d
...
@@ -280,7 +280,8 @@ class ForwardBatch:
...
@@ -280,7 +280,8 @@ class ForwardBatch:
).
to
(
device
,
non_blocking
=
True
)
).
to
(
device
,
non_blocking
=
True
)
if
(
if
(
model_runner
.
server_args
.
attention_backend
!=
"torch_native"
model_runner
.
server_args
.
attention_backend
!=
"torch_native"
and
model_runner
.
server_args
.
speculative_algorithm
!=
"NEXTN"
# TODO: Fix triton kernel illegal memory access for EAGLE
and
model_runner
.
server_args
.
speculative_algorithm
!=
"EAGLE"
):
):
ret
.
extend_num_tokens
=
batch
.
extend_num_tokens
ret
.
extend_num_tokens
=
batch
.
extend_num_tokens
positions
,
ret
.
extend_start_loc
=
compute_position_triton
(
positions
,
ret
.
extend_start_loc
=
compute_position_triton
(
...
...
python/sglang/srt/models/deepseek_nextn.py
View file @
9fafa62d
...
@@ -116,14 +116,14 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
...
@@ -116,14 +116,14 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
self
.
model
=
DeepseekModelNextN
(
config
,
quant_config
)
self
.
model
=
DeepseekModelNextN
(
config
,
quant_config
)
if
global_server_args_dict
[
"enable_dp_attention"
]:
if
global_server_args_dict
[
"enable_dp_attention"
]:
self
.
model
.
shared_head
.
head
=
ReplicatedLinear
(
self
.
lm_
head
=
ReplicatedLinear
(
config
.
hidden_size
,
config
.
hidden_size
,
config
.
vocab_size
,
config
.
vocab_size
,
bias
=
False
,
bias
=
False
,
)
)
self
.
logits_processor
=
LogitsProcessor
(
config
,
skip_all_gather
=
True
)
self
.
logits_processor
=
LogitsProcessor
(
config
,
skip_all_gather
=
True
)
else
:
else
:
self
.
model
.
shared_head
.
head
=
ParallelLMHead
(
self
.
lm_
head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
...
@@ -139,7 +139,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
...
@@ -139,7 +139,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
forward_batch
)
hidden_states
=
self
.
model
(
input_ids
,
positions
,
forward_batch
)
return
self
.
logits_processor
(
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
model
.
shared_head
.
head
,
forward_batch
input_ids
,
hidden_states
,
self
.
lm_
head
,
forward_batch
)
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
...
@@ -168,10 +168,8 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
...
@@ -168,10 +168,8 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
nextn_layer_prefix
=
"model.layers.0"
nextn_layer_prefix
=
"model.layers.0"
nextn_spec_weight_names
=
[
nextn_spec_weight_names
=
[
"shared_head.head"
,
"shared_head.norm"
,
"shared_head.norm"
,
"eh_proj"
,
"eh_proj"
,
"embed_tokens"
,
"enorm"
,
"enorm"
,
"hnorm"
,
"hnorm"
,
]
]
...
@@ -180,7 +178,11 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
...
@@ -180,7 +178,11 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
not
name
.
startswith
(
nextn_layer_prefix
):
if
not
name
.
startswith
(
nextn_layer_prefix
):
continue
continue
else
:
# Use shared head and embed weights from target model
if
"shared_head.head"
in
name
or
"embed_tokens"
in
name
:
continue
is_decoder
=
True
is_decoder
=
True
# For nextn specific weights
# For nextn specific weights
for
weight_name
in
nextn_spec_weight_names
:
for
weight_name
in
nextn_spec_weight_names
:
...
...
python/sglang/srt/models/deepseek_v2.py
View file @
9fafa62d
...
@@ -1179,6 +1179,17 @@ class DeepseekV2ForCausalLM(nn.Module):
...
@@ -1179,6 +1179,17 @@ class DeepseekV2ForCausalLM(nn.Module):
if
is_hip_
:
if
is_hip_
:
self_attn
.
w_scale
*=
2.0
self_attn
.
w_scale
*=
2.0
def
get_embed_and_head
(
self
):
return
self
.
model
.
embed_tokens
.
weight
,
self
.
lm_head
.
weight
def
set_embed_and_head
(
self
,
embed
,
head
):
del
self
.
model
.
embed_tokens
.
weight
del
self
.
lm_head
.
weight
self
.
model
.
embed_tokens
.
weight
=
embed
self
.
lm_head
.
weight
=
head
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
synchronize
()
class
DeepseekV3ForCausalLM
(
DeepseekV2ForCausalLM
):
class
DeepseekV3ForCausalLM
(
DeepseekV2ForCausalLM
):
pass
pass
...
...
python/sglang/srt/server_args.py
View file @
9fafa62d
...
@@ -270,10 +270,11 @@ class ServerArgs:
...
@@ -270,10 +270,11 @@ class ServerArgs:
)
)
# Speculative Decoding
# Speculative Decoding
if
(
if
self
.
speculative_algorithm
==
"NEXTN"
:
self
.
speculative_algorithm
==
"EAGLE"
# NEXTN shares the same implementation of EAGLE
or
self
.
speculative_algorithm
==
"NEXTN"
self
.
speculative_algorithm
=
"EAGLE"
):
if
self
.
speculative_algorithm
==
"EAGLE"
:
self
.
disable_overlap_schedule
=
True
self
.
disable_overlap_schedule
=
True
self
.
prefill_only_one_req
=
True
self
.
prefill_only_one_req
=
True
self
.
disable_cuda_graph_padding
=
True
self
.
disable_cuda_graph_padding
=
True
...
...
python/sglang/srt/speculative/eagle_worker.py
View file @
9fafa62d
...
@@ -83,7 +83,6 @@ class EAGLEWorker(TpModelWorker):
...
@@ -83,7 +83,6 @@ class EAGLEWorker(TpModelWorker):
self
.
server_args
=
server_args
self
.
server_args
=
server_args
# Share the embedding and lm_head
# Share the embedding and lm_head
if
not
self
.
speculative_algorithm
.
is_nextn
():
embed
,
head
=
self
.
target_worker
.
model_runner
.
model
.
get_embed_and_head
()
embed
,
head
=
self
.
target_worker
.
model_runner
.
model
.
get_embed_and_head
()
if
server_args
.
speculative_token_map
is
not
None
:
if
server_args
.
speculative_token_map
is
not
None
:
head
=
head
.
clone
()
head
=
head
.
clone
()
...
@@ -94,12 +93,6 @@ class EAGLEWorker(TpModelWorker):
...
@@ -94,12 +93,6 @@ class EAGLEWorker(TpModelWorker):
else
:
else
:
self
.
hot_token_id
=
None
self
.
hot_token_id
=
None
self
.
model_runner
.
model
.
set_embed_and_head
(
embed
,
head
)
self
.
model_runner
.
model
.
set_embed_and_head
(
embed
,
head
)
else
:
if
server_args
.
speculative_token_map
is
not
None
:
raise
NotImplementedError
(
"NEXTN does not support speculative-token-map now"
)
self
.
hot_token_id
=
None
self
.
model_runner
.
server_args
.
disable_cuda_graph
=
backup_disable_cuda_graph
self
.
model_runner
.
server_args
.
disable_cuda_graph
=
backup_disable_cuda_graph
# Create multi-step attn backends and cuda graph runners
# Create multi-step attn backends and cuda graph runners
...
...
python/sglang/srt/speculative/spec_info.py
View file @
9fafa62d
...
@@ -5,24 +5,16 @@ class SpeculativeAlgorithm(IntEnum):
...
@@ -5,24 +5,16 @@ class SpeculativeAlgorithm(IntEnum):
NONE
=
auto
()
NONE
=
auto
()
EAGLE
=
auto
()
EAGLE
=
auto
()
# NEXTN spec decoding is for DeepSeek V3/R1
# currently it's implemented based on EAGLE
NEXTN
=
auto
()
def
is_none
(
self
):
def
is_none
(
self
):
return
self
==
SpeculativeAlgorithm
.
NONE
return
self
==
SpeculativeAlgorithm
.
NONE
def
is_eagle
(
self
):
def
is_eagle
(
self
):
return
self
==
SpeculativeAlgorithm
.
EAGLE
or
self
==
SpeculativeAlgorithm
.
NEXTN
return
self
==
SpeculativeAlgorithm
.
EAGLE
def
is_nextn
(
self
):
return
self
==
SpeculativeAlgorithm
.
NEXTN
@
staticmethod
@
staticmethod
def
from_string
(
name
:
str
):
def
from_string
(
name
:
str
):
name_map
=
{
name_map
=
{
"EAGLE"
:
SpeculativeAlgorithm
.
EAGLE
,
"EAGLE"
:
SpeculativeAlgorithm
.
EAGLE
,
"NEXTN"
:
SpeculativeAlgorithm
.
NEXTN
,
None
:
SpeculativeAlgorithm
.
NONE
,
None
:
SpeculativeAlgorithm
.
NONE
,
}
}
if
name
is
not
None
:
if
name
is
not
None
:
...
...
scripts/export_deepseek_nextn.py
View file @
9fafa62d
...
@@ -62,6 +62,8 @@ def export_nextn_layer_parameters(input_dir, output_dir, nextn_layer_id):
...
@@ -62,6 +62,8 @@ def export_nextn_layer_parameters(input_dir, output_dir, nextn_layer_id):
continue
continue
for
key
in
matching_keys
:
for
key
in
matching_keys
:
if
"embed_tokens"
in
key
or
"shared_head.head"
in
key
:
continue
new_key
=
key
.
replace
(
prefix
,
"model.layers.0"
)
new_key
=
key
.
replace
(
prefix
,
"model.layers.0"
)
params
[
new_key
]
=
f
.
get_tensor
(
key
)
params
[
new_key
]
=
f
.
get_tensor
(
key
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment