Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ac4cc84e
Commit
ac4cc84e
authored
Aug 18, 2025
by
zhuwenwen
Browse files
[fix]修复mtp eager模式下显存占用增加问题
parent
43e67933
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
13 deletions
+8
-13
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+8
-13
No files found.
vllm/model_executor/models/deepseek_mtp.py
View file @
ac4cc84e
...
...
@@ -59,6 +59,11 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
enorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
hnorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
eh_proj
=
nn
.
Linear
(
config
.
hidden_size
*
2
,
...
...
@@ -76,6 +81,8 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
spec_step_index
:
int
=
0
,
)
->
torch
.
Tensor
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
assert
inputs_embeds
is
not
None
# masking inputs at position 0, as not needed by MTP
inputs_embeds
[
positions
==
0
]
=
0
...
...
@@ -112,10 +119,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
for
idx
in
range
(
self
.
mtp_start_layer_idx
,
self
.
mtp_start_layer_idx
+
self
.
num_mtp_layers
)
})
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
def
forward
(
...
...
@@ -126,8 +130,6 @@ class DeepSeekMultiTokenPredictor(nn.Module):
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
spec_step_idx
:
int
=
0
,
)
->
torch
.
Tensor
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
current_step_idx
=
(
spec_step_idx
%
self
.
num_mtp_layers
)
return
self
.
layers
[
str
(
self
.
mtp_start_layer_idx
+
current_step_idx
)](
input_ids
,
...
...
@@ -315,20 +317,13 @@ class DeepSeekMTP(nn.Module, SupportsPP):
spec_layer_weight_names
=
[
"embed_tokens"
,
"enorm"
,
"hnorm"
,
"eh_proj"
,
"shared_head"
]
shared_weight_names
=
[
"embed_tokens"
]
spec_layer_weight
=
False
shared_weight
=
False
for
weight_name
in
spec_layer_weight_names
:
if
weight_name
in
name
:
spec_layer_weight
=
True
if
weight_name
in
shared_weight_names
:
shared_weight
=
True
break
if
not
spec_layer_weight
:
# treat rest weights as weights for transformer layer block
name
=
name
.
replace
(
f
"model.layers.
{
spec_layer
}
."
,
f
"model.layers.
{
spec_layer
}
.mtp_block."
)
elif
shared_weight
:
# treat shared weights as top level weights
name
=
name
.
replace
(
f
"model.layers.
{
spec_layer
}
."
,
"model."
)
return
name
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment