Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3aee6573
Unverified
Commit
3aee6573
authored
Mar 24, 2025
by
Nick Hill
Committed by
GitHub
Mar 24, 2025
Browse files
[V1] Aggregate chunked prompt logprobs in model runner (#14875)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
9cc64514
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
68 additions
and
44 deletions
+68
-44
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+4
-2
vllm/v1/engine/logprobs.py
vllm/v1/engine/logprobs.py
+0
-1
vllm/v1/engine/output_processor.py
vllm/v1/engine/output_processor.py
+3
-18
vllm/v1/metrics/stats.py
vllm/v1/metrics/stats.py
+4
-15
vllm/v1/outputs.py
vllm/v1/outputs.py
+19
-0
vllm/v1/worker/gpu_input_batch.py
vllm/v1/worker/gpu_input_batch.py
+5
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+33
-8
No files found.
vllm/v1/core/sched/scheduler.py
View file @
3aee6573
...
@@ -627,8 +627,7 @@ class Scheduler(SchedulerInterface):
...
@@ -627,8 +627,7 @@ class Scheduler(SchedulerInterface):
# Get prompt logprobs for this request.
# Get prompt logprobs for this request.
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
# Transmit partial if chunked prefill & prompt logprobs is enabled
if
new_token_ids
:
if
new_token_ids
or
prompt_logprobs_tensors
is
not
None
:
# Add EngineCoreOutput for this Request.
# Add EngineCoreOutput for this Request.
outputs
.
append
(
outputs
.
append
(
EngineCoreOutput
(
EngineCoreOutput
(
...
@@ -639,6 +638,9 @@ class Scheduler(SchedulerInterface):
...
@@ -639,6 +638,9 @@ class Scheduler(SchedulerInterface):
new_prompt_logprobs_tensors
=
prompt_logprobs_tensors
,
new_prompt_logprobs_tensors
=
prompt_logprobs_tensors
,
stop_reason
=
request
.
stop_reason
,
stop_reason
=
request
.
stop_reason
,
events
=
request
.
take_events
()))
events
=
request
.
take_events
()))
else
:
# Invariant: EngineCore returns no partial prefill outputs.
assert
not
prompt_logprobs_tensors
self
.
scheduled_req_ids
.
remove
(
request
.
request_id
)
self
.
scheduled_req_ids
.
remove
(
request
.
request_id
)
if
not
stopped
:
if
not
stopped
:
...
...
vllm/v1/engine/logprobs.py
View file @
3aee6573
...
@@ -115,7 +115,6 @@ class LogprobsProcessor:
...
@@ -115,7 +115,6 @@ class LogprobsProcessor:
num_prompt_tokens
,
num_logprobs
=
logprobs
.
shape
num_prompt_tokens
,
num_logprobs
=
logprobs
.
shape
# Pythonize the torch tensors.
# Pythonize the torch tensors.
# TODO(rob): experiment with doing this in EngineCore?
prompt_token_ranks
=
ranks
.
tolist
()
prompt_token_ranks
=
ranks
.
tolist
()
prompt_logprobs
=
logprobs
.
tolist
()
prompt_logprobs
=
logprobs
.
tolist
()
token_ids
=
token_ids
.
tolist
()
token_ids
=
token_ids
.
tolist
()
...
...
vllm/v1/engine/output_processor.py
View file @
3aee6573
...
@@ -105,9 +105,7 @@ class RequestState:
...
@@ -105,9 +105,7 @@ class RequestState:
finished
=
finish_reason
is
not
None
finished
=
finish_reason
is
not
None
final_only
=
self
.
output_kind
==
RequestOutputKind
.
FINAL_ONLY
final_only
=
self
.
output_kind
==
RequestOutputKind
.
FINAL_ONLY
# In follow up, we will switch to invariant where EngineCore
if
not
finished
and
final_only
:
# does not stream partial prefills.
if
not
finished
and
(
self
.
is_prefilling
or
final_only
):
# Only the final output is required in FINAL_ONLY mode.
# Only the final output is required in FINAL_ONLY mode.
return
None
return
None
...
@@ -285,19 +283,7 @@ class OutputProcessor:
...
@@ -285,19 +283,7 @@ class OutputProcessor:
finish_reason
=
engine_core_output
.
finish_reason
finish_reason
=
engine_core_output
.
finish_reason
stop_reason
=
engine_core_output
.
stop_reason
stop_reason
=
engine_core_output
.
stop_reason
# TODO(andy): prompt logprobs + chunked prefill can
req_state
.
is_prefilling
=
False
# result in engine core returning an output for a
# partial prefill (in order to send back partial
# prompt logprobs.) This breaks the invariant that
# process_outputs is only operating on engine core
# outputs associated with non-partial completions.
# Currently this is handled by having `is_prefilling`
# check for new decoded tokens, indicating that
# the completion is not partial.
#
# Follow up will aggregate partial prompt logprobs
# in the EngineCore.
req_state
.
is_prefilling
=
not
new_token_ids
# 2) Detokenize the token ids into text and perform stop checks.
# 2) Detokenize the token ids into text and perform stop checks.
stop_string
=
req_state
.
detokenizer
.
update
(
stop_string
=
req_state
.
detokenizer
.
update
(
...
@@ -306,8 +292,7 @@ class OutputProcessor:
...
@@ -306,8 +292,7 @@ class OutputProcessor:
finish_reason
=
FinishReason
.
STOP
finish_reason
=
FinishReason
.
STOP
stop_reason
=
stop_string
stop_reason
=
stop_string
# 3) Compute sample and prompt logprobs for request,
# 3) Compute sample and prompt logprobs for request, if required.
# if required.
req_state
.
logprobs_processor
.
update_from_output
(
engine_core_output
)
req_state
.
logprobs_processor
.
update_from_output
(
engine_core_output
)
# 4) Create and handle RequestOutput objects.
# 4) Create and handle RequestOutput objects.
...
...
vllm/v1/metrics/stats.py
View file @
3aee6573
...
@@ -100,15 +100,8 @@ class IterationStats:
...
@@ -100,15 +100,8 @@ class IterationStats:
num_new_generation_tokens
=
len
(
output
.
new_token_ids
)
num_new_generation_tokens
=
len
(
output
.
new_token_ids
)
self
.
num_generation_tokens
+=
num_new_generation_tokens
self
.
num_generation_tokens
+=
num_new_generation_tokens
if
is_prefilling
and
num_new_generation_tokens
>
0
:
if
is_prefilling
:
# TODO(andy): we used to assert that num_new_generation_tokens
assert
num_new_generation_tokens
>
0
# > 0 with an invariant that EngineCore does not stream outputs
# for partially completed prefills (scheduler.update_from_output
# makes EngineCoreOutput iff num_computed_tokens == num_tokens).
# When prompt logprobs are enabled, we currently stream out the
# partially completed prompt.
# This will be reverted in a follow up PR and we should re-enable
# this assertion / invariant.
self
.
num_prompt_tokens
+=
prompt_len
self
.
num_prompt_tokens
+=
prompt_len
first_token_latency
=
self
.
_time_since
(
req_stats
.
arrival_time
)
first_token_latency
=
self
.
_time_since
(
req_stats
.
arrival_time
)
...
@@ -123,15 +116,11 @@ class IterationStats:
...
@@ -123,15 +116,11 @@ class IterationStats:
# Process the batch-level "new tokens" engine core event
# Process the batch-level "new tokens" engine core event
if
is_prefilling
:
if
is_prefilling
:
# TODO: re-enable no-output-for-partial-prefills invariant as above
if
num_new_generation_tokens
>
0
:
req_stats
.
first_token_ts
=
engine_core_timestamp
req_stats
.
first_token_ts
=
engine_core_timestamp
else
:
else
:
tpot
=
engine_core_timestamp
-
req_stats
.
last_token_ts
tpot
=
engine_core_timestamp
-
req_stats
.
last_token_ts
self
.
time_per_output_tokens_iter
.
append
(
tpot
)
self
.
time_per_output_tokens_iter
.
append
(
tpot
)
# TODO: re-enable no-output-for-partial-prefills invariant as above
if
num_new_generation_tokens
>
0
:
req_stats
.
last_token_ts
=
engine_core_timestamp
req_stats
.
last_token_ts
=
engine_core_timestamp
def
update_from_events
(
self
,
req_id
:
str
,
events
:
list
[
"EngineCoreEvent"
],
def
update_from_events
(
self
,
req_id
:
str
,
events
:
list
[
"EngineCoreEvent"
],
...
...
vllm/v1/outputs.py
View file @
3aee6573
...
@@ -39,6 +39,25 @@ class LogprobsTensors(NamedTuple):
...
@@ -39,6 +39,25 @@ class LogprobsTensors(NamedTuple):
self
.
selected_token_ranks
.
tolist
(),
self
.
selected_token_ranks
.
tolist
(),
)
)
@staticmethod
def empty_cpu(num_positions: int,
              num_tokens_per_position: int) -> "LogprobsTensors":
    """Allocate a LogprobsTensors whose buffers live on the CPU.

    The buffers come from ``torch.empty`` / ``torch.empty_like``, so
    their contents are uninitialized; callers are expected to write
    into them before reading.

    Args:
        num_positions: Length of the first dimension of every buffer.
        num_tokens_per_position: Length of the second dimension of the
            token-id and logprob buffers.

    Returns:
        A LogprobsTensors with:
          * ``logprob_token_ids``: int32,
            shape (num_positions, num_tokens_per_position)
          * ``logprobs``: float32, same shape as ``logprob_token_ids``
          * ``selected_token_ranks``: int32, shape (num_positions,)
    """
    shape = (num_positions, num_tokens_per_position)
    token_ids = torch.empty(shape, dtype=torch.int32, device="cpu")
    # Mirrors the shape and device of the token-id buffer; only the
    # dtype is overridden.
    values = torch.empty_like(token_ids, dtype=torch.float32)
    ranks = torch.empty(num_positions, dtype=torch.int32, device="cpu")
    return LogprobsTensors(
        logprob_token_ids=token_ids,
        logprobs=values,
        selected_token_ranks=ranks,
    )
@
dataclass
@
dataclass
class
SamplerOutput
:
class
SamplerOutput
:
...
...
vllm/v1/worker/gpu_input_batch.py
View file @
3aee6573
...
@@ -11,6 +11,7 @@ from vllm.lora.request import LoRARequest
...
@@ -11,6 +11,7 @@ from vllm.lora.request import LoRARequest
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.utils
import
swap_dict_values
from
vllm.utils
import
swap_dict_values
from
vllm.v1.outputs
import
LogprobsTensors
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.utils
import
copy_slice
from
vllm.v1.utils
import
copy_slice
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
...
@@ -197,6 +198,9 @@ class InputBatch:
...
@@ -197,6 +198,9 @@ class InputBatch:
# that are currently in the prefill phase.
# that are currently in the prefill phase.
self
.
num_prompt_logprobs
:
dict
[
str
,
int
]
=
{}
self
.
num_prompt_logprobs
:
dict
[
str
,
int
]
=
{}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self
.
in_progress_prompt_logprobs_cpu
:
dict
[
str
,
LogprobsTensors
]
=
{}
self
.
logit_bias
:
list
[
Optional
[
dict
[
int
,
self
.
logit_bias
:
list
[
Optional
[
dict
[
int
,
float
]]]
=
[
None
]
*
max_num_reqs
float
]]]
=
[
None
]
*
max_num_reqs
self
.
has_allowed_token_ids
:
set
[
str
]
=
set
()
self
.
has_allowed_token_ids
:
set
[
str
]
=
set
()
...
@@ -362,6 +366,7 @@ class InputBatch:
...
@@ -362,6 +366,7 @@ class InputBatch:
self
.
generators
.
pop
(
req_index
,
None
)
self
.
generators
.
pop
(
req_index
,
None
)
self
.
num_logprobs
.
pop
(
req_id
,
None
)
self
.
num_logprobs
.
pop
(
req_id
,
None
)
self
.
num_prompt_logprobs
.
pop
(
req_id
,
None
)
self
.
num_prompt_logprobs
.
pop
(
req_id
,
None
)
self
.
in_progress_prompt_logprobs_cpu
.
pop
(
req_id
,
None
)
# LoRA
# LoRA
lora_id
=
self
.
request_lora_mapping
[
req_index
]
lora_id
=
self
.
request_lora_mapping
[
req_index
]
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
3aee6573
...
@@ -1191,6 +1191,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1191,6 +1191,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
not
num_prompt_logprobs_dict
:
if
not
num_prompt_logprobs_dict
:
return
{}
return
{}
in_progress_dict
=
self
.
input_batch
.
in_progress_prompt_logprobs_cpu
prompt_logprobs_dict
:
dict
[
str
,
Optional
[
LogprobsTensors
]]
=
{}
prompt_logprobs_dict
:
dict
[
str
,
Optional
[
LogprobsTensors
]]
=
{}
# Since prompt logprobs are a rare feature, prioritize simple,
# Since prompt logprobs are a rare feature, prioritize simple,
...
@@ -1206,16 +1207,36 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1206,16 +1207,36 @@ class GPUModelRunner(LoRAModelRunnerMixin):
prompt_token_ids
=
torch
.
tensor
(
request
.
prompt_token_ids
).
to
(
prompt_token_ids
=
torch
.
tensor
(
request
.
prompt_token_ids
).
to
(
self
.
device
,
non_blocking
=
True
)
self
.
device
,
non_blocking
=
True
)
# Set up target LogprobsTensors object.
logprobs_tensors
=
in_progress_dict
.
get
(
req_id
)
if
not
logprobs_tensors
:
# Create empty logprobs CPU tensors for the entire prompt.
# If chunked, we'll copy in slice by slice.
logprobs_tensors
=
LogprobsTensors
.
empty_cpu
(
num_prompt_tokens
-
1
,
num_prompt_logprobs
+
1
)
in_progress_dict
[
req_id
]
=
logprobs_tensors
# Determine number of logits to retrieve.
# Determine number of logits to retrieve.
start_tok
=
request
.
num_computed_tokens
+
1
start_idx
=
request
.
num_computed_tokens
start_tok
=
start_idx
+
1
num_remaining_tokens
=
num_prompt_tokens
-
start_tok
num_remaining_tokens
=
num_prompt_tokens
-
start_tok
if
num_tokens
<
num_remaining_tokens
:
if
num_tokens
<
=
num_remaining_tokens
:
# This is a chunk, more tokens remain.
# This is a chunk, more tokens remain.
# In the == case, there are no more prompt logprobs to produce
# but we want to defer returning them to the next step where we
# have new generated tokens to return.
num_logits
=
num_tokens
num_logits
=
num_tokens
else
:
else
:
# This is the last chunk of prompt tokens to return.
# This is the last chunk of prompt tokens to return.
num_logits
=
num_remaining_tokens
num_logits
=
num_remaining_tokens
completed_prefill_reqs
.
append
(
req_id
)
completed_prefill_reqs
.
append
(
req_id
)
prompt_logprobs_dict
[
req_id
]
=
logprobs_tensors
if
num_logits
<=
0
:
# This can happen for the final chunk if we prefilled exactly
# (num_prompt_tokens - 1) tokens for this request in the prior
# step. There are no more prompt logprobs to produce.
continue
# Get the logits corresponding to this req's prompt tokens.
# Get the logits corresponding to this req's prompt tokens.
# If this is a partial request (i.e. chunked prefill),
# If this is a partial request (i.e. chunked prefill),
...
@@ -1236,18 +1257,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1236,18 +1257,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logprobs
,
num_prompt_logprobs
,
tgt_token_ids
)
logprobs
,
num_prompt_logprobs
,
tgt_token_ids
)
# Transfer GPU->CPU async.
# Transfer GPU->CPU async.
prompt_logprobs_dict
[
req_id
]
=
LogprobsTensors
(
chunk_slice
=
slice
(
start_idx
,
start_idx
+
num_logits
)
token_ids
.
to
(
"cpu"
,
non_blocking
=
True
),
logprobs_tensors
.
logprob_token_ids
[
chunk_slice
].
copy_
(
logprobs
.
to
(
"cpu"
,
non_blocking
=
True
),
token_ids
,
non_blocking
=
True
)
ranks
.
to
(
"cpu"
,
non_blocking
=
True
),
logprobs_tensors
.
logprobs
[
chunk_slice
].
copy_
(
logprobs
,
)
non_blocking
=
True
)
logprobs_tensors
.
selected_token_ranks
[
chunk_slice
].
copy_
(
ranks
,
non_blocking
=
True
)
# Remove requests that have completed prefill from the batch
# Remove requests that have completed prefill from the batch
# num_prompt_logprobs_dict.
# num_prompt_logprobs_dict.
for
req_id
in
completed_prefill_reqs
:
for
req_id
in
completed_prefill_reqs
:
del
num_prompt_logprobs_dict
[
req_id
]
del
num_prompt_logprobs_dict
[
req_id
]
del
in_progress_dict
[
req_id
]
# Must synchronize the non-blocking GPU->CPU transfers.
# Must synchronize the non-blocking GPU->CPU transfers.
if
prompt_logprobs_dict
:
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
return
prompt_logprobs_dict
return
prompt_logprobs_dict
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment