Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9a8853f7
Unverified
Commit
9a8853f7
authored
Feb 16, 2026
by
zhanqiuhu
Committed by
GitHub
Feb 16, 2026
Browse files
[Core] Pipeline Parallel support for Model Runner V2 (#33960)
Signed-off-by:
Zhanqiu Hu
<
zh338@cornell.edu
>
parent
387a1898
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
221 additions
and
17 deletions
+221
-17
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+102
-17
vllm/v1/worker/gpu/pp_handler.py
vllm/v1/worker/gpu/pp_handler.py
+119
-0
No files found.
vllm/v1/worker/gpu/model_runner.py
View file @
9a8853f7
...
...
@@ -3,7 +3,6 @@
import
gc
import
time
from
copy
import
deepcopy
from
typing
import
Any
import
numpy
as
np
import
torch
...
...
@@ -11,11 +10,15 @@ import torch.nn as nn
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.distributed.parallel_state
import
prepare_communication_buffer_for_model
from
vllm.distributed.parallel_state
import
(
get_pp_group
,
prepare_communication_buffer_for_model
,
)
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.mem_utils
import
DeviceMemoryProfiler
,
format_gib
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
...
...
@@ -54,6 +57,7 @@ from vllm.v1.worker.gpu.kv_connector import (
from
vllm.v1.worker.gpu.lora_utils
import
LoraState
from
vllm.v1.worker.gpu.mm.encoder_runner
import
EncoderRunner
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.pp_handler
import
PPHandler
,
get_pp_handler
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.prompt_logprob
import
PromptLogprobsWorker
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
...
...
@@ -178,6 +182,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# KV Connector if configured.
self
.
kv_connector
:
KVConnector
=
NO_OP_KV_CONNECTOR
# Pipeline parallelism.
self
.
use_pp
=
self
.
parallel_config
.
pipeline_parallel_size
>
1
self
.
pp_handler
:
PPHandler
|
None
=
(
get_pp_handler
(
self
.
parallel_config
)
if
self
.
use_pp
else
None
)
def
update_max_model_len
(
self
,
max_model_len
:
int
)
->
None
:
self
.
max_model_len
=
max_model_len
self
.
req_states
.
max_model_len
=
max_model_len
...
...
@@ -290,7 +300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@
torch
.
inference_mode
()
def
_dummy_run
(
self
,
num_tokens
:
int
,
*
args
,
skip_attn
:
bool
=
True
,
**
kwargs
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
|
None
,
torch
.
Tensor
|
None
]:
# Create a dummy scheduler output.
num_reqs
=
min
(
num_tokens
,
self
.
max_num_reqs
)
num_tokens_per_request
=
[
num_tokens
//
num_reqs
]
*
num_reqs
...
...
@@ -306,13 +316,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Disable any use of KVConnector for dummy runs.
self
.
kv_connector
.
set_disabled
(
True
)
# For non-first PP ranks, create dummy intermediate_tensors.
intermediate_tensors
=
None
if
self
.
use_pp
and
not
get_pp_group
().
is_first_rank
:
intermediate_tensors
=
self
.
model
.
make_empty_intermediate_tensors
(
batch_size
=
num_tokens
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
)
# Execute the model.
self
.
execute_model
(
dummy_scheduler_output
,
dummy_run
=
True
,
skip_attn_for_dummy_run
=
skip_attn
dummy_scheduler_output
,
intermediate_tensors
=
intermediate_tensors
,
dummy_run
=
True
,
skip_attn_for_dummy_run
=
skip_attn
,
)
self
.
kv_connector
.
set_disabled
(
False
)
# Non-last PP ranks don't produce output for sampling.
if
self
.
use_pp
and
not
get_pp_group
().
is_last_rank
:
return
None
,
None
assert
self
.
execute_model_state
is
not
None
hidden_states
,
input_batch
,
_
=
self
.
execute_model_state
assert
hidden_states
is
not
None
# Last PP rank always has hidden_states
sample_hidden_states
=
hidden_states
[
input_batch
.
logits_indices
]
return
hidden_states
,
sample_hidden_states
...
...
@@ -345,6 +373,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states
,
sample_hidden_states
=
self
.
_dummy_run
(
self
.
max_num_tokens
,
skip_attn
=
True
)
# Only run sampler on last PP rank (non-last ranks return None).
if
not
self
.
use_pp
or
get_pp_group
().
is_last_rank
:
assert
sample_hidden_states
is
not
None
self
.
_dummy_sampler_run
(
sample_hidden_states
)
if
self
.
do_spec_decode
:
num_tokens_across_dp
=
make_num_tokens_across_dp
(
...
...
@@ -381,6 +412,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
return
0
# TODO (zhanqiu): support CUDA graph for PP.
if
self
.
use_pp
:
logger
.
warning_once
(
"Skipping CUDA graph capture because pipeline parallel is "
"enabled. Pipeline parallel is currently eager-only."
,
)
return
0
start_time
=
time
.
perf_counter
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
...
...
@@ -801,11 +840,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
execute_model
(
self
,
scheduler_output
:
SchedulerOutput
,
intermediate_tensors
:
Any
|
None
=
None
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
dummy_run
:
bool
=
False
,
skip_attn_for_dummy_run
:
bool
=
False
,
)
->
ModelRunnerOutput
|
None
:
assert
intermediate_tensors
is
None
)
->
ModelRunnerOutput
|
IntermediateTensors
|
None
:
if
not
dummy_run
:
# Update the request states.
self
.
finish_requests
(
scheduler_output
)
...
...
@@ -851,8 +889,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
self
.
_set_active_loras
(
*
lora_inputs
)
if
self
.
supports_mm_inputs
:
# Execute the multimodal encoder.
# Only first PP rank prepares multimodal embeddings.
if
self
.
supports_mm_inputs
and
(
not
self
.
use_pp
or
get_pp_group
().
is_first_rank
):
mm_embeds
,
is_mm_embed
=
self
.
get_mm_embeddings
(
scheduler_output
.
scheduled_encoder_inputs
,
input_batch
)
...
...
@@ -894,6 +934,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
uses_mrope
:
assert
input_batch
.
mrope_positions
is
not
None
positions
=
input_batch
.
mrope_positions
with
set_forward_context
(
input_batch
.
attn_metadata
,
self
.
vllm_config
,
...
...
@@ -904,6 +945,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
slot_mapping
=
input_batch
.
slot_mappings
,
):
self
.
kv_connector
.
pre_forward
(
scheduler_output
)
if
self
.
use_pp
and
not
get_pp_group
().
is_first_rank
:
# Non-first PP rank: forward with intermediate tensors.
assert
intermediate_tensors
is
not
None
hidden_states
=
self
.
model
(
input_ids
=
None
,
positions
=
positions
,
inputs_embeds
=
None
,
intermediate_tensors
=
intermediate_tensors
,
)
else
:
hidden_states
=
self
.
model
(
input_ids
=
input_batch
.
input_ids
,
positions
=
positions
,
...
...
@@ -911,20 +962,54 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
kv_connector_output
=
self
.
kv_connector
.
post_forward
(
scheduler_output
)
self
.
execute_model_state
=
hidden_states
,
input_batch
,
kv_connector_output
if
self
.
use_pp
and
not
get_pp_group
().
is_last_rank
:
# Non-last PP rank: return IntermediateTensors for sending.
assert
isinstance
(
hidden_states
,
IntermediateTensors
)
hidden_states
.
kv_connector_output
=
kv_connector_output
self
.
execute_model_state
=
(
None
,
input_batch
,
kv_connector_output
)
return
hidden_states
assert
isinstance
(
hidden_states
,
torch
.
Tensor
)
# Last rank (or no PP): hidden_states is a tensor for sampling.
self
.
execute_model_state
=
(
hidden_states
,
input_batch
,
kv_connector_output
)
return
None
@
torch
.
inference_mode
()
def
sample_tokens
(
self
,
grammar_output
:
GrammarOutput
|
None
)
->
AsyncOutput
|
ModelRunnerOutput
:
)
->
AsyncOutput
|
ModelRunnerOutput
|
None
:
assert
self
.
execute_model_state
is
not
None
hidden_states
,
input_batch
,
kv_connector_output
=
self
.
execute_model_state
self
.
execute_model_state
=
None
# type: ignore
# Non-last PP rank: hidden_states is None because this rank produced
# IntermediateTensors instead of final hidden states. Receive the
# sampled tokens broadcast by the last rank and update local state.
if
self
.
use_pp
and
not
get_pp_group
().
is_last_rank
:
assert
self
.
pp_handler
is
not
None
received
=
self
.
pp_handler
.
maybe_receive_sampled_tokens
(
input_batch
.
num_reqs
,
self
.
device
,
max_sample_len
=
self
.
num_speculative_steps
+
1
,
)
if
received
is
not
None
:
sampled
,
num_sampled
,
num_rejected
=
received
self
.
postprocess
(
input_batch
,
sampled
,
num_sampled
,
num_rejected
)
return
None
# Last rank: sample tokens
sampler_output
,
num_sampled
,
num_rejected
=
self
.
sample
(
hidden_states
,
input_batch
,
grammar_output
)
# Broadcast to non-last PP ranks (handles spec decode multi-token).
if
self
.
use_pp
:
assert
self
.
pp_handler
is
not
None
self
.
pp_handler
.
maybe_broadcast_sampled_tokens
(
sampler_output
,
num_sampled
,
num_rejected
)
prompt_logprobs_dict
=
self
.
prompt_logprobs_worker
.
compute_prompt_logprobs
(
self
.
model
.
compute_logits
,
hidden_states
,
...
...
vllm/v1/worker/gpu/pp_handler.py
0 → 100644
View file @
9a8853f7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pipeline Parallelism handler for V2 Model Runner."""
import
torch
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
class
PPHandler
:
"""Pipeline parallelism handler for Model Runner V2.
Manages sampled token synchronization between PP ranks.
Only instantiated when PP is enabled (pp_size > 1).
"""
def
maybe_broadcast_sampled_tokens
(
self
,
sampler_output
:
SamplerOutput
,
num_sampled
:
torch
.
Tensor
,
num_rejected
:
torch
.
Tensor
,
)
->
None
:
"""Broadcast sampled tokens from the last PP rank to all other ranks.
No-ops if this is not the last rank.
Broadcasts sampled_token_ids [num_reqs, max_sample_len], num_sampled
[num_reqs], and num_rejected [num_reqs] to support both regular decode
and speculative decoding.
Args:
sampler_output: SamplerOutput from sampling.
num_sampled: Number of accepted tokens per request.
num_rejected: Number of rejected tokens per request.
"""
pp
=
get_pp_group
()
if
not
pp
.
is_last_rank
:
return
torch
.
distributed
.
broadcast
(
sampler_output
.
sampled_token_ids
.
contiguous
(),
src
=
pp
.
last_rank
,
group
=
pp
.
device_group
,
)
# NOTE: num_sampled/num_rejected are only needed
# for speculative decoding.
torch
.
distributed
.
broadcast
(
num_sampled
.
contiguous
(),
src
=
pp
.
last_rank
,
group
=
pp
.
device_group
,
)
torch
.
distributed
.
broadcast
(
num_rejected
.
contiguous
(),
src
=
pp
.
last_rank
,
group
=
pp
.
device_group
,
)
def
maybe_receive_sampled_tokens
(
self
,
num_reqs
:
int
,
device
:
torch
.
device
,
max_sample_len
:
int
=
1
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]
|
None
:
"""Receive sampled tokens broadcast by the last PP rank.
Returns None if this is the last rank (which samples, not receives).
Args:
num_reqs: Number of requests in the batch.
device: Device to create tensors on.
max_sample_len: Maximum number of tokens sampled per request
(1 for regular decode, >1 for speculative decoding).
Returns:
None if called on last rank.
Otherwise, tuple of (sampled_tokens, num_sampled, num_rejected):
- sampled_tokens: shape [num_reqs, max_sample_len]
- num_sampled: shape [num_reqs]
- num_rejected: shape [num_reqs]
"""
pp
=
get_pp_group
()
if
pp
.
is_last_rank
:
return
None
sampled_tokens
=
torch
.
empty
(
num_reqs
,
max_sample_len
,
dtype
=
torch
.
int64
,
device
=
device
)
torch
.
distributed
.
broadcast
(
sampled_tokens
,
src
=
pp
.
last_rank
,
group
=
pp
.
device_group
,
)
# NOTE: num_sampled/num_rejected are only needed
# for speculative decoding.
num_sampled
=
torch
.
empty
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
torch
.
distributed
.
broadcast
(
num_sampled
,
src
=
pp
.
last_rank
,
group
=
pp
.
device_group
,
)
num_rejected
=
torch
.
empty
(
num_reqs
,
dtype
=
torch
.
int32
,
device
=
device
)
torch
.
distributed
.
broadcast
(
num_rejected
,
src
=
pp
.
last_rank
,
group
=
pp
.
device_group
,
)
return
sampled_tokens
,
num_sampled
,
num_rejected
def
get_pp_handler
(
parallel_config
)
->
PPHandler
:
"""Factory function to create PPHandler.
Must only be called when PP is enabled (pp_size > 1).
"""
assert
parallel_config
.
pipeline_parallel_size
>
1
,
(
"PPHandler should not be created when pipeline parallelism is disabled."
)
return
PPHandler
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment