Unverified Commit 39264545 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Decouple TimingContext from InputProcessingContext (#35083)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 1e8438a8
...@@ -389,13 +389,13 @@ def _test_processing_correctness_one( ...@@ -389,13 +389,13 @@ def _test_processing_correctness_one(
mm_items = baseline_processor.info.parse_mm_data(mm_data) mm_items = baseline_processor.info.parse_mm_data(mm_data)
ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]()) ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]())
baseline_tokenized_result = baseline_processor.apply( baseline_tokenized_result = baseline_processor(
token_prompt, token_prompt,
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
) )
cached_tokenized_result = cached_processor.apply( cached_tokenized_result = cached_processor(
token_prompt, token_prompt,
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
...@@ -409,12 +409,12 @@ def _test_processing_correctness_one( ...@@ -409,12 +409,12 @@ def _test_processing_correctness_one(
) )
if text_prompt is not None: if text_prompt is not None:
baseline_text_result = baseline_processor.apply( baseline_text_result = baseline_processor(
text_prompt, text_prompt,
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
) )
cached_text_result = cached_processor.apply( cached_text_result = cached_processor(
text_prompt, text_prompt,
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
......
...@@ -176,7 +176,7 @@ def test_get_image_size_with_most_features( ...@@ -176,7 +176,7 @@ def test_get_image_size_with_most_features(
for asset in image_assets: for asset in image_assets:
mm_data = {"image": [asset.pil_image]} mm_data = {"image": [asset.pil_image]}
processed_inputs = processor.apply( processed_inputs = processor(
prompt, prompt,
mm_items=processor.info.parse_mm_data(mm_data), mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......
...@@ -52,7 +52,7 @@ def test_processor_override( ...@@ -52,7 +52,7 @@ def test_processor_override(
metadata["fps"] = fps metadata["fps"] = fps
mm_data = {"video": [(video, metadata)]} mm_data = {"video": [(video, metadata)]}
processed_inputs = processor.apply( processed_inputs = processor(
prompt, prompt,
mm_items=processor.info.parse_mm_data(mm_data), mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
...@@ -104,12 +104,12 @@ def test_video_loader_consistency( ...@@ -104,12 +104,12 @@ def test_video_loader_consistency(
static_mm_data = {"video": [(static_video, static_metadata)]} static_mm_data = {"video": [(static_video, static_metadata)]}
dynamic_mm_data = {"video": [(dynamic_video, dynamic_metadata)]} dynamic_mm_data = {"video": [(dynamic_video, dynamic_metadata)]}
static_outputs = processor.apply( static_outputs = processor(
prompt, prompt,
mm_items=processor.info.parse_mm_data(static_mm_data), mm_items=processor.info.parse_mm_data(static_mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
) )
dynamic_outputs = processor.apply( dynamic_outputs = processor(
prompt, prompt,
mm_items=processor.info.parse_mm_data(dynamic_mm_data), mm_items=processor.info.parse_mm_data(dynamic_mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......
...@@ -106,7 +106,7 @@ def _run_check( ...@@ -106,7 +106,7 @@ def _run_check(
for image in images for image in images
) )
processed_inputs = processor.apply( processed_inputs = processor(
prompt, prompt,
mm_items=processor.info.parse_mm_data(mm_data), mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=mm_processor_kwargs, hf_processor_mm_kwargs=mm_processor_kwargs,
......
...@@ -61,7 +61,7 @@ def test_processor_override( ...@@ -61,7 +61,7 @@ def test_processor_override(
dummy_image = image_assets[0].pil_image.resize(dummy_image_size) dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
mm_data = {"image": [dummy_image] * num_imgs} mm_data = {"image": [dummy_image] * num_imgs}
processed_inputs = processor.apply( processed_inputs = processor(
prompt, prompt,
mm_items=processor.info.parse_mm_data(mm_data), mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......
...@@ -66,7 +66,7 @@ def _run_check( ...@@ -66,7 +66,7 @@ def _run_check(
for image in images for image in images
) )
processed_inputs = processor.apply( processed_inputs = processor(
prompt, prompt,
mm_items=processor.info.parse_mm_data(mm_data), mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=mm_processor_kwargs, hf_processor_mm_kwargs=mm_processor_kwargs,
......
...@@ -49,7 +49,7 @@ def test_processor_override( ...@@ -49,7 +49,7 @@ def test_processor_override(
if tokenized_prompt: if tokenized_prompt:
prompt = tokenizer.encode(prompt) prompt = tokenizer.encode(prompt)
processed_inputs = processor.apply( processed_inputs = processor(
prompt, prompt,
mm_items=processor.info.parse_mm_data(mm_data), mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=mm_processor_kwargs, hf_processor_mm_kwargs=mm_processor_kwargs,
......
...@@ -87,7 +87,7 @@ def _validate_image_prompt_replacements_one( ...@@ -87,7 +87,7 @@ def _validate_image_prompt_replacements_one(
try: try:
# The processor will throw an error if there is a mismatch # The processor will throw an error if there is a mismatch
# in the prompt replacements # in the prompt replacements
processed_inputs = processor.apply( processed_inputs = processor(
prompt, prompt,
mm_items=processor.info.parse_mm_data(mm_data), mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
......
...@@ -87,7 +87,7 @@ def _validate_image_prompt_replacements_one( ...@@ -87,7 +87,7 @@ def _validate_image_prompt_replacements_one(
try: try:
# The processor will throw an error if there is a mismatch # The processor will throw an error if there is a mismatch
# in the prompt replacements # in the prompt replacements
processed_inputs = processor.apply( processed_inputs = processor(
prompt, prompt,
mm_items=processor.info.parse_mm_data(mm_data), mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
......
...@@ -29,7 +29,7 @@ def test_processor_override( ...@@ -29,7 +29,7 @@ def test_processor_override(
image = Image.new("RGB", size=(364, 364)) image = Image.new("RGB", size=(364, 364))
mm_data = {"image": [image] * num_imgs} mm_data = {"image": [image] * num_imgs}
processed_inputs = processor.apply( processed_inputs = processor(
prompt, prompt,
mm_items=processor.info.parse_mm_data(mm_data), mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
...@@ -50,7 +50,7 @@ def _validate_image_prompt_replacements_one( ...@@ -50,7 +50,7 @@ def _validate_image_prompt_replacements_one(
mm_data = {"image": [image] * num_imgs} mm_data = {"image": [image] * num_imgs}
try: try:
processed_inputs = processor.apply( processed_inputs = processor(
prompt, prompt,
mm_items=processor.info.parse_mm_data(mm_data), mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
......
...@@ -68,7 +68,7 @@ def _run_check( ...@@ -68,7 +68,7 @@ def _run_check(
for image in images for image in images
) )
print(total_expected_num_patches) print(total_expected_num_patches)
processed_inputs = processor.apply( processed_inputs = processor(
prompt, prompt,
mm_items=processor.info.parse_mm_data(mm_data), mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=mm_processor_kwargs, hf_processor_mm_kwargs=mm_processor_kwargs,
......
...@@ -47,7 +47,7 @@ def test_processor_override( ...@@ -47,7 +47,7 @@ def test_processor_override(
prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
mm_data = {"image": [image_assets[0].pil_image] * num_imgs} mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
processed_inputs = processor.apply( processed_inputs = processor(
prompt, prompt,
mm_items=processor.info.parse_mm_data(mm_data), mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......
...@@ -51,7 +51,7 @@ def test_processor_override( ...@@ -51,7 +51,7 @@ def test_processor_override(
dummy_image = image_assets[0].pil_image.resize(dummy_image_size) dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
mm_data = {"image": [dummy_image] * num_imgs} mm_data = {"image": [dummy_image] * num_imgs}
processed_inputs = processor.apply( processed_inputs = processor(
prompt, prompt,
mm_items=processor.info.parse_mm_data(mm_data), mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......
...@@ -42,7 +42,7 @@ def test_processor_override( ...@@ -42,7 +42,7 @@ def test_processor_override(
prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
mm_data = {"image": [image_assets[0].pil_image] * num_imgs} mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
processed_inputs = processor.apply( processed_inputs = processor(
prompt, prompt,
mm_items=processor.info.parse_mm_data(mm_data), mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
...@@ -88,7 +88,7 @@ def test_get_image_size_with_most_features( ...@@ -88,7 +88,7 @@ def test_get_image_size_with_most_features(
prompt = "<|vision_start|><|image_pad|><|vision_end|>" prompt = "<|vision_start|><|image_pad|><|vision_end|>"
for asset in image_assets: for asset in image_assets:
mm_data = {"image": [asset.pil_image]} mm_data = {"image": [asset.pil_image]}
processed_inputs = processor.apply( processed_inputs = processor(
prompt, prompt,
mm_items=processor.info.parse_mm_data(mm_data), mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......
...@@ -51,7 +51,7 @@ def test_processor_with_audio_sample_rate( ...@@ -51,7 +51,7 @@ def test_processor_with_audio_sample_rate(
hf_processor_mm_kwargs: dict[str, Any] = { hf_processor_mm_kwargs: dict[str, Any] = {
"audio_sample_rate": audio_sample_rate, "audio_sample_rate": audio_sample_rate,
} }
processed_inputs = processor.apply( processed_inputs = processor(
prompt, prompt,
mm_items=processor.info.parse_mm_data(mm_data), mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
...@@ -94,7 +94,7 @@ def test_longer_audio_generates_more_tokens(model_id: str) -> None: ...@@ -94,7 +94,7 @@ def test_longer_audio_generates_more_tokens(model_id: str) -> None:
hf_processor_mm_kwargs: dict[str, Any] = { hf_processor_mm_kwargs: dict[str, Any] = {
"audio_sample_rate": audio_sample_rate, "audio_sample_rate": audio_sample_rate,
} }
processed = processor.apply( processed = processor(
prompt, prompt,
mm_items=processor.info.parse_mm_data(mm_data), mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......
...@@ -61,7 +61,7 @@ def test_processor_override( ...@@ -61,7 +61,7 @@ def test_processor_override(
dummy_image = image_assets[0].pil_image.resize(dummy_image_size) dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
mm_data = {"image": [dummy_image] * num_imgs} mm_data = {"image": [dummy_image] * num_imgs}
processed_inputs = processor.apply( processed_inputs = processor(
prompt, prompt,
mm_items=processor.info.parse_mm_data(mm_data), mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......
...@@ -99,7 +99,7 @@ def create_batched_mm_kwargs( ...@@ -99,7 +99,7 @@ def create_batched_mm_kwargs(
mm_counts=mm_counts, mm_counts=mm_counts,
mm_options={}, mm_options={},
) )
mm_items = processor_inputs.mm_items mm_items = processor_inputs.mm_data_items
resized_mm_data = { resized_mm_data = {
modality: resize_mm_data(items.data, size_factors) modality: resize_mm_data(items.data, size_factors)
for modality, items in mm_items.items() for modality, items in mm_items.items()
...@@ -108,11 +108,10 @@ def create_batched_mm_kwargs( ...@@ -108,11 +108,10 @@ def create_batched_mm_kwargs(
# video metadata will be added back to the resized video data here. # video metadata will be added back to the resized video data here.
text_prompt, token_prompt = get_text_token_prompts(processor, resized_mm_data) text_prompt, token_prompt = get_text_token_prompts(processor, resized_mm_data)
mm_kwargs = processor.apply( mm_kwargs = processor(
prompt=token_prompt if text_prompt is None else text_prompt, prompt=token_prompt if text_prompt is None else text_prompt,
mm_items=processor.info.parse_mm_data(resized_mm_data), mm_items=processor.info.parse_mm_data(resized_mm_data),
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
tokenization_kwargs=processor_inputs.tokenization_kwargs,
)["mm_kwargs"].require_data() )["mm_kwargs"].require_data()
return group_mm_kwargs_by_modality( return group_mm_kwargs_by_modality(
......
...@@ -19,7 +19,7 @@ def test_multimodal_processor(model_id): ...@@ -19,7 +19,7 @@ def test_multimodal_processor(model_id):
image_pil = ImageAsset("cherry_blossom").pil_image image_pil = ImageAsset("cherry_blossom").pil_image
mm_data = {"image": image_pil} mm_data = {"image": image_pil}
str_prompt = "<|im_start|>user <image>\nWhat is the content of this image?<|im_end|><|im_start|>assistant\n" # noqa: E501 str_prompt = "<|im_start|>user <image>\nWhat is the content of this image?<|im_end|><|im_start|>assistant\n" # noqa: E501
str_processed_inputs = mm_processor.apply( str_processed_inputs = mm_processor(
prompt=str_prompt, prompt=str_prompt,
mm_items=mm_processor.info.parse_mm_data(mm_data), mm_items=mm_processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
...@@ -44,7 +44,7 @@ def test_multimodal_processor(model_id): ...@@ -44,7 +44,7 @@ def test_multimodal_processor(model_id):
77091, 77091,
198, 198,
] ]
ids_processed_inputs = mm_processor.apply( ids_processed_inputs = mm_processor(
prompt=ids_prompt, prompt=ids_prompt,
mm_items=mm_processor.info.parse_mm_data(mm_data), mm_items=mm_processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
......
...@@ -934,7 +934,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid): ...@@ -934,7 +934,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most") exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
with exc_ctx: with exc_ctx:
processor.apply( processor(
"<image>" * num_images, "<image>" * num_images,
mm_items=processor.info.parse_mm_data(mm_data), mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={}, hf_processor_mm_kwargs={},
......
...@@ -17,8 +17,9 @@ import argparse ...@@ -17,8 +17,9 @@ import argparse
import dataclasses import dataclasses
import json import json
import time import time
from collections import defaultdict
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, Literal
import numpy as np import numpy as np
...@@ -59,12 +60,13 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f ...@@ -59,12 +60,13 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f
Example: Example:
{ {
'request-123': { 'request-123': {
'hf_processor_time': 0.45, 'get_mm_hashes_secs': 0.02,
'hashing_time': 0.02, 'get_cache_missing_items_secs': 0.01,
'cache_lookup_time': 0.01, 'apply_hf_processor_secs': 0.45,
'prompt_update_time': 0.03, 'merge_mm_kwargs_secs': 0.01,
'preprocessor_total_time': 0.51, 'apply_prompt_updates_secs': 0.03,
'encoder_forward_time': 0.23, 'preprocessor_total_secs': 0.51,
'encoder_forward_secs': 0.23,
'num_encoder_calls': 1 'num_encoder_calls': 1
} }
} }
...@@ -74,8 +76,7 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f ...@@ -74,8 +76,7 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f
return {} return {}
renderer = llm_engine.renderer renderer = llm_engine.renderer
mm_processor = renderer.get_mm_processor() mm_processor_stats = renderer._mm_timing_registry.stat()
preprocessing_stats = mm_processor.info.ctx.get_all_timing_stats()
encoder_stats = dict[str, dict[str, float]]() encoder_stats = dict[str, dict[str, float]]()
for worker_stats in llm_engine.collective_rpc("get_encoder_timing_stats"): for worker_stats in llm_engine.collective_rpc("get_encoder_timing_stats"):
...@@ -88,10 +89,10 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f ...@@ -88,10 +89,10 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f
else: else:
# Aggregate timing metrics across workers # Aggregate timing metrics across workers
current_time = encoder_stats[request_id].get( current_time = encoder_stats[request_id].get(
"encoder_forward_time", 0.0 "encoder_forward_secs", 0.0
) )
new_time = stats_dict.get("encoder_forward_time", 0.0) new_time = stats_dict.get("encoder_forward_secs", 0.0)
encoder_stats[request_id]["encoder_forward_time"] = max( encoder_stats[request_id]["encoder_forward_secs"] = max(
current_time, new_time current_time, new_time
) )
...@@ -103,7 +104,7 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f ...@@ -103,7 +104,7 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f
merged_stats = dict[str, dict[str, float]]() merged_stats = dict[str, dict[str, float]]()
for request_id, prep_dict in preprocessing_stats.items(): for request_id, prep_dict in mm_processor_stats.items():
merged_stats[request_id] = dict(prep_dict) merged_stats[request_id] = dict(prep_dict)
for request_id, enc_dict in encoder_stats.items(): for request_id, enc_dict in encoder_stats.items():
...@@ -124,34 +125,18 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f ...@@ -124,34 +125,18 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f
return merged_stats return merged_stats
def collect_mm_processor_stats( def collect_mm_processor_stats(llm_engine: LLMEngine) -> dict[str, list[float]]:
llm_engine: LLMEngine,
num_warmup_reqs: int = 0,
) -> dict[str, list[float]]:
""" """
Collect multimodal processor timing stats. Collect multimodal processor timing stats.
Returns a dictionary mapping stage names to lists of timing values (in seconds). Returns a dictionary mapping stage names to lists of timing values (in seconds).
""" """
all_stats = get_timing_stats_from_engine(llm_engine) all_stats = get_timing_stats_from_engine(llm_engine)
stat_keys = [ stats_by_stage = defaultdict[str, list[float]](list)
"hf_processor_time",
"hashing_time",
"cache_lookup_time",
"prompt_update_time",
"preprocessor_total_time",
"encoder_forward_time",
"num_encoder_calls",
]
stats_by_stage = {key: [] for key in stat_keys}
# Skip warmup requests
stats_list = list(all_stats.values())[num_warmup_reqs:]
for stats_dict in stats_list: for stats_dict in all_stats.values():
for key in stat_keys: for stat_key, stat_val in stats_dict.items():
if key in stats_dict: stats_by_stage[stat_key].append(stat_val)
stats_by_stage[key].append(stats_dict[key])
return stats_by_stage return stats_by_stage
...@@ -159,13 +144,20 @@ def collect_mm_processor_stats( ...@@ -159,13 +144,20 @@ def collect_mm_processor_stats(
def calculate_mm_processor_metrics( def calculate_mm_processor_metrics(
stats_by_stage: dict[str, list[float]], stats_by_stage: dict[str, list[float]],
selected_percentiles: list[float], selected_percentiles: list[float],
*,
unit: Literal["us", "ms", "s"] = "ms",
) -> dict[str, dict[str, float]]: ) -> dict[str, dict[str, float]]:
""" """
Calculate aggregate metrics from stats by stage. Calculate aggregate metrics from stats by stage.
""" """
unit2mult = {"us": 1000000, "ms": 1000, "s": 1}
unit_mult = unit2mult[unit]
metrics = {} metrics = {}
for stage_name, times in stats_by_stage.items(): for stage, times in stats_by_stage.items():
stage_name = stage.replace("_secs", "_" + unit)
if not times: if not times:
metrics[stage_name] = { metrics[stage_name] = {
"mean": 0.0, "mean": 0.0,
...@@ -175,8 +167,8 @@ def calculate_mm_processor_metrics( ...@@ -175,8 +167,8 @@ def calculate_mm_processor_metrics(
} }
continue continue
is_count_metric = stage_name == "num_encoder_calls" is_count_metric = stage == "num_encoder_calls"
values = times if is_count_metric else [t * 1000 for t in times] values = times if is_count_metric else [t * unit_mult for t in times]
metrics[stage_name] = { metrics[stage_name] = {
"mean": float(np.mean(values)), "mean": float(np.mean(values)),
...@@ -285,6 +277,9 @@ def benchmark_multimodal_processor( ...@@ -285,6 +277,9 @@ def benchmark_multimodal_processor(
use_tqdm=not getattr(args, "disable_tqdm", False), use_tqdm=not getattr(args, "disable_tqdm", False),
) )
# Clear stats from warmup requests
collect_mm_processor_stats(llm.llm_engine)
print(f"Processing {len(prompts)} requests...") print(f"Processing {len(prompts)} requests...")
start_time = time.perf_counter() start_time = time.perf_counter()
...@@ -295,7 +290,7 @@ def benchmark_multimodal_processor( ...@@ -295,7 +290,7 @@ def benchmark_multimodal_processor(
end_time = time.perf_counter() end_time = time.perf_counter()
total_time = end_time - start_time total_time = end_time - start_time
mm_stats_by_stage = collect_mm_processor_stats(llm.llm_engine, num_warmups) mm_stats_by_stage = collect_mm_processor_stats(llm.llm_engine)
if not any(mm_stats_by_stage.values()): if not any(mm_stats_by_stage.values()):
print( print(
...@@ -475,11 +470,8 @@ def main(args: argparse.Namespace) -> None: ...@@ -475,11 +470,8 @@ def main(args: argparse.Namespace) -> None:
] ]
mm_data = [] mm_data = []
for stage, metrics in result["mm_processor_stats"].items(): for stage, metrics in result["mm_processor_stats"].items():
is_count = stage == "num_encoder_calls"
unit = "" if is_count else " (ms)"
row = { row = {
"Stage": stage + unit, "Stage": stage,
"Mean": f"{metrics['mean']:.2f}", "Mean": f"{metrics['mean']:.2f}",
"Median": f"{metrics['median']:.2f}", "Median": f"{metrics['median']:.2f}",
"Std": f"{metrics['std']:.2f}", "Std": f"{metrics['std']:.2f}",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment