Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
876a16f4
Unverified
Commit
876a16f4
authored
Jan 30, 2026
by
Nick Hill
Committed by
GitHub
Jan 31, 2026
Browse files
[ModelRunner V2] Fix spec decoding + logprobs (#33391)
Signed-off-by:
Nick Hill
<
nickhill123@gmail.com
>
parent
aaa901ad
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
22 additions
and
5 deletions
+22
-5
tests/v1/engine/test_output_processor.py
tests/v1/engine/test_output_processor.py
+1
-0
vllm/v1/engine/logprobs.py
vllm/v1/engine/logprobs.py
+1
-1
vllm/v1/outputs.py
vllm/v1/outputs.py
+9
-1
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+2
-1
vllm/v1/worker/gpu/sample/logprob.py
vllm/v1/worker/gpu/sample/logprob.py
+2
-0
vllm/v1/worker/gpu/sample/sampler.py
vllm/v1/worker/gpu/sample/sampler.py
+6
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-1
No files found.
tests/v1/engine/test_output_processor.py
View file @
876a16f4
...
@@ -335,6 +335,7 @@ def _validate_logprobs(
...
@@ -335,6 +335,7 @@ def _validate_logprobs(
ref_prompt_logprob_toks
,
ref_prompt_logprob_toks
,
ref_prompt_logprob_vals
,
ref_prompt_logprob_vals
,
ref_prompt_token_ranks
,
ref_prompt_token_ranks
,
_
,
)
=
ref_prompt_logprobs
)
=
ref_prompt_logprobs
for
idx
,
(
prompt_token
,
pos_logprob_dict
)
in
enumerate
(
for
idx
,
(
prompt_token
,
pos_logprob_dict
)
in
enumerate
(
zip
(
prompt_token_ids
[
1
:],
prompt_logprobs
[
1
:])
zip
(
prompt_token_ids
[
1
:],
prompt_logprobs
[
1
:])
...
...
vllm/v1/engine/logprobs.py
View file @
876a16f4
...
@@ -130,7 +130,7 @@ class LogprobsProcessor:
...
@@ -130,7 +130,7 @@ class LogprobsProcessor:
assert
self
.
num_prompt_logprobs
is
not
None
assert
self
.
num_prompt_logprobs
is
not
None
assert
self
.
prompt_logprobs
is
not
None
assert
self
.
prompt_logprobs
is
not
None
token_ids
,
logprobs
,
ranks
=
prompt_logprobs_tensors
token_ids
,
logprobs
,
ranks
,
_
=
prompt_logprobs_tensors
# Recover shapes.
# Recover shapes.
num_prompt_tokens
,
num_logprobs
=
logprobs
.
shape
num_prompt_tokens
,
num_logprobs
=
logprobs
.
shape
...
...
vllm/v1/outputs.py
View file @
876a16f4
...
@@ -51,13 +51,17 @@ class LogprobsTensors(NamedTuple):
...
@@ -51,13 +51,17 @@ class LogprobsTensors(NamedTuple):
logprobs
:
torch
.
Tensor
logprobs
:
torch
.
Tensor
# [num_reqs x num_generated_tokens]
# [num_reqs x num_generated_tokens]
selected_token_ranks
:
torch
.
Tensor
selected_token_ranks
:
torch
.
Tensor
# [num_reqs]
cu_num_generated_tokens
:
list
[
int
]
|
None
=
None
def
tolists
(
self
,
cu_num_generated_tokens
:
list
[
int
]
|
None
=
None
):
def
tolists
(
self
,
cu_num_generated_tokens
:
list
[
int
]
|
None
=
None
):
return
LogprobsLists
(
return
LogprobsLists
(
self
.
logprob_token_ids
.
cpu
().
numpy
(),
self
.
logprob_token_ids
.
cpu
().
numpy
(),
self
.
logprobs
.
cpu
().
numpy
(),
self
.
logprobs
.
cpu
().
numpy
(),
self
.
selected_token_ranks
.
cpu
().
numpy
(),
self
.
selected_token_ranks
.
cpu
().
numpy
(),
cu_num_generated_tokens
,
cu_num_generated_tokens
if
cu_num_generated_tokens
is
not
None
else
self
.
cu_num_generated_tokens
,
)
)
def
to_cpu_nonblocking
(
self
)
->
"LogprobsTensors"
:
def
to_cpu_nonblocking
(
self
)
->
"LogprobsTensors"
:
...
@@ -67,10 +71,14 @@ class LogprobsTensors(NamedTuple):
...
@@ -67,10 +71,14 @@ class LogprobsTensors(NamedTuple):
self
.
logprob_token_ids
.
to
(
"cpu"
,
non_blocking
=
True
),
self
.
logprob_token_ids
.
to
(
"cpu"
,
non_blocking
=
True
),
self
.
logprobs
.
to
(
"cpu"
,
non_blocking
=
True
),
self
.
logprobs
.
to
(
"cpu"
,
non_blocking
=
True
),
self
.
selected_token_ranks
.
to
(
"cpu"
,
non_blocking
=
True
),
self
.
selected_token_ranks
.
to
(
"cpu"
,
non_blocking
=
True
),
self
.
cu_num_generated_tokens
,
)
)
def
filter
(
self
,
mask
:
torch
.
Tensor
)
->
"LogprobsTensors"
:
def
filter
(
self
,
mask
:
torch
.
Tensor
)
->
"LogprobsTensors"
:
"""Filter the logprobs tensors with the given bool mask."""
"""Filter the logprobs tensors with the given bool mask."""
assert
self
.
cu_num_generated_tokens
is
None
,
(
"filter can't be used with cu_num_generated_tokens"
)
return
LogprobsTensors
(
return
LogprobsTensors
(
self
.
logprob_token_ids
[
mask
],
self
.
logprob_token_ids
[
mask
],
self
.
logprobs
[
mask
],
self
.
logprobs
[
mask
],
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
876a16f4
...
@@ -316,7 +316,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -316,7 +316,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE(woosuk): During the initial memory profiling, the sampler may skip
# NOTE(woosuk): During the initial memory profiling, the sampler may skip
# top_k, top_p, and logprobs, using less GPU memory than what is possible
# top_k, top_p, and logprobs, using less GPU memory than what is possible
# during actual execution.
# during actual execution.
self
.
sampler
(
logits
,
idx_mapping
,
idx_mapping_np
,
pos
)
self
.
sampler
(
logits
,
idx_mapping
,
idx_mapping_np
,
idx_mapping_np
,
pos
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
profile_run
(
self
)
->
None
:
def
profile_run
(
self
)
->
None
:
...
@@ -686,6 +686,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -686,6 +686,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits
,
logits
,
input_batch
.
expanded_idx_mapping
,
input_batch
.
expanded_idx_mapping
,
input_batch
.
idx_mapping_np
,
input_batch
.
idx_mapping_np
,
input_batch
.
cu_num_logits_np
,
sample_pos
,
sample_pos
,
)
)
...
...
vllm/v1/worker/gpu/sample/logprob.py
View file @
876a16f4
...
@@ -103,6 +103,7 @@ def compute_topk_logprobs(
...
@@ -103,6 +103,7 @@ def compute_topk_logprobs(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
num_logprobs
:
int
,
num_logprobs
:
int
,
sampled_token_ids
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
cu_num_logits
:
list
[
int
]
|
None
=
None
,
)
->
LogprobsTensors
:
)
->
LogprobsTensors
:
assert
num_logprobs
>=
0
assert
num_logprobs
>=
0
batch_size
,
vocab_size
=
logits
.
shape
batch_size
,
vocab_size
=
logits
.
shape
...
@@ -135,4 +136,5 @@ def compute_topk_logprobs(
...
@@ -135,4 +136,5 @@ def compute_topk_logprobs(
logprob_token_ids
=
logprob_token_ids
,
logprob_token_ids
=
logprob_token_ids
,
logprobs
=
logprobs
,
logprobs
=
logprobs
,
selected_token_ranks
=
token_ranks
,
selected_token_ranks
=
token_ranks
,
cu_num_generated_tokens
=
cu_num_logits
,
)
)
vllm/v1/worker/gpu/sample/sampler.py
View file @
876a16f4
...
@@ -62,6 +62,7 @@ class Sampler:
...
@@ -62,6 +62,7 @@ class Sampler:
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
idx_mapping
:
torch
.
Tensor
,
idx_mapping_np
:
np
.
ndarray
,
idx_mapping_np
:
np
.
ndarray
,
cu_num_logits_np
:
np
.
ndarray
,
pos
:
torch
.
Tensor
,
pos
:
torch
.
Tensor
,
)
->
SamplerOutput
:
)
->
SamplerOutput
:
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
...
@@ -78,7 +79,11 @@ class Sampler:
...
@@ -78,7 +79,11 @@ class Sampler:
if
self
.
logprobs_mode
==
"processed_logprobs"
if
self
.
logprobs_mode
==
"processed_logprobs"
else
logits
else
logits
)
)
logprobs_tensors
=
compute_topk_logprobs
(
logits
,
max_num_logprobs
,
sampled
)
expanded_logits
=
logits
.
shape
[
0
]
!=
idx_mapping_np
.
shape
[
0
]
cu_num_logits
=
cu_num_logits_np
.
tolist
()
if
expanded_logits
else
None
logprobs_tensors
=
compute_topk_logprobs
(
logits
,
max_num_logprobs
,
sampled
,
cu_num_logits
)
else
:
else
:
logprobs_tensors
=
None
logprobs_tensors
=
None
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
876a16f4
...
@@ -4449,7 +4449,7 @@ class GPUModelRunner(
...
@@ -4449,7 +4449,7 @@ class GPUModelRunner(
# Compute prompt logprobs.
# Compute prompt logprobs.
logprobs
=
self
.
sampler
.
compute_logprobs
(
logits
)
logprobs
=
self
.
sampler
.
compute_logprobs
(
logits
)
token_ids
,
logprobs
,
ranks
=
self
.
sampler
.
gather_logprobs
(
token_ids
,
logprobs
,
ranks
,
_
=
self
.
sampler
.
gather_logprobs
(
logprobs
,
num_prompt_logprobs
,
tgt_token_ids
logprobs
,
num_prompt_logprobs
,
tgt_token_ids
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment