Unverified Commit fd3034da authored by Yuan Luo, committed by GitHub

[VLM] Optimize qwen_vl preprocess_video (#12240)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: 羽癫 <yudian.zy@antgroup.com>
parent fbbe16fa
@@ -2,8 +2,10 @@ import asyncio
 import math
 import os
 import re
+import time
 from typing import List, Union
+import numpy as np
 import torch
 import torchvision
 from PIL import Image
@@ -175,12 +177,15 @@ async def preprocess_video(
     image_factor: int = IMAGE_FACTOR,
     # vr: VideoReader, image_factor: int = IMAGE_FACTOR
 ) -> torch.Tensor:
+    entry_time = time.perf_counter()
     ele = {}
     total_frames, video_fps = len(vr), vr.get_avg_fps()
     nframes = smart_nframes({}, total_frames=total_frames, video_fps=video_fps)
-    idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
-    video = vr.get_batch(idx).asnumpy()
-    video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
+    idx = np.linspace(0, total_frames - 1, num=nframes, dtype=np.int64)
+    idx = np.unique(idx)
+    video_np = vr.get_batch(idx).asnumpy()
+    video = torch.from_numpy(video_np).pin_memory()
+    video = video.permute(0, 3, 1, 2)  # Convert to TCHW format
     nframes, _, height, width = video.shape
     min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
     total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
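The sampling change above is the core of the optimization: `torch.linspace(...).round()` can emit duplicate frame indices (for example when `nframes` approaches or exceeds `total_frames`), so the old path decoded some frames more than once, and `torch.tensor(video)` then copied the entire decoded array. A minimal sketch of the new behavior; the clip dimensions and the random stand-in for `vr.get_batch(idx).asnumpy()` are illustrative, not from the commit:

```python
import numpy as np
import torch

# Illustrative values; in the real code they come from the VideoReader
# and smart_nframes().
total_frames, nframes = 48, 64

idx = np.linspace(0, total_frames - 1, num=nframes, dtype=np.int64)
idx = np.unique(idx)                 # deduplicate, so each frame is decoded once
assert idx.size <= total_frames

# Stand-in for vr.get_batch(idx).asnumpy(): a THWC uint8 batch of frames.
video_np = np.random.randint(0, 256, size=(idx.size, 360, 640, 3), dtype=np.uint8)

video = torch.from_numpy(video_np)   # zero-copy view; torch.tensor() would copy
if torch.cuda.is_available():        # pin_memory() needs an accelerator backend
    video = video.pin_memory()       # page-locked memory speeds host-to-device copies
video = video.permute(0, 3, 1, 2)    # THWC -> TCHW
```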
@@ -188,6 +193,9 @@ async def preprocess_video(
         min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
         int(min_pixels * 1.05),
     )
+    get_batch_time = time.perf_counter()
     max_pixels_supposed = ele.get("max_pixels", max_pixels)
     if max_pixels_supposed > max_pixels:
         logger.warning(
@@ -208,12 +216,13 @@ async def preprocess_video(
         min_pixels=min_pixels,
         max_pixels=max_pixels,
     )
+    smart_resize_time = time.perf_counter()
     video = torchvision.transforms.functional.resize(
         video,
         [resized_height, resized_width],
-        interpolation=InterpolationMode.BICUBIC,
-        antialias=True,
-    ).float()
+        interpolation=InterpolationMode.BILINEAR,
+    )
+    video = video.pin_memory()
     video_metadata = {
         "fps": video_fps,
         "duration": total_frames / video_fps,
@@ -221,6 +230,14 @@ async def preprocess_video(
         "frames_indices": idx,
         "video_backend": "torchvision",
     }
+    torchvision_resize_time = time.perf_counter()
+    logger.debug(
+        f"[preprocess_video Perf], "
+        f"get_batch_time: {(get_batch_time - entry_time) * 1000:.2f} ms, "
+        f"smart_resize_time: {(smart_resize_time - get_batch_time) * 1000:.2f} ms, "
+        f"torchvision_resize_time: {(torchvision_resize_time - smart_resize_time) * 1000:.2f} ms, "
+        f"total_time: {(torchvision_resize_time - entry_time) * 1000:.2f} ms"
+    )
     return video, video_metadata
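Two further costs are shaved in the resize step above: the interpolation drops from BICUBIC with explicit antialiasing to BILINEAR, and the trailing `.float()` is gone, so frames stay uint8 until the downstream HF processor normalizes them. The bracketing `time.perf_counter()` calls feed the `logger.debug` summary, so each stage's cost is visible in debug logs. A rough micro-benchmark sketch of the interpolation trade-off; the clip shape and target size are made up, and absolute numbers depend on hardware and torchvision version:

```python
import time
import torch
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

video = torch.randint(0, 256, (32, 3, 720, 1280), dtype=torch.uint8)  # TCHW clip

for mode in (InterpolationMode.BICUBIC, InterpolationMode.BILINEAR):
    start = time.perf_counter()
    out = F.resize(video, [504, 896], interpolation=mode, antialias=True)
    elapsed_ms = (time.perf_counter() - start) * 1000
    print(f"{mode.value}: {elapsed_ms:.2f} ms, dtype={out.dtype}")  # dtype stays uint8
```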
@@ -273,6 +290,7 @@ class QwenVLImageProcessor(SGLangBaseProcessor):
         *args,
         **kwargs,
     ):
+        entry_time = time.perf_counter()
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
@@ -280,6 +298,8 @@ class QwenVLImageProcessor(SGLangBaseProcessor):
             audio_data=request_obj.audio_data,
             multimodal_tokens=self.mm_tokens,
         )
+        load_time = time.perf_counter()
+        rid = getattr(request_obj, "rid", "anonymous_rid")
         # Qwen-specific: resize images if they are raw Image objects
         if base_output.images and isinstance(base_output.images[0], Image.Image):
@@ -288,10 +308,12 @@ class QwenVLImageProcessor(SGLangBaseProcessor):
         video_metadata = None
         if base_output.videos:
-            video_results = await asyncio.gather(
-                *[preprocess_video(video) for video in base_output.videos]
-            )
-            base_output.videos, video_metadata = map(list, zip(*video_results))
+            videos_processed = [
+                await preprocess_video(video) for video in base_output.videos
+            ]
+            base_output.videos, video_metadata = map(list, zip(*videos_processed))
+        preprocess_time = time.perf_counter()
         # NOTE: for qwen3-vl, video_meta needs to be passed in, since do_sample_frames is already done in preprocess_video
         if self.hf_config.model_type in ("qwen3_vl", "qwen3_vl_moe"):
@@ -319,6 +341,8 @@ class QwenVLImageProcessor(SGLangBaseProcessor):
                 ret, "video_second_per_grid", None
             )
+        process_time = time.perf_counter()
         input_ids = input_ids.flatten()
         mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
@@ -343,6 +367,15 @@
             ),
         )
         mrope_positions = mrope_positions.squeeze(1)
+        get_rope_index_time = time.perf_counter()
+        logger.debug(
+            f"[QwenVLProcessor Perf] {rid=}, "
+            f"load_time: {(load_time - entry_time) * 1000:.2f} ms, "
+            f"preprocess_time: {(preprocess_time - load_time) * 1000:.2f} ms, "
+            f"process_time: {(process_time - preprocess_time) * 1000:.2f} ms, "
+            f"get_rope_index_time: {(get_rope_index_time - process_time) * 1000:.2f} ms, "
+            f"total_time: {(get_rope_index_time - entry_time) * 1000:.2f} ms"
+        )
         return {
             "input_ids": input_ids.tolist(),
......
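On the processor side, the awaited `asyncio.gather` over `preprocess_video` calls becomes a plain sequential `await` in a list comprehension. `preprocess_video` is dominated by synchronous CPU work (decode, resize), so `gather` presumably bought no real overlap while still paying task-creation and scheduling overhead; awaiting in order returns the same results in the same order. A hedged sketch of the equivalence, with a hypothetical `do_work` coroutine standing in for `preprocess_video`:

```python
import asyncio

async def do_work(x: int) -> int:
    # Hypothetical CPU-bound coroutine: no internal await point, so it
    # cannot actually overlap with siblings under asyncio.gather.
    return x * x

async def main() -> None:
    items = [1, 2, 3]

    # Old pattern: wraps every coroutine in a task, then collects results.
    gathered = await asyncio.gather(*[do_work(v) for v in items])

    # New pattern: awaits each coroutine in order; same results, less
    # event-loop bookkeeping when there is no real concurrency to exploit.
    sequential = [await do_work(v) for v in items]

    assert gathered == sequential

asyncio.run(main())
```

If the preprocessing ever moves its blocking work onto threads (e.g. via `asyncio.to_thread`), `gather` would become useful again; as written, the sequential form is the cheaper of the two.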