Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
51cdd81f
Unverified
Commit
51cdd81f
authored
May 30, 2025
by
Zilin Zhu
Committed by
GitHub
May 29, 2025
Browse files
[fix][RL] Fix DeepSeekV3ForCausalLM.post_load_weights for multiple update weight (#6265)
parent
73def253
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
47 additions
and
14 deletions
+47
-14
python/sglang/srt/models/deepseek_v2.py
python/sglang/srt/models/deepseek_v2.py
+39
-14
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+8
-0
No files found.
python/sglang/srt/models/deepseek_v2.py
View file @
51cdd81f
...
@@ -92,6 +92,7 @@ from sglang.srt.utils import (
...
@@ -92,6 +92,7 @@ from sglang.srt.utils import (
BumpAllocator
,
BumpAllocator
,
DeepEPMode
,
DeepEPMode
,
add_prefix
,
add_prefix
,
bind_or_assign
,
get_bool_env_var
,
get_bool_env_var
,
get_int_env_var
,
get_int_env_var
,
is_cuda
,
is_cuda
,
...
@@ -1713,14 +1714,23 @@ class DeepseekV2ForCausalLM(nn.Module):
...
@@ -1713,14 +1714,23 @@ class DeepseekV2ForCausalLM(nn.Module):
input_ids
,
hidden_states
,
self
.
lm_head
,
forward_batch
input_ids
,
hidden_states
,
self
.
lm_head
,
forward_batch
)
)
def
post_load_weights
(
self
,
is_nextn
=
False
):
def
post_load_weights
(
self
,
is_nextn
=
False
,
weight_names
=
None
):
# Perform post-processing after loading weights
# Perform post-processing after loading weights
layer_ids
=
(
if
is_nextn
:
range
(
self
.
config
.
num_hidden_layers
)
layer_ids
=
[
self
.
config
.
num_hidden_layers
]
if
not
is_nextn
else
:
else
[
self
.
config
.
num_hidden_layers
]
if
weight_names
is
None
:
)
layer_ids
=
range
(
self
.
config
.
num_hidden_layers
)
else
:
layer_ids
=
set
()
for
name
in
weight_names
:
if
"kv_b_proj"
in
name
:
layer_id
=
int
(
name
.
split
(
"."
)[
2
])
# filter the nextn layer.
if
layer_id
!=
self
.
config
.
num_hidden_layers
:
layer_ids
.
add
(
layer_id
)
for
layer_id
in
layer_ids
:
for
layer_id
in
layer_ids
:
self_attn
=
(
self_attn
=
(
self
.
model
.
layers
[
layer_id
].
self_attn
self
.
model
.
layers
[
layer_id
].
self_attn
...
@@ -1830,13 +1840,19 @@ class DeepseekV2ForCausalLM(nn.Module):
...
@@ -1830,13 +1840,19 @@ class DeepseekV2ForCausalLM(nn.Module):
0
,
(
-
1
,
self_attn
.
qk_nope_head_dim
+
self_attn
.
v_head_dim
)
0
,
(
-
1
,
self_attn
.
qk_nope_head_dim
+
self_attn
.
v_head_dim
)
).
split
([
self_attn
.
qk_nope_head_dim
,
self_attn
.
v_head_dim
],
dim
=
1
)
).
split
([
self_attn
.
qk_nope_head_dim
,
self_attn
.
v_head_dim
],
dim
=
1
)
if
not
use_deep_gemm_bmm
:
if
not
use_deep_gemm_bmm
:
self_attn
.
w_kc
=
w_kc
.
transpose
(
1
,
2
).
contiguous
().
transpose
(
1
,
2
)
self_attn
.
w_kc
=
bind_or_assign
(
self_attn
.
w_vc
=
w_vc
.
contiguous
().
transpose
(
1
,
2
)
self_attn
.
w_kc
,
w_kc
.
transpose
(
1
,
2
).
contiguous
().
transpose
(
1
,
2
)
)
self_attn
.
w_vc
=
bind_or_assign
(
self_attn
.
w_vc
,
w_vc
.
contiguous
().
transpose
(
1
,
2
)
)
if
(
if
(
hasattr
(
self_attn
.
kv_b_proj
,
"weight_scale"
)
hasattr
(
self_attn
.
kv_b_proj
,
"weight_scale"
)
and
self_attn
.
w_scale
is
None
and
self_attn
.
w_scale
is
None
):
):
self_attn
.
w_scale
=
self_attn
.
kv_b_proj
.
weight_scale
self_attn
.
w_scale
=
bind_or_assign
(
self_attn
.
w_scale
,
self_attn
.
kv_b_proj
.
weight_scale
)
if
_is_hip
:
if
_is_hip
:
self_attn
.
w_scale
*=
2.0
self_attn
.
w_scale
*=
2.0
else
:
else
:
...
@@ -1845,10 +1861,16 @@ class DeepseekV2ForCausalLM(nn.Module):
...
@@ -1845,10 +1861,16 @@ class DeepseekV2ForCausalLM(nn.Module):
ws_kc
,
ws_vc
=
block_scale
.
unflatten
(
ws_kc
,
ws_vc
=
block_scale
.
unflatten
(
0
,
(
-
1
,
(
num_tiles_k
+
num_tiles_n
))
0
,
(
-
1
,
(
num_tiles_k
+
num_tiles_n
))
).
split
([
num_tiles_k
,
num_tiles_n
],
dim
=
1
)
).
split
([
num_tiles_k
,
num_tiles_n
],
dim
=
1
)
self_attn
.
w_scale_k
=
ws_kc
.
transpose
(
1
,
2
).
contiguous
()
self_attn
.
w_scale_k
=
bind_or_assign
(
self_attn
.
w_scale_v
=
ws_vc
.
contiguous
()
self_attn
.
w_scale_k
,
ws_kc
.
transpose
(
1
,
2
).
contiguous
()
self_attn
.
w_kc
=
w_kc
.
transpose
(
1
,
2
).
contiguous
()
)
self_attn
.
w_vc
=
w_vc
.
contiguous
()
self_attn
.
w_scale_v
=
bind_or_assign
(
self_attn
.
w_scale_v
,
ws_vc
.
contiguous
()
)
self_attn
.
w_kc
=
bind_or_assign
(
self_attn
.
w_kc
,
w_kc
.
transpose
(
1
,
2
).
contiguous
()
)
self_attn
.
w_vc
=
bind_or_assign
(
self_attn
.
w_vc
,
w_vc
.
contiguous
())
self_attn
.
use_deep_gemm_bmm
=
True
self_attn
.
use_deep_gemm_bmm
=
True
# TODO support nextn later
# TODO support nextn later
...
@@ -1958,7 +1980,10 @@ class DeepseekV2ForCausalLM(nn.Module):
...
@@ -1958,7 +1980,10 @@ class DeepseekV2ForCausalLM(nn.Module):
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
weight_names
=
[]
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
weight_names
.
append
(
name
)
if
not
is_nextn
:
if
not
is_nextn
:
if
hasattr
(
self
.
config
,
"num_nextn_predict_layers"
):
if
hasattr
(
self
.
config
,
"num_nextn_predict_layers"
):
num_nextn_layers
=
self
.
config
.
num_nextn_predict_layers
num_nextn_layers
=
self
.
config
.
num_nextn_predict_layers
...
@@ -2075,7 +2100,7 @@ class DeepseekV2ForCausalLM(nn.Module):
...
@@ -2075,7 +2100,7 @@ class DeepseekV2ForCausalLM(nn.Module):
)
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
self
.
post_load_weights
(
is_nextn
=
is_nextn
)
self
.
post_load_weights
(
is_nextn
=
is_nextn
,
weight_names
=
weight_names
)
def
get_embed_and_head
(
self
):
def
get_embed_and_head
(
self
):
return
self
.
model
.
embed_tokens
.
weight
,
self
.
lm_head
.
weight
return
self
.
model
.
embed_tokens
.
weight
,
self
.
lm_head
.
weight
...
...
python/sglang/srt/utils.py
View file @
51cdd81f
...
@@ -2217,3 +2217,11 @@ def read_system_prompt_from_file(model_name: str) -> str:
...
@@ -2217,3 +2217,11 @@ def read_system_prompt_from_file(model_name: str) -> str:
except
Exception
:
except
Exception
:
# If anything fails, return empty string
# If anything fails, return empty string
return
""
return
""
def bind_or_assign(target, source):
    """Bind *source* into an existing *target* tensor, or hand back *source*.

    If *target* is None there is nothing to reuse, so *source* itself is
    returned. Otherwise *source* is copied in place into *target* (via
    ``copy_``) and *target* is returned, preserving the original storage.
    """
    if target is None:
        return source
    # Reuse the existing buffer so references to `target` stay valid.
    target.copy_(source)
    return target
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment