Unverified Commit bf6a3d0f authored by Wei Wei's avatar Wei Wei Committed by GitHub
Browse files

[Misc] Add more scoping for improved trace (#28329)


Signed-off-by: default avatarWei Wei <wwei6@meta.com>
parent 40d33264
......@@ -38,6 +38,7 @@ from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import record_function_or_nullcontext
logger = init_logger(__name__)
......@@ -259,6 +260,7 @@ class Scheduler(SchedulerInterface):
continue
# Schedule newly needed KV blocks for the request.
with record_function_or_nullcontext("schedule: allocate_slots"):
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
......@@ -280,7 +282,9 @@ class Scheduler(SchedulerInterface):
self.running.remove(preempted_req)
if preempted_req in scheduled_running_reqs:
scheduled_running_reqs.remove(preempted_req)
token_budget += num_scheduled_tokens[preempted_req.request_id]
token_budget += num_scheduled_tokens[
preempted_req.request_id
]
req_to_new_blocks.pop(preempted_req.request_id)
num_scheduled_tokens.pop(preempted_req.request_id)
req_index -= 1
......@@ -599,6 +603,7 @@ class Scheduler(SchedulerInterface):
# Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention.
num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
if self.running:
any_request = self.running[0]
num_common_prefix_blocks = (
......@@ -614,6 +619,7 @@ class Scheduler(SchedulerInterface):
)
for req in scheduled_new_reqs
]
with record_function_or_nullcontext("schedule: make_cached_request_data"):
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs,
scheduled_resumed_reqs,
......@@ -649,7 +655,7 @@ class Scheduler(SchedulerInterface):
if self.connector is not None:
meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
with record_function_or_nullcontext("schedule: update_after_schedule"):
self._update_after_schedule(scheduler_output)
return scheduler_output
......
......@@ -61,6 +61,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import record_function_or_nullcontext
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
......@@ -315,7 +316,10 @@ class EngineCore:
# or finished and not yet removed from the batch.
if not self.scheduler.has_requests():
return {}, False
with record_function_or_nullcontext("core step: schedule"):
scheduler_output = self.scheduler.schedule()
with record_function_or_nullcontext("core step: execute_model"):
future = self.model_executor.execute_model(scheduler_output, non_block=True)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with self.log_error_detail(scheduler_output):
......@@ -323,6 +327,7 @@ class EngineCore:
if model_output is None:
model_output = self.model_executor.sample_tokens(grammar_output)
with record_function_or_nullcontext("core step: update_from_output"):
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
......@@ -363,28 +368,45 @@ class EngineCore:
model_executed = False
deferred_scheduler_output = None
if self.scheduler.has_requests():
with record_function_or_nullcontext("core step_with_batch_queue: schedule"):
scheduler_output = self.scheduler.schedule()
with record_function_or_nullcontext(
"core step_with_batch_queue: execute_model"
):
exec_future = self.model_executor.execute_model(
scheduler_output, non_block=True
)
model_executed = scheduler_output.total_num_scheduled_tokens > 0
if scheduler_output.pending_structured_output_tokens:
with record_function_or_nullcontext(
"core step_with_batch_queue: pending_structured_output_tokens"
):
# We need to defer sampling until we have processed the model output
# from the prior step.
deferred_scheduler_output = scheduler_output
# Block-wait for execute to return (continues running async on the GPU).
# Block-wait for execute to return
# (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
assert exec_result is None
else:
# We aren't waiting for any tokens, get any grammar output immediately.
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with record_function_or_nullcontext(
"core step_with_batch_queue: get_grammar_bitmask"
):
# We aren't waiting for any tokens, get any grammar
# output immediately.
grammar_output = self.scheduler.get_grammar_bitmask(
scheduler_output
)
# Block-wait for execute to return (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
if exec_result is None:
with record_function_or_nullcontext(
"core step_with_batch_queue: sample_tokens"
):
# Call sample tokens.
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
......@@ -408,12 +430,14 @@ class EngineCore:
# only be called when the scheduler contains requests or the queue
# is non-empty.
return None, False
with record_function_or_nullcontext("core step_with_batch_queue: model_output"):
# Block until the next result is available.
future, scheduler_output = batch_queue.pop()
with self.log_error_detail(scheduler_output):
model_output = future.result()
with record_function_or_nullcontext(
"core step_with_batch_queue: update_from_output"
):
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
......@@ -422,12 +446,17 @@ class EngineCore:
# in a field and do it immediately once step_with_batch_queue is
# re-called. The latter slightly favors TTFT over TPOT/throughput.
if deferred_scheduler_output:
with record_function_or_nullcontext(
"core step_with_batch_queue: deferred_scheduler_output"
):
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
grammar_output = self.scheduler.get_grammar_bitmask(
deferred_scheduler_output
)
future = self.model_executor.sample_tokens(grammar_output, non_block=True)
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
batch_queue.appendleft((future, deferred_scheduler_output))
return engine_core_outputs, model_executed
......
......@@ -36,6 +36,7 @@ from vllm.v1.executor import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
from vllm.v1.utils import record_function_or_nullcontext
from vllm.v1.worker.worker_base import WorkerBase
logger = init_logger(__name__)
......@@ -280,9 +281,11 @@ class LLMEngine:
return []
# 1) Get EngineCoreOutput from the EngineCore.
with record_function_or_nullcontext("llm_genine step: get_output"):
outputs = self.engine_core.get_output()
# 2) Process EngineCoreOutputs.
with record_function_or_nullcontext("llm_genine step: process_outputs"):
iteration_stats = IterationStats() if self.log_stats else None
processed_outputs = self.output_processor.process_outputs(
outputs.outputs,
......@@ -292,9 +295,11 @@ class LLMEngine:
self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
# 3) Abort any reqs that finished due to stop strings.
with record_function_or_nullcontext("llm_genine step: abort_requests"):
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
# 4) Record stats
with record_function_or_nullcontext("llm_genine step: record_stats"):
if self.logger_manager is not None and outputs.scheduler_stats is not None:
self.logger_manager.record(
scheduler_stats=outputs.scheduler_stats,
......
......@@ -2525,7 +2525,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"after execute_model() returns None."
)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
with record_function_or_nullcontext("Preprocess"):
with record_function_or_nullcontext("gpu_model_runner: preprocess"):
with self.synchronize_input_prep():
# Update persistent batch states.
self._update_states(scheduler_output)
......@@ -2648,7 +2648,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices,
),
record_function_or_nullcontext("Forward"),
record_function_or_nullcontext("gpu_model_runner: forward"),
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
):
model_output = self._model_forward(
......@@ -2659,7 +2659,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
**model_kwargs,
)
with record_function_or_nullcontext("Postprocess"):
with record_function_or_nullcontext("gpu_model_runner: postprocess"):
if self.use_aux_hidden_state_outputs:
# True when EAGLE 3 is used.
hidden_states, aux_hidden_states = model_output
......@@ -2756,12 +2756,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
scheduler_output, grammar_output, self.input_batch, logits
)
with record_function_or_nullcontext("Sample"):
with record_function_or_nullcontext("gpu_model_runner: sample"):
sampler_output = self._sample(logits, spec_decode_metadata)
def propose_draft_token_ids(sampled_token_ids):
assert spec_decode_common_attn_metadata is not None
with record_function_or_nullcontext("Draft"):
with record_function_or_nullcontext("gpu_model_runner: draft"):
self._draft_token_ids = self.propose_draft_token_ids(
scheduler_output,
sampled_token_ids,
......@@ -2799,7 +2799,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# as inputs, and does not need to wait for bookkeeping to finish.
propose_draft_token_ids(sampler_output.sampled_token_ids)
with record_function_or_nullcontext("Bookkeep"):
with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
(
num_nans_in_logits,
logprobs_lists,
......@@ -2826,9 +2826,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# tokens on the CPU, so they are run after bookkeeping.
propose_draft_token_ids(valid_sampled_token_ids)
with record_function_or_nullcontext("EPLB"):
with record_function_or_nullcontext("gpu_model_runner: eplb"):
self.eplb_step()
with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
output = ModelRunnerOutput(
req_ids=req_ids_output_copy,
req_id_to_index=req_id_to_index_output_copy,
......@@ -2842,7 +2842,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if not self.use_async_scheduling:
return output
with record_function_or_nullcontext(
"gpu_model_runner: AsyncGPUModelRunnerOutput"
):
async_output = AsyncGPUModelRunnerOutput(
model_runner_output=output,
sampled_token_ids=sampler_output.sampled_token_ids,
......@@ -2850,7 +2852,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
invalid_req_indices=invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream,
)
with record_function_or_nullcontext(
"gpu_model_runner: set_async_sampled_token_ids"
):
# Save ref of sampled_token_ids CPU tensor if the batch contains
# any requests with sampling params that that require output ids.
self.input_batch.set_async_sampled_token_ids(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment