Unverified Commit d16862ad authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

chore: add context manager based timer (#7007)


Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
Signed-off-by: default avatarGuanLuo <41310872+GuanLuo@users.noreply.github.com>
parent 82f60cc7
......@@ -24,6 +24,7 @@ from dynamo.common.utils import (
paths,
prometheus,
runtime,
time_section,
)
__all__ = [
......@@ -32,6 +33,7 @@ __all__ = [
"namespace",
"nvtx_utils",
"otel_tracing",
"time_section",
"paths",
"prometheus",
"runtime",
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import time
from contextlib import contextmanager
from typing import Callable
logger = logging.getLogger(__name__)
DEFAULT_LOG_LEVEL = logging.DEBUG
class Timer:
"""Simple timer implementation that can time interval, on constructing the Timer,
it starts the timer immediately, which will be stopped when calling stop().
It also supports timing intervals by calling start_interval() and stop_interval(),
so that you can time different parts of the code while the timer is running, note
that stop_interval() will update the interval start time to current time.
Example guide:
t = Timer(
lambda elapsed: logger.debug(f"phase1: {elapsed:.2f}s"),
lambda total: logger.debug(f"total: {total:.2f}s"),
)
t.start_interval()
do_phase1()
t.stop_interval() # prefer "start_something"
do_phase2()
t.stop()
"""
def __init__(
self,
interval_func: Callable[[float], None] = None,
stop_func: Callable[[float], None] = None,
):
"""Initialize the Timer and start timing immediately.
Args:
interval_func: Optional callback invoked with elapsed seconds during the timed interval when stop_interval() is called.
stop_func: Optional callback invoked with total elapsed seconds when stop() is called.
"""
self.start_time = time.perf_counter()
self.interval_start_time = self.start_time
self.interval_func = interval_func
self.stop_func = stop_func
def start_interval(self):
"""Start the interval timer."""
self.interval_start_time = time.perf_counter()
def stop_interval(self):
"""Stop the interval timer and return the elapsed time during the interval."""
now = time.perf_counter()
interval_time = now - self.interval_start_time
if self.interval_func:
self.interval_func(interval_time)
self.interval_start_time = now
return interval_time
def stop(self):
"""Stop the timer and return the total elapsed time."""
total_time = time.perf_counter() - self.start_time
if self.stop_func:
self.stop_func(total_time)
return total_time
@contextmanager
def time_and_log_code_section(log_message: str, log_level=DEFAULT_LOG_LEVEL):
"""Context manager that times a code block and logs total elapsed on exit.
Use the yielded timer's start_interval() and stop_interval() inside the block to
log sub-intervals (e.g. time to first token). Total elapsed is always logged
when the block exits.
Example:
with time_and_log_code_section("[DECODE] generate") as t:
t.start_interval()
async for chunk in generator():
if first_token:
t.stop_interval() # Log time to first chunk
first_token = False
yield chunk
Expected output (at default log level DEBUG):
[DECODE] generate - interval 0.1234 seconds # if stop_interval() was called
[DECODE] generate - total elapsed 1.5678 seconds # always on exit
Args:
log_message: Base message to use for logging, interval and total times will be appended.
log_level: Logging level to use for the messages (default: logging.DEBUG).
"""
timer = Timer(
lambda elapsed: logger.log(
log_level, f"{log_message} - interval {elapsed:.4f} seconds"
),
lambda total: logger.log(
log_level, f"{log_message} - total elapsed {total:.4f} seconds"
),
)
try:
yield timer
finally:
timer.stop()
......@@ -28,6 +28,7 @@ from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.common.utils.input_params import InputParamManager
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.llm import (
KvEventPublisher,
ModelInput,
......@@ -1314,14 +1315,21 @@ class DecodeWorkerHandler(BaseWorkerHandler):
# Use context ID for request tracking and correlation
request_id = context.id()
logger.debug(f"Decode Request ID: {request_id}")
first_token = True
with time_and_log_code_section(
f"[DECODE] request: {request_id} generate"
) as decode_timer:
if self.use_vllm_tokenizer:
# Text-in-text-out mode: use InputParamManager and OpenAI-compatible format
async for chunk in self._generate_text_mode(request, context, request_id):
yield chunk
generator = self._generate_text_mode(request, context, request_id)
else:
# Token-in-token-out mode: internal protocol format
async for chunk in self._generate_token_mode(request, context, request_id):
generator = self._generate_token_mode(request, context, request_id)
async for chunk in generator:
if first_token:
decode_timer.stop_interval()
first_token = False
yield chunk
async def _generate_token_mode(self, request, context, request_id):
......@@ -1524,6 +1532,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
logger.debug(f"Prefill Request ID: {request_id}")
# Token-in-token-out mode: internal protocol format
with time_and_log_code_section(f"[PREFILL] request: {request_id} generate"):
async for chunk in self._generate_token_mode(request, context, request_id):
yield chunk
......
......@@ -19,6 +19,7 @@ from dynamo.common.multimodal import (
NixlWriteEmbeddingSender,
)
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.runtime import DistributedRuntime
from ..constants import EmbeddingTransferMode
......@@ -177,7 +178,11 @@ class EncodeWorkerHandler:
# keep track of key to avoid recompute of it
need_encode_indexes.append((idx, embedding_key))
with _nvtx.annotate("mm:enc:image_load", color="green"):
with _nvtx.annotate(
"mm:enc:image_load", color="green"
), time_and_log_code_section(
f"[ENCODE] request: {request_id} image loading"
):
# Load and generate image tensors
image_tasks = []
image_to_load = []
......@@ -207,12 +212,20 @@ class EncodeWorkerHandler:
)
if loaded_images:
with _nvtx.annotate("mm:enc:image_preprocess", color="yellow"):
with _nvtx.annotate(
"mm:enc:image_preprocess", color="yellow"
), time_and_log_code_section(
f"[ENCODE] request: {request_id} image processing"
):
image_embeds = await asyncio.to_thread(
self.image_processor, images=loaded_images, return_tensors="pt"
)
with _nvtx.annotate("mm:enc:vision_encode", color="red"):
with _nvtx.annotate(
"mm:enc:vision_encode", color="red"
), time_and_log_code_section(
f"[ENCODE] request: {request_id} encoding"
):
# Encode the image embeddings using model-specific encoder
embeddings = await asyncio.to_thread(
encode_image_embeddings,
......
......@@ -21,6 +21,7 @@ from dynamo.common.multimodal.embedding_transfer import (
NixlWriteEmbeddingReceiver,
)
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.runtime import Client, DistributedRuntime
from ..args import Config
......@@ -303,6 +304,11 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
rng_ttft=None,
):
"""Prefill locally, then forward to a remote decode worker."""
with _nvtx.annotate(
"mm:pd:disagg_prefill", color="darkred"
), time_and_log_code_section(
f"[PREFILL] request: {request.request_id} prefill time"
):
# Prepare prefill-only request
prefill_only_request = copy.deepcopy(request)
extra_args = prefill_only_request.sampling_params.extra_args or {}
......@@ -313,7 +319,6 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
logger.debug("Prefill request: %s", prefill_only_request)
lora_request = self._resolve_lora_request(request.model)
with _nvtx.annotate("mm:pd:disagg_prefill", color="darkred"):
gen = self.engine_client.generate(
prompt=TokensPrompt(
prompt_token_ids=prefill_only_request.engine_prompt[
......@@ -367,7 +372,12 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
f"— ensure the same adapter is loaded on the decode worker."
)
with _nvtx.annotate("mm:pd:disagg_remote_decode", color="purple"):
with (
_nvtx.annotate("mm:pd:disagg_remote_decode", color="purple"),
time_and_log_code_section(
f"[PREFILL] request: {request.request_id} remote decode time"
) as decode_timer,
):
num_output_tokens_so_far = 0
async for (
decode_response
......@@ -377,6 +387,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
output = MyRequestOutput.model_validate_json(decode_response.data()) # type: ignore
yield self._format_engine_output(output, num_output_tokens_so_far)
if output.outputs:
if num_output_tokens_so_far == 0:
decode_timer.stop_interval() # Log time to first decode response
num_output_tokens_so_far = len(output.outputs[0].token_ids)
# ── Public entry point ───────────────────────────────────────────
......@@ -386,6 +398,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
rng_pd = _nvtx.start_range("mm:pd_worker_generate", color="green")
rng_ttft = _nvtx.start_range("mm:pd:ttft", color="orange")
with time_and_log_code_section("[REQUEST] embedding processing time"):
rng_parse = _nvtx.start_range("mm:pd:parse_request", color="cyan")
request, image_urls = self._parse_frontend_request(raw_request)
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")
......
......@@ -7,6 +7,7 @@ from vllm.inputs.data import TokensPrompt
import dynamo.nixl_connect as connect
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.runtime import DistributedRuntime
from ..args import Config
......@@ -56,6 +57,9 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
async def generate(self, request: vLLMMultimodalRequest, context):
rng_decode = _nvtx.start_range("mm:decode_worker_generate", color="blue")
logger.debug(f"Got raw request: {request}")
with time_and_log_code_section(
f"[DECODE] request: {request.request_id} preprocessing time"
):
if not isinstance(request, vLLMMultimodalRequest):
if isinstance(request, str):
request = vLLMMultimodalRequest.model_validate_json(request)
......@@ -85,8 +89,11 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
multi_modal_data = construct_qwen_decode_mm_data(
image_grid_thw, embeddings_shape, request.request_id
)
lora_request = self._resolve_lora_request(request.model)
with time_and_log_code_section(
f"[DECODE] request: {request.request_id} generate time"
) as gen_timer:
gen = self.engine_client.generate(
prompt=TokensPrompt(
prompt_token_ids=request.engine_prompt["prompt_token_ids"],
......@@ -102,6 +109,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
try:
async for response in gen:
if first_token:
gen_timer.stop_interval() # Log time to first response
_nvtx.end_range(rng_first)
first_token = False
logger.debug(
......
......@@ -18,6 +18,7 @@ from dynamo.common.multimodal.embedding_transfer import (
AbstractEmbeddingReceiver,
LocalEmbeddingReceiver,
)
from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.runtime import Client
from .encode_utils import get_embedding_hash
......@@ -163,6 +164,7 @@ async def _fetch_from_encode_workers(
multimodal_inputs=[],
)
with time_and_log_code_section(f"[PREFILL] request: {request_id} dispatch encode"):
batch: List[MultiModalGroup] = []
encode_response_streams = []
for url in image_urls:
......@@ -185,6 +187,9 @@ async def _fetch_from_encode_workers(
await encode_worker_client.round_robin(payload) # type: ignore[arg-type]
)
with time_and_log_code_section(
f"[PREFILL] request: {request_id} receive encode responses"
):
multimodal_groups: List[MultiModalGroup] = []
for stream in encode_response_streams:
async for response in stream:
......@@ -193,6 +198,9 @@ async def _fetch_from_encode_workers(
if output.multimodal_inputs:
multimodal_groups.extend(output.multimodal_inputs)
with time_and_log_code_section(
f"[PREFILL] request: {request_id} receive embeddings"
):
tasks = [
asyncio.create_task(receiver.receive_embeddings(group.serialized_request))
for group in multimodal_groups
......@@ -302,6 +310,9 @@ async def load_multimodal_embeddings(
)
multi_modal_data: Dict[str, Any] = defaultdict(list)
with time_and_log_code_section(
f"[PREFILL] request: {request_id} accumulate embeddings"
):
for group in groups:
assert group.loaded_embedding is not None
_accumulate_embeddings(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment