change / sglang · Commits · 9fafa62d

Unverified commit 9fafa62d, authored Mar 04, 2025 by Ke Bao, committed by GitHub Mar 03, 2025
Share target model embed and head weights for nextn (#4033)
parent 146ac8df

Showing 7 changed files with 47 additions and 45 deletions
python/sglang/srt/model_executor/forward_batch_info.py (+2, -1)
python/sglang/srt/models/deepseek_nextn.py (+18, -16)
python/sglang/srt/models/deepseek_v2.py (+11, -0)
python/sglang/srt/server_args.py (+5, -4)
python/sglang/srt/speculative/eagle_worker.py (+8, -15)
python/sglang/srt/speculative/spec_info.py (+1, -9)
scripts/export_deepseek_nextn.py (+2, -0)
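In short: the NextN draft model (DeepSeek V3/R1 speculative decoding) stops loading its own copies of `embed_tokens` and the output head and instead aliases the target model's tensors. A minimal standalone sketch of that sharing pattern (toy modules, not the sglang classes):

```python
import torch
import torch.nn as nn


class TinyLM(nn.Module):
    def __init__(self, vocab: int = 100, hidden: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)


target, draft = TinyLM(), TinyLM()

# Drop the draft's duplicates and alias the target's tensors (the commit does
# this via get_embed_and_head()/set_embed_and_head(), added in deepseek_v2.py).
draft.embed_tokens.weight = target.embed_tokens.weight
draft.lm_head.weight = target.lm_head.weight

# Same storage, so the duplicate memory can be reclaimed.
assert draft.lm_head.weight.data_ptr() == target.lm_head.weight.data_ptr()
```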
python/sglang/srt/model_executor/forward_batch_info.py
```diff
@@ -280,7 +280,8 @@ class ForwardBatch:
         ).to(device, non_blocking=True)
         if (
             model_runner.server_args.attention_backend != "torch_native"
-            and model_runner.server_args.speculative_algorithm != "NEXTN"
+            # TODO: Fix triton kernel illegal memory access for EAGLE
+            and model_runner.server_args.speculative_algorithm != "EAGLE"
         ):
             ret.extend_num_tokens = batch.extend_num_tokens
             positions, ret.extend_start_loc = compute_position_triton(
```
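A note on the guard: server_args.py (later in this commit) rewrites "NEXTN" to "EAGLE" before this point, so the comparison here must use "EAGLE". The guarded block builds per-token positions and per-request start offsets for an extend batch; below is a plausible pure-torch sketch of that computation (my reading of what compute_position_triton produces, not its actual kernel):

```python
import torch

# Toy extend batch: request 0 has 4 cached tokens and extends by 3,
# request 1 has no cache and extends by 2.
extend_prefix_lens = torch.tensor([4, 0])
extend_seq_lens = torch.tensor([3, 2])

# Each new token's position continues from the cached prefix.
positions = torch.cat(
    [
        torch.arange(p, p + s)
        for p, s in zip(extend_prefix_lens.tolist(), extend_seq_lens.tolist())
    ]
)
# Start offset of each request's tokens within the flattened batch.
extend_start_loc = torch.cumsum(extend_seq_lens, dim=0) - extend_seq_lens

print(positions)         # tensor([4, 5, 6, 0, 1])
print(extend_start_loc)  # tensor([0, 3])
```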
python/sglang/srt/models/deepseek_nextn.py
```diff
@@ -116,14 +116,14 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         self.model = DeepseekModelNextN(config, quant_config)

         if global_server_args_dict["enable_dp_attention"]:
-            self.model.shared_head.head = ReplicatedLinear(
+            self.lm_head = ReplicatedLinear(
                 config.hidden_size,
                 config.vocab_size,
                 bias=False,
             )
             self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
         else:
-            self.model.shared_head.head = ParallelLMHead(
+            self.lm_head = ParallelLMHead(
                 config.vocab_size,
                 config.hidden_size,
                 quant_config=quant_config,
```
```diff
@@ -139,7 +139,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.model.shared_head.head, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
```
```diff
@@ -168,10 +168,8 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         nextn_layer_prefix = "model.layers.0"
         nextn_spec_weight_names = [
-            "shared_head.head",
             "shared_head.norm",
             "eh_proj",
-            "embed_tokens",
             "enorm",
             "hnorm",
         ]
```
```diff
@@ -180,17 +178,21 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         for name, loaded_weight in weights:
             if not name.startswith(nextn_layer_prefix):
                 continue
-            else:
-                is_decoder = True
-                # For nextn specific weights
-                for weight_name in nextn_spec_weight_names:
-                    if weight_name in name:
-                        name = name.replace(nextn_layer_prefix, "model")
-                        is_decoder = False
-                        break
-                # For decoder layer weights
-                if is_decoder:
-                    name = name.replace(nextn_layer_prefix, "model.decoder")
+
+            # Use shared head and embed weights from target model
+            if "shared_head.head" in name or "embed_tokens" in name:
+                continue
+
+            is_decoder = True
+            # For nextn specific weights
+            for weight_name in nextn_spec_weight_names:
+                if weight_name in name:
+                    name = name.replace(nextn_layer_prefix, "model")
+                    is_decoder = False
+                    break
+            # For decoder layer weights
+            if is_decoder:
+                name = name.replace(nextn_layer_prefix, "model.decoder")

             if "rotary_emb.inv_freq" in name:
                 continue
```
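To make the new load_weights routing concrete, here is a standalone sketch of the renaming rules above, applied to illustrative checkpoint keys (the prefix and list values are taken from this diff; the keys are made up):

```python
nextn_layer_prefix = "model.layers.0"
nextn_spec_weight_names = ["shared_head.norm", "eh_proj", "enorm", "hnorm"]


def remap(name):
    if not name.startswith(nextn_layer_prefix):
        return None  # not a NextN-layer weight
    # Shared head/embed weights now come from the target model: skip them.
    if "shared_head.head" in name or "embed_tokens" in name:
        return None
    # NextN-specific weights live directly under "model".
    for weight_name in nextn_spec_weight_names:
        if weight_name in name:
            return name.replace(nextn_layer_prefix, "model")
    # Everything else is a decoder-layer weight.
    return name.replace(nextn_layer_prefix, "model.decoder")


for key in [
    "model.layers.0.embed_tokens.weight",      # skipped (shared with target)
    "model.layers.0.shared_head.head.weight",  # skipped (shared with target)
    "model.layers.0.eh_proj.weight",           # -> model.eh_proj.weight
    "model.layers.0.self_attn.q_proj.weight",  # -> model.decoder.self_attn...
]:
    print(key, "->", remap(key))
```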
python/sglang/srt/models/deepseek_v2.py
```diff
@@ -1179,6 +1179,17 @@ class DeepseekV2ForCausalLM(nn.Module):
             if is_hip_:
                 self_attn.w_scale *= 2.0

+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+

 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass
```
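Why set_embed_and_head() deletes the old parameters and then calls torch.cuda.empty_cache()/synchronize(): dropping the last reference to the duplicate weights lets the caching allocator reclaim that memory. A hedged illustration (arbitrary shapes, not DeepSeek's real dimensions):

```python
import torch

if torch.cuda.is_available():
    duplicate = torch.empty(4096, 4096, device="cuda")  # ~64 MiB of float32
    before_alloc = torch.cuda.memory_allocated()
    before_resv = torch.cuda.memory_reserved()

    del duplicate                 # analogous to `del self.lm_head.weight`
    torch.cuda.empty_cache()      # return cached blocks to the driver
    torch.cuda.synchronize()

    freed_alloc = (before_alloc - torch.cuda.memory_allocated()) // 2**20
    freed_resv = (before_resv - torch.cuda.memory_reserved()) // 2**20
    print(f"allocated freed: {freed_alloc} MiB, reserved freed: {freed_resv} MiB")
```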
python/sglang/srt/server_args.py
```diff
@@ -270,10 +270,11 @@ class ServerArgs:
         )

         # Speculative Decoding
-        if (
-            self.speculative_algorithm == "EAGLE"
-            or self.speculative_algorithm == "NEXTN"
-        ):
+        if self.speculative_algorithm == "NEXTN":
+            # NEXTN shares the same implementation of EAGLE
+            self.speculative_algorithm = "EAGLE"
+
+        if self.speculative_algorithm == "EAGLE":
             self.disable_overlap_schedule = True
             self.prefill_only_one_req = True
             self.disable_cuda_graph_padding = True
```
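The effect of the reordered logic: "NEXTN" is first normalized to "EAGLE", and the EAGLE-specific flags then apply to both. A minimal sketch with an illustrative stand-in for the real ServerArgs class (field names from the diff):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Args:  # illustrative stand-in, not the real ServerArgs
    speculative_algorithm: Optional[str] = None
    disable_overlap_schedule: bool = False
    prefill_only_one_req: bool = False
    disable_cuda_graph_padding: bool = False

    def normalize(self):
        if self.speculative_algorithm == "NEXTN":
            # NEXTN shares the same implementation of EAGLE
            self.speculative_algorithm = "EAGLE"
        if self.speculative_algorithm == "EAGLE":
            self.disable_overlap_schedule = True
            self.prefill_only_one_req = True
            self.disable_cuda_graph_padding = True


args = Args(speculative_algorithm="NEXTN")
args.normalize()
print(args)  # speculative_algorithm='EAGLE', all three flags True
```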
python/sglang/srt/speculative/eagle_worker.py
```diff
@@ -83,23 +83,16 @@ class EAGLEWorker(TpModelWorker):
         self.server_args = server_args

         # Share the embedding and lm_head
-        if not self.speculative_algorithm.is_nextn():
-            embed, head = self.target_worker.model_runner.model.get_embed_and_head()
-            if server_args.speculative_token_map is not None:
-                head = head.clone()
-                self.hot_token_id = torch.tensor(
-                    self.hot_token_id, dtype=torch.int32, device=head.device
-                )
-                head.data = head.data[self.hot_token_id]
-            else:
-                self.hot_token_id = None
-            self.model_runner.model.set_embed_and_head(embed, head)
+        embed, head = self.target_worker.model_runner.model.get_embed_and_head()
+        if server_args.speculative_token_map is not None:
+            head = head.clone()
+            self.hot_token_id = torch.tensor(
+                self.hot_token_id, dtype=torch.int32, device=head.device
+            )
+            head.data = head.data[self.hot_token_id]
         else:
-            if server_args.speculative_token_map is not None:
-                raise NotImplementedError(
-                    "NEXTN does not support speculative-token-map now"
-                )
             self.hot_token_id = None
+        self.model_runner.model.set_embed_and_head(embed, head)
         self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph

         # Create multi-step attn backends and cuda graph runners
```
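For the speculative-token-map branch that now also runs for NEXTN: the shared head is cloned, then restricted to a "hot" vocabulary subset, mirroring `head.data = head.data[self.hot_token_id]` above. A standalone sketch (toy sizes, illustrative token ids):

```python
import torch

vocab, hidden = 1000, 64
head = torch.randn(vocab, hidden)  # stands in for the full lm_head weight
hot_token_id = torch.tensor([3, 17, 256], dtype=torch.int32)  # dtype as in the diff

head = head.clone()                  # keep the target model's tensor untouched
head.data = head.data[hot_token_id]  # keep only the hot-token rows
print(head.shape)                    # torch.Size([3, 64])
```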
python/sglang/srt/speculative/spec_info.py
```diff
@@ -5,24 +5,16 @@ class SpeculativeAlgorithm(IntEnum):
     NONE = auto()
     EAGLE = auto()

-    # NEXTN spec decoding is for DeepSeek V3/R1
-    # currently it's implemented based on EAGLE
     NEXTN = auto()

     def is_none(self):
         return self == SpeculativeAlgorithm.NONE

     def is_eagle(self):
-        return (
-            self == SpeculativeAlgorithm.EAGLE
-            or self == SpeculativeAlgorithm.NEXTN
-        )
-
-    def is_nextn(self):
-        return self == SpeculativeAlgorithm.NEXTN
+        return self == SpeculativeAlgorithm.EAGLE

     @staticmethod
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
             "NEXTN": SpeculativeAlgorithm.NEXTN,
             None: SpeculativeAlgorithm.NONE,
         }
         if name is not None:
```
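A short usage sketch of the trimmed enum (import path taken from this diff; assumes sglang is installed):

```python
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

algo = SpeculativeAlgorithm.from_string("EAGLE")
assert algo.is_eagle()

# "NEXTN" still maps to its own member, but after this commit is_eagle() no
# longer treats it as EAGLE; ServerArgs rewrites the string before this point.
assert not SpeculativeAlgorithm.from_string("NEXTN").is_eagle()
```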
scripts/export_deepseek_nextn.py
```diff
@@ -62,6 +62,8 @@ def export_nextn_layer_parameters(input_dir, output_dir, nextn_layer_id):
                 continue
             for key in matching_keys:
+                if "embed_tokens" in key or "shared_head.head" in key:
+                    continue
                 new_key = key.replace(prefix, "model.layers.0")
                 params[new_key] = f.get_tensor(key)
```
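And the exporter's new skip rule as a standalone sketch: shared weights are no longer written into the exported NextN checkpoint, since they are taken from the target model at load time. The layer id and keys below are illustrative, and the prefix construction is assumed from the function's arguments:

```python
nextn_layer_id = 61  # hypothetical NextN layer id
prefix = f"model.layers.{nextn_layer_id}"
matching_keys = [
    f"{prefix}.embed_tokens.weight",      # now skipped: shared with the target
    f"{prefix}.shared_head.head.weight",  # now skipped: shared with the target
    f"{prefix}.eh_proj.weight",           # still exported
]

params = {}
for key in matching_keys:
    if "embed_tokens" in key or "shared_head.head" in key:
        continue
    # In the real script the value comes from f.get_tensor(key).
    params[key.replace(prefix, "model.layers.0")] = "tensor"

print(list(params))  # ['model.layers.0.eh_proj.weight']
```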