Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bf6a3d0f
Unverified
Commit
bf6a3d0f
authored
Nov 10, 2025
by
Wei Wei
Committed by
GitHub
Nov 10, 2025
Browse files
[Misc] Add more scoping for improved trace (#28329)
Signed-off-by:
Wei Wei
<
wwei6@meta.com
>
parent
40d33264
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
192 additions
and
148 deletions
+192
-148
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+61
-55
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+73
-44
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+21
-16
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+37
-33
No files found.
vllm/v1/core/sched/scheduler.py
View file @
bf6a3d0f
...
@@ -38,6 +38,7 @@ from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
...
@@ -38,6 +38,7 @@ from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
from
vllm.v1.structured_output
import
StructuredOutputManager
from
vllm.v1.structured_output
import
StructuredOutputManager
from
vllm.v1.utils
import
record_function_or_nullcontext
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -259,49 +260,52 @@ class Scheduler(SchedulerInterface):
...
@@ -259,49 +260,52 @@ class Scheduler(SchedulerInterface):
continue
continue
# Schedule newly needed KV blocks for the request.
# Schedule newly needed KV blocks for the request.
while
True
:
with
record_function_or_nullcontext
(
"schedule: allocate_slots"
):
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
while
True
:
request
,
new_blocks
=
self
.
kv_cache_manager
.
allocate_slots
(
num_new_tokens
,
request
,
num_lookahead_tokens
=
self
.
num_lookahead_tokens
,
num_new_tokens
,
)
num_lookahead_tokens
=
self
.
num_lookahead_tokens
,
if
new_blocks
is
not
None
:
# The request can be scheduled.
break
# The request cannot be scheduled.
# Preempt the lowest-priority request.
if
self
.
policy
==
SchedulingPolicy
.
PRIORITY
:
preempted_req
=
max
(
self
.
running
,
key
=
lambda
r
:
(
r
.
priority
,
r
.
arrival_time
),
)
)
self
.
running
.
remove
(
preempted_req
)
if
preempted_req
in
scheduled_running_reqs
:
scheduled_running_reqs
.
remove
(
preempted_req
)
token_budget
+=
num_scheduled_tokens
[
preempted_req
.
request_id
]
req_to_new_blocks
.
pop
(
preempted_req
.
request_id
)
num_scheduled_tokens
.
pop
(
preempted_req
.
request_id
)
req_index
-=
1
else
:
preempted_req
=
self
.
running
.
pop
()
self
.
kv_cache_manager
.
free
(
preempted_req
)
if
new_blocks
is
not
None
:
self
.
encoder_cache_manager
.
free
(
preempted_req
)
# The request can be scheduled.
preempted_req
.
status
=
RequestStatus
.
PREEMPTED
break
preempted_req
.
num_computed_tokens
=
0
preempted_req
.
num_preemptions
+=
1
if
self
.
log_stats
:
preempted_req
.
record_event
(
EngineCoreEventType
.
PREEMPTED
,
scheduled_timestamp
)
self
.
waiting
.
prepend_request
(
preempted_req
)
# The request cannot be scheduled.
preempted_reqs
.
append
(
preempted_req
)
# Preempt the lowest-priority request.
if
preempted_req
==
request
:
if
self
.
policy
==
SchedulingPolicy
.
PRIORITY
:
# No more request to preempt. Cannot schedule this request.
preempted_req
=
max
(
break
self
.
running
,
key
=
lambda
r
:
(
r
.
priority
,
r
.
arrival_time
),
)
self
.
running
.
remove
(
preempted_req
)
if
preempted_req
in
scheduled_running_reqs
:
scheduled_running_reqs
.
remove
(
preempted_req
)
token_budget
+=
num_scheduled_tokens
[
preempted_req
.
request_id
]
req_to_new_blocks
.
pop
(
preempted_req
.
request_id
)
num_scheduled_tokens
.
pop
(
preempted_req
.
request_id
)
req_index
-=
1
else
:
preempted_req
=
self
.
running
.
pop
()
self
.
kv_cache_manager
.
free
(
preempted_req
)
self
.
encoder_cache_manager
.
free
(
preempted_req
)
preempted_req
.
status
=
RequestStatus
.
PREEMPTED
preempted_req
.
num_computed_tokens
=
0
preempted_req
.
num_preemptions
+=
1
if
self
.
log_stats
:
preempted_req
.
record_event
(
EngineCoreEventType
.
PREEMPTED
,
scheduled_timestamp
)
self
.
waiting
.
prepend_request
(
preempted_req
)
preempted_reqs
.
append
(
preempted_req
)
if
preempted_req
==
request
:
# No more request to preempt. Cannot schedule this request.
break
if
new_blocks
is
None
:
if
new_blocks
is
None
:
# Cannot schedule this request.
# Cannot schedule this request.
...
@@ -599,13 +603,14 @@ class Scheduler(SchedulerInterface):
...
@@ -599,13 +603,14 @@ class Scheduler(SchedulerInterface):
# Get the longest common prefix among all requests in the running queue.
# Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention.
# This can be potentially used for cascade attention.
num_common_prefix_blocks
=
[
0
]
*
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
num_common_prefix_blocks
=
[
0
]
*
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
if
self
.
running
:
with
record_function_or_nullcontext
(
"schedule: get_num_common_prefix_blocks"
):
any_request
=
self
.
running
[
0
]
if
self
.
running
:
num_common_prefix_blocks
=
(
any_request
=
self
.
running
[
0
]
self
.
kv_cache_manager
.
get_num_common_prefix_blocks
(
num_common_prefix_blocks
=
(
any_request
.
request_id
self
.
kv_cache_manager
.
get_num_common_prefix_blocks
(
any_request
.
request_id
)
)
)
)
# Construct the scheduler output.
# Construct the scheduler output.
new_reqs_data
=
[
new_reqs_data
=
[
...
@@ -614,13 +619,14 @@ class Scheduler(SchedulerInterface):
...
@@ -614,13 +619,14 @@ class Scheduler(SchedulerInterface):
)
)
for
req
in
scheduled_new_reqs
for
req
in
scheduled_new_reqs
]
]
cached_reqs_data
=
self
.
_make_cached_request_data
(
with
record_function_or_nullcontext
(
"schedule: make_cached_request_data"
):
scheduled_running_reqs
,
cached_reqs_data
=
self
.
_make_cached_request_data
(
scheduled_resumed_reqs
,
scheduled_running_reqs
,
num_scheduled_tokens
,
scheduled_resumed_reqs
,
scheduled_spec_decode_tokens
,
num_scheduled_tokens
,
req_to_new_blocks
,
scheduled_spec_decode_tokens
,
)
req_to_new_blocks
,
)
# Record the request ids that were scheduled in this step.
# Record the request ids that were scheduled in this step.
self
.
prev_step_scheduled_req_ids
.
clear
()
self
.
prev_step_scheduled_req_ids
.
clear
()
...
@@ -649,8 +655,8 @@ class Scheduler(SchedulerInterface):
...
@@ -649,8 +655,8 @@ class Scheduler(SchedulerInterface):
if
self
.
connector
is
not
None
:
if
self
.
connector
is
not
None
:
meta
=
self
.
connector
.
build_connector_meta
(
scheduler_output
)
meta
=
self
.
connector
.
build_connector_meta
(
scheduler_output
)
scheduler_output
.
kv_connector_metadata
=
meta
scheduler_output
.
kv_connector_metadata
=
meta
with
record_function_or_nullcontext
(
"schedule: update_after_schedule"
):
self
.
_update_after_schedule
(
scheduler_output
)
self
.
_update_after_schedule
(
scheduler_output
)
return
scheduler_output
return
scheduler_output
def
_update_after_schedule
(
def
_update_after_schedule
(
...
...
vllm/v1/engine/core.py
View file @
bf6a3d0f
...
@@ -61,6 +61,7 @@ from vllm.v1.outputs import ModelRunnerOutput
...
@@ -61,6 +61,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.v1.serial_utils
import
MsgpackDecoder
,
MsgpackEncoder
from
vllm.v1.structured_output
import
StructuredOutputManager
from
vllm.v1.structured_output
import
StructuredOutputManager
from
vllm.v1.utils
import
record_function_or_nullcontext
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -315,17 +316,21 @@ class EngineCore:
...
@@ -315,17 +316,21 @@ class EngineCore:
# or finished and not yet removed from the batch.
# or finished and not yet removed from the batch.
if
not
self
.
scheduler
.
has_requests
():
if
not
self
.
scheduler
.
has_requests
():
return
{},
False
return
{},
False
scheduler_output
=
self
.
scheduler
.
schedule
()
with
record_function_or_nullcontext
(
"core step: schedule"
):
future
=
self
.
model_executor
.
execute_model
(
scheduler_output
,
non_block
=
True
)
scheduler_output
=
self
.
scheduler
.
schedule
()
grammar_output
=
self
.
scheduler
.
get_grammar_bitmask
(
scheduler_output
)
with
self
.
log_error_detail
(
scheduler_output
):
with
record_function_or_nullcontext
(
"core step: execute_model"
):
model_output
=
future
.
result
()
future
=
self
.
model_executor
.
execute_model
(
scheduler_output
,
non_block
=
True
)
if
model_output
is
None
:
grammar_output
=
self
.
scheduler
.
get_grammar_bitmask
(
scheduler_output
)
model_output
=
self
.
model_executor
.
sample_tokens
(
grammar_output
)
with
self
.
log_error_detail
(
scheduler_output
):
model_output
=
future
.
result
()
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
if
model_output
is
None
:
scheduler_output
,
model_output
model_output
=
self
.
model_executor
.
sample_tokens
(
grammar_output
)
)
with
record_function_or_nullcontext
(
"core step: update_from_output"
):
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
scheduler_output
,
model_output
)
return
engine_core_outputs
,
scheduler_output
.
total_num_scheduled_tokens
>
0
return
engine_core_outputs
,
scheduler_output
.
total_num_scheduled_tokens
>
0
...
@@ -363,32 +368,49 @@ class EngineCore:
...
@@ -363,32 +368,49 @@ class EngineCore:
model_executed
=
False
model_executed
=
False
deferred_scheduler_output
=
None
deferred_scheduler_output
=
None
if
self
.
scheduler
.
has_requests
():
if
self
.
scheduler
.
has_requests
():
scheduler_output
=
self
.
scheduler
.
schedule
()
with
record_function_or_nullcontext
(
"core step_with_batch_queue: schedule"
):
exec_future
=
self
.
model_executor
.
execute_model
(
scheduler_output
=
self
.
scheduler
.
schedule
()
scheduler_output
,
non_block
=
True
with
record_function_or_nullcontext
(
)
"core step_with_batch_queue: execute_model"
):
exec_future
=
self
.
model_executor
.
execute_model
(
scheduler_output
,
non_block
=
True
)
model_executed
=
scheduler_output
.
total_num_scheduled_tokens
>
0
model_executed
=
scheduler_output
.
total_num_scheduled_tokens
>
0
if
scheduler_output
.
pending_structured_output_tokens
:
if
scheduler_output
.
pending_structured_output_tokens
:
# We need to defer sampling until we have processed the model output
with
record_function_or_nullcontext
(
# from the prior step.
"core step_with_batch_queue: pending_structured_output_tokens"
deferred_scheduler_output
=
scheduler_output
):
# Block-wait for execute to return (continues running async on the GPU).
# We need to defer sampling until we have processed the model output
with
self
.
log_error_detail
(
scheduler_output
):
# from the prior step.
exec_result
=
exec_future
.
result
()
deferred_scheduler_output
=
scheduler_output
assert
exec_result
is
None
# Block-wait for execute to return
# (continues running async on the GPU).
with
self
.
log_error_detail
(
scheduler_output
):
exec_result
=
exec_future
.
result
()
assert
exec_result
is
None
else
:
else
:
# We aren't waiting for any tokens, get any grammar output immediately.
with
record_function_or_nullcontext
(
grammar_output
=
self
.
scheduler
.
get_grammar_bitmask
(
scheduler_output
)
"core step_with_batch_queue: get_grammar_bitmask"
):
# We aren't waiting for any tokens, get any grammar
# output immediately.
grammar_output
=
self
.
scheduler
.
get_grammar_bitmask
(
scheduler_output
)
# Block-wait for execute to return (continues running async on the GPU).
# Block-wait for execute to return (continues running async on the GPU).
with
self
.
log_error_detail
(
scheduler_output
):
with
self
.
log_error_detail
(
scheduler_output
):
exec_result
=
exec_future
.
result
()
exec_result
=
exec_future
.
result
()
if
exec_result
is
None
:
if
exec_result
is
None
:
# Call sample tokens.
with
record_function_or_nullcontext
(
future
=
self
.
model_executor
.
sample_tokens
(
"core step_with_batch_queue: sample_tokens"
grammar_output
,
non_block
=
True
):
)
# Call sample tokens.
future
=
self
.
model_executor
.
sample_tokens
(
grammar_output
,
non_block
=
True
)
else
:
else
:
# No sampling required (e.g. all requests finished).
# No sampling required (e.g. all requests finished).
future
=
cast
(
Future
[
ModelRunnerOutput
],
exec_future
)
future
=
cast
(
Future
[
ModelRunnerOutput
],
exec_future
)
...
@@ -408,27 +430,34 @@ class EngineCore:
...
@@ -408,27 +430,34 @@ class EngineCore:
# only be called when the scheduler contains requests or the queue
# only be called when the scheduler contains requests or the queue
# is non-empty.
# is non-empty.
return
None
,
False
return
None
,
False
with
record_function_or_nullcontext
(
"core step_with_batch_queue: model_output"
):
# Block until the next result is available.
# Block until the next result is available.
future
,
scheduler_output
=
batch_queue
.
pop
()
future
,
scheduler_output
=
batch_queue
.
pop
()
with
self
.
log_error_detail
(
scheduler_output
):
with
self
.
log_error_detail
(
scheduler_output
):
model_output
=
future
.
result
()
model_output
=
future
.
result
()
with
record_function_or_nullcontext
(
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
"core step_with_batch_queue: update_from_output"
scheduler_output
,
model_output
):
)
engine_core_outputs
=
self
.
scheduler
.
update_from_output
(
scheduler_output
,
model_output
)
# NOTE(nick): We can either handle the deferred tasks here or save
# NOTE(nick): We can either handle the deferred tasks here or save
# in a field and do it immediately once step_with_batch_queue is
# in a field and do it immediately once step_with_batch_queue is
# re-called. The latter slightly favors TTFT over TPOT/throughput.
# re-called. The latter slightly favors TTFT over TPOT/throughput.
if
deferred_scheduler_output
:
if
deferred_scheduler_output
:
# We now have the tokens needed to compute the bitmask for the
with
record_function_or_nullcontext
(
# deferred request. Get the bitmask and call sample tokens.
"core step_with_batch_queue: deferred_scheduler_output"
grammar_output
=
self
.
scheduler
.
get_grammar_bitmask
(
):
deferred_scheduler_output
# We now have the tokens needed to compute the bitmask for the
)
# deferred request. Get the bitmask and call sample tokens.
future
=
self
.
model_executor
.
sample_tokens
(
grammar_output
,
non_block
=
True
)
grammar_output
=
self
.
scheduler
.
get_grammar_bitmask
(
batch_queue
.
appendleft
((
future
,
deferred_scheduler_output
))
deferred_scheduler_output
)
future
=
self
.
model_executor
.
sample_tokens
(
grammar_output
,
non_block
=
True
)
batch_queue
.
appendleft
((
future
,
deferred_scheduler_output
))
return
engine_core_outputs
,
model_executed
return
engine_core_outputs
,
model_executed
...
...
vllm/v1/engine/llm_engine.py
View file @
bf6a3d0f
...
@@ -36,6 +36,7 @@ from vllm.v1.executor import Executor
...
@@ -36,6 +36,7 @@ from vllm.v1.executor import Executor
from
vllm.v1.metrics.loggers
import
StatLoggerFactory
,
StatLoggerManager
from
vllm.v1.metrics.loggers
import
StatLoggerFactory
,
StatLoggerManager
from
vllm.v1.metrics.reader
import
Metric
,
get_metrics_snapshot
from
vllm.v1.metrics.reader
import
Metric
,
get_metrics_snapshot
from
vllm.v1.metrics.stats
import
IterationStats
from
vllm.v1.metrics.stats
import
IterationStats
from
vllm.v1.utils
import
record_function_or_nullcontext
from
vllm.v1.worker.worker_base
import
WorkerBase
from
vllm.v1.worker.worker_base
import
WorkerBase
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -280,28 +281,32 @@ class LLMEngine:
...
@@ -280,28 +281,32 @@ class LLMEngine:
return
[]
return
[]
# 1) Get EngineCoreOutput from the EngineCore.
# 1) Get EngineCoreOutput from the EngineCore.
outputs
=
self
.
engine_core
.
get_output
()
with
record_function_or_nullcontext
(
"llm_genine step: get_output"
):
outputs
=
self
.
engine_core
.
get_output
()
# 2) Process EngineCoreOutputs.
# 2) Process EngineCoreOutputs.
iteration_stats
=
IterationStats
()
if
self
.
log_stats
else
None
with
record_function_or_nullcontext
(
"llm_genine step: process_outputs"
):
processed_outputs
=
self
.
output_processor
.
process_outputs
(
iteration_stats
=
IterationStats
()
if
self
.
log_stats
else
None
outputs
.
outputs
,
processed_outputs
=
self
.
output_processor
.
process_outputs
(
engine_core_timestamp
=
outputs
.
timestamp
,
outputs
.
outputs
,
iteration_stats
=
iteration_stats
,
engine_core_timestamp
=
outputs
.
timestamp
,
)
iteration_stats
=
iteration_stats
,
self
.
output_processor
.
update_scheduler_stats
(
outputs
.
scheduler_stats
)
)
self
.
output_processor
.
update_scheduler_stats
(
outputs
.
scheduler_stats
)
# 3) Abort any reqs that finished due to stop strings.
# 3) Abort any reqs that finished due to stop strings.
self
.
engine_core
.
abort_requests
(
processed_outputs
.
reqs_to_abort
)
with
record_function_or_nullcontext
(
"llm_genine step: abort_requests"
):
self
.
engine_core
.
abort_requests
(
processed_outputs
.
reqs_to_abort
)
# 4) Record stats
# 4) Record stats
if
self
.
logger_manager
is
not
None
and
outputs
.
scheduler_stats
is
not
None
:
with
record_function_or_nullcontext
(
"llm_genine step: record_stats"
):
self
.
logger_manager
.
record
(
if
self
.
logger_manager
is
not
None
and
outputs
.
scheduler_stats
is
not
None
:
scheduler_stats
=
outputs
.
scheduler_stats
,
self
.
logger_manager
.
record
(
iteration_stats
=
iteration_stats
,
scheduler_stats
=
outputs
.
scheduler_stats
,
mm_cache_stats
=
self
.
processor
.
stat_mm_cache
(),
iteration_stats
=
iteration_stats
,
)
mm_cache_stats
=
self
.
processor
.
stat_mm_cache
(),
self
.
do_log_stats_with_interval
()
)
self
.
do_log_stats_with_interval
()
return
processed_outputs
.
request_outputs
return
processed_outputs
.
request_outputs
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
bf6a3d0f
...
@@ -2525,7 +2525,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2525,7 +2525,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"after execute_model() returns None."
"after execute_model() returns None."
)
)
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
num_scheduled_tokens
=
scheduler_output
.
total_num_scheduled_tokens
with
record_function_or_nullcontext
(
"
P
reprocess"
):
with
record_function_or_nullcontext
(
"
gpu_model_runner: p
reprocess"
):
with
self
.
synchronize_input_prep
():
with
self
.
synchronize_input_prep
():
# Update persistent batch states.
# Update persistent batch states.
self
.
_update_states
(
scheduler_output
)
self
.
_update_states
(
scheduler_output
)
...
@@ -2648,7 +2648,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2648,7 +2648,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
batch_descriptor
=
batch_descriptor
,
batch_descriptor
=
batch_descriptor
,
ubatch_slices
=
ubatch_slices
,
ubatch_slices
=
ubatch_slices
,
),
),
record_function_or_nullcontext
(
"
F
orward"
),
record_function_or_nullcontext
(
"
gpu_model_runner: f
orward"
),
self
.
maybe_get_kv_connector_output
(
scheduler_output
)
as
kv_connector_output
,
self
.
maybe_get_kv_connector_output
(
scheduler_output
)
as
kv_connector_output
,
):
):
model_output
=
self
.
_model_forward
(
model_output
=
self
.
_model_forward
(
...
@@ -2659,7 +2659,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2659,7 +2659,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
**
model_kwargs
,
**
model_kwargs
,
)
)
with
record_function_or_nullcontext
(
"
P
ostprocess"
):
with
record_function_or_nullcontext
(
"
gpu_model_runner: p
ostprocess"
):
if
self
.
use_aux_hidden_state_outputs
:
if
self
.
use_aux_hidden_state_outputs
:
# True when EAGLE 3 is used.
# True when EAGLE 3 is used.
hidden_states
,
aux_hidden_states
=
model_output
hidden_states
,
aux_hidden_states
=
model_output
...
@@ -2756,12 +2756,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2756,12 +2756,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
scheduler_output
,
grammar_output
,
self
.
input_batch
,
logits
scheduler_output
,
grammar_output
,
self
.
input_batch
,
logits
)
)
with
record_function_or_nullcontext
(
"
S
ample"
):
with
record_function_or_nullcontext
(
"
gpu_model_runner: s
ample"
):
sampler_output
=
self
.
_sample
(
logits
,
spec_decode_metadata
)
sampler_output
=
self
.
_sample
(
logits
,
spec_decode_metadata
)
def
propose_draft_token_ids
(
sampled_token_ids
):
def
propose_draft_token_ids
(
sampled_token_ids
):
assert
spec_decode_common_attn_metadata
is
not
None
assert
spec_decode_common_attn_metadata
is
not
None
with
record_function_or_nullcontext
(
"
D
raft"
):
with
record_function_or_nullcontext
(
"
gpu_model_runner: d
raft"
):
self
.
_draft_token_ids
=
self
.
propose_draft_token_ids
(
self
.
_draft_token_ids
=
self
.
propose_draft_token_ids
(
scheduler_output
,
scheduler_output
,
sampled_token_ids
,
sampled_token_ids
,
...
@@ -2799,7 +2799,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2799,7 +2799,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# as inputs, and does not need to wait for bookkeeping to finish.
# as inputs, and does not need to wait for bookkeeping to finish.
propose_draft_token_ids
(
sampler_output
.
sampled_token_ids
)
propose_draft_token_ids
(
sampler_output
.
sampled_token_ids
)
with
record_function_or_nullcontext
(
"
B
ookkeep"
):
with
record_function_or_nullcontext
(
"
gpu_model_runner: b
ookkeep"
):
(
(
num_nans_in_logits
,
num_nans_in_logits
,
logprobs_lists
,
logprobs_lists
,
...
@@ -2826,37 +2826,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2826,37 +2826,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# tokens on the CPU, so they are run after bookkeeping.
# tokens on the CPU, so they are run after bookkeeping.
propose_draft_token_ids
(
valid_sampled_token_ids
)
propose_draft_token_ids
(
valid_sampled_token_ids
)
with
record_function_or_nullcontext
(
"
EPLB
"
):
with
record_function_or_nullcontext
(
"
gpu_model_runner: eplb
"
):
self
.
eplb_step
()
self
.
eplb_step
()
with
record_function_or_nullcontext
(
"gpu_model_runner: ModelRunnerOutput"
):
output
=
ModelRunnerOutput
(
output
=
ModelRunnerOutput
(
req_ids
=
req_ids_output_copy
,
req_ids
=
req_ids_output_copy
,
req_id_to_index
=
req_id_to_index_output_copy
,
req_id_to_index
=
req_id_to_index_output_copy
,
sampled_token_ids
=
valid_sampled_token_ids
,
sampled_token_ids
=
valid_sampled_token_ids
,
logprobs
=
logprobs_lists
,
logprobs
=
logprobs_lists
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
pooler_output
=
[],
pooler_output
=
[],
kv_connector_output
=
kv_connector_output
,
kv_connector_output
=
kv_connector_output
,
num_nans_in_logits
=
num_nans_in_logits
,
num_nans_in_logits
=
num_nans_in_logits
,
)
)
if
not
self
.
use_async_scheduling
:
if
not
self
.
use_async_scheduling
:
return
output
return
output
with
record_function_or_nullcontext
(
async_output
=
AsyncGPUModelRunnerOutput
(
"gpu_model_runner: AsyncGPUModelRunnerOutput"
model_runner_output
=
output
,
):
sampled_token_ids
=
sampler_output
.
sampled_token_ids
,
async_output
=
AsyncGPUModelRunnerOutput
(
logprobs_tensors
=
sampler_output
.
logprobs_tensors
,
model_runner_output
=
output
,
invalid_req_indices
=
invalid_req_indices
,
sampled_token_ids
=
sampler_output
.
sampled_token_ids
,
async_output_copy_stream
=
self
.
async_output_copy_stream
,
logprobs_tensors
=
sampler_output
.
logprobs_tensors
,
)
invalid_req_indices
=
invalid_req_indices
,
async_output_copy_stream
=
self
.
async_output_copy_stream
,
# Save ref of sampled_token_ids CPU tensor if the batch contains
)
# any requests with sampling params that require output ids.
with
record_function_or_nullcontext
(
self
.
input_batch
.
set_async_sampled_token_ids
(
"gpu_model_runner: set_async_sampled_token_ids"
async_output
.
sampled_token_ids_cpu
,
):
async_output
.
async_copy_ready_event
,
# Save ref of sampled_token_ids CPU tensor if the batch contains
)
# any requests with sampling params that require output ids.
self
.
input_batch
.
set_async_sampled_token_ids
(
async_output
.
sampled_token_ids_cpu
,
async_output
.
async_copy_ready_event
,
)
return
async_output
return
async_output
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment