Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9a8853f7
Unverified
Commit
9a8853f7
authored
Feb 16, 2026
by
zhanqiuhu
Committed by
GitHub
Feb 16, 2026
Browse files
[Core] Pipeline Parallel support for Model Runner V2 (#33960)
Signed-off-by:
Zhanqiu Hu
<
zh338@cornell.edu
>
parent
387a1898
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
221 additions
and
17 deletions
+221
-17
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+102
-17
vllm/v1/worker/gpu/pp_handler.py
vllm/v1/worker/gpu/pp_handler.py
+119
-0
No files found.
vllm/v1/worker/gpu/model_runner.py
View file @
9a8853f7
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
import
gc
import
gc
import
time
import
time
from
copy
import
deepcopy
from
copy
import
deepcopy
from
typing
import
Any
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -11,11 +10,15 @@ import torch.nn as nn
...
@@ -11,11 +10,15 @@ import torch.nn as nn
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.distributed.parallel_state
import
prepare_communication_buffer_for_model
from
vllm.distributed.parallel_state
import
(
get_pp_group
,
prepare_communication_buffer_for_model
,
)
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.mem_utils
import
DeviceMemoryProfiler
,
format_gib
from
vllm.utils.mem_utils
import
DeviceMemoryProfiler
,
format_gib
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
...
@@ -54,6 +57,7 @@ from vllm.v1.worker.gpu.kv_connector import (
...
@@ -54,6 +57,7 @@ from vllm.v1.worker.gpu.kv_connector import (
from
vllm.v1.worker.gpu.lora_utils
import
LoraState
from
vllm.v1.worker.gpu.lora_utils
import
LoraState
from
vllm.v1.worker.gpu.mm.encoder_runner
import
EncoderRunner
from
vllm.v1.worker.gpu.mm.encoder_runner
import
EncoderRunner
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.pp_handler
import
PPHandler
,
get_pp_handler
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.prompt_logprob
import
PromptLogprobsWorker
from
vllm.v1.worker.gpu.sample.prompt_logprob
import
PromptLogprobsWorker
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
...
@@ -178,6 +182,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -178,6 +182,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# KV Connector if configured.
# KV Connector if configured.
self
.
kv_connector
:
KVConnector
=
NO_OP_KV_CONNECTOR
self
.
kv_connector
:
KVConnector
=
NO_OP_KV_CONNECTOR
# Pipeline parallelism.
self
.
use_pp
=
self
.
parallel_config
.
pipeline_parallel_size
>
1
self
.
pp_handler
:
PPHandler
|
None
=
(
get_pp_handler
(
self
.
parallel_config
)
if
self
.
use_pp
else
None
)
def
update_max_model_len
(
self
,
max_model_len
:
int
)
->
None
:
def
update_max_model_len
(
self
,
max_model_len
:
int
)
->
None
:
self
.
max_model_len
=
max_model_len
self
.
max_model_len
=
max_model_len
self
.
req_states
.
max_model_len
=
max_model_len
self
.
req_states
.
max_model_len
=
max_model_len
...
@@ -290,7 +300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -290,7 +300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
_dummy_run
(
def
_dummy_run
(
self
,
num_tokens
:
int
,
*
args
,
skip_attn
:
bool
=
True
,
**
kwargs
self
,
num_tokens
:
int
,
*
args
,
skip_attn
:
bool
=
True
,
**
kwargs
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
|
None
,
torch
.
Tensor
|
None
]:
# Create a dummy scheduler output.
# Create a dummy scheduler output.
num_reqs
=
min
(
num_tokens
,
self
.
max_num_reqs
)
num_reqs
=
min
(
num_tokens
,
self
.
max_num_reqs
)
num_tokens_per_request
=
[
num_tokens
//
num_reqs
]
*
num_reqs
num_tokens_per_request
=
[
num_tokens
//
num_reqs
]
*
num_reqs
...
@@ -306,13 +316,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -306,13 +316,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Disable any use of KVConnector for dummy runs.
# Disable any use of KVConnector for dummy runs.
self
.
kv_connector
.
set_disabled
(
True
)
self
.
kv_connector
.
set_disabled
(
True
)
# For non-first PP ranks, create dummy intermediate_tensors.
intermediate_tensors
=
None
if
self
.
use_pp
and
not
get_pp_group
().
is_first_rank
:
intermediate_tensors
=
self
.
model
.
make_empty_intermediate_tensors
(
batch_size
=
num_tokens
,
dtype
=
self
.
model_config
.
dtype
,
device
=
self
.
device
,
)
# Execute the model.
# Execute the model.
self
.
execute_model
(
self
.
execute_model
(
dummy_scheduler_output
,
dummy_run
=
True
,
skip_attn_for_dummy_run
=
skip_attn
dummy_scheduler_output
,
intermediate_tensors
=
intermediate_tensors
,
dummy_run
=
True
,
skip_attn_for_dummy_run
=
skip_attn
,
)
)
self
.
kv_connector
.
set_disabled
(
False
)
self
.
kv_connector
.
set_disabled
(
False
)
# Non-last PP ranks don't produce output for sampling.
if
self
.
use_pp
and
not
get_pp_group
().
is_last_rank
:
return
None
,
None
assert
self
.
execute_model_state
is
not
None
assert
self
.
execute_model_state
is
not
None
hidden_states
,
input_batch
,
_
=
self
.
execute_model_state
hidden_states
,
input_batch
,
_
=
self
.
execute_model_state
assert
hidden_states
is
not
None
# Last PP rank always has hidden_states
sample_hidden_states
=
hidden_states
[
input_batch
.
logits_indices
]
sample_hidden_states
=
hidden_states
[
input_batch
.
logits_indices
]
return
hidden_states
,
sample_hidden_states
return
hidden_states
,
sample_hidden_states
...
@@ -345,7 +373,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -345,7 +373,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states
,
sample_hidden_states
=
self
.
_dummy_run
(
hidden_states
,
sample_hidden_states
=
self
.
_dummy_run
(
self
.
max_num_tokens
,
skip_attn
=
True
self
.
max_num_tokens
,
skip_attn
=
True
)
)
self
.
_dummy_sampler_run
(
sample_hidden_states
)
# Only run sampler on last PP rank (non-last ranks return None).
if
not
self
.
use_pp
or
get_pp_group
().
is_last_rank
:
assert
sample_hidden_states
is
not
None
self
.
_dummy_sampler_run
(
sample_hidden_states
)
if
self
.
do_spec_decode
:
if
self
.
do_spec_decode
:
num_tokens_across_dp
=
make_num_tokens_across_dp
(
num_tokens_across_dp
=
make_num_tokens_across_dp
(
self
.
parallel_config
.
data_parallel_size
,
self
.
max_num_tokens
self
.
parallel_config
.
data_parallel_size
,
self
.
max_num_tokens
...
@@ -381,6 +412,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -381,6 +412,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
)
return
0
return
0
# TODO (zhanqiu): support CUDA graph for PP.
if
self
.
use_pp
:
logger
.
warning_once
(
"Skipping CUDA graph capture because pipeline parallel is "
"enabled. Pipeline parallel is currently eager-only."
,
)
return
0
start_time
=
time
.
perf_counter
()
start_time
=
time
.
perf_counter
()
gc
.
collect
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
...
@@ -801,11 +840,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -801,11 +840,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def
execute_model
(
def
execute_model
(
self
,
self
,
scheduler_output
:
SchedulerOutput
,
scheduler_output
:
SchedulerOutput
,
intermediate_tensors
:
Any
|
None
=
None
,
intermediate_tensors
:
IntermediateTensors
|
None
=
None
,
dummy_run
:
bool
=
False
,
dummy_run
:
bool
=
False
,
skip_attn_for_dummy_run
:
bool
=
False
,
skip_attn_for_dummy_run
:
bool
=
False
,
)
->
ModelRunnerOutput
|
None
:
)
->
ModelRunnerOutput
|
IntermediateTensors
|
None
:
assert
intermediate_tensors
is
None
if
not
dummy_run
:
if
not
dummy_run
:
# Update the request states.
# Update the request states.
self
.
finish_requests
(
scheduler_output
)
self
.
finish_requests
(
scheduler_output
)
...
@@ -851,8 +889,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -851,8 +889,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
)
self
.
_set_active_loras
(
*
lora_inputs
)
self
.
_set_active_loras
(
*
lora_inputs
)
if
self
.
supports_mm_inputs
:
# Only first PP rank prepares multimodal embeddings.
# Execute the multimodal encoder.
if
self
.
supports_mm_inputs
and
(
not
self
.
use_pp
or
get_pp_group
().
is_first_rank
):
mm_embeds
,
is_mm_embed
=
self
.
get_mm_embeddings
(
mm_embeds
,
is_mm_embed
=
self
.
get_mm_embeddings
(
scheduler_output
.
scheduled_encoder_inputs
,
input_batch
scheduler_output
.
scheduled_encoder_inputs
,
input_batch
)
)
...
@@ -894,6 +934,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -894,6 +934,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
assert
input_batch
.
mrope_positions
is
not
None
assert
input_batch
.
mrope_positions
is
not
None
positions
=
input_batch
.
mrope_positions
positions
=
input_batch
.
mrope_positions
with
set_forward_context
(
with
set_forward_context
(
input_batch
.
attn_metadata
,
input_batch
.
attn_metadata
,
self
.
vllm_config
,
self
.
vllm_config
,
...
@@ -904,27 +945,71 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -904,27 +945,71 @@ class GPUModelRunner(LoRAModelRunnerMixin):
slot_mapping
=
input_batch
.
slot_mappings
,
slot_mapping
=
input_batch
.
slot_mappings
,
):
):
self
.
kv_connector
.
pre_forward
(
scheduler_output
)
self
.
kv_connector
.
pre_forward
(
scheduler_output
)
hidden_states
=
self
.
model
(
if
self
.
use_pp
and
not
get_pp_group
().
is_first_rank
:
input_ids
=
input_batch
.
input_ids
,
# Non-first PP rank: forward with intermediate tensors.
positions
=
positions
,
assert
intermediate_tensors
is
not
None
inputs_embeds
=
input_batch
.
inputs_embeds
,
hidden_states
=
self
.
model
(
)
input_ids
=
None
,
positions
=
positions
,
inputs_embeds
=
None
,
intermediate_tensors
=
intermediate_tensors
,
)
else
:
hidden_states
=
self
.
model
(
input_ids
=
input_batch
.
input_ids
,
positions
=
positions
,
inputs_embeds
=
input_batch
.
inputs_embeds
,
)
kv_connector_output
=
self
.
kv_connector
.
post_forward
(
scheduler_output
)
kv_connector_output
=
self
.
kv_connector
.
post_forward
(
scheduler_output
)
self
.
execute_model_state
=
hidden_states
,
input_batch
,
kv_connector_output
if
self
.
use_pp
and
not
get_pp_group
().
is_last_rank
:
# Non-last PP rank: return IntermediateTensors for sending.
assert
isinstance
(
hidden_states
,
IntermediateTensors
)
hidden_states
.
kv_connector_output
=
kv_connector_output
self
.
execute_model_state
=
(
None
,
input_batch
,
kv_connector_output
)
return
hidden_states
assert
isinstance
(
hidden_states
,
torch
.
Tensor
)
# Last rank (or no PP): hidden_states is a tensor for sampling.
self
.
execute_model_state
=
(
hidden_states
,
input_batch
,
kv_connector_output
)
return
None
return
None
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
sample_tokens
(
def
sample_tokens
(
self
,
grammar_output
:
GrammarOutput
|
None
self
,
grammar_output
:
GrammarOutput
|
None
)
->
AsyncOutput
|
ModelRunnerOutput
:
)
->
AsyncOutput
|
ModelRunnerOutput
|
None
:
assert
self
.
execute_model_state
is
not
None
assert
self
.
execute_model_state
is
not
None
hidden_states
,
input_batch
,
kv_connector_output
=
self
.
execute_model_state
hidden_states
,
input_batch
,
kv_connector_output
=
self
.
execute_model_state
self
.
execute_model_state
=
None
# type: ignore
self
.
execute_model_state
=
None
# type: ignore
# Non-last PP rank: hidden_states is None because this rank produced
# IntermediateTensors instead of final hidden states. Receive the
# sampled tokens broadcast by the last rank and update local state.
if
self
.
use_pp
and
not
get_pp_group
().
is_last_rank
:
assert
self
.
pp_handler
is
not
None
received
=
self
.
pp_handler
.
maybe_receive_sampled_tokens
(
input_batch
.
num_reqs
,
self
.
device
,
max_sample_len
=
self
.
num_speculative_steps
+
1
,
)
if
received
is
not
None
:
sampled
,
num_sampled
,
num_rejected
=
received
self
.
postprocess
(
input_batch
,
sampled
,
num_sampled
,
num_rejected
)
return
None
# Last rank: sample tokens
sampler_output
,
num_sampled
,
num_rejected
=
self
.
sample
(
sampler_output
,
num_sampled
,
num_rejected
=
self
.
sample
(
hidden_states
,
input_batch
,
grammar_output
hidden_states
,
input_batch
,
grammar_output
)
)
# Broadcast to non-last PP ranks (handles spec decode multi-token).
if
self
.
use_pp
:
assert
self
.
pp_handler
is
not
None
self
.
pp_handler
.
maybe_broadcast_sampled_tokens
(
sampler_output
,
num_sampled
,
num_rejected
)
prompt_logprobs_dict
=
self
.
prompt_logprobs_worker
.
compute_prompt_logprobs
(
prompt_logprobs_dict
=
self
.
prompt_logprobs_worker
.
compute_prompt_logprobs
(
self
.
model
.
compute_logits
,
self
.
model
.
compute_logits
,
hidden_states
,
hidden_states
,
...
...
vllm/v1/worker/gpu/pp_handler.py
0 → 100644
View file @
9a8853f7
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pipeline Parallelism handler for V2 Model Runner."""
import
torch
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
class PPHandler:
    """Synchronize sampled tokens across pipeline-parallel (PP) ranks.

    The last PP rank sends its sampling results with
    ``maybe_broadcast_sampled_tokens``; every other rank obtains them with
    ``maybe_receive_sampled_tokens``. Both sides issue the same three
    broadcasts in the same order (sampled tokens, num_sampled,
    num_rejected), so the two methods must be kept in lock-step.
    """

    @staticmethod
    def _broadcast_from_last_rank(pp, tensor: torch.Tensor) -> None:
        """Run one broadcast on the PP device group, sourced at the last rank."""
        torch.distributed.broadcast(
            tensor,
            src=pp.last_rank,
            group=pp.device_group,
        )

    def maybe_broadcast_sampled_tokens(
        self,
        sampler_output: SamplerOutput,
        num_sampled: torch.Tensor,
        num_rejected: torch.Tensor,
    ) -> None:
        """Send the sampling results from the last PP rank to the others.

        Returns immediately, without communicating, when the calling rank
        is not the last PP rank.

        Args:
            sampler_output: Sampling result; only its ``sampled_token_ids``
                tensor is sent.
            num_sampled: Per-request count of sampled tokens.
            num_rejected: Per-request count of rejected tokens.

        NOTE(review): the receiving side allocates an int64
        ``[num_reqs, max_sample_len]`` buffer and two int32 ``[num_reqs]``
        buffers; the tensors sent here are presumably expected to match
        those shapes and dtypes -- confirm against the sampler.
        """
        pp_group = get_pp_group()
        if not pp_group.is_last_rank:
            return

        # Order matters: it must mirror the receive order on other ranks.
        # num_sampled/num_rejected are only needed for speculative decoding.
        outgoing = (
            sampler_output.sampled_token_ids,
            num_sampled,
            num_rejected,
        )
        for payload in outgoing:
            self._broadcast_from_last_rank(pp_group, payload.contiguous())

    def maybe_receive_sampled_tokens(
        self,
        num_reqs: int,
        device: torch.device,
        max_sample_len: int = 1,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
        """Receive the sampling results broadcast by the last PP rank.

        Args:
            num_reqs: Number of requests in the batch (first dimension of
                every receive buffer).
            device: Device on which the receive buffers are allocated.
            max_sample_len: Second dimension of the sampled-token buffer
                (defaults to 1).

        Returns:
            ``None`` when called on the last PP rank. Otherwise a tuple
            ``(sampled_tokens, num_sampled, num_rejected)`` where
            ``sampled_tokens`` is int64 of shape
            ``[num_reqs, max_sample_len]`` and the other two are int32 of
            shape ``[num_reqs]``.
        """
        pp_group = get_pp_group()
        if pp_group.is_last_rank:
            return None

        # (shape, dtype) of each buffer, in the order the last rank sends.
        # num_sampled/num_rejected are only needed for speculative decoding.
        buffer_specs = (
            ((num_reqs, max_sample_len), torch.int64),
            ((num_reqs,), torch.int32),
            ((num_reqs,), torch.int32),
        )
        buffers = []
        for shape, dtype in buffer_specs:
            buf = torch.empty(shape, dtype=dtype, device=device)
            self._broadcast_from_last_rank(pp_group, buf)
            buffers.append(buf)

        sampled_tokens, num_sampled, num_rejected = buffers
        return sampled_tokens, num_sampled, num_rejected
def get_pp_handler(parallel_config) -> PPHandler:
    """Build a ``PPHandler`` for a pipeline-parallel configuration.

    Args:
        parallel_config: Object exposing ``pipeline_parallel_size``.

    Returns:
        A new ``PPHandler`` instance.

    The caller is responsible for invoking this only when pipeline
    parallelism is enabled (``pipeline_parallel_size > 1``); this
    precondition is checked with an assertion.
    """
    pp_enabled = parallel_config.pipeline_parallel_size > 1
    assert pp_enabled, (
        "PPHandler should not be created when pipeline parallelism is disabled."
    )
    return PPHandler()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment