Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
862dd76c
Unverified
Commit
862dd76c
authored
Feb 15, 2025
by
Ke Bao
Committed by
GitHub
Feb 15, 2025
Browse files
Support NextN (MTP) speculative decoding for DeepSeek-V3/R1 (#3582)
parent
fb4c9c3a
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
437 additions
and
7 deletions
+437
-7
python/sglang/srt/configs/model_config.py
python/sglang/srt/configs/model_config.py
+1
-0
python/sglang/srt/models/deepseek_nextn.py
python/sglang/srt/models/deepseek_nextn.py
+295
-0
python/sglang/srt/models/deepseek_v2.py
python/sglang/srt/models/deepseek_v2.py
+4
-1
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+6
-3
python/sglang/srt/speculative/eagle_worker.py
python/sglang/srt/speculative/eagle_worker.py
+7
-2
python/sglang/srt/speculative/spec_info.py
python/sglang/srt/speculative/spec_info.py
+11
-1
scripts/export_deepseek_nextn.py
scripts/export_deepseek_nextn.py
+113
-0
No files found.
python/sglang/srt/configs/model_config.py
View file @
862dd76c
...
@@ -98,6 +98,7 @@ class ModelConfig:
...
@@ -98,6 +98,7 @@ class ModelConfig:
if
(
if
(
"DeepseekV2ForCausalLM"
in
self
.
hf_config
.
architectures
"DeepseekV2ForCausalLM"
in
self
.
hf_config
.
architectures
or
"DeepseekV3ForCausalLM"
in
self
.
hf_config
.
architectures
or
"DeepseekV3ForCausalLM"
in
self
.
hf_config
.
architectures
or
"DeepseekV3ForCausalLMNextN"
in
self
.
hf_config
.
architectures
):
):
self
.
head_dim
=
256
self
.
head_dim
=
256
self
.
attention_arch
=
AttentionArch
.
MLA
self
.
attention_arch
=
AttentionArch
.
MLA
...
...
python/sglang/srt/models/deepseek_nextn.py
0 → 100644
View file @
862dd76c
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only DeepSeek NextN Speculative Decoding."""
from
typing
import
Iterable
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm
import
_custom_ops
as
ops
from
sglang.srt.layers.layernorm
import
RMSNorm
from
sglang.srt.layers.linear
import
ReplicatedLinear
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.moe.ep_moe.layer
import
EPMoE
from
sglang.srt.layers.moe.fused_moe_triton
import
FusedMoE
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.quantization.fp8_utils
import
(
block_quant_to_tensor_quant
,
normalize_e4m3fn_to_e4m3fnuz
,
)
from
sglang.srt.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
,
)
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.models.deepseek_v2
import
DeepseekV2DecoderLayer
,
DeepseekV3ForCausalLM
from
sglang.srt.utils
import
is_hip
is_hip_
=
is_hip
()
class DeepseekModelNextN(nn.Module):
    """Draft model used for NextN (MTP) speculative decoding.

    Wraps a single DeepseekV2 decoder layer together with the NextN-specific
    fusion: the current token embedding and the hidden state handed over by
    the target model are each RMS-normalized, concatenated, and projected
    back down to ``hidden_size`` before running through the decoder layer.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.vocab_size = config.vocab_size

        # Embedding tensor-parallelism is turned off when DP attention is on.
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            enable_tp=not global_server_args_dict["enable_dp_attention"],
        )

        # Pre-fusion norms: enorm for the token embedding, hnorm for the
        # hidden state coming from the target model.
        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Projects the concatenated [embedding, hidden_state] pair
        # (2 * hidden_size) back to hidden_size.
        self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)

        # The single NextN decoder layer, registered at layer index 0.
        self.decoder = DeepseekV2DecoderLayer(
            config, 0, quant_config=quant_config, is_nextn=True
        )

        # Bare container so parameter names line up with the checkpoint's
        # "shared_head.*" weights; the lm head itself is attached by the
        # causal-LM wrapper.
        self.shared_head = nn.Module()
        self.shared_head.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        token_states = (
            self.embed_tokens(input_ids) if input_embeds is None else input_embeds
        )

        # Fuse the normalized embedding with the normalized hidden state that
        # the target model produced for the previous step.
        fused = torch.cat(
            (
                self.enorm(token_states),
                self.hnorm(forward_batch.spec_info.hidden_states),
            ),
            dim=-1,
        )

        hidden_states, residual = self.decoder(
            positions, self.eh_proj(fused), forward_batch, None
        )

        # Idle batches (DP attention padding) carry no real tokens; skip the
        # final norm for them.
        if not forward_batch.forward_mode.is_idle():
            hidden_states, _ = self.shared_head.norm(hidden_states, residual)

        return hidden_states
class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
    """Causal-LM wrapper around the single-layer NextN draft model.

    Reuses DeepseekV3ForCausalLM's interface but builds a DeepseekModelNextN
    body and its own lm head / logits processor, chosen by the
    ``enable_dp_attention`` server flag.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        # Deliberately bypass DeepseekV3ForCausalLM.__init__ (which would
        # build the full multi-layer model); only nn.Module state is wanted.
        nn.Module.__init__(self)
        self.config = config
        self.quant_config = quant_config

        self.model = DeepseekModelNextN(config, quant_config)

        if global_server_args_dict["enable_dp_attention"]:
            # DP attention: replicate the head on every rank and skip the
            # logits all-gather.
            self.model.shared_head.head = ReplicatedLinear(
                config.hidden_size,
                config.vocab_size,
                bias=False,
            )
            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
        else:
            self.model.shared_head.head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )
            self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run the draft layer and project hidden states to logits."""
        hidden_states = self.model(input_ids, positions, forward_batch)
        return self.logits_processor(
            input_ids, hidden_states, self.model.shared_head.head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load NextN checkpoint weights.

        Only tensors under ``model.layers.0`` are consumed (the export script
        remaps the NextN layer to index 0). NextN-specific tensors are
        renamed to live directly under ``model``; everything else is routed
        into ``model.decoder``. Afterwards the MLA kv_b_proj weight is split
        into the absorbed w_kc / w_vc factors.

        Raises:
            ValueError: if the config has no ``num_nextn_predict_layers``.
        """
        if hasattr(self.config, "num_nextn_predict_layers"):
            num_nextn_layers = self.config.num_nextn_predict_layers
            assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
            # Exported NextN configs must advertise exactly one hidden layer.
            assert num_nextn_layers == self.config.num_hidden_layers
        else:
            raise ValueError("num_nextn_predict_layers is not in the config")

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
        expert_params_mapping = MoEImpl.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
        )

        # Checkpoint prefix of the (remapped) NextN layer.
        nextn_layer_prefix = "model.layers.0"
        # Tensors unique to NextN; these move to "model.<name>" rather than
        # into the decoder layer.
        nextn_spec_weight_names = [
            "shared_head.head",
            "shared_head.norm",
            "eh_proj",
            "embed_tokens",
            "enorm",
            "hnorm",
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if not name.startswith(nextn_layer_prefix):
                continue
            else:
                is_decoder = True
                # For nextn specific weights
                for weight_name in nextn_spec_weight_names:
                    if weight_name in name:
                        name = name.replace(nextn_layer_prefix, "model")
                        is_decoder = False
                        break
                # For decoder layer weights
                if is_decoder:
                    name = name.replace(nextn_layer_prefix, "model.decoder")

            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked weight: try the MoE expert mapping next.
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Plain (non-stacked, non-expert) parameter.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

        # Post-process kv_b_proj for MLA weight absorption (mirrors
        # DeepseekV2ForCausalLM.load_weights).
        if not global_server_args_dict["disable_mla"]:
            self_attn = self.model.decoder.self_attn
            if hasattr(self_attn.kv_b_proj, "qweight"):
                # AWQ compatible
                w = ops.awq_dequantize(
                    self_attn.kv_b_proj.qweight,
                    self_attn.kv_b_proj.scales,
                    self_attn.kv_b_proj.qzeros,
                    0,
                    0,
                    0,
                ).T
            else:
                w = self_attn.kv_b_proj.weight
            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
            # This may affect the accuracy of fp8 model.
            if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
                torch.float8_e4m3fn,
                torch.float8_e4m3fnuz,
            ):
                weight_block_size = self.quant_config.weight_block_size
                if weight_block_size is not None:
                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                    if is_hip_:
                        # ROCm fp8 uses the e4m3fnuz encoding; renormalize.
                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                            weight=w,
                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
                            input_scale=None,
                        )
                    else:
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale_inv

                    w, scale = block_quant_to_tensor_quant(
                        weight, weight_scale, weight_block_size
                    )
                    self_attn.w_scale = scale
            # Split kv_b_proj into the nope-key and value halves per head.
            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
            self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
            if hasattr(self_attn.kv_b_proj, "weight_scale") and self_attn.w_scale is None:
                self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                if is_hip_:
                    # NOTE(review): assumes the ROCm fp8 path needs a 2x scale
                    # correction — confirm against the hip fp8 kernels.
                    self_attn.w_scale *= 2.0


EntryClass = [DeepseekV3ForCausalLMNextN]
python/sglang/srt/models/deepseek_v2.py
View file @
862dd76c
...
@@ -519,6 +519,8 @@ class DeepseekV2AttentionMLA(nn.Module):
...
@@ -519,6 +519,8 @@ class DeepseekV2AttentionMLA(nn.Module):
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
if
(
if
(
forward_batch
.
forward_mode
.
is_extend
()
forward_batch
.
forward_mode
.
is_extend
()
and
not
forward_batch
.
forward_mode
.
is_target_verify
()
and
not
forward_batch
.
forward_mode
.
is_draft_extend
()
and
forward_batch
.
extend_prefix_lens
.
sum
()
==
0
and
forward_batch
.
extend_prefix_lens
.
sum
()
==
0
):
):
return
self
.
forward_normal
(
positions
,
hidden_states
,
forward_batch
)
return
self
.
forward_normal
(
positions
,
hidden_states
,
forward_batch
)
...
@@ -680,6 +682,7 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -680,6 +682,7 @@ class DeepseekV2DecoderLayer(nn.Module):
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
layer_id
:
int
,
layer_id
:
int
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
is_nextn
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -731,7 +734,7 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -731,7 +734,7 @@ class DeepseekV2DecoderLayer(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
layer_id
=
layer_id
,
layer_id
=
layer_id
,
)
)
if
(
if
is_nextn
or
(
config
.
n_routed_experts
is
not
None
config
.
n_routed_experts
is
not
None
and
layer_id
>=
config
.
first_k_dense_replace
and
layer_id
>=
config
.
first_k_dense_replace
and
layer_id
%
config
.
moe_layer_freq
==
0
and
layer_id
%
config
.
moe_layer_freq
==
0
...
...
python/sglang/srt/server_args.py
View file @
862dd76c
...
@@ -262,14 +262,17 @@ class ServerArgs:
...
@@ -262,14 +262,17 @@ class ServerArgs:
)
)
# Speculative Decoding
# Speculative Decoding
if
self
.
speculative_algorithm
==
"EAGLE"
:
if
(
self
.
speculative_algorithm
==
"EAGLE"
or
self
.
speculative_algorithm
==
"NEXTN"
):
self
.
prefill_only_one_req
=
True
self
.
prefill_only_one_req
=
True
self
.
disable_cuda_graph_padding
=
True
self
.
disable_cuda_graph_padding
=
True
self
.
disable_radix_cache
=
True
self
.
disable_radix_cache
=
True
self
.
disable_overlap_schedule
=
True
self
.
disable_overlap_schedule
=
True
self
.
chunked_prefill_size
=
-
1
self
.
chunked_prefill_size
=
-
1
logger
.
info
(
logger
.
info
(
"The radix cache, chunked prefill, and overlap scheduler are disabled because of using
eagle
speculative decoding."
f
"The radix cache, chunked prefill, and overlap scheduler are disabled because of using
{
self
.
speculative_algorithm
}
speculative decoding."
)
)
# GGUF
# GGUF
...
@@ -705,7 +708,7 @@ class ServerArgs:
...
@@ -705,7 +708,7 @@ class ServerArgs:
parser
.
add_argument
(
parser
.
add_argument
(
"--speculative-algorithm"
,
"--speculative-algorithm"
,
type
=
str
,
type
=
str
,
choices
=
[
"EAGLE"
],
choices
=
[
"EAGLE"
,
"NEXTN"
],
help
=
"Speculative algorithm."
,
help
=
"Speculative algorithm."
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
...
...
python/sglang/srt/speculative/eagle_worker.py
View file @
862dd76c
...
@@ -24,6 +24,7 @@ from sglang.srt.speculative.eagle_utils import (
...
@@ -24,6 +24,7 @@ from sglang.srt.speculative.eagle_utils import (
fast_topk
,
fast_topk
,
select_top_k_tokens
,
select_top_k_tokens
,
)
)
from
sglang.srt.speculative.spec_info
import
SpeculativeAlgorithm
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -57,11 +58,15 @@ class EAGLEWorker(TpModelWorker):
...
@@ -57,11 +58,15 @@ class EAGLEWorker(TpModelWorker):
# Parse arguments
# Parse arguments
self
.
topk
=
server_args
.
speculative_eagle_topk
self
.
topk
=
server_args
.
speculative_eagle_topk
self
.
speculative_num_steps
=
server_args
.
speculative_num_steps
self
.
speculative_num_steps
=
server_args
.
speculative_num_steps
self
.
speculative_algorithm
=
SpeculativeAlgorithm
.
from_string
(
server_args
.
speculative_algorithm
)
self
.
server_args
=
server_args
self
.
server_args
=
server_args
# Share the embedding and lm_head
# Share the embedding and lm_head
embed
,
head
=
self
.
target_worker
.
model_runner
.
model
.
get_embed_and_head
()
if
not
self
.
speculative_algorithm
.
is_nextn
():
self
.
model_runner
.
model
.
set_embed_and_head
(
embed
,
head
)
embed
,
head
=
self
.
target_worker
.
model_runner
.
model
.
get_embed_and_head
()
self
.
model_runner
.
model
.
set_embed_and_head
(
embed
,
head
)
self
.
model_runner
.
server_args
.
disable_cuda_graph
=
backup_disable_cuda_graph
self
.
model_runner
.
server_args
.
disable_cuda_graph
=
backup_disable_cuda_graph
# Create multi-step attn backends and cuda graph runners
# Create multi-step attn backends and cuda graph runners
...
...
python/sglang/srt/speculative/spec_info.py
View file @
862dd76c
...
@@ -5,18 +5,28 @@ class SpeculativeAlgorithm(IntEnum):
...
@@ -5,18 +5,28 @@ class SpeculativeAlgorithm(IntEnum):
NONE
=
auto
()
NONE
=
auto
()
EAGLE
=
auto
()
EAGLE
=
auto
()
# NEXTN spec decoding is for DeepSeek V3/R1
# currently it's implemented based on EAGLE
NEXTN
=
auto
()
def
is_none
(
self
):
def
is_none
(
self
):
return
self
==
SpeculativeAlgorithm
.
NONE
return
self
==
SpeculativeAlgorithm
.
NONE
def
is_eagle
(
self
):
def
is_eagle
(
self
):
return
self
==
SpeculativeAlgorithm
.
EAGLE
return
self
==
SpeculativeAlgorithm
.
EAGLE
or
self
==
SpeculativeAlgorithm
.
NEXTN
def
is_nextn
(
self
):
return
self
==
SpeculativeAlgorithm
.
NEXTN
@
staticmethod
@
staticmethod
def
from_string
(
name
:
str
):
def
from_string
(
name
:
str
):
name_map
=
{
name_map
=
{
"EAGLE"
:
SpeculativeAlgorithm
.
EAGLE
,
"EAGLE"
:
SpeculativeAlgorithm
.
EAGLE
,
"NEXTN"
:
SpeculativeAlgorithm
.
NEXTN
,
None
:
SpeculativeAlgorithm
.
NONE
,
None
:
SpeculativeAlgorithm
.
NONE
,
}
}
if
name
is
not
None
:
name
=
name
.
upper
()
return
name_map
[
name
]
return
name_map
[
name
]
...
...
scripts/export_deepseek_nextn.py
0 → 100644
View file @
862dd76c
"""
Export NextN layer for DeepSeek-V3/R1 model. The exported model can be used for speculative decoding.
Usage:
python3 export_deepseek_nextn.py --input-dir /path/to/DeepSeek-V3 --output-dir /path/to/DeepSeek-V3-NextN
"""
import
argparse
import
json
import
os
import
shutil
from
safetensors
import
safe_open
from
safetensors.torch
import
save_file
from
transformers
import
AutoConfig
def get_nexn_layer_id(config):
    """Return the index of the NextN layer.

    The NextN (MTP) layer sits immediately after the last regular decoder
    layer, so its index equals ``num_hidden_layers``.
    """
    if hasattr(config, "num_hidden_layers"):
        return config.num_hidden_layers
    raise ValueError("'num_hidden_layers' not found in model config.")
def update_and_save_config(config, output_dir):
    """Write the exported NextN model's ``config.json`` into *output_dir*.

    Keeps every field of the source model config but overrides the
    architecture (so sglang dispatches to ``DeepseekV3ForCausalLMNextN``)
    and the hidden-layer count.
    """
    new_config = config.to_dict()
    new_config.update(
        {
            # Must be 1, not 0: DeepseekV3ForCausalLMNextN.load_weights
            # asserts num_nextn_predict_layers == num_hidden_layers (== 1),
            # so a config exported with 0 fails to load.
            "num_hidden_layers": 1,
            "architectures": ["DeepseekV3ForCausalLMNextN"],
        }
    )
    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(new_config, f, indent=2, ensure_ascii=False, sort_keys=True)
def copy_non_safetensors_files(input_dir, output_dir):
    """Copy every regular file except ``*.safetensors`` from *input_dir*
    to *output_dir* (tokenizer files, configs, etc.), preserving metadata.
    """
    for entry in os.listdir(input_dir):
        src = os.path.join(input_dir, entry)
        # Weights are exported separately; subdirectories are ignored.
        if entry.endswith(".safetensors") or not os.path.isfile(src):
            continue
        shutil.copy2(src, os.path.join(output_dir, entry))
    print(f"All non-safetensors files have been copied to {output_dir}")
def export_nextn_layer_parameters(input_dir, output_dir, nexn_layer_id):
    """Extract the NextN layer's tensors into a standalone safetensors file.

    Scans every ``*.safetensors`` shard in *input_dir*, collects tensors whose
    key starts with ``model.layers.<nexn_layer_id>``, renames them to layer 0,
    saves them as ``nextn_layer_parameters.safetensors`` in *output_dir*, and
    writes a matching ``model.safetensors.index.json``.
    """
    # NOTE(review): a plain startswith prefix match would also catch e.g.
    # "model.layers.610" for id 61 — fine for DeepSeek layer counts, but
    # worth keeping in mind.
    prefix = f"model.layers.{nexn_layer_id}"
    output_path = os.path.join(output_dir, "nextn_layer_parameters.safetensors")
    params = {}

    for filename in os.listdir(input_dir):
        if not filename.endswith(".safetensors"):
            continue
        file_path = os.path.join(input_dir, filename)
        print(f"Processing: {filename}")
        try:
            with safe_open(file_path, framework="pt") as f:
                matching_keys = [k for k in f.keys() if k.startswith(prefix)]
                if not matching_keys:
                    print(f"  No parameters starting with '{prefix}' found")
                    continue
                for key in matching_keys:
                    # Remap the NextN layer to index 0, matching what
                    # DeepseekV3ForCausalLMNextN.load_weights expects.
                    new_key = key.replace(prefix, "model.layers.0")
                    params[new_key] = f.get_tensor(key)
        except Exception as e:
            # Best-effort: report and keep scanning the remaining shards.
            print(f"  Error processing {filename}: {str(e)}")

    if params:
        print(f"Saving {len(params)} parameters to {output_path}")
        save_file(params, output_path)
    else:
        print("No matching parameters found.")

    # Update safetensors index
    index_path = os.path.join(output_dir, "model.safetensors.index.json")
    print(f"Updating safetensors index to {index_path}")
    index_data = {"weight_map": {}}
    for key in params:
        index_data["weight_map"][key] = "nextn_layer_parameters.safetensors"
    with open(index_path, "w") as f:
        json.dump(index_data, f, indent=4)
    print("All done.")
def _main():
    """CLI entry point: export the NextN layer of a DeepSeek-V3/R1 checkpoint
    into a standalone draft-model directory."""
    parser = argparse.ArgumentParser(
        description="Export NextN layer paramerters for DeepSeek-V3/R1"
    )
    parser.add_argument(
        "--input-dir",
        type=str,
        required=True,
        help="Input HF model directory.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Output nextn model directory.",
    )
    args = parser.parse_args()

    config = AutoConfig.from_pretrained(args.input_dir, trust_remote_code=True)
    assert config.num_nextn_predict_layers == 1, "Only 1 nextn layer is supported."

    nextn_layer_id = get_nexn_layer_id(config)
    os.makedirs(args.output_dir, exist_ok=True)
    copy_non_safetensors_files(args.input_dir, args.output_dir)
    update_and_save_config(config, args.output_dir)
    export_nextn_layer_parameters(args.input_dir, args.output_dir, nextn_layer_id)


if __name__ == "__main__":
    _main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment