Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2f8b4ce0
Unverified
Commit
2f8b4ce0
authored
Mar 11, 2026
by
Woosuk Kwon
Committed by
GitHub
Mar 12, 2026
Browse files
[Model Runner V2] Do not initialize sampler for non-last PP ranks (#36824)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
2ef69456
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
75 additions
and
50 deletions
+75
-50
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+11
-8
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+60
-39
vllm/v1/worker/gpu/pool/pooling_runner.py
vllm/v1/worker/gpu/pool/pooling_runner.py
+4
-3
No files found.
vllm/v1/worker/gpu/input_batch.py
View file @
2f8b4ce0
...
...
@@ -438,17 +438,20 @@ def _post_update_kernel(
for
i
in
range
(
num_sampled
):
token_id
=
tl
.
load
(
sampled_tokens_ptr
+
req_id
*
sampled_tokens_stride
+
i
)
token_ptr
=
(
output_bin_counts_ptr
+
req_state_idx
*
output_bin_counts_stride
+
token_id
)
count
=
tl
.
load
(
token_ptr
)
count
+=
1
tl
.
store
(
token_ptr
,
count
)
tl
.
store
(
all_token_ids_ptr
+
req_state_idx
*
all_token_ids_stride
+
total_len
+
i
,
token_id
,
)
if
output_bin_counts_ptr
is
not
None
:
token_ptr
=
(
output_bin_counts_ptr
+
req_state_idx
*
output_bin_counts_stride
+
token_id
)
count
=
tl
.
load
(
token_ptr
)
tl
.
store
(
token_ptr
,
count
+
1
)
query_start
=
tl
.
load
(
query_start_loc_ptr
+
req_id
)
query_end
=
tl
.
load
(
query_start_loc_ptr
+
req_id
+
1
)
query_len
=
query_end
-
query_start
...
...
@@ -467,7 +470,7 @@ def post_update(
# [max_num_reqs]
last_sampled_tokens
:
torch
.
Tensor
,
# [max_num_reqs, vocab_size]
output_bin_counts
:
torch
.
Tensor
,
output_bin_counts
:
torch
.
Tensor
|
None
,
# [num_reqs, num_speculative_steps + 1]
sampled_tokens
:
torch
.
Tensor
,
# [num_reqs]
...
...
@@ -487,7 +490,7 @@ def post_update(
num_computed_tokens
,
last_sampled_tokens
,
output_bin_counts
,
output_bin_counts
.
stride
(
0
),
output_bin_counts
.
stride
(
0
)
if
output_bin_counts
is
not
None
else
0
,
sampled_tokens
,
sampled_tokens
.
stride
(
0
),
num_sampled
,
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
2f8b4ce0
...
...
@@ -183,6 +183,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Draft tokens propagation - for spec-dec + struct outputs.
self
.
draft_tokens_handler
=
DraftTokensHandler
(
self
.
device
)
# Pooling models.
self
.
is_pooling_model
=
self
.
model_config
.
runner_type
==
"pooling"
self
.
pooling_runner
:
PoolingRunner
|
None
=
None
# General request states.
self
.
req_states
=
RequestState
(
max_num_reqs
=
self
.
max_num_reqs
,
...
...
@@ -199,6 +203,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_tokens
=
self
.
max_num_tokens
,
device
=
self
.
device
,
)
self
.
sampler
:
Sampler
|
None
=
None
self
.
rejection_sampler
:
RejectionSampler
|
None
=
None
self
.
prompt_logprobs_worker
:
PromptLogprobsWorker
|
None
=
None
self
.
structured_outputs_worker
:
StructuredOutputsWorker
|
None
=
None
if
self
.
is_last_pp_rank
and
not
self
.
is_pooling_model
:
# Initialize sampling-related workers.
# These components are only set up on the last PP rank and
# for generative (non-pooling) models.
self
.
sampler
=
Sampler
(
max_num_reqs
=
self
.
max_num_reqs
,
vocab_size
=
self
.
vocab_size
,
...
...
@@ -213,6 +226,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
use_strict_rejection_sampling
=
use_strict_rejection_sampling
,
)
self
.
prompt_logprobs_worker
=
PromptLogprobsWorker
(
self
.
max_num_reqs
)
self
.
structured_outputs_worker
=
StructuredOutputsWorker
(
max_num_logits
=
self
.
max_num_reqs
*
(
self
.
num_speculative_steps
+
1
),
vocab_size
=
self
.
vocab_size
,
device
=
self
.
device
,
)
# CUDA graphs.
self
.
decode_query_len
=
self
.
num_speculative_steps
+
1
...
...
@@ -222,21 +240,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
compilation_config
.
cudagraph_mode
,
decode_query_len
=
self
.
decode_query_len
,
)
# Structured outputs worker.
self
.
structured_outputs_worker
=
StructuredOutputsWorker
(
max_num_logits
=
self
.
max_num_reqs
*
(
self
.
num_speculative_steps
+
1
),
vocab_size
=
self
.
vocab_size
,
device
=
self
.
device
,
)
# LoRA-related workers.
self
.
lora_state
=
LoraState
(
max_num_reqs
=
self
.
max_num_reqs
)
# KV Connector if configured.
self
.
kv_connector
:
KVConnector
=
NO_OP_KV_CONNECTOR
# Pooling models.
self
.
is_pooling_model
=
self
.
model_config
.
runner_type
==
"pooling"
self
.
pooling_runner
:
PoolingRunner
|
None
=
None
# For transferring state from execute_model to subsequent sample_tokens call.
self
.
execute_model_state
:
ExecuteModelState
|
None
=
None
...
...
@@ -248,8 +256,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
tasks
:
list
[
SupportedTask
]
=
[]
if
self
.
model_config
.
runner_type
==
"generate"
:
tasks
.
extend
(
self
.
model_state
.
get_supported_generation_tasks
())
if
self
.
pooling_runner
is
not
None
:
tasks
.
extend
(
self
.
pooling_runner
.
get_supported_pooling_tasks
())
if
self
.
is_pooling_model
:
# Do not rely on pooling_runner here, since this information is needed
# on the first PP rank, while pooling_runner is only initialized
# on the last PP rank.
tasks
.
extend
(
PoolingRunner
.
get_supported_tasks
(
self
.
model
))
return
tuple
(
tasks
)
def
load_model
(
self
,
*
args
,
**
kwargs
)
->
None
:
...
...
@@ -289,7 +300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
model_state
=
init_model_state
(
self
.
vllm_config
,
self
.
model
,
self
.
encoder_cache
,
self
.
device
)
if
self
.
is_pooling_model
:
if
self
.
is_pooling_model
and
self
.
is_last_pp_rank
:
self
.
pooling_runner
=
PoolingRunner
(
self
.
model
)
def
get_model
(
self
)
->
nn
.
Module
:
...
...
@@ -420,6 +431,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# dummy run the eagle speculator's propose to ensure DP/EP sync.
if
self
.
speculator
is
not
None
:
assert
self
.
sampler
is
not
None
self
.
speculator
.
propose
(
input_batch
=
input_batch
,
attn_metadata
=
attn_metadata
,
...
...
@@ -457,10 +469,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE(woosuk): During the initial memory profiling, the sampler may skip
# top_k, top_p, and logprobs, using less GPU memory than what is possible
# during actual execution.
self
.
sampler
(
logits
,
dummy_input_batch
,
)
assert
self
.
sampler
is
not
None
self
.
sampler
(
logits
,
dummy_input_batch
)
@
torch
.
inference_mode
()
def
_dummy_pooler_run
(
self
,
hidden_states
:
torch
.
Tensor
)
->
None
:
...
...
@@ -558,6 +568,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
req_states
.
remove_request
(
req_id
)
if
self
.
encoder_cache
is
not
None
:
self
.
encoder_cache
.
remove_request
(
req_id
)
if
self
.
prompt_logprobs_worker
is
not
None
:
self
.
prompt_logprobs_worker
.
remove_request
(
req_id
)
self
.
lora_state
.
remove_request
(
req_id
)
...
...
@@ -589,18 +600,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
self
.
lora_state
.
add_request
(
req_id
,
req_index
,
new_req_data
.
lora_request
)
if
new_req_data
.
sampling_params
is
not
None
:
if
self
.
is_last_pp_rank
and
new_req_data
.
sampling_params
is
not
None
:
assert
self
.
sampler
is
not
None
self
.
sampler
.
add_request
(
req_index
,
prompt_len
,
new_req_data
.
sampling_params
)
assert
self
.
prompt_logprobs_worker
is
not
None
self
.
prompt_logprobs_worker
.
add_request
(
req_id
,
req_index
,
new_req_data
.
sampling_params
)
if
scheduler_output
.
scheduled_new_reqs
:
self
.
req_states
.
apply_staged_writes
()
self
.
sampler
.
apply_staged_writes
()
self
.
model_state
.
apply_staged_writes
()
if
self
.
sampler
is
not
None
:
self
.
sampler
.
apply_staged_writes
()
def
update_requests
(
self
,
scheduler_output
:
SchedulerOutput
)
->
None
:
# Add new blocks for the existing requests.
...
...
@@ -788,6 +802,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
)
if
grammar_output
is
not
None
:
# Apply grammar bitmask to the logits in-place.
assert
self
.
structured_outputs_worker
is
not
None
self
.
structured_outputs_worker
.
apply_grammar_bitmask
(
logits
,
input_batch
,
...
...
@@ -797,12 +812,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
input_batch
.
num_draft_tokens
==
0
:
# No draft tokens (common case).
sampler_output
=
self
.
sampler
(
logits
,
input_batch
,
)
assert
self
.
sampler
is
not
None
sampler_output
=
self
.
sampler
(
logits
,
input_batch
)
else
:
# Rejection sampling for spec decoding.
assert
self
.
rejection_sampler
is
not
None
sampler_output
=
self
.
rejection_sampler
(
logits
,
input_batch
,
...
...
@@ -831,11 +845,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_rejected
:
torch
.
Tensor
,
)
->
None
:
# Update the number of computed tokens.
if
self
.
is_last_pp_rank
:
assert
self
.
sampler
is
not
None
output_bin_counts
=
self
.
sampler
.
penalties_state
.
output_bin_counts
else
:
output_bin_counts
=
None
post_update
(
input_batch
.
idx_mapping
,
self
.
req_states
.
num_computed_tokens
.
gpu
,
self
.
req_states
.
last_sampled_tokens
,
self
.
sampler
.
penalties_state
.
output_bin_counts
,
output_bin_counts
,
sampled_tokens
,
num_sampled
,
num_rejected
,
...
...
@@ -1076,6 +1095,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Broadcast to non-last PP ranks (handles spec decode multi-token).
pp_broadcast
(
sampler_output
.
sampled_token_ids
,
num_sampled
,
num_rejected
)
assert
self
.
prompt_logprobs_worker
is
not
None
prompt_logprobs_dict
=
self
.
prompt_logprobs_worker
.
compute_prompt_logprobs
(
self
.
model
.
compute_logits
,
hidden_states
,
...
...
@@ -1115,6 +1135,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_batch
,
sampler_output
.
sampled_token_ids
,
num_sampled
,
num_rejected
)
if
self
.
speculator
is
not
None
:
assert
self
.
sampler
is
not
None
draft_tokens
=
self
.
speculator
.
propose
(
input_batch
,
attn_metadata
,
...
...
vllm/v1/worker/gpu/pool/pooling_runner.py
View file @
2f8b4ce0
...
...
@@ -19,10 +19,11 @@ class PoolingRunner:
def
__init__
(
self
,
model
:
nn
.
Module
):
self
.
model
=
cast
(
VllmModelForPooling
,
model
)
def
get_supported_pooling_tasks
(
self
)
->
list
[
PoolingTask
]:
if
not
is_pooling_model
(
self
.
model
):
@
staticmethod
def
get_supported_tasks
(
model
:
nn
.
Module
)
->
list
[
PoolingTask
]:
if
not
is_pooling_model
(
model
):
return
[]
assert
"embed"
in
self
.
model
.
pooler
.
get_supported_tasks
()
assert
"embed"
in
model
.
pooler
.
get_supported_tasks
()
return
[
"embed"
]
def
pool
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment