Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3e903b6c
Unverified
Commit
3e903b6c
authored
Sep 13, 2025
by
Woosuk Kwon
Committed by
GitHub
Sep 13, 2025
Browse files
[Chore] Minor simplification for non-PP path (#24810)
Signed-off-by:
Woosuk Kwon
<
woosuk@thinkingmachines.ai
>
parent
973c9d01
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
39 additions
and
25 deletions
+39
-25
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+39
-25
No files found.
vllm/v1/worker/gpu_model_runner.py
View file @
3e903b6c
...
...
@@ -86,7 +86,7 @@ from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from
vllm.v1.utils
import
CpuGpuBuffer
,
record_function_or_nullcontext
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
(
KVConnectorModelRunnerMixin
,
KVConnectorOutput
)
KVConnectorModelRunnerMixin
)
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
.utils
import
(
AttentionGroup
,
MultiModalBudget
,
...
...
@@ -196,6 +196,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
max_num_tokens
=
scheduler_config
.
max_num_batched_tokens
self
.
max_num_reqs
=
scheduler_config
.
max_num_seqs
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping mirco-batches
# https://github.com/vllm-project/vllm/issues/18019
self
.
broadcast_pp_output
=
(
self
.
parallel_config
.
distributed_executor_backend
==
"external_launcher"
and
len
(
get_pp_group
().
ranks
)
>
0
)
# Model-related.
self
.
num_query_heads
=
model_config
.
get_num_attention_heads
(
parallel_config
)
...
...
@@ -1701,7 +1709,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
hidden_states
:
torch
.
Tensor
,
num_scheduled_tokens
:
int
,
num_scheduled_tokens_np
:
np
.
ndarray
,
kv_connector_output
:
Optional
[
KVConnectorOutput
],
)
->
ModelRunnerOutput
:
assert
self
.
input_batch
.
num_reqs
==
\
len
(
self
.
input_batch
.
pooling_params
),
\
...
...
@@ -1732,7 +1739,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logprobs
=
None
,
prompt_logprobs_dict
=
{},
pooler_output
=
pooler_output
,
kv_connector_output
=
kv_connector_output
,
)
def
_preprocess
(
...
...
@@ -2073,39 +2079,47 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
with
record_function_or_nullcontext
(
"Postprocess"
):
if
self
.
use_aux_hidden_state_outputs
:
# True when EAGLE 3 is used.
hidden_states
,
aux_hidden_states
=
model_output
else
:
# Common case.
hidden_states
=
model_output
aux_hidden_states
=
None
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping mirco-batches
# https://github.com/vllm-project/vllm/issues/18019
broadcast_pp_output
=
\
self
.
parallel_config
.
distributed_executor_backend
\
==
"external_launcher"
and
len
(
get_pp_group
().
ranks
)
>
0
if
not
get_pp_group
().
is_last_rank
:
# For mid-pipeline stages, return the hidden states.
assert
isinstance
(
hidden_states
,
IntermediateTensors
)
if
not
broadcast_pp_output
:
if
not
self
.
broadcast_pp_output
:
# Common case.
if
not
get_pp_group
().
is_last_rank
:
# Return the intermediate tensors.
assert
isinstance
(
hidden_states
,
IntermediateTensors
)
hidden_states
.
kv_connector_output
=
kv_connector_output
return
hidden_states
get_pp_group
().
send_tensor_dict
(
hidden_states
.
tensors
,
all_gather_group
=
get_tp_group
())
logits
=
None
else
:
if
self
.
is_pooling_model
:
return
self
.
_pool
(
hidden_states
,
num_scheduled_tokens
,
num_scheduled_tokens_np
,
kv_connector_output
)
# Return the pooling output.
output
=
self
.
_pool
(
hidden_states
,
num_scheduled_tokens
,
num_scheduled_tokens_np
)
output
.
kv_connector_output
=
kv_connector_output
return
output
sample_hidden_states
=
hidden_states
[
logits_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
if
broadcast_pp_output
:
model_output_broadcast_data
=
{
"logits"
:
logits
.
contiguous
(),
}
if
logits
is
not
None
else
{}
else
:
# Rare case.
assert
not
self
.
is_pooling_model
if
not
get_pp_group
().
is_last_rank
:
get_pp_group
().
send_tensor_dict
(
hidden_states
.
tensors
,
all_gather_group
=
get_tp_group
())
logits
=
None
else
:
sample_hidden_states
=
hidden_states
[
logits_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
model_output_broadcast_data
=
{}
if
logits
is
not
None
:
model_output_broadcast_data
[
"logits"
]
=
logits
.
contiguous
()
model_output_broadcast_data
=
get_pp_group
(
).
broadcast_tensor_dict
(
model_output_broadcast_data
,
src
=
len
(
get_pp_group
().
ranks
)
-
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment