Unverified commit 39264545, authored by Cyrus Leung, committed by GitHub
Browse files

[Refactor] Decouple TimingContext from InputProcessingContext (#35083)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 1e8438a8
......@@ -389,13 +389,13 @@ def _test_processing_correctness_one(
mm_items = baseline_processor.info.parse_mm_data(mm_data)
ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]())
baseline_tokenized_result = baseline_processor.apply(
baseline_tokenized_result = baseline_processor(
token_prompt,
mm_items=mm_items,
hf_processor_mm_kwargs={},
)
cached_tokenized_result = cached_processor.apply(
cached_tokenized_result = cached_processor(
token_prompt,
mm_items=mm_items,
hf_processor_mm_kwargs={},
......@@ -409,12 +409,12 @@ def _test_processing_correctness_one(
)
if text_prompt is not None:
baseline_text_result = baseline_processor.apply(
baseline_text_result = baseline_processor(
text_prompt,
mm_items=mm_items,
hf_processor_mm_kwargs={},
)
cached_text_result = cached_processor.apply(
cached_text_result = cached_processor(
text_prompt,
mm_items=mm_items,
hf_processor_mm_kwargs={},
......
......@@ -176,7 +176,7 @@ def test_get_image_size_with_most_features(
for asset in image_assets:
mm_data = {"image": [asset.pil_image]}
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......
......@@ -52,7 +52,7 @@ def test_processor_override(
metadata["fps"] = fps
mm_data = {"video": [(video, metadata)]}
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......@@ -104,12 +104,12 @@ def test_video_loader_consistency(
static_mm_data = {"video": [(static_video, static_metadata)]}
dynamic_mm_data = {"video": [(dynamic_video, dynamic_metadata)]}
static_outputs = processor.apply(
static_outputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(static_mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
dynamic_outputs = processor.apply(
dynamic_outputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(dynamic_mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......
......@@ -106,7 +106,7 @@ def _run_check(
for image in images
)
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=mm_processor_kwargs,
......
......@@ -61,7 +61,7 @@ def test_processor_override(
dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
mm_data = {"image": [dummy_image] * num_imgs}
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......
......@@ -66,7 +66,7 @@ def _run_check(
for image in images
)
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=mm_processor_kwargs,
......
......@@ -49,7 +49,7 @@ def test_processor_override(
if tokenized_prompt:
prompt = tokenizer.encode(prompt)
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=mm_processor_kwargs,
......
......@@ -87,7 +87,7 @@ def _validate_image_prompt_replacements_one(
try:
# The processor will throw an error if there is a mismatch
# in the prompt replacements
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={},
......
......@@ -87,7 +87,7 @@ def _validate_image_prompt_replacements_one(
try:
# The processor will throw an error if there is a mismatch
# in the prompt replacements
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={},
......
......@@ -29,7 +29,7 @@ def test_processor_override(
image = Image.new("RGB", size=(364, 364))
mm_data = {"image": [image] * num_imgs}
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={},
......@@ -50,7 +50,7 @@ def _validate_image_prompt_replacements_one(
mm_data = {"image": [image] * num_imgs}
try:
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={},
......
......@@ -68,7 +68,7 @@ def _run_check(
for image in images
)
print(total_expected_num_patches)
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=mm_processor_kwargs,
......
......@@ -47,7 +47,7 @@ def test_processor_override(
prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......
......@@ -51,7 +51,7 @@ def test_processor_override(
dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
mm_data = {"image": [dummy_image] * num_imgs}
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......
......@@ -42,7 +42,7 @@ def test_processor_override(
prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......@@ -88,7 +88,7 @@ def test_get_image_size_with_most_features(
prompt = "<|vision_start|><|image_pad|><|vision_end|>"
for asset in image_assets:
mm_data = {"image": [asset.pil_image]}
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......
......@@ -51,7 +51,7 @@ def test_processor_with_audio_sample_rate(
hf_processor_mm_kwargs: dict[str, Any] = {
"audio_sample_rate": audio_sample_rate,
}
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......@@ -94,7 +94,7 @@ def test_longer_audio_generates_more_tokens(model_id: str) -> None:
hf_processor_mm_kwargs: dict[str, Any] = {
"audio_sample_rate": audio_sample_rate,
}
processed = processor.apply(
processed = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......
......@@ -61,7 +61,7 @@ def test_processor_override(
dummy_image = image_assets[0].pil_image.resize(dummy_image_size)
mm_data = {"image": [dummy_image] * num_imgs}
processed_inputs = processor.apply(
processed_inputs = processor(
prompt,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......
......@@ -99,7 +99,7 @@ def create_batched_mm_kwargs(
mm_counts=mm_counts,
mm_options={},
)
mm_items = processor_inputs.mm_items
mm_items = processor_inputs.mm_data_items
resized_mm_data = {
modality: resize_mm_data(items.data, size_factors)
for modality, items in mm_items.items()
......@@ -108,11 +108,10 @@ def create_batched_mm_kwargs(
# video metadata will be added back to the resized video data here.
text_prompt, token_prompt = get_text_token_prompts(processor, resized_mm_data)
mm_kwargs = processor.apply(
mm_kwargs = processor(
prompt=token_prompt if text_prompt is None else text_prompt,
mm_items=processor.info.parse_mm_data(resized_mm_data),
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
tokenization_kwargs=processor_inputs.tokenization_kwargs,
)["mm_kwargs"].require_data()
return group_mm_kwargs_by_modality(
......
......@@ -19,7 +19,7 @@ def test_multimodal_processor(model_id):
image_pil = ImageAsset("cherry_blossom").pil_image
mm_data = {"image": image_pil}
str_prompt = "<|im_start|>user <image>\nWhat is the content of this image?<|im_end|><|im_start|>assistant\n" # noqa: E501
str_processed_inputs = mm_processor.apply(
str_processed_inputs = mm_processor(
prompt=str_prompt,
mm_items=mm_processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={},
......@@ -44,7 +44,7 @@ def test_multimodal_processor(model_id):
77091,
198,
]
ids_processed_inputs = mm_processor.apply(
ids_processed_inputs = mm_processor(
prompt=ids_prompt,
mm_items=mm_processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={},
......
......@@ -934,7 +934,7 @@ def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
with exc_ctx:
processor.apply(
processor(
"<image>" * num_images,
mm_items=processor.info.parse_mm_data(mm_data),
hf_processor_mm_kwargs={},
......
......@@ -17,8 +17,9 @@ import argparse
import dataclasses
import json
import time
from collections import defaultdict
from datetime import datetime
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal
import numpy as np
......@@ -59,12 +60,13 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f
Example:
{
'request-123': {
'hf_processor_time': 0.45,
'hashing_time': 0.02,
'cache_lookup_time': 0.01,
'prompt_update_time': 0.03,
'preprocessor_total_time': 0.51,
'encoder_forward_time': 0.23,
'get_mm_hashes_secs': 0.02,
'get_cache_missing_items_secs': 0.01,
'apply_hf_processor_secs': 0.45,
'merge_mm_kwargs_secs': 0.01,
'apply_prompt_updates_secs': 0.03,
'preprocessor_total_secs': 0.51,
'encoder_forward_secs': 0.23,
'num_encoder_calls': 1
}
}
......@@ -74,8 +76,7 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f
return {}
renderer = llm_engine.renderer
mm_processor = renderer.get_mm_processor()
preprocessing_stats = mm_processor.info.ctx.get_all_timing_stats()
mm_processor_stats = renderer._mm_timing_registry.stat()
encoder_stats = dict[str, dict[str, float]]()
for worker_stats in llm_engine.collective_rpc("get_encoder_timing_stats"):
......@@ -88,10 +89,10 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f
else:
# Aggregate timing metrics across workers
current_time = encoder_stats[request_id].get(
"encoder_forward_time", 0.0
"encoder_forward_secs", 0.0
)
new_time = stats_dict.get("encoder_forward_time", 0.0)
encoder_stats[request_id]["encoder_forward_time"] = max(
new_time = stats_dict.get("encoder_forward_secs", 0.0)
encoder_stats[request_id]["encoder_forward_secs"] = max(
current_time, new_time
)
......@@ -103,7 +104,7 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f
merged_stats = dict[str, dict[str, float]]()
for request_id, prep_dict in preprocessing_stats.items():
for request_id, prep_dict in mm_processor_stats.items():
merged_stats[request_id] = dict(prep_dict)
for request_id, enc_dict in encoder_stats.items():
......@@ -124,34 +125,18 @@ def get_timing_stats_from_engine(llm_engine: LLMEngine) -> dict[str, dict[str, f
return merged_stats
def collect_mm_processor_stats(
llm_engine: LLMEngine,
num_warmup_reqs: int = 0,
) -> dict[str, list[float]]:
def collect_mm_processor_stats(llm_engine: LLMEngine) -> dict[str, list[float]]:
"""
Collect multimodal processor timing stats.
Returns a dictionary mapping stage names to lists of timing values (in seconds).
"""
all_stats = get_timing_stats_from_engine(llm_engine)
stat_keys = [
"hf_processor_time",
"hashing_time",
"cache_lookup_time",
"prompt_update_time",
"preprocessor_total_time",
"encoder_forward_time",
"num_encoder_calls",
]
stats_by_stage = {key: [] for key in stat_keys}
# Skip warmup requests
stats_list = list(all_stats.values())[num_warmup_reqs:]
stats_by_stage = defaultdict[str, list[float]](list)
for stats_dict in stats_list:
for key in stat_keys:
if key in stats_dict:
stats_by_stage[key].append(stats_dict[key])
for stats_dict in all_stats.values():
for stat_key, stat_val in stats_dict.items():
stats_by_stage[stat_key].append(stat_val)
return stats_by_stage
......@@ -159,13 +144,20 @@ def collect_mm_processor_stats(
def calculate_mm_processor_metrics(
stats_by_stage: dict[str, list[float]],
selected_percentiles: list[float],
*,
unit: Literal["us", "ms", "s"] = "ms",
) -> dict[str, dict[str, float]]:
"""
Calculate aggregate metrics from stats by stage.
"""
unit2mult = {"us": 1000000, "ms": 1000, "s": 1}
unit_mult = unit2mult[unit]
metrics = {}
for stage_name, times in stats_by_stage.items():
for stage, times in stats_by_stage.items():
stage_name = stage.replace("_secs", "_" + unit)
if not times:
metrics[stage_name] = {
"mean": 0.0,
......@@ -175,8 +167,8 @@ def calculate_mm_processor_metrics(
}
continue
is_count_metric = stage_name == "num_encoder_calls"
values = times if is_count_metric else [t * 1000 for t in times]
is_count_metric = stage == "num_encoder_calls"
values = times if is_count_metric else [t * unit_mult for t in times]
metrics[stage_name] = {
"mean": float(np.mean(values)),
......@@ -285,6 +277,9 @@ def benchmark_multimodal_processor(
use_tqdm=not getattr(args, "disable_tqdm", False),
)
# Clear stats from warmup requests
collect_mm_processor_stats(llm.llm_engine)
print(f"Processing {len(prompts)} requests...")
start_time = time.perf_counter()
......@@ -295,7 +290,7 @@ def benchmark_multimodal_processor(
end_time = time.perf_counter()
total_time = end_time - start_time
mm_stats_by_stage = collect_mm_processor_stats(llm.llm_engine, num_warmups)
mm_stats_by_stage = collect_mm_processor_stats(llm.llm_engine)
if not any(mm_stats_by_stage.values()):
print(
......@@ -475,11 +470,8 @@ def main(args: argparse.Namespace) -> None:
]
mm_data = []
for stage, metrics in result["mm_processor_stats"].items():
is_count = stage == "num_encoder_calls"
unit = "" if is_count else " (ms)"
row = {
"Stage": stage + unit,
"Stage": stage,
"Mean": f"{metrics['mean']:.2f}",
"Median": f"{metrics['median']:.2f}",
"Std": f"{metrics['std']:.2f}",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment