Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
be3af2d2
Unverified
Commit
be3af2d2
authored
Feb 17, 2026
by
Woosuk Kwon
Committed by
GitHub
Feb 17, 2026
Browse files
[Model Runner V2] Further simplification for PP (#34724)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
c656ba3b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
46 additions
and
118 deletions
+46
-118
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+3
-9
vllm/v1/worker/gpu/pp_handler.py
vllm/v1/worker/gpu/pp_handler.py
+0
-109
vllm/v1/worker/gpu/pp_utils.py
vllm/v1/worker/gpu/pp_utils.py
+43
-0
No files found.
vllm/v1/worker/gpu/model_runner.py
View file @
be3af2d2
...
@@ -57,7 +57,7 @@ from vllm.v1.worker.gpu.kv_connector import (
...
@@ -57,7 +57,7 @@ from vllm.v1.worker.gpu.kv_connector import (
from
vllm.v1.worker.gpu.lora_utils
import
LoraState
from
vllm.v1.worker.gpu.lora_utils
import
LoraState
from
vllm.v1.worker.gpu.mm.encoder_runner
import
EncoderRunner
from
vllm.v1.worker.gpu.mm.encoder_runner
import
EncoderRunner
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.pp_
handler
import
PPHandler
from
vllm.v1.worker.gpu.pp_
utils
import
pp_broadcast
,
pp_receive
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.prompt_logprob
import
PromptLogprobsWorker
from
vllm.v1.worker.gpu.sample.prompt_logprob
import
PromptLogprobsWorker
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
from
vllm.v1.worker.gpu.sample.sampler
import
Sampler
...
@@ -185,11 +185,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -185,11 +185,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
self
.
use_pp
:
if
self
.
use_pp
:
self
.
is_first_pp_rank
=
get_pp_group
().
is_first_rank
self
.
is_first_pp_rank
=
get_pp_group
().
is_first_rank
self
.
is_last_pp_rank
=
get_pp_group
().
is_last_rank
self
.
is_last_pp_rank
=
get_pp_group
().
is_last_rank
self
.
pp_handler
:
PPHandler
|
None
=
PPHandler
(
self
.
device
)
else
:
else
:
self
.
is_first_pp_rank
=
True
self
.
is_first_pp_rank
=
True
self
.
is_last_pp_rank
=
True
self
.
is_last_pp_rank
=
True
self
.
pp_handler
=
None
def
update_max_model_len
(
self
,
max_model_len
:
int
)
->
None
:
def
update_max_model_len
(
self
,
max_model_len
:
int
)
->
None
:
self
.
max_model_len
=
max_model_len
self
.
max_model_len
=
max_model_len
...
@@ -987,8 +985,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -987,8 +985,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# IntermediateTensors instead of final hidden states. Receive the
# IntermediateTensors instead of final hidden states. Receive the
# sampled tokens broadcast by the last rank and update local state.
# sampled tokens broadcast by the last rank and update local state.
if
not
self
.
is_last_pp_rank
:
if
not
self
.
is_last_pp_rank
:
assert
self
.
pp_handler
is
not
None
received
=
pp_receive
(
received
=
self
.
pp_handler
.
maybe_receive_sampled_tokens
(
input_batch
.
num_reqs
,
max_sample_len
=
self
.
num_speculative_steps
+
1
input_batch
.
num_reqs
,
max_sample_len
=
self
.
num_speculative_steps
+
1
)
)
assert
received
is
not
None
assert
received
is
not
None
...
@@ -1003,10 +1000,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1003,10 +1000,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Broadcast to non-last PP ranks (handles spec decode multi-token).
# Broadcast to non-last PP ranks (handles spec decode multi-token).
if
self
.
use_pp
:
if
self
.
use_pp
:
assert
self
.
pp_handler
is
not
None
pp_broadcast
(
sampler_output
.
sampled_token_ids
,
num_sampled
,
num_rejected
)
self
.
pp_handler
.
maybe_broadcast_sampled_tokens
(
sampler_output
,
num_sampled
,
num_rejected
)
prompt_logprobs_dict
=
self
.
prompt_logprobs_worker
.
compute_prompt_logprobs
(
prompt_logprobs_dict
=
self
.
prompt_logprobs_worker
.
compute_prompt_logprobs
(
self
.
model
.
compute_logits
,
self
.
model
.
compute_logits
,
...
...
vllm/v1/worker/gpu/pp_handler.py
deleted
100644 → 0
View file @
c656ba3b
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pipeline Parallelism handler for V2 Model Runner."""
import
torch
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
class PPHandler:
    """Synchronizes sampled tokens across pipeline-parallel (PP) ranks.

    Used by Model Runner V2. The last PP rank sends its sampling results and
    every other rank receives them. Intended to be created only when PP is
    enabled (pp_size > 1).
    """

    def __init__(self, device: torch.device):
        # Device on which the receive buffers are allocated.
        self.device = device

    def maybe_broadcast_sampled_tokens(
        self,
        sampler_output: SamplerOutput,
        num_sampled: torch.Tensor,
        num_rejected: torch.Tensor,
    ) -> None:
        """Send the sampling results from the last PP rank to the other ranks.

        Does nothing when the calling process is not the last PP rank.

        Three broadcasts are issued, in this order: sampled_token_ids
        [num_reqs, max_sample_len], num_sampled [num_reqs] and num_rejected
        [num_reqs]. This covers both regular decode and speculative decoding.

        Args:
            sampler_output: SamplerOutput from sampling; only its
                ``sampled_token_ids`` field is read here.
            num_sampled: Number of accepted tokens per request.
            num_rejected: Number of rejected tokens per request.
        """
        pp_group = get_pp_group()
        if not pp_group.is_last_rank:
            return

        # NOTE: num_sampled/num_rejected are only needed
        # for speculative decoding.
        outgoing = (sampler_output.sampled_token_ids, num_sampled, num_rejected)
        for tensor in outgoing:
            torch.distributed.broadcast(
                tensor.contiguous(),
                src=pp_group.last_rank,
                group=pp_group.device_group,
            )

    def maybe_receive_sampled_tokens(
        self,
        num_reqs: int,
        max_sample_len: int = 1,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
        """Receive the sampling results broadcast by the last PP rank.

        The last rank samples rather than receives, so it gets None back.

        Args:
            num_reqs: Number of requests in the batch.
            max_sample_len: Maximum number of tokens sampled per request
                (1 for regular decode, >1 for speculative decoding).

        Returns:
            None if called on the last rank.
            Otherwise, a tuple (sampled_tokens, num_sampled, num_rejected):
            - sampled_tokens: int64, shape [num_reqs, max_sample_len]
            - num_sampled: int32, shape [num_reqs]
            - num_rejected: int32, shape [num_reqs]
        """
        pp_group = get_pp_group()
        if pp_group.is_last_rank:
            return None

        def _recv(*shape: int, dtype: torch.dtype) -> torch.Tensor:
            # Allocate an uninitialized buffer and fill it from the last rank.
            buffer = torch.empty(*shape, dtype=dtype, device=self.device)
            torch.distributed.broadcast(
                buffer,
                src=pp_group.last_rank,
                group=pp_group.device_group,
            )
            return buffer

        # The receive order must mirror the send order used by
        # maybe_broadcast_sampled_tokens.
        sampled_tokens = _recv(num_reqs, max_sample_len, dtype=torch.int64)
        # NOTE: num_sampled/num_rejected are only needed
        # for speculative decoding.
        num_sampled = _recv(num_reqs, dtype=torch.int32)
        num_rejected = _recv(num_reqs, dtype=torch.int32)
        return sampled_tokens, num_sampled, num_rejected
vllm/v1/worker/gpu/pp_utils.py
0 → 100644
View file @
be3af2d2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Pipeline Parallelism utils for V2 Model Runner."""
import
torch
from
vllm.distributed.parallel_state
import
get_pp_group
def pp_broadcast(
    sampled_token_ids: torch.Tensor,
    num_sampled: torch.Tensor,
    num_rejected: torch.Tensor,
) -> None:
    """Send the sampling results from the last PP rank to the other ranks.

    Returns immediately when the calling process is not the last
    pipeline-parallel rank. On the last rank, two broadcasts are issued:
    first a contiguous ``sampled_token_ids`` (asserted to be int64), then
    ``num_sampled`` and ``num_rejected`` stacked along a new leading
    dimension so both travel in a single collective.

    Args:
        sampled_token_ids: Sampled token ids; must have dtype torch.int64.
        num_sampled: Per-request sampled counts. Must be stackable with
            ``num_rejected`` (same shape).
        num_rejected: Per-request rejected counts.
    """
    pp_group = get_pp_group()
    if not pp_group.is_last_rank:
        return

    assert sampled_token_ids.dtype == torch.int64
    src_rank = pp_group.last_rank
    device_group = pp_group.device_group

    torch.distributed.broadcast(
        sampled_token_ids.contiguous(), src=src_rank, group=device_group
    )

    # Pack the two count vectors into one [2, ...] tensor so a single
    # broadcast carries both.
    counts = torch.stack((num_sampled, num_rejected), dim=0)
    torch.distributed.broadcast(counts, src=src_rank, group=device_group)
def pp_receive(
    num_reqs: int, max_sample_len: int = 1
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
    """Receive the sampling results broadcast by the last PP rank.

    Args:
        num_reqs: Number of requests in the batch; sizes the receive buffers.
        max_sample_len: Second dimension of the sampled-token buffer.

    Returns:
        None when called on the last pipeline-parallel rank. Otherwise a
        tuple (sampled_tokens, num_sampled, num_rejected):
        - sampled_tokens: int64, shape [num_reqs, max_sample_len]
        - num_sampled: int32, shape [num_reqs] (row 0 of the packed buffer)
        - num_rejected: int32, shape [num_reqs] (row 1 of the packed buffer)
    """
    pp_group = get_pp_group()
    if pp_group.is_last_rank:
        return None

    src_rank = pp_group.last_rank
    device_group = pp_group.device_group

    # First collective: the sampled token ids.
    sampled_tokens = torch.empty(
        num_reqs, max_sample_len, dtype=torch.int64, device=pp_group.device
    )
    torch.distributed.broadcast(sampled_tokens, src=src_rank, group=device_group)

    # Second collective: both count vectors packed as one [2, num_reqs] tensor.
    counts = torch.empty(2, num_reqs, dtype=torch.int32, device=pp_group.device)
    torch.distributed.broadcast(counts, src=src_rank, group=device_group)

    num_sampled, num_rejected = counts.unbind(dim=0)
    return sampled_tokens, num_sampled, num_rejected
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment