Unverified Commit fd3034da authored by Yuan Luo, committed by GitHub

[VLM] Optimize qwen_vl preprocess_video (#12240)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: 羽癫 <yudian.zy@antgroup.com>
parent fbbe16fa
@@ -2,8 +2,10 @@ import asyncio
 import math
 import os
 import re
+import time
 from typing import List, Union
+import numpy as np
 import torch
 import torchvision
 from PIL import Image
@@ -175,12 +177,15 @@ async def preprocess_video(
     image_factor: int = IMAGE_FACTOR,
     # vr: VideoReader, image_factor: int = IMAGE_FACTOR
 ) -> torch.Tensor:
+    entry_time = time.perf_counter()
     ele = {}
     total_frames, video_fps = len(vr), vr.get_avg_fps()
     nframes = smart_nframes({}, total_frames=total_frames, video_fps=video_fps)
-    idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
-    video = vr.get_batch(idx).asnumpy()
-    video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
+    idx = np.linspace(0, total_frames - 1, num=nframes, dtype=np.int64)
+    idx = np.unique(idx)
+    video_np = vr.get_batch(idx).asnumpy()
+    video = torch.from_numpy(video_np).pin_memory()
+    video = video.permute(0, 3, 1, 2)  # Convert to TCHW format
     nframes, _, height, width = video.shape
     min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
     total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
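
Note on the hunk above: np.linspace(..., dtype=np.int64) followed by np.unique drops duplicate frame indices when the requested frame count exceeds the frames actually in the clip, so the decoder fetches each frame at most once, and torch.from_numpy(...).pin_memory() wraps the decoded array without an extra Python-side copy and places it in page-locked memory for faster host-to-device transfer; the later nframes, _, height, width = video.shape keeps the count consistent after deduplication. A minimal standalone sketch of that sampling step, assuming a decord-style VideoReader supplies the frames (helper names here are illustrative, not part of the patch):

import numpy as np
import torch


def sample_frame_indices(total_frames: int, nframes: int) -> np.ndarray:
    """Evenly spaced frame indices, deduplicated for short videos."""
    idx = np.linspace(0, total_frames - 1, num=nframes, dtype=np.int64)
    # When nframes > total_frames, linspace yields repeated indices;
    # unique() keeps each frame once (output stays sorted).
    return np.unique(idx)


def frames_to_pinned_tchw(frames_np: np.ndarray) -> torch.Tensor:
    """THWC uint8 frames -> pinned TCHW tensor ready for async GPU copy."""
    video = torch.from_numpy(frames_np)  # zero-copy view of the ndarray
    video = video.pin_memory()           # copies into page-locked memory (needs a CUDA-enabled runtime)
    return video.permute(0, 3, 1, 2)     # THWC -> TCHW


# Example: 16 requested frames from a 10-frame clip yields 10 unique indices.
print(sample_frame_indices(total_frames=10, nframes=16))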
@@ -188,6 +193,9 @@ async def preprocess_video(
         min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
         int(min_pixels * 1.05),
     )
+    get_batch_time = time.perf_counter()
     max_pixels_supposed = ele.get("max_pixels", max_pixels)
     if max_pixels_supposed > max_pixels:
         logger.warning(
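
For orientation, the context lines above compute the per-frame pixel budget: it is capped at VIDEO_MAX_PIXELS, shares total_pixels across the sampled frames scaled by FRAME_FACTOR, and never drops below roughly 1.05 × min_pixels (the enclosing max(...) assignment is truncated in this view). A small worked sketch with made-up constants; the real values are defined in the qwen_vl module and may differ:

# Hypothetical constants for illustration only.
VIDEO_MAX_PIXELS = 768 * 28 * 28   # 602112
FRAME_FACTOR = 2
min_pixels = 128 * 28 * 28         # 100352
total_pixels = 20_000_000
nframes = 64

max_pixels = max(
    min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
    int(min_pixels * 1.05),
)
# min(602112, 20_000_000 / 64 * 2 = 625000.0) -> 602112
# max(602112, int(100352 * 1.05) = 105369)    -> 602112
print(max_pixels)  # 602112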
@@ -208,12 +216,13 @@ async def preprocess_video(
         min_pixels=min_pixels,
         max_pixels=max_pixels,
     )
+    smart_resize_time = time.perf_counter()
     video = torchvision.transforms.functional.resize(
         video,
         [resized_height, resized_width],
-        interpolation=InterpolationMode.BICUBIC,
-        antialias=True,
-    ).float()
+        interpolation=InterpolationMode.BILINEAR,
+    )
+    video = video.pin_memory()
     video_metadata = {
         "fps": video_fps,
         "duration": total_frames / video_fps,
@@ -221,6 +230,14 @@ async def preprocess_video(
         "frames_indices": idx,
         "video_backend": "torchvision",
     }
+    torchvision_resize_time = time.perf_counter()
+    logger.debug(
+        f"[preprocess_video Perf], "
+        f"get_batch_time: {(get_batch_time - entry_time) * 1000:.2f} ms, "
+        f"smart_resize_time: {(smart_resize_time - get_batch_time) * 1000:.2f} ms, "
+        f"torchvision_resize_time: {(torchvision_resize_time - smart_resize_time) * 1000:.2f} ms, "
+        f"total_time: {(torchvision_resize_time - entry_time) * 1000:.2f} ms"
+    )
     return video, video_metadata
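
The resize hunk above switches the torchvision resize from bicubic to bilinear interpolation, removes the explicit antialias argument, and drops the trailing .float() cast so frames stay uint8 through host-side preprocessing; the result is pinned again, presumably because resize returns a fresh pageable tensor. A minimal sketch of that call pattern, assuming a uint8 TCHW tensor and a CUDA-enabled runtime (shapes and sizes are illustrative):

import torch
import torchvision
from torchvision.transforms import InterpolationMode

# Illustrative clip; a real one would come from the video decoder.
video = torch.randint(0, 256, (8, 3, 360, 640), dtype=torch.uint8)  # TCHW, uint8

resized = torchvision.transforms.functional.resize(
    video,
    [224, 224],                                # [resized_height, resized_width]
    interpolation=InterpolationMode.BILINEAR,  # cheaper than bicubic on CPU
)
resized = resized.pin_memory()  # re-pin: resize allocated a new pageable tensor
print(resized.dtype, resized.shape)  # torch.uint8 torch.Size([8, 3, 224, 224])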
@@ -273,6 +290,7 @@ class QwenVLImageProcessor(SGLangBaseProcessor):
         *args,
         **kwargs,
     ):
+        entry_time = time.perf_counter()
         base_output = self.load_mm_data(
             prompt=input_text,
             image_data=image_data,
@@ -280,6 +298,8 @@ class QwenVLImageProcessor(SGLangBaseProcessor):
             audio_data=request_obj.audio_data,
             multimodal_tokens=self.mm_tokens,
         )
+        load_time = time.perf_counter()
+        rid = getattr(request_obj, "rid", "anonymous_rid")
         # Qwen-specific: resize images if they are raw Image objects
         if base_output.images and isinstance(base_output.images[0], Image.Image):
@@ -288,10 +308,12 @@ class QwenVLImageProcessor(SGLangBaseProcessor):
         video_metadata = None
         if base_output.videos:
-            video_results = await asyncio.gather(
-                *[preprocess_video(video) for video in base_output.videos]
-            )
-            base_output.videos, video_metadata = map(list, zip(*video_results))
+            videos_processed = [
+                await preprocess_video(video) for video in base_output.videos
+            ]
+            base_output.videos, video_metadata = map(list, zip(*videos_processed))
+        preprocess_time = time.perf_counter()
         # NOTE: for qwen3-vl, video_meta need to be passed in, since do_sample_frames is already done in preprocess_video
         if self.hf_config.model_type in ("qwen3_vl", "qwen3_vl_moe"):
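
The hunk above replaces asyncio.gather over preprocess_video coroutines with a plain list comprehension that awaits each video in turn, keeping request order and skipping gather's task scheduling, which buys little here since the preprocessing itself is CPU-bound. A minimal sketch contrasting the two awaiting patterns, using a toy coroutine rather than the real preprocess_video:

import asyncio


async def preprocess(video: str) -> tuple[str, dict]:
    # Stand-in for the real preprocess_video coroutine.
    await asyncio.sleep(0)
    return f"tensor({video})", {"src": video}


async def main() -> None:
    videos = ["a.mp4", "b.mp4"]

    # Previous pattern: schedule all coroutines concurrently via gather.
    gathered = await asyncio.gather(*[preprocess(v) for v in videos])

    # New pattern from the hunk: await each coroutine sequentially.
    sequential = [await preprocess(v) for v in videos]

    # Either way, the results unzip into frames and per-video metadata.
    frames, metas = map(list, zip(*sequential))
    print(frames, metas)
    assert gathered == sequential


asyncio.run(main())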
@@ -319,6 +341,8 @@ class QwenVLImageProcessor(SGLangBaseProcessor):
                 ret, "video_second_per_grid", None
             )
+        process_time = time.perf_counter()
         input_ids = input_ids.flatten()
         mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index(
@@ -343,6 +367,15 @@ class QwenVLImageProcessor(SGLangBaseProcessor):
             ),
         )
         mrope_positions = mrope_positions.squeeze(1)
+        get_rope_index_time = time.perf_counter()
+        logger.debug(
+            f"[QwenVLProcessor Perf] {rid=}, "
+            f"load_time: {(load_time - entry_time) * 1000:.2f} ms, "
+            f"preprocess_time: {(preprocess_time - load_time) * 1000:.2f} ms, "
+            f"process_time: {(process_time - preprocess_time) * 1000:.2f} ms, "
+            f"get_rope_index_time: {(get_rope_index_time - process_time) * 1000:.2f} ms, "
+            f"total_time: {(get_rope_index_time - entry_time) * 1000:.2f} ms"
+        )
         return {
             "input_ids": input_ids.tolist(),
......
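
Both functions touched by this patch use the same lightweight instrumentation: capture time.perf_counter() at each stage boundary and emit a single logger.debug line with per-stage deltas in milliseconds. A condensed, self-contained sketch of that pattern (stage bodies and logger setup are illustrative):

import logging
import time

logger = logging.getLogger(__name__)


def timed_pipeline() -> None:
    entry_time = time.perf_counter()

    sum(range(100_000))     # stage 1 stand-in (e.g. frame loading)
    load_time = time.perf_counter()

    sorted(range(100_000))  # stage 2 stand-in (e.g. resizing)
    resize_time = time.perf_counter()

    # One debug line per call, mirroring the "[... Perf]" logs in the diff.
    logger.debug(
        f"[timed_pipeline Perf], "
        f"load: {(load_time - entry_time) * 1000:.2f} ms, "
        f"resize: {(resize_time - load_time) * 1000:.2f} ms, "
        f"total: {(resize_time - entry_time) * 1000:.2f} ms"
    )


logging.basicConfig(level=logging.DEBUG)
timed_pipeline()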