change / sglang · Commits · 7ed77d6b

Unverified commit 7ed77d6b, authored Apr 05, 2025 by inkcherry, committed by GitHub Apr 04, 2025

    fix dummy-load deepseekv2 (#4535)

Parent: 4c54f442

Showing 2 changed files with 87 additions and 73 deletions:

    python/sglang/srt/model_loader/loader.py   +8   -0
    python/sglang/srt/models/deepseek_v2.py    +79  -73
python/sglang/srt/model_loader/loader.py

@@ -489,6 +489,14 @@ class DummyModelLoader(BaseModelLoader):
             # NOTE(woosuk): For accurate performance evaluation, we assign
             # random values to the weights.
             initialize_dummy_weights(model)
+
+            # Model weight loading consists of two stages:
+            # 1. Initial weight loading.
+            # 2. Post-processing of weights, including assigning specific member variables.
+            # For `dummy_init`, only the second stage is required.
+            if hasattr(model, "post_load_weights"):
+                model.post_load_weights()
+
         return model.eval()
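The hook added above is what the dummy load path needed: post-processing that used to happen only inside `load_weights()` can now run even when no checkpoint is read. Below is a minimal, self-contained sketch of the two-stage pattern, using a hypothetical `ToyModel` and `dummy_load` (illustrative stand-ins, not sglang APIs; only the `hasattr(model, "post_load_weights")` check mirrors the actual change):

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical model with a post-load hook; not part of sglang."""

    def __init__(self):
        super().__init__()
        self.kv_b_proj = nn.Linear(16, 32, bias=False)
        self.w_kc = None  # derived member, normally filled in after weight loading

    def post_load_weights(self):
        # Derive auxiliary members from whatever weights are present,
        # whether they came from a checkpoint or from dummy initialization.
        self.w_kc = self.kv_b_proj.weight.detach().view(2, 16, 16)


def dummy_load(model: nn.Module) -> nn.Module:
    # Stage 1: random-initialize the weights (stand-in for initialize_dummy_weights).
    with torch.no_grad():
        for p in model.parameters():
            p.uniform_(-1e-3, 1e-3)
    # Stage 2: run model-specific post-processing if the model defines it,
    # mirroring the new hasattr(model, "post_load_weights") check in this commit.
    if hasattr(model, "post_load_weights"):
        model.post_load_weights()
    return model.eval()


model = dummy_load(ToyModel())
assert model.w_kc is not None  # post-processing ran even without a checkpoint

Factoring the post-processing into a method on the model is what lets both the regular `load_weights()` path and `DummyModelLoader` trigger it, so a dummy-initialized DeepseekV2 also gets the derived `w_kc`/`w_vc`/`w_scale` members that were previously set only inside `load_weights()`.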
python/sglang/srt/models/deepseek_v2.py

@@ -1380,6 +1380,84 @@ class DeepseekV2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
+    def post_load_weights(self):
+
+        # Perform post-processing after loading weights
+
+        if not global_server_args_dict["disable_mla"]:
+            for layer_id in range(self.config.num_hidden_layers):
+                self_attn = self.model.layers[layer_id].self_attn
+                if hasattr(self_attn.kv_b_proj, "qweight"):
+                    # AWQ compatible
+                    if _is_cuda:
+                        w = awq_dequantize(
+                            self_attn.kv_b_proj.qweight,
+                            self_attn.kv_b_proj.scales,
+                            self_attn.kv_b_proj.qzeros,
+                        ).T
+                    else:
+                        w = ops.awq_dequantize(
+                            self_attn.kv_b_proj.qweight,
+                            self_attn.kv_b_proj.scales,
+                            self_attn.kv_b_proj.qzeros,
+                            0,
+                            0,
+                            0,
+                        ).T
+                else:
+                    w = self_attn.kv_b_proj.weight
+                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+                # This may affect the accuracy of fp8 model.
+                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+                    torch.float8_e4m3fn,
+                    torch.float8_e4m3fnuz,
+                ):
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        if _is_hip:
+                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                                weight=w,
+                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                                input_scale=None,
+                            )
+                        else:
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                        w, scale = block_quant_to_tensor_quant(
+                            weight, weight_scale, weight_block_size
+                        )
+                        self_attn.w_scale = scale
+                if w.dtype == torch.int8:
+                    if hasattr(self.quant_config, "weight_block_size"):
+                        # block-wise int8 need it
+                        weight_block_size = self.quant_config.weight_block_size
+                        if weight_block_size is not None:
+                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                            w = int8_block_dequant(
+                                weight, weight_scale, weight_block_size
+                            ).to(torch.bfloat16)
+                    else:
+                        # channel-wise int8 need it
+                        w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                            torch.bfloat16
+                        )
+                w_kc, w_vc = w.unflatten(
+                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if (
+                    hasattr(self_attn.kv_b_proj, "weight_scale")
+                    and self_attn.w_scale is None
+                ):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    if _is_hip:
+                        self_attn.w_scale *= 2.0
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
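To make the tail of `post_load_weights` concrete: `kv_b_proj` projects the compressed KV latent into stacked per-head `[qk_nope_head_dim + v_head_dim]` outputs, and the `unflatten(...).split(...)` call slices its weight into the `w_kc` and `w_vc` factors that the MLA attention path consumes. A small sketch of that slicing with toy dimensions (the sizes below are invented for illustration and are not DeepSeek-V2's real configuration):

import torch

# Toy dimensions, chosen only for illustration.
num_heads = 4
kv_lora_rank = 8        # width of the compressed KV latent
qk_nope_head_dim = 16   # per-head non-positional key dim
v_head_dim = 16         # per-head value dim

# kv_b_proj.weight has shape (num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank).
w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# Split the stacked per-head output rows into the key part and the value part,
# mirroring the unflatten(...).split(...) in post_load_weights.
w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
    [qk_nope_head_dim, v_head_dim], dim=1
)

print(w_kc.shape)  # torch.Size([4, 16, 8]) -> (heads, qk_nope_head_dim, kv_lora_rank)
print(w_vc.shape)  # torch.Size([4, 16, 8]) -> (heads, v_head_dim, kv_lora_rank)

# The commit then stores transposed/contiguous views on self_attn:
w_kc_stored = w_kc.transpose(1, 2).contiguous().transpose(1, 2)  # shape (4, 16, 8), contiguous in (heads, rank, head_dim) order
w_vc_stored = w_vc.contiguous().transpose(1, 2)
print(w_vc_stored.shape)  # torch.Size([4, 8, 16])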
@@ -1504,79 +1582,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                 )
                 weight_loader(param, loaded_weight)
 
-        if not global_server_args_dict["disable_mla"]:
-            for layer_id in range(self.config.num_hidden_layers):
-                self_attn = self.model.layers[layer_id].self_attn
-                if hasattr(self_attn.kv_b_proj, "qweight"):
-                    # AWQ compatible
-                    if _is_cuda:
-                        w = awq_dequantize(
-                            self_attn.kv_b_proj.qweight,
-                            self_attn.kv_b_proj.scales,
-                            self_attn.kv_b_proj.qzeros,
-                        ).T
-                    else:
-                        w = ops.awq_dequantize(
-                            self_attn.kv_b_proj.qweight,
-                            self_attn.kv_b_proj.scales,
-                            self_attn.kv_b_proj.qzeros,
-                            0,
-                            0,
-                            0,
-                        ).T
-                else:
-                    w = self_attn.kv_b_proj.weight
-                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-                # This may affect the accuracy of fp8 model.
-                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
-                    torch.float8_e4m3fn,
-                    torch.float8_e4m3fnuz,
-                ):
-                    weight_block_size = self.quant_config.weight_block_size
-                    if weight_block_size is not None:
-                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                        if _is_hip:
-                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                                weight=w,
-                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                                input_scale=None,
-                            )
-                        else:
-                            weight = w
-                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                        w, scale = block_quant_to_tensor_quant(
-                            weight, weight_scale, weight_block_size
-                        )
-                        self_attn.w_scale = scale
-                if w.dtype == torch.int8:
-                    if hasattr(self.quant_config, "weight_block_size"):
-                        # block-wise int8 need it
-                        weight_block_size = self.quant_config.weight_block_size
-                        if weight_block_size is not None:
-                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                            weight = w
-                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                            w = int8_block_dequant(
-                                weight, weight_scale, weight_block_size
-                            ).to(torch.bfloat16)
-                    else:
-                        # channel-wise int8 need it
-                        w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
-                            torch.bfloat16
-                        )
-                w_kc, w_vc = w.unflatten(
-                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-                if (
-                    hasattr(self_attn.kv_b_proj, "weight_scale")
-                    and self_attn.w_scale is None
-                ):
-                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                    if _is_hip:
-                        self_attn.w_scale *= 2.0
+        if not global_server_args_dict["disable_mla"]:
+            self.post_load_weights()
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
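One step of the post-processing worth unpacking: for block-wise int8 checkpoints, `int8_block_dequant(weight, weight_scale, weight_block_size)` expands a small grid of per-block scales over the full int8 weight before the cast to bfloat16. Below is a rough pure-torch sketch of that idea; the helper name `block_dequant_sketch` and all shapes are made up for illustration and do not reproduce sglang's actual implementation:

import torch


def block_dequant_sketch(w_q: torch.Tensor, scale: torch.Tensor, block_size) -> torch.Tensor:
    """Illustrative stand-in for a block-wise dequant: each (bn x bk) block of the
    int8 weight shares one scale entry from the (ceil(N/bn) x ceil(K/bk)) scale grid."""
    bn, bk = block_size
    n, k = w_q.shape
    # Expand each scale entry over its block, then crop to the weight shape.
    scale_full = scale.repeat_interleave(bn, dim=0).repeat_interleave(bk, dim=1)[:n, :k]
    return w_q.to(torch.float32) * scale_full


# Toy shapes only: a 256x128 int8 weight with 128x128 scale blocks.
w_q = torch.randint(-128, 127, (256, 128), dtype=torch.int8)
scale = torch.rand(2, 1)  # one scale per 128x128 block
w = block_dequant_sketch(w_q, scale, (128, 128)).to(torch.bfloat16)
print(w.shape, w.dtype)  # torch.Size([256, 128]) torch.bfloat16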