Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
04925b22
Unverified
Commit
04925b22
authored
Feb 16, 2026
by
Woosuk Kwon
Committed by
GitHub
Feb 16, 2026
Browse files
[Model Runner V2] Minor cleanup for PP (#34666)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
d74278fb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
53 additions
and
58 deletions
+53
-58
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+47
-43
vllm/v1/worker/gpu/pp_handler.py
vllm/v1/worker/gpu/pp_handler.py
+6
-15
No files found.
vllm/v1/worker/gpu/model_runner.py
View file @
04925b22
...
...
@@ -57,7 +57,7 @@ from vllm.v1.worker.gpu.kv_connector import (
from
vllm.v1.worker.gpu.lora_utils
import
LoraState
from
vllm.v1.worker.gpu.mm.encoder_runner
import
EncoderRunner
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.pp_handler
import
PPHandler
,
get_pp_handler
from
vllm.v1.worker.gpu.pp_handler
import
PPHandler
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.prompt_logprob
import
PromptLogprobsWorker
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
...
...
@@ -184,9 +184,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Pipeline parallelism.
self
.
use_pp
=
self
.
parallel_config
.
pipeline_parallel_size
>
1
self
.
pp_handler
:
PPHandler
|
None
=
(
get_pp_handler
(
self
.
parallel_config
)
if
self
.
use_pp
else
None
)
if
self
.
use_pp
:
self
.
is_first_pp_rank
=
get_pp_group
().
is_first_rank
self
.
is_last_pp_rank
=
get_pp_group
().
is_last_rank
self
.
pp_handler
:
PPHandler
|
None
=
PPHandler
(
self
.
device
)
else
:
self
.
is_first_pp_rank
=
True
self
.
is_last_pp_rank
=
True
self
.
pp_handler
=
None
def
update_max_model_len
(
self
,
max_model_len
:
int
)
->
None
:
self
.
max_model_len
=
max_model_len
...
...
@@ -318,7 +323,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# For non-first PP ranks, create dummy intermediate_tensors.
intermediate_tensors
=
None
if
self
.
use_pp
and
not
get_pp_group
()
.
is_first_rank
:
if
not
self
.
is_first_
pp_
rank
:
intermediate_tensors
=
self
.
model
.
make_empty_intermediate_tensors
(
batch_size
=
num_tokens
,
dtype
=
self
.
model_config
.
dtype
,
...
...
@@ -335,7 +340,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
kv_connector
.
set_disabled
(
False
)
# Non-last PP ranks don't produce output for sampling.
if
self
.
use_pp
and
not
get_pp_group
()
.
is_last_rank
:
if
not
self
.
is_last_
pp_
rank
:
return
None
,
None
assert
self
.
execute_model_state
is
not
None
...
...
@@ -373,20 +378,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states
,
sample_hidden_states
=
self
.
_dummy_run
(
self
.
max_num_tokens
,
skip_attn
=
True
)
# Only run sampler on last PP rank (non-last ranks return None).
if
not
self
.
use_pp
or
get_pp_group
().
is_last_rank
:
if
self
.
is_last_
pp_
rank
:
assert
sample_hidden_states
is
not
None
self
.
_dummy_sampler_run
(
sample_hidden_states
)
if
self
.
do_spec_decode
:
num_tokens_across_dp
=
make_num_tokens_across_dp
(
self
.
parallel_config
.
data_parallel_size
,
self
.
max_num_tokens
)
self
.
speculator
.
run_model
(
self
.
max_num_tokens
,
attn_metadata
=
None
,
slot_mappings
=
None
,
num_tokens_across_dp
=
num_tokens_across_dp
,
)
if
self
.
do_spec_decode
:
num_tokens_across_dp
=
make_num_tokens_across_dp
(
self
.
parallel_config
.
data_parallel_size
,
self
.
max_num_tokens
)
self
.
speculator
.
run_model
(
self
.
max_num_tokens
,
attn_metadata
=
None
,
slot_mappings
=
None
,
num_tokens_across_dp
=
num_tokens_across_dp
,
)
torch
.
cuda
.
synchronize
()
del
hidden_states
,
sample_hidden_states
gc
.
collect
()
...
...
@@ -890,9 +898,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
_set_active_loras
(
*
lora_inputs
)
# Only first PP rank prepares multimodal embeddings.
if
self
.
supports_mm_inputs
and
(
not
self
.
use_pp
or
get_pp_group
().
is_first_rank
):
if
self
.
supports_mm_inputs
and
self
.
is_first_pp_rank
:
mm_embeds
,
is_mm_embed
=
self
.
get_mm_embeddings
(
scheduler_output
.
scheduled_encoder_inputs
,
input_batch
)
...
...
@@ -935,6 +941,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert
input_batch
.
mrope_positions
is
not
None
positions
=
input_batch
.
mrope_positions
if
self
.
is_first_pp_rank
:
input_ids
=
input_batch
.
input_ids
inputs_embeds
=
input_batch
.
inputs_embeds
assert
intermediate_tensors
is
None
else
:
input_ids
=
None
inputs_embeds
=
None
assert
intermediate_tensors
is
not
None
with
set_forward_context
(
input_batch
.
attn_metadata
,
self
.
vllm_config
,
...
...
@@ -945,25 +960,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
slot_mapping
=
input_batch
.
slot_mappings
,
):
self
.
kv_connector
.
pre_forward
(
scheduler_output
)
if
self
.
use_pp
and
not
get_pp_group
().
is_first_rank
:
# Non-first PP rank: forward with intermediate tensors.
assert
intermediate_tensors
is
not
None
hidden_states
=
self
.
model
(
input_ids
=
None
,
positions
=
positions
,
inputs_embeds
=
None
,
intermediate_tensors
=
intermediate_tensors
,
)
else
:
hidden_states
=
self
.
model
(
input_ids
=
input_batch
.
input_ids
,
positions
=
positions
,
inputs_embeds
=
input_batch
.
inputs_embeds
,
)
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
positions
=
positions
,
inputs_embeds
=
inputs_embeds
,
intermediate_tensors
=
intermediate_tensors
,
)
kv_connector_output
=
self
.
kv_connector
.
post_forward
(
scheduler_output
)
if
self
.
use_pp
and
not
get_pp_group
()
.
is_last_rank
:
if
not
self
.
is_last_
pp_
rank
:
# Non-last PP rank: return IntermediateTensors for sending.
assert
isinstance
(
hidden_states
,
IntermediateTensors
)
hidden_states
.
kv_connector_output
=
kv_connector_output
...
...
@@ -986,16 +992,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Non-last PP rank: hidden_states is None because this rank produced
# IntermediateTensors instead of final hidden states. Receive the
# sampled tokens broadcast by the last rank and update local state.
if
self
.
use_pp
and
not
get_pp_group
()
.
is_last_rank
:
if
not
self
.
is_last_
pp_
rank
:
assert
self
.
pp_handler
is
not
None
received
=
self
.
pp_handler
.
maybe_receive_sampled_tokens
(
input_batch
.
num_reqs
,
self
.
device
,
max_sample_len
=
self
.
num_speculative_steps
+
1
,
input_batch
.
num_reqs
,
max_sample_len
=
self
.
num_speculative_steps
+
1
)
if
received
is
not
None
:
sampled
,
num_sampled
,
num_rejected
=
received
self
.
postprocess
(
input_batch
,
sampled
,
num_sampled
,
num_rejected
)
assert
received
is
not
None
sampled
,
num_sampled
,
num_rejected
=
received
self
.
postprocess
(
input_batch
,
sampled
,
num_sampled
,
num_rejected
)
return
None
# Last rank: sample tokens
...
...
vllm/v1/worker/gpu/pp_handler.py
View file @
04925b22
...
...
@@ -15,6 +15,9 @@ class PPHandler:
Only instantiated when PP is enabled (pp_size > 1).
"""
def
__init__
(
self
,
device
:
torch
.
device
):
self
.
device
=
device
def
maybe_broadcast_sampled_tokens
(
self
,
sampler_output
:
SamplerOutput
,
...
...
@@ -59,7 +62,6 @@ class PPHandler:
def
maybe_receive_sampled_tokens
(
self
,
num_reqs
:
int
,
device
:
torch
.
device
,
max_sample_len
:
int
=
1
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
:
"""Receive sampled tokens broadcast by the last PP rank.
...
...
@@ -84,7 +86,7 @@ class PPHandler:
return
None
sampled_tokens
=
torch
.
empty
(
num_reqs
,
max_sample_len
,
dtype
=
torch
.
int64
,
device
=
device
num_reqs
,
max_sample_len
,
dtype
=
torch
.
int64
,
device
=
self
.
device
)
torch
.
distributed
.
broadcast
(
sampled_tokens
,
...
...
@@ -93,27 +95,16 @@ class PPHandler:
)
# NOTE: num_sampled/num_rejected are only needed
# for speculative decoding.
num_sampled
=
torch
.
empty
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
num_sampled
=
torch
.
empty
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
torch
.
distributed
.
broadcast
(
num_sampled
,
src
=
pp
.
last_rank
,
group
=
pp
.
device_group
,
)
num_rejected
=
torch
.
empty
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
num_rejected
=
torch
.
empty
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
torch
.
distributed
.
broadcast
(
num_rejected
,
src
=
pp
.
last_rank
,
group
=
pp
.
device_group
,
)
return
sampled_tokens
,
num_sampled
,
num_rejected
def
get_pp_handler
(
parallel_config
)
->
PPHandler
:
"""Factory function to create PPHandler.
Must only be called when PP is enabled (pp_size > 1).
"""
assert
parallel_config
.
pipeline_parallel_size
>
1
,
(
"PPHandler should not be created when pipeline parallelism is disabled."
)
return
PPHandler
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment