Unverified Commit d16862ad authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

chore: add context manager based timer (#7007)


Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
Signed-off-by: default avatarGuanLuo <41310872+GuanLuo@users.noreply.github.com>
parent 82f60cc7
......@@ -24,6 +24,7 @@ from dynamo.common.utils import (
paths,
prometheus,
runtime,
time_section,
)
__all__ = [
......@@ -32,6 +33,7 @@ __all__ = [
"namespace",
"nvtx_utils",
"otel_tracing",
"time_section",
"paths",
"prometheus",
"runtime",
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import time
from contextlib import contextmanager
from typing import Callable
logger = logging.getLogger(__name__)
DEFAULT_LOG_LEVEL = logging.DEBUG
class Timer:
"""Simple timer implementation that can time interval, on constructing the Timer,
it starts the timer immediately, which will be stopped when calling stop().
It also supports timing intervals by calling start_interval() and stop_interval(),
so that you can time different parts of the code while the timer is running, note
that stop_interval() will update the interval start time to current time.
Example guide:
t = Timer(
lambda elapsed: logger.debug(f"phase1: {elapsed:.2f}s"),
lambda total: logger.debug(f"total: {total:.2f}s"),
)
t.start_interval()
do_phase1()
t.stop_interval() # prefer "start_something"
do_phase2()
t.stop()
"""
def __init__(
self,
interval_func: Callable[[float], None] = None,
stop_func: Callable[[float], None] = None,
):
"""Initialize the Timer and start timing immediately.
Args:
interval_func: Optional callback invoked with elapsed seconds during the timed interval when stop_interval() is called.
stop_func: Optional callback invoked with total elapsed seconds when stop() is called.
"""
self.start_time = time.perf_counter()
self.interval_start_time = self.start_time
self.interval_func = interval_func
self.stop_func = stop_func
def start_interval(self):
"""Start the interval timer."""
self.interval_start_time = time.perf_counter()
def stop_interval(self):
"""Stop the interval timer and return the elapsed time during the interval."""
now = time.perf_counter()
interval_time = now - self.interval_start_time
if self.interval_func:
self.interval_func(interval_time)
self.interval_start_time = now
return interval_time
def stop(self):
"""Stop the timer and return the total elapsed time."""
total_time = time.perf_counter() - self.start_time
if self.stop_func:
self.stop_func(total_time)
return total_time
@contextmanager
def time_and_log_code_section(log_message: str, log_level=DEFAULT_LOG_LEVEL):
"""Context manager that times a code block and logs total elapsed on exit.
Use the yielded timer's start_interval() and stop_interval() inside the block to
log sub-intervals (e.g. time to first token). Total elapsed is always logged
when the block exits.
Example:
with time_and_log_code_section("[DECODE] generate") as t:
t.start_interval()
async for chunk in generator():
if first_token:
t.stop_interval() # Log time to first chunk
first_token = False
yield chunk
Expected output (at default log level DEBUG):
[DECODE] generate - interval 0.1234 seconds # if stop_interval() was called
[DECODE] generate - total elapsed 1.5678 seconds # always on exit
Args:
log_message: Base message to use for logging, interval and total times will be appended.
log_level: Logging level to use for the messages (default: logging.DEBUG).
"""
timer = Timer(
lambda elapsed: logger.log(
log_level, f"{log_message} - interval {elapsed:.4f} seconds"
),
lambda total: logger.log(
log_level, f"{log_message} - total elapsed {total:.4f} seconds"
),
)
try:
yield timer
finally:
timer.stop()
......@@ -28,6 +28,7 @@ from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.common.utils.input_params import InputParamManager
from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.llm import (
KvEventPublisher,
ModelInput,
......@@ -1314,14 +1315,21 @@ class DecodeWorkerHandler(BaseWorkerHandler):
# Use context ID for request tracking and correlation
request_id = context.id()
logger.debug(f"Decode Request ID: {request_id}")
first_token = True
with time_and_log_code_section(
f"[DECODE] request: {request_id} generate"
) as decode_timer:
if self.use_vllm_tokenizer:
# Text-in-text-out mode: use InputParamManager and OpenAI-compatible format
generator = self._generate_text_mode(request, context, request_id)
else:
# Token-in-token-out mode: internal protocol format
generator = self._generate_token_mode(request, context, request_id)
if self.use_vllm_tokenizer:
# Text-in-text-out mode: use InputParamManager and OpenAI-compatible format
async for chunk in self._generate_text_mode(request, context, request_id):
yield chunk
else:
# Token-in-token-out mode: internal protocol format
async for chunk in self._generate_token_mode(request, context, request_id):
async for chunk in generator:
if first_token:
decode_timer.stop_interval()
first_token = False
yield chunk
async def _generate_token_mode(self, request, context, request_id):
......@@ -1524,8 +1532,9 @@ class PrefillWorkerHandler(BaseWorkerHandler):
logger.debug(f"Prefill Request ID: {request_id}")
# Token-in-token-out mode: internal protocol format
async for chunk in self._generate_token_mode(request, context, request_id):
yield chunk
with time_and_log_code_section(f"[PREFILL] request: {request_id} generate"):
async for chunk in self._generate_token_mode(request, context, request_id):
yield chunk
async def _generate_token_mode(self, request, context, request_id):
"""Generate prefill using internal protocol format (token-in-token-out)."""
......
......@@ -19,6 +19,7 @@ from dynamo.common.multimodal import (
NixlWriteEmbeddingSender,
)
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.runtime import DistributedRuntime
from ..constants import EmbeddingTransferMode
......@@ -177,7 +178,11 @@ class EncodeWorkerHandler:
# keep track of key to avoid recompute of it
need_encode_indexes.append((idx, embedding_key))
with _nvtx.annotate("mm:enc:image_load", color="green"):
with _nvtx.annotate(
"mm:enc:image_load", color="green"
), time_and_log_code_section(
f"[ENCODE] request: {request_id} image loading"
):
# Load and generate image tensors
image_tasks = []
image_to_load = []
......@@ -207,12 +212,20 @@ class EncodeWorkerHandler:
)
if loaded_images:
with _nvtx.annotate("mm:enc:image_preprocess", color="yellow"):
with _nvtx.annotate(
"mm:enc:image_preprocess", color="yellow"
), time_and_log_code_section(
f"[ENCODE] request: {request_id} image processing"
):
image_embeds = await asyncio.to_thread(
self.image_processor, images=loaded_images, return_tensors="pt"
)
with _nvtx.annotate("mm:enc:vision_encode", color="red"):
with _nvtx.annotate(
"mm:enc:vision_encode", color="red"
), time_and_log_code_section(
f"[ENCODE] request: {request_id} encoding"
):
# Encode the image embeddings using model-specific encoder
embeddings = await asyncio.to_thread(
encode_image_embeddings,
......
......@@ -21,6 +21,7 @@ from dynamo.common.multimodal.embedding_transfer import (
NixlWriteEmbeddingReceiver,
)
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.runtime import Client, DistributedRuntime
from ..args import Config
......@@ -303,17 +304,21 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
rng_ttft=None,
):
"""Prefill locally, then forward to a remote decode worker."""
# Prepare prefill-only request
prefill_only_request = copy.deepcopy(request)
extra_args = prefill_only_request.sampling_params.extra_args or {}
extra_args["kv_transfer_params"] = {"do_remote_decode": True}
prefill_only_request.sampling_params.extra_args = extra_args
prefill_only_request.sampling_params.max_tokens = 1
prefill_only_request.sampling_params.min_tokens = 1
logger.debug("Prefill request: %s", prefill_only_request)
lora_request = self._resolve_lora_request(request.model)
with _nvtx.annotate("mm:pd:disagg_prefill", color="darkred"):
with _nvtx.annotate(
"mm:pd:disagg_prefill", color="darkred"
), time_and_log_code_section(
f"[PREFILL] request: {request.request_id} prefill time"
):
# Prepare prefill-only request
prefill_only_request = copy.deepcopy(request)
extra_args = prefill_only_request.sampling_params.extra_args or {}
extra_args["kv_transfer_params"] = {"do_remote_decode": True}
prefill_only_request.sampling_params.extra_args = extra_args
prefill_only_request.sampling_params.max_tokens = 1
prefill_only_request.sampling_params.min_tokens = 1
logger.debug("Prefill request: %s", prefill_only_request)
lora_request = self._resolve_lora_request(request.model)
gen = self.engine_client.generate(
prompt=TokensPrompt(
prompt_token_ids=prefill_only_request.engine_prompt[
......@@ -367,7 +372,12 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
f"— ensure the same adapter is loaded on the decode worker."
)
with _nvtx.annotate("mm:pd:disagg_remote_decode", color="purple"):
with (
_nvtx.annotate("mm:pd:disagg_remote_decode", color="purple"),
time_and_log_code_section(
f"[PREFILL] request: {request.request_id} remote decode time"
) as decode_timer,
):
num_output_tokens_so_far = 0
async for (
decode_response
......@@ -377,6 +387,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
output = MyRequestOutput.model_validate_json(decode_response.data()) # type: ignore
yield self._format_engine_output(output, num_output_tokens_so_far)
if output.outputs:
if num_output_tokens_so_far == 0:
decode_timer.stop_interval() # Log time to first decode response
num_output_tokens_so_far = len(output.outputs[0].token_ids)
# ── Public entry point ───────────────────────────────────────────
......@@ -386,18 +398,19 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
rng_pd = _nvtx.start_range("mm:pd_worker_generate", color="green")
rng_ttft = _nvtx.start_range("mm:pd:ttft", color="orange")
rng_parse = _nvtx.start_range("mm:pd:parse_request", color="cyan")
request, image_urls = self._parse_frontend_request(raw_request)
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")
_nvtx.end_range(rng_parse)
with time_and_log_code_section("[REQUEST] embedding processing time"):
rng_parse = _nvtx.start_range("mm:pd:parse_request", color="cyan")
request, image_urls = self._parse_frontend_request(raw_request)
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")
_nvtx.end_range(rng_parse)
rng_load = _nvtx.start_range("mm:pd:load_multimodal", color="yellow")
multi_modal_data = await self._load_multimodal_data(
image_urls, request.request_id
)
_nvtx.end_range(rng_load)
rng_load = _nvtx.start_range("mm:pd:load_multimodal", color="yellow")
multi_modal_data = await self._load_multimodal_data(
image_urls, request.request_id
)
_nvtx.end_range(rng_load)
self._finalize_request_metadata(request, multi_modal_data)
self._finalize_request_metadata(request, multi_modal_data)
if self.enable_disagg and self.decode_worker_client:
rng_disagg = _nvtx.start_range("mm:pd:generate_disagg", color="red")
......
......@@ -7,6 +7,7 @@ from vllm.inputs.data import TokensPrompt
import dynamo.nixl_connect as connect
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.runtime import DistributedRuntime
from ..args import Config
......@@ -56,68 +57,75 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
async def generate(self, request: vLLMMultimodalRequest, context):
rng_decode = _nvtx.start_range("mm:decode_worker_generate", color="blue")
logger.debug(f"Got raw request: {request}")
if not isinstance(request, vLLMMultimodalRequest):
if isinstance(request, str):
request = vLLMMultimodalRequest.model_validate_json(request)
else:
request = vLLMMultimodalRequest.model_validate(request)
logger.debug(f"Received decode request: {{ id: {request.request_id} }}.")
with time_and_log_code_section(
f"[DECODE] request: {request.request_id} preprocessing time"
):
if not isinstance(request, vLLMMultimodalRequest):
if isinstance(request, str):
request = vLLMMultimodalRequest.model_validate_json(request)
else:
request = vLLMMultimodalRequest.model_validate(request)
logger.debug(f"Received decode request: {{ id: {request.request_id} }}.")
# For Qwen VL models with mRoPE, we need to pass multi_modal_data containing
# image_grid_thw for position embeddings calculation. The decode worker
# receives the ORIGINAL unexpanded prompt (with placeholders), and vLLM
# will expand it using the multi_modal_data, ensuring the block count
# matches what prefill computed.
#
# We pass unique placeholder embeddings (seeded by request_id) since the
# actual embeddings are already in the KV cache from prefill. The unique
# values prevent incorrect prefix cache matches between different images.
multi_modal_data = None
if is_qwen_vl_model(self.config.model):
image_grid_thw = getattr(request, "image_grid_thw", None)
embeddings_shape = getattr(request, "embeddings_shape", None)
if image_grid_thw is None or embeddings_shape is None:
logger.warning(
"Missing Qwen VL decode fields (image_grid_thw/embeddings_shape); "
"skipping multi_modal_data construction."
)
else:
multi_modal_data = construct_qwen_decode_mm_data(
image_grid_thw, embeddings_shape, request.request_id
)
# For Qwen VL models with mRoPE, we need to pass multi_modal_data containing
# image_grid_thw for position embeddings calculation. The decode worker
# receives the ORIGINAL unexpanded prompt (with placeholders), and vLLM
# will expand it using the multi_modal_data, ensuring the block count
# matches what prefill computed.
#
# We pass unique placeholder embeddings (seeded by request_id) since the
# actual embeddings are already in the KV cache from prefill. The unique
# values prevent incorrect prefix cache matches between different images.
multi_modal_data = None
if is_qwen_vl_model(self.config.model):
image_grid_thw = getattr(request, "image_grid_thw", None)
embeddings_shape = getattr(request, "embeddings_shape", None)
if image_grid_thw is None or embeddings_shape is None:
logger.warning(
"Missing Qwen VL decode fields (image_grid_thw/embeddings_shape); "
"skipping multi_modal_data construction."
)
else:
multi_modal_data = construct_qwen_decode_mm_data(
image_grid_thw, embeddings_shape, request.request_id
)
lora_request = self._resolve_lora_request(request.model)
lora_request = self._resolve_lora_request(request.model)
gen = self.engine_client.generate(
prompt=TokensPrompt(
prompt_token_ids=request.engine_prompt["prompt_token_ids"],
multi_modal_data=multi_modal_data,
),
sampling_params=request.sampling_params,
request_id=request.request_id,
lora_request=lora_request,
)
with time_and_log_code_section(
f"[DECODE] request: {request.request_id} generate time"
) as gen_timer:
gen = self.engine_client.generate(
prompt=TokensPrompt(
prompt_token_ids=request.engine_prompt["prompt_token_ids"],
multi_modal_data=multi_modal_data,
),
sampling_params=request.sampling_params,
request_id=request.request_id,
lora_request=lora_request,
)
rng_first = _nvtx.start_range("mm:decode:first_token", color="darkred")
first_token = True
try:
async for response in gen:
rng_first = _nvtx.start_range("mm:decode:first_token", color="darkred")
first_token = True
try:
async for response in gen:
if first_token:
gen_timer.stop_interval() # Log time to first response
_nvtx.end_range(rng_first)
first_token = False
logger.debug(
f"Response kv_transfer_params: {response.kv_transfer_params}"
)
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
metrics=response.metrics,
kv_transfer_params=response.kv_transfer_params,
).model_dump_json()
finally:
if first_token:
_nvtx.end_range(rng_first)
first_token = False
logger.debug(
f"Response kv_transfer_params: {response.kv_transfer_params}"
)
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
metrics=response.metrics,
kv_transfer_params=response.kv_transfer_params,
).model_dump_json()
finally:
if first_token:
_nvtx.end_range(rng_first)
_nvtx.end_range(rng_decode)
_nvtx.end_range(rng_decode)
......@@ -18,6 +18,7 @@ from dynamo.common.multimodal.embedding_transfer import (
AbstractEmbeddingReceiver,
LocalEmbeddingReceiver,
)
from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.runtime import Client
from .encode_utils import get_embedding_hash
......@@ -163,41 +164,48 @@ async def _fetch_from_encode_workers(
multimodal_inputs=[],
)
batch: List[MultiModalGroup] = []
encode_response_streams = []
for url in image_urls:
multimodal_input = MultiModalInput()
multimodal_input.image_url = url
batch.append(MultiModalGroup(multimodal_input=multimodal_input))
with time_and_log_code_section(f"[PREFILL] request: {request_id} dispatch encode"):
batch: List[MultiModalGroup] = []
encode_response_streams = []
for url in image_urls:
multimodal_input = MultiModalInput()
multimodal_input.image_url = url
batch.append(MultiModalGroup(multimodal_input=multimodal_input))
if len(batch) >= encode_batch_size:
encode_request.multimodal_inputs = batch
payload = encode_request.model_dump_json()
encode_response_streams.append(
await encode_worker_client.round_robin(payload) # type: ignore[arg-type]
)
batch = []
if len(batch) >= encode_batch_size:
if batch:
encode_request.multimodal_inputs = batch
payload = encode_request.model_dump_json()
encode_response_streams.append(
await encode_worker_client.round_robin(payload) # type: ignore[arg-type]
)
batch = []
if batch:
encode_request.multimodal_inputs = batch
payload = encode_request.model_dump_json()
encode_response_streams.append(
await encode_worker_client.round_robin(payload) # type: ignore[arg-type]
)
multimodal_groups: List[MultiModalGroup] = []
for stream in encode_response_streams:
async for response in stream:
logger.debug(f"Received response from encode worker: {response}")
output = vLLMMultimodalRequest.model_validate_json(response.data()) # type: ignore[attr-defined]
if output.multimodal_inputs:
multimodal_groups.extend(output.multimodal_inputs)
tasks = [
asyncio.create_task(receiver.receive_embeddings(group.serialized_request))
for group in multimodal_groups
]
loaded = await asyncio.gather(*tasks)
with time_and_log_code_section(
f"[PREFILL] request: {request_id} receive encode responses"
):
multimodal_groups: List[MultiModalGroup] = []
for stream in encode_response_streams:
async for response in stream:
logger.debug(f"Received response from encode worker: {response}")
output = vLLMMultimodalRequest.model_validate_json(response.data()) # type: ignore[attr-defined]
if output.multimodal_inputs:
multimodal_groups.extend(output.multimodal_inputs)
with time_and_log_code_section(
f"[PREFILL] request: {request_id} receive embeddings"
):
tasks = [
asyncio.create_task(receiver.receive_embeddings(group.serialized_request))
for group in multimodal_groups
]
loaded = await asyncio.gather(*tasks)
is_local = isinstance(receiver, LocalEmbeddingReceiver)
pending: _PendingRelease | None = None if is_local else _PendingRelease(receiver)
......@@ -302,15 +310,18 @@ async def load_multimodal_embeddings(
)
multi_modal_data: Dict[str, Any] = defaultdict(list)
for group in groups:
assert group.loaded_embedding is not None
_accumulate_embeddings(
multi_modal_data,
model,
embeddings_dtype,
group.loaded_embedding,
group.image_grid_thw,
)
with time_and_log_code_section(
f"[PREFILL] request: {request_id} accumulate embeddings"
):
for group in groups:
assert group.loaded_embedding is not None
_accumulate_embeddings(
multi_modal_data,
model,
embeddings_dtype,
group.loaded_embedding,
group.image_grid_thw,
)
if pending is not None:
# Multi-image: torch.cat in _accumulate_embeddings already created
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment