Unverified Commit d16862ad authored by GuanLuo's avatar GuanLuo Committed by GitHub
Browse files

chore: add context manager based timer (#7007)


Signed-off-by: default avatarGuan Luo <41310872+GuanLuo@users.noreply.github.com>
Signed-off-by: default avatarGuanLuo <41310872+GuanLuo@users.noreply.github.com>
parent 82f60cc7
...@@ -24,6 +24,7 @@ from dynamo.common.utils import ( ...@@ -24,6 +24,7 @@ from dynamo.common.utils import (
paths, paths,
prometheus, prometheus,
runtime, runtime,
time_section,
) )
__all__ = [ __all__ = [
...@@ -32,6 +33,7 @@ __all__ = [ ...@@ -32,6 +33,7 @@ __all__ = [
"namespace", "namespace",
"nvtx_utils", "nvtx_utils",
"otel_tracing", "otel_tracing",
"time_section",
"paths", "paths",
"prometheus", "prometheus",
"runtime", "runtime",
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import time
from contextlib import contextmanager
from typing import Callable
logger = logging.getLogger(__name__)
DEFAULT_LOG_LEVEL = logging.DEBUG
class Timer:
"""Simple timer implementation that can time interval, on constructing the Timer,
it starts the timer immediately, which will be stopped when calling stop().
It also supports timing intervals by calling start_interval() and stop_interval(),
so that you can time different parts of the code while the timer is running, note
that stop_interval() will update the interval start time to current time.
Example guide:
t = Timer(
lambda elapsed: logger.debug(f"phase1: {elapsed:.2f}s"),
lambda total: logger.debug(f"total: {total:.2f}s"),
)
t.start_interval()
do_phase1()
t.stop_interval() # prefer "start_something"
do_phase2()
t.stop()
"""
def __init__(
self,
interval_func: Callable[[float], None] = None,
stop_func: Callable[[float], None] = None,
):
"""Initialize the Timer and start timing immediately.
Args:
interval_func: Optional callback invoked with elapsed seconds during the timed interval when stop_interval() is called.
stop_func: Optional callback invoked with total elapsed seconds when stop() is called.
"""
self.start_time = time.perf_counter()
self.interval_start_time = self.start_time
self.interval_func = interval_func
self.stop_func = stop_func
def start_interval(self):
"""Start the interval timer."""
self.interval_start_time = time.perf_counter()
def stop_interval(self):
"""Stop the interval timer and return the elapsed time during the interval."""
now = time.perf_counter()
interval_time = now - self.interval_start_time
if self.interval_func:
self.interval_func(interval_time)
self.interval_start_time = now
return interval_time
def stop(self):
"""Stop the timer and return the total elapsed time."""
total_time = time.perf_counter() - self.start_time
if self.stop_func:
self.stop_func(total_time)
return total_time
@contextmanager
def time_and_log_code_section(log_message: str, log_level=DEFAULT_LOG_LEVEL):
"""Context manager that times a code block and logs total elapsed on exit.
Use the yielded timer's start_interval() and stop_interval() inside the block to
log sub-intervals (e.g. time to first token). Total elapsed is always logged
when the block exits.
Example:
with time_and_log_code_section("[DECODE] generate") as t:
t.start_interval()
async for chunk in generator():
if first_token:
t.stop_interval() # Log time to first chunk
first_token = False
yield chunk
Expected output (at default log level DEBUG):
[DECODE] generate - interval 0.1234 seconds # if stop_interval() was called
[DECODE] generate - total elapsed 1.5678 seconds # always on exit
Args:
log_message: Base message to use for logging, interval and total times will be appended.
log_level: Logging level to use for the messages (default: logging.DEBUG).
"""
timer = Timer(
lambda elapsed: logger.log(
log_level, f"{log_message} - interval {elapsed:.4f} seconds"
),
lambda total: logger.log(
log_level, f"{log_message} - total elapsed {total:.4f} seconds"
),
)
try:
yield timer
finally:
timer.stop()
...@@ -28,6 +28,7 @@ from dynamo.common.multimodal.image_loader import ImageLoader ...@@ -28,6 +28,7 @@ from dynamo.common.multimodal.image_loader import ImageLoader
from dynamo.common.utils.engine_response import normalize_finish_reason from dynamo.common.utils.engine_response import normalize_finish_reason
from dynamo.common.utils.input_params import InputParamManager from dynamo.common.utils.input_params import InputParamManager
from dynamo.common.utils.otel_tracing import build_trace_headers from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.llm import ( from dynamo.llm import (
KvEventPublisher, KvEventPublisher,
ModelInput, ModelInput,
...@@ -1314,14 +1315,21 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -1314,14 +1315,21 @@ class DecodeWorkerHandler(BaseWorkerHandler):
# Use context ID for request tracking and correlation # Use context ID for request tracking and correlation
request_id = context.id() request_id = context.id()
logger.debug(f"Decode Request ID: {request_id}") logger.debug(f"Decode Request ID: {request_id}")
first_token = True
with time_and_log_code_section(
f"[DECODE] request: {request_id} generate"
) as decode_timer:
if self.use_vllm_tokenizer:
# Text-in-text-out mode: use InputParamManager and OpenAI-compatible format
generator = self._generate_text_mode(request, context, request_id)
else:
# Token-in-token-out mode: internal protocol format
generator = self._generate_token_mode(request, context, request_id)
if self.use_vllm_tokenizer: async for chunk in generator:
# Text-in-text-out mode: use InputParamManager and OpenAI-compatible format if first_token:
async for chunk in self._generate_text_mode(request, context, request_id): decode_timer.stop_interval()
yield chunk first_token = False
else:
# Token-in-token-out mode: internal protocol format
async for chunk in self._generate_token_mode(request, context, request_id):
yield chunk yield chunk
async def _generate_token_mode(self, request, context, request_id): async def _generate_token_mode(self, request, context, request_id):
...@@ -1524,8 +1532,9 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -1524,8 +1532,9 @@ class PrefillWorkerHandler(BaseWorkerHandler):
logger.debug(f"Prefill Request ID: {request_id}") logger.debug(f"Prefill Request ID: {request_id}")
# Token-in-token-out mode: internal protocol format # Token-in-token-out mode: internal protocol format
async for chunk in self._generate_token_mode(request, context, request_id): with time_and_log_code_section(f"[PREFILL] request: {request_id} generate"):
yield chunk async for chunk in self._generate_token_mode(request, context, request_id):
yield chunk
async def _generate_token_mode(self, request, context, request_id): async def _generate_token_mode(self, request, context, request_id):
"""Generate prefill using internal protocol format (token-in-token-out).""" """Generate prefill using internal protocol format (token-in-token-out)."""
......
...@@ -19,6 +19,7 @@ from dynamo.common.multimodal import ( ...@@ -19,6 +19,7 @@ from dynamo.common.multimodal import (
NixlWriteEmbeddingSender, NixlWriteEmbeddingSender,
) )
from dynamo.common.utils import nvtx_utils as _nvtx from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from ..constants import EmbeddingTransferMode from ..constants import EmbeddingTransferMode
...@@ -177,7 +178,11 @@ class EncodeWorkerHandler: ...@@ -177,7 +178,11 @@ class EncodeWorkerHandler:
# keep track of key to avoid recompute of it # keep track of key to avoid recompute of it
need_encode_indexes.append((idx, embedding_key)) need_encode_indexes.append((idx, embedding_key))
with _nvtx.annotate("mm:enc:image_load", color="green"): with _nvtx.annotate(
"mm:enc:image_load", color="green"
), time_and_log_code_section(
f"[ENCODE] request: {request_id} image loading"
):
# Load and generate image tensors # Load and generate image tensors
image_tasks = [] image_tasks = []
image_to_load = [] image_to_load = []
...@@ -207,12 +212,20 @@ class EncodeWorkerHandler: ...@@ -207,12 +212,20 @@ class EncodeWorkerHandler:
) )
if loaded_images: if loaded_images:
with _nvtx.annotate("mm:enc:image_preprocess", color="yellow"): with _nvtx.annotate(
"mm:enc:image_preprocess", color="yellow"
), time_and_log_code_section(
f"[ENCODE] request: {request_id} image processing"
):
image_embeds = await asyncio.to_thread( image_embeds = await asyncio.to_thread(
self.image_processor, images=loaded_images, return_tensors="pt" self.image_processor, images=loaded_images, return_tensors="pt"
) )
with _nvtx.annotate("mm:enc:vision_encode", color="red"): with _nvtx.annotate(
"mm:enc:vision_encode", color="red"
), time_and_log_code_section(
f"[ENCODE] request: {request_id} encoding"
):
# Encode the image embeddings using model-specific encoder # Encode the image embeddings using model-specific encoder
embeddings = await asyncio.to_thread( embeddings = await asyncio.to_thread(
encode_image_embeddings, encode_image_embeddings,
......
...@@ -21,6 +21,7 @@ from dynamo.common.multimodal.embedding_transfer import ( ...@@ -21,6 +21,7 @@ from dynamo.common.multimodal.embedding_transfer import (
NixlWriteEmbeddingReceiver, NixlWriteEmbeddingReceiver,
) )
from dynamo.common.utils import nvtx_utils as _nvtx from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.runtime import Client, DistributedRuntime from dynamo.runtime import Client, DistributedRuntime
from ..args import Config from ..args import Config
...@@ -303,17 +304,21 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -303,17 +304,21 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
rng_ttft=None, rng_ttft=None,
): ):
"""Prefill locally, then forward to a remote decode worker.""" """Prefill locally, then forward to a remote decode worker."""
# Prepare prefill-only request with _nvtx.annotate(
prefill_only_request = copy.deepcopy(request) "mm:pd:disagg_prefill", color="darkred"
extra_args = prefill_only_request.sampling_params.extra_args or {} ), time_and_log_code_section(
extra_args["kv_transfer_params"] = {"do_remote_decode": True} f"[PREFILL] request: {request.request_id} prefill time"
prefill_only_request.sampling_params.extra_args = extra_args ):
prefill_only_request.sampling_params.max_tokens = 1 # Prepare prefill-only request
prefill_only_request.sampling_params.min_tokens = 1 prefill_only_request = copy.deepcopy(request)
logger.debug("Prefill request: %s", prefill_only_request) extra_args = prefill_only_request.sampling_params.extra_args or {}
extra_args["kv_transfer_params"] = {"do_remote_decode": True}
lora_request = self._resolve_lora_request(request.model) prefill_only_request.sampling_params.extra_args = extra_args
with _nvtx.annotate("mm:pd:disagg_prefill", color="darkred"): prefill_only_request.sampling_params.max_tokens = 1
prefill_only_request.sampling_params.min_tokens = 1
logger.debug("Prefill request: %s", prefill_only_request)
lora_request = self._resolve_lora_request(request.model)
gen = self.engine_client.generate( gen = self.engine_client.generate(
prompt=TokensPrompt( prompt=TokensPrompt(
prompt_token_ids=prefill_only_request.engine_prompt[ prompt_token_ids=prefill_only_request.engine_prompt[
...@@ -367,7 +372,12 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -367,7 +372,12 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
f"— ensure the same adapter is loaded on the decode worker." f"— ensure the same adapter is loaded on the decode worker."
) )
with _nvtx.annotate("mm:pd:disagg_remote_decode", color="purple"): with (
_nvtx.annotate("mm:pd:disagg_remote_decode", color="purple"),
time_and_log_code_section(
f"[PREFILL] request: {request.request_id} remote decode time"
) as decode_timer,
):
num_output_tokens_so_far = 0 num_output_tokens_so_far = 0
async for ( async for (
decode_response decode_response
...@@ -377,6 +387,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -377,6 +387,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
output = MyRequestOutput.model_validate_json(decode_response.data()) # type: ignore output = MyRequestOutput.model_validate_json(decode_response.data()) # type: ignore
yield self._format_engine_output(output, num_output_tokens_so_far) yield self._format_engine_output(output, num_output_tokens_so_far)
if output.outputs: if output.outputs:
if num_output_tokens_so_far == 0:
decode_timer.stop_interval() # Log time to first decode response
num_output_tokens_so_far = len(output.outputs[0].token_ids) num_output_tokens_so_far = len(output.outputs[0].token_ids)
# ── Public entry point ─────────────────────────────────────────── # ── Public entry point ───────────────────────────────────────────
...@@ -386,18 +398,19 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -386,18 +398,19 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
rng_pd = _nvtx.start_range("mm:pd_worker_generate", color="green") rng_pd = _nvtx.start_range("mm:pd_worker_generate", color="green")
rng_ttft = _nvtx.start_range("mm:pd:ttft", color="orange") rng_ttft = _nvtx.start_range("mm:pd:ttft", color="orange")
rng_parse = _nvtx.start_range("mm:pd:parse_request", color="cyan") with time_and_log_code_section("[REQUEST] embedding processing time"):
request, image_urls = self._parse_frontend_request(raw_request) rng_parse = _nvtx.start_range("mm:pd:parse_request", color="cyan")
logger.debug(f"Received PD request: {{ id: {request.request_id} }}.") request, image_urls = self._parse_frontend_request(raw_request)
_nvtx.end_range(rng_parse) logger.debug(f"Received PD request: {{ id: {request.request_id} }}.")
_nvtx.end_range(rng_parse)
rng_load = _nvtx.start_range("mm:pd:load_multimodal", color="yellow") rng_load = _nvtx.start_range("mm:pd:load_multimodal", color="yellow")
multi_modal_data = await self._load_multimodal_data( multi_modal_data = await self._load_multimodal_data(
image_urls, request.request_id image_urls, request.request_id
) )
_nvtx.end_range(rng_load) _nvtx.end_range(rng_load)
self._finalize_request_metadata(request, multi_modal_data) self._finalize_request_metadata(request, multi_modal_data)
if self.enable_disagg and self.decode_worker_client: if self.enable_disagg and self.decode_worker_client:
rng_disagg = _nvtx.start_range("mm:pd:generate_disagg", color="red") rng_disagg = _nvtx.start_range("mm:pd:generate_disagg", color="red")
......
...@@ -7,6 +7,7 @@ from vllm.inputs.data import TokensPrompt ...@@ -7,6 +7,7 @@ from vllm.inputs.data import TokensPrompt
import dynamo.nixl_connect as connect import dynamo.nixl_connect as connect
from dynamo.common.utils import nvtx_utils as _nvtx from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from ..args import Config from ..args import Config
...@@ -56,68 +57,75 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler): ...@@ -56,68 +57,75 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
async def generate(self, request: vLLMMultimodalRequest, context): async def generate(self, request: vLLMMultimodalRequest, context):
rng_decode = _nvtx.start_range("mm:decode_worker_generate", color="blue") rng_decode = _nvtx.start_range("mm:decode_worker_generate", color="blue")
logger.debug(f"Got raw request: {request}") logger.debug(f"Got raw request: {request}")
if not isinstance(request, vLLMMultimodalRequest): with time_and_log_code_section(
if isinstance(request, str): f"[DECODE] request: {request.request_id} preprocessing time"
request = vLLMMultimodalRequest.model_validate_json(request) ):
else: if not isinstance(request, vLLMMultimodalRequest):
request = vLLMMultimodalRequest.model_validate(request) if isinstance(request, str):
logger.debug(f"Received decode request: {{ id: {request.request_id} }}.") request = vLLMMultimodalRequest.model_validate_json(request)
else:
request = vLLMMultimodalRequest.model_validate(request)
logger.debug(f"Received decode request: {{ id: {request.request_id} }}.")
# For Qwen VL models with mRoPE, we need to pass multi_modal_data containing # For Qwen VL models with mRoPE, we need to pass multi_modal_data containing
# image_grid_thw for position embeddings calculation. The decode worker # image_grid_thw for position embeddings calculation. The decode worker
# receives the ORIGINAL unexpanded prompt (with placeholders), and vLLM # receives the ORIGINAL unexpanded prompt (with placeholders), and vLLM
# will expand it using the multi_modal_data, ensuring the block count # will expand it using the multi_modal_data, ensuring the block count
# matches what prefill computed. # matches what prefill computed.
# #
# We pass unique placeholder embeddings (seeded by request_id) since the # We pass unique placeholder embeddings (seeded by request_id) since the
# actual embeddings are already in the KV cache from prefill. The unique # actual embeddings are already in the KV cache from prefill. The unique
# values prevent incorrect prefix cache matches between different images. # values prevent incorrect prefix cache matches between different images.
multi_modal_data = None multi_modal_data = None
if is_qwen_vl_model(self.config.model): if is_qwen_vl_model(self.config.model):
image_grid_thw = getattr(request, "image_grid_thw", None) image_grid_thw = getattr(request, "image_grid_thw", None)
embeddings_shape = getattr(request, "embeddings_shape", None) embeddings_shape = getattr(request, "embeddings_shape", None)
if image_grid_thw is None or embeddings_shape is None: if image_grid_thw is None or embeddings_shape is None:
logger.warning( logger.warning(
"Missing Qwen VL decode fields (image_grid_thw/embeddings_shape); " "Missing Qwen VL decode fields (image_grid_thw/embeddings_shape); "
"skipping multi_modal_data construction." "skipping multi_modal_data construction."
) )
else: else:
multi_modal_data = construct_qwen_decode_mm_data( multi_modal_data = construct_qwen_decode_mm_data(
image_grid_thw, embeddings_shape, request.request_id image_grid_thw, embeddings_shape, request.request_id
) )
lora_request = self._resolve_lora_request(request.model)
lora_request = self._resolve_lora_request(request.model) with time_and_log_code_section(
gen = self.engine_client.generate( f"[DECODE] request: {request.request_id} generate time"
prompt=TokensPrompt( ) as gen_timer:
prompt_token_ids=request.engine_prompt["prompt_token_ids"], gen = self.engine_client.generate(
multi_modal_data=multi_modal_data, prompt=TokensPrompt(
), prompt_token_ids=request.engine_prompt["prompt_token_ids"],
sampling_params=request.sampling_params, multi_modal_data=multi_modal_data,
request_id=request.request_id, ),
lora_request=lora_request, sampling_params=request.sampling_params,
) request_id=request.request_id,
lora_request=lora_request,
)
rng_first = _nvtx.start_range("mm:decode:first_token", color="darkred") rng_first = _nvtx.start_range("mm:decode:first_token", color="darkred")
first_token = True first_token = True
try: try:
async for response in gen: async for response in gen:
if first_token:
gen_timer.stop_interval() # Log time to first response
_nvtx.end_range(rng_first)
first_token = False
logger.debug(
f"Response kv_transfer_params: {response.kv_transfer_params}"
)
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
metrics=response.metrics,
kv_transfer_params=response.kv_transfer_params,
).model_dump_json()
finally:
if first_token: if first_token:
_nvtx.end_range(rng_first) _nvtx.end_range(rng_first)
first_token = False _nvtx.end_range(rng_decode)
logger.debug(
f"Response kv_transfer_params: {response.kv_transfer_params}"
)
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
metrics=response.metrics,
kv_transfer_params=response.kv_transfer_params,
).model_dump_json()
finally:
if first_token:
_nvtx.end_range(rng_first)
_nvtx.end_range(rng_decode)
...@@ -18,6 +18,7 @@ from dynamo.common.multimodal.embedding_transfer import ( ...@@ -18,6 +18,7 @@ from dynamo.common.multimodal.embedding_transfer import (
AbstractEmbeddingReceiver, AbstractEmbeddingReceiver,
LocalEmbeddingReceiver, LocalEmbeddingReceiver,
) )
from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.runtime import Client from dynamo.runtime import Client
from .encode_utils import get_embedding_hash from .encode_utils import get_embedding_hash
...@@ -163,41 +164,48 @@ async def _fetch_from_encode_workers( ...@@ -163,41 +164,48 @@ async def _fetch_from_encode_workers(
multimodal_inputs=[], multimodal_inputs=[],
) )
batch: List[MultiModalGroup] = [] with time_and_log_code_section(f"[PREFILL] request: {request_id} dispatch encode"):
encode_response_streams = [] batch: List[MultiModalGroup] = []
for url in image_urls: encode_response_streams = []
multimodal_input = MultiModalInput() for url in image_urls:
multimodal_input.image_url = url multimodal_input = MultiModalInput()
batch.append(MultiModalGroup(multimodal_input=multimodal_input)) multimodal_input.image_url = url
batch.append(MultiModalGroup(multimodal_input=multimodal_input))
if len(batch) >= encode_batch_size:
encode_request.multimodal_inputs = batch
payload = encode_request.model_dump_json()
encode_response_streams.append(
await encode_worker_client.round_robin(payload) # type: ignore[arg-type]
)
batch = []
if len(batch) >= encode_batch_size: if batch:
encode_request.multimodal_inputs = batch encode_request.multimodal_inputs = batch
payload = encode_request.model_dump_json() payload = encode_request.model_dump_json()
encode_response_streams.append( encode_response_streams.append(
await encode_worker_client.round_robin(payload) # type: ignore[arg-type] await encode_worker_client.round_robin(payload) # type: ignore[arg-type]
) )
batch = []
if batch:
encode_request.multimodal_inputs = batch
payload = encode_request.model_dump_json()
encode_response_streams.append(
await encode_worker_client.round_robin(payload) # type: ignore[arg-type]
)
multimodal_groups: List[MultiModalGroup] = []
for stream in encode_response_streams:
async for response in stream:
logger.debug(f"Received response from encode worker: {response}")
output = vLLMMultimodalRequest.model_validate_json(response.data()) # type: ignore[attr-defined]
if output.multimodal_inputs:
multimodal_groups.extend(output.multimodal_inputs)
tasks = [ with time_and_log_code_section(
asyncio.create_task(receiver.receive_embeddings(group.serialized_request)) f"[PREFILL] request: {request_id} receive encode responses"
for group in multimodal_groups ):
] multimodal_groups: List[MultiModalGroup] = []
loaded = await asyncio.gather(*tasks) for stream in encode_response_streams:
async for response in stream:
logger.debug(f"Received response from encode worker: {response}")
output = vLLMMultimodalRequest.model_validate_json(response.data()) # type: ignore[attr-defined]
if output.multimodal_inputs:
multimodal_groups.extend(output.multimodal_inputs)
with time_and_log_code_section(
f"[PREFILL] request: {request_id} receive embeddings"
):
tasks = [
asyncio.create_task(receiver.receive_embeddings(group.serialized_request))
for group in multimodal_groups
]
loaded = await asyncio.gather(*tasks)
is_local = isinstance(receiver, LocalEmbeddingReceiver) is_local = isinstance(receiver, LocalEmbeddingReceiver)
pending: _PendingRelease | None = None if is_local else _PendingRelease(receiver) pending: _PendingRelease | None = None if is_local else _PendingRelease(receiver)
...@@ -302,15 +310,18 @@ async def load_multimodal_embeddings( ...@@ -302,15 +310,18 @@ async def load_multimodal_embeddings(
) )
multi_modal_data: Dict[str, Any] = defaultdict(list) multi_modal_data: Dict[str, Any] = defaultdict(list)
for group in groups: with time_and_log_code_section(
assert group.loaded_embedding is not None f"[PREFILL] request: {request_id} accumulate embeddings"
_accumulate_embeddings( ):
multi_modal_data, for group in groups:
model, assert group.loaded_embedding is not None
embeddings_dtype, _accumulate_embeddings(
group.loaded_embedding, multi_modal_data,
group.image_grid_thw, model,
) embeddings_dtype,
group.loaded_embedding,
group.image_grid_thw,
)
if pending is not None: if pending is not None:
# Multi-image: torch.cat in _accumulate_embeddings already created # Multi-image: torch.cat in _accumulate_embeddings already created
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment