Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c982ac57
Unverified
Commit
c982ac57
authored
Mar 10, 2025
by
Concurrensee
Committed by
GitHub
Mar 10, 2025
Browse files
[Bugfix] Fix FP16 overflow for DeepSeek V2 (#13232)
Signed-off-by:
Yida Wu
<
yida.wu@amd.com
>
parent
4290b704
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
4 deletions
+24
-4
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+24
-4
No files found.
vllm/model_executor/models/deepseek_v2.py
View file @
c982ac57
...
@@ -155,11 +155,21 @@ class DeepseekV2MoE(nn.Module):
...
@@ -155,11 +155,21 @@ class DeepseekV2MoE(nn.Module):
shared_output
=
self
.
shared_experts
(
hidden_states
)
shared_output
=
self
.
shared_experts
(
hidden_states
)
# router_logits: (num_tokens, n_experts)
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
self
.
experts
(
if
hidden_states
.
dtype
!=
torch
.
float16
:
hidden_states
=
hidden_states
,
final_hidden_states
=
self
.
experts
(
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
*
self
.
routed_scaling_factor
else
:
# This is a special case to avoid FP16 overflow
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
)
if
shared_output
is
not
None
:
if
shared_output
is
not
None
:
final_hidden_states
=
final_hidden_states
+
shared_output
if
hidden_states
.
dtype
!=
torch
.
float16
:
final_hidden_states
=
final_hidden_states
+
shared_output
else
:
# This is a special case to avoid FP16 overflow
final_hidden_states
=
final_hidden_states
+
shared_output
\
*
(
1.
/
self
.
routed_scaling_factor
)
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
final_hidden_states
)
...
@@ -531,6 +541,7 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -531,6 +541,7 @@ class DeepseekV2DecoderLayer(nn.Module):
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
self
.
routed_scaling_factor
=
config
.
routed_scaling_factor
def
forward
(
def
forward
(
self
,
self
,
...
@@ -551,9 +562,18 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -551,9 +562,18 @@ class DeepseekV2DecoderLayer(nn.Module):
)
)
# Fully Connected
# Fully Connected
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
and
\
hidden_states
.
dtype
==
torch
.
float16
:
# This is a special case to avoid FP16 overflow
hidden_states
*=
1.
/
self
.
routed_scaling_factor
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
if
isinstance
(
self
.
mlp
,
DeepseekV2MLP
)
and
\
hidden_states
.
dtype
==
torch
.
float16
:
# This is a special case to avoid FP16 overflow
hidden_states
*=
1.
/
self
.
routed_scaling_factor
residual
*=
1.
/
self
.
routed_scaling_factor
return
hidden_states
,
residual
return
hidden_states
,
residual
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment