Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e60a51b4
Commit
e60a51b4
authored
Dec 29, 2025
by
yangql
Browse files
增加mtp的pad
parent
a0be38cb
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
65 additions
and
14 deletions
+65
-14
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+59
-11
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+4
-1
vllm/zero_overhead/v1/eagle.py
vllm/zero_overhead/v1/eagle.py
+2
-2
No files found.
vllm/v1/spec_decode/eagle.py
View file @
e60a51b4
...
...
@@ -10,8 +10,9 @@ from vllm.attention.layer import Attention
from
vllm.config
import
(
CompilationLevel
,
VllmConfig
,
get_layers_from_vllm_config
)
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
DPMetadata
,
set_forward_context
from
vllm.logger
import
init_logger
import
vllm.envs
as
envs
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.models
import
supports_multimodal
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
...
...
@@ -185,6 +186,9 @@ class EagleProposer:
else
:
num_input_tokens
=
num_tokens
num_pad
,
num_tokens_across_dp
=
self
.
get_dp_padding
(
num_input_tokens
)
num_input_tokens
+=
num_pad
# copy inputs to buffer for cudagraph
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
...
...
@@ -224,7 +228,8 @@ class EagleProposer:
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
skip_cuda_graphs
=
not
decoding
):
num_tokens_across_dp
=
num_tokens_across_dp
):
#skip_cuda_graphs=not decoding):
ret_hidden_states
=
self
.
model
(
self
.
input_ids
[:
num_input_tokens
],
self
.
positions
[:
num_input_tokens
],
...
...
@@ -369,7 +374,8 @@ class EagleProposer:
# Run the model.
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
num_tokens
=
input_batch_size
):
num_tokens
=
input_batch_size
,
num_tokens_across_dp
=
num_tokens_across_dp
):
ret_hidden_states
=
self
.
model
(
self
.
input_ids
[:
input_batch_size
],
self
.
positions
[:
input_batch_size
],
...
...
@@ -496,6 +502,40 @@ class EagleProposer:
logger
.
info
(
"Loading EAGLE LM head weights from the target model."
)
self
.
model
.
lm_head
=
target_language_model
.
lm_head
def
get_dp_padding
(
self
,
num_tokens
:
int
)
->
tuple
[
int
,
Optional
[
torch
.
Tensor
]]:
dp_size
=
self
.
vllm_config
.
parallel_config
.
data_parallel_size
dp_rank
=
self
.
vllm_config
.
parallel_config
.
data_parallel_rank
# For DP: Don't pad when setting enforce_eager.
# This lets us set enforce_eager on the prefiller in a P/D setup and
# still use CUDA graphs (enabled by this padding) on the decoder.
#
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.
if
dp_size
==
1
or
self
.
vllm_config
.
model_config
.
enforce_eager
or
envs
.
VLLM_ALL2ALL_BACKEND
!=
'naive'
:
# auto
if
not
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_auto"
:
# Early exit.
return
0
,
None
try
:
num_tokens_across_dp
=
DPMetadata
.
num_tokens_across_dp
(
num_tokens
,
dp_size
,
dp_rank
)
max_tokens_across_dp_cpu
=
torch
.
max
(
num_tokens_across_dp
).
item
()
num_tokens_after_padding
=
torch
.
tensor
([
max_tokens_across_dp_cpu
]
*
dp_size
,
device
=
"cpu"
,
dtype
=
torch
.
int32
)
return
max_tokens_across_dp_cpu
-
num_tokens
,
num_tokens_after_padding
except
(
RuntimeError
,
AttributeError
)
as
e
:
# DP group may not be initialized yet during dummy run
# Skip padding in this case
logger
.
debug
(
"Skipping DP padding in eagle get_dp_padding due to: %s"
,
e
)
return
0
,
None
@
torch
.
inference_mode
()
def
dummy_run
(
self
,
...
...
@@ -505,24 +545,32 @@ class EagleProposer:
if
attn_metadata
is
not
None
and
self
.
attn_metadata_cudagraph
is
None
:
self
.
attn_metadata_cudagraph
=
attn_metadata
[
self
.
attn_layer_names
[
0
]]
# Padding for DP
num_input_tokens
=
num_tokens
num_pad
,
num_tokens_across_dp
=
self
.
get_dp_padding
(
num_tokens
)
num_input_tokens
+=
num_pad
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_tokens
):
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
):
self
.
model
(
self
.
input_ids
[:
num_tokens
],
self
.
positions
[:
num_tokens
],
self
.
hidden_states
[:
num_tokens
],
self
.
input_ids
[:
num_
input_
tokens
],
self
.
positions
[:
num_
input_
tokens
],
self
.
hidden_states
[:
num_
input_
tokens
],
)
if
self
.
dp_size
>
1
and
self
.
enable_expert_parallel
and
self
.
num_speculative_tokens
>
1
:
for
_
in
range
(
self
.
num_speculative_tokens
-
1
):
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_tokens
):
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
):
self
.
model
(
self
.
input_ids
[:
num_tokens
],
self
.
positions
[:
num_tokens
],
self
.
hidden_states
[:
num_tokens
],
self
.
input_ids
[:
num_
input_
tokens
],
self
.
positions
[:
num_
input_
tokens
],
self
.
hidden_states
[:
num_
input_
tokens
],
)
def
validate_same_kv_cache_group
(
self
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
e60a51b4
...
...
@@ -2081,7 +2081,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
if
not
self
.
ep_sp
:
num_scheduled_tokens_list
[
-
1
]
+=
num_tokens
%
num_reqs
else
:
if
self
.
speculative_config
is
not
None
:
num_scheduled_tokens_list
[
-
1
]
+=
num_tokens
%
min_tokens_per_req
else
:
num_scheduled_tokens_list
[
-
1
]
+=
num_tokens
%
num_reqs
assert
sum
(
num_scheduled_tokens_list
)
==
num_tokens
assert
len
(
num_scheduled_tokens_list
)
==
num_reqs
...
...
vllm/zero_overhead/v1/eagle.py
View file @
e60a51b4
...
...
@@ -146,8 +146,8 @@ class V1ZeroEagleProposer(EagleProposer):
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
skip_cuda_graphs
=
not
decoding
):
num_tokens
=
num_input_tokens
,
):
#
skip_cuda_graphs=not decoding):
ret_hidden_states
=
self
.
model
(
self
.
input_ids
[:
num_input_tokens
],
self
.
positions
[:
num_input_tokens
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment