Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0369a0a6
Commit
0369a0a6
authored
Aug 29, 2025
by
zhuwenwen
Browse files
修复residual FP16 overflow,解决mtp采样率和数据集精度的冲突
parent
ab1ea369
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
2 deletions
+6
-2
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+6
-2
No files found.
vllm/model_executor/models/deepseek_v2.py
View file @
0369a0a6
...
@@ -66,7 +66,7 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
...
@@ -66,7 +66,7 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm.utils
import
W8a8GetCacheJSON
os
.
environ
[
'DPSK_FP16_QUICK'
]
=
os
.
environ
.
get
(
'DPSK_FP16_QUICK'
,
'
1
'
)
os
.
environ
[
'DPSK_FP16_QUICK'
]
=
os
.
environ
.
get
(
'DPSK_FP16_QUICK'
,
'
0
'
)
class
DeepseekV2MLP
(
nn
.
Module
):
class
DeepseekV2MLP
(
nn
.
Module
):
def
__init__
(
def
__init__
(
...
@@ -621,9 +621,13 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -621,9 +621,13 @@ class DeepseekV2DecoderLayer(nn.Module):
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Self Attention
# Self Attention
# Fix residual FP16 overflow
residual_fix_overflow
=
False
if
residual
is
None
:
if
residual
is
None
:
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
residual_fix_overflow
=
True
else
:
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
,
residual
)
...
@@ -637,7 +641,7 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -637,7 +641,7 @@ class DeepseekV2DecoderLayer(nn.Module):
# We scale both hidden_states and residual before
# We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale.
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states
*=
1.
/
self
.
routed_scaling_factor
hidden_states
*=
1.
/
self
.
routed_scaling_factor
if
self
.
layer_idx
==
0
:
if
self
.
layer_idx
==
0
or
residual_fix_overflow
:
# The residual is shared by all layers, we only scale it on
# The residual is shared by all layers, we only scale it on
# first layer.
# first layer.
residual
*=
1.
/
self
.
routed_scaling_factor
residual
*=
1.
/
self
.
routed_scaling_factor
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment