Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
18f030d9
Commit
18f030d9
authored
Jul 15, 2025
by
zhuwenwen
Browse files
修复在fp16下的数值越界导致的精度问题
parent
7bf6c98f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
33 additions
and
38 deletions
+33
-38
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+33
-38
No files found.
vllm/model_executor/models/deepseek_v2.py
View file @
18f030d9
...
...
@@ -164,29 +164,24 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
# if hidden_states.dtype != torch.float16:
# final_hidden_states = self.experts(
# hidden_states=hidden_states,
# router_logits=router_logits) * self.routed_scaling_factor
# else:
# # Fix FP16 overflow
# # See DeepseekV2DecoderLayer for more details.
# final_hidden_states = self.experts(hidden_states=hidden_states,
# router_logits=router_logits)
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
else
:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
shared_output
is
not
None
:
final_hidden_states
=
final_hidden_states
+
shared_output
# if shared_output is not None:
# if hidden_states.dtype != torch.float16:
# final_hidden_states = final_hidden_states + shared_output
# else:
# # Fix FP16 overflow
# # See DeepseekV2DecoderLayer for more details.
# final_hidden_states = final_hidden_states + shared_output \
# * (1. / self.routed_scaling_factor)
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
final_hidden_states
+
shared_output
else
:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
if
self
.
tp_size
>
1
:
if
envs
.
VLLM_ENABLE_TBO
:
...
...
@@ -593,29 +588,29 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states
=
hidden_states
,
)
#
if hidden_states.dtype == torch.float16:
#
# Fix FP16 overflow
#
# We scale both hidden_states and residual before
#
# rmsnorm, and rmsnorm result would not affect by scale.
#
hidden_states *= 1. / self.routed_scaling_factor
#
if self.layer_idx == 0:
#
# The residual is shared by all layers, we only scale it on
#
# first layer.
#
residual *= 1. / self.routed_scaling_factor
if
hidden_states
.
dtype
==
torch
.
float16
:
# Fix FP16 overflow
# We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states
*=
1.
/
self
.
routed_scaling_factor
if
self
.
layer_idx
==
0
:
# The residual is shared by all layers, we only scale it on
# first layer.
residual
*=
1.
/
self
.
routed_scaling_factor
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
#
if isinstance(self.mlp,
#
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
#
# Fix FP16 overflow
#
# Scaling the DeepseekV2MLP output, it is the input of
#
# input_layernorm of next decoder layer.
#
# The scaling of DeepseekV2MOE output would be done in the forward
#
# of DeepseekV2MOE
#
hidden_states *= 1. / self.routed_scaling_factor
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
hidden_states
.
dtype
==
torch
.
float16
:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states
*=
1.
/
self
.
routed_scaling_factor
return
hidden_states
,
residual
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment