Unverified Commit 15f99347, authored by Kevin Tuan, committed by GitHub

refactor(InternVL): Use gpu to preprocess the input image (#9795)

parent bcf1955f
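
In outline, the change opens videos with decord's GPU context when available and falls back to CPU decoding otherwise. A minimal sketch of that fallback pattern, assuming a decord build with NVDEC support for the GPU path (the file name is hypothetical):

```python
import torch
from decord import VideoReader, cpu, gpu

def open_reader(video_path):
    # Try hardware decoding first; decord raises if the GPU context is
    # unavailable (e.g. a CPU-only build), so fall back to CPU decoding.
    try:
        return VideoReader(video_path, ctx=gpu(0), num_threads=1), True
    except (RuntimeError, OSError) as e:
        print(f"[WARNING] GPU video decoding failed: {e}. Falling back to CPU.")
        return VideoReader(video_path, ctx=cpu(0), num_threads=1), False

vr, use_gpu = open_reader("sample.mp4")  # hypothetical input file
# Either way, downstream code works with CHW float tensors in [0, 1] on CUDA.
frame = torch.from_numpy(vr[0].asnumpy()).permute(2, 0, 1).cuda().float() / 255.0
```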
@@ -2,8 +2,10 @@
 import numpy as np
 import torch
-from decord import VideoReader, cpu
+import torchvision.transforms as T
+from decord import VideoReader, cpu, gpu
 from PIL import Image
+from torchvision.transforms import InterpolationMode
 
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.interns1 import InternS1ForConditionalGeneration
@@ -48,99 +50,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             image_token_id=tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT_TOKEN),
         ).build(_image_processor)
 
-    @staticmethod
-    def build_transform(input_size):
-        IMAGENET_MEAN = (0.485, 0.456, 0.406)
-        IMAGENET_STD = (0.229, 0.224, 0.225)
-
-        def resize_image(img, size):
-            return img.resize((size, size), Image.Resampling.BICUBIC)
-
-        def to_tensor(img):
-            # Convert PIL Image to numpy array
-            img_array = np.array(img).astype(np.float32) / 255.0
-            # Convert HWC to CHW format
-            img_array = img_array.transpose(2, 0, 1)
-            return torch.from_numpy(img_array)
-
-        def normalize(tensor, mean, std):
-            mean = torch.tensor(mean).view(-1, 1, 1)
-            std = torch.tensor(std).view(-1, 1, 1)
-            return (tensor - mean) / std
-
-        def transform(img):
-            img = img.convert("RGB") if img.mode != "RGB" else img
-            img = resize_image(img, input_size)
-            tensor = to_tensor(img)
-            tensor = normalize(tensor, IMAGENET_MEAN, IMAGENET_STD)
-            return tensor
-
-        return transform
-
-    @staticmethod
-    def dynamic_preprocess(
-        image, min_num=1, max_num=12, image_size=448, use_thumbnail=False
-    ):
-        def find_closest_aspect_ratio(
-            aspect_ratio, target_ratios, width, height, image_size
-        ):
-            best_ratio_diff = float("inf")
-            best_ratio = (1, 1)
-            area = width * height
-            for ratio in target_ratios:
-                target_aspect_ratio = ratio[0] / ratio[1]
-                ratio_diff = abs(aspect_ratio - target_aspect_ratio)
-                if ratio_diff < best_ratio_diff:
-                    best_ratio_diff = ratio_diff
-                    best_ratio = ratio
-                elif ratio_diff == best_ratio_diff:
-                    if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
-                        best_ratio = ratio
-            return best_ratio
-
-        orig_width, orig_height = image.size
-        aspect_ratio = orig_width / orig_height
-
-        # calculate the existing image aspect ratio
-        target_ratios = set(
-            (i, j)
-            for n in range(min_num, max_num + 1)
-            for i in range(1, n + 1)
-            for j in range(1, n + 1)
-            if i * j <= max_num and i * j >= min_num
-        )
-        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
-
-        # find the closest aspect ratio to the target
-        target_aspect_ratio = find_closest_aspect_ratio(
-            aspect_ratio, target_ratios, orig_width, orig_height, image_size
-        )
-
-        # calculate the target width and height
-        target_width = image_size * target_aspect_ratio[0]
-        target_height = image_size * target_aspect_ratio[1]
-        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
-
-        # resize the image
-        resized_img = image.resize((target_width, target_height))
-        processed_images = []
-        for i in range(blocks):
-            box = (
-                (i % (target_width // image_size)) * image_size,
-                (i // (target_width // image_size)) * image_size,
-                ((i % (target_width // image_size)) + 1) * image_size,
-                ((i // (target_width // image_size)) + 1) * image_size,
-            )
-            # split the image
-            split_img = resized_img.crop(box)
-            processed_images.append(split_img)
-        assert len(processed_images) == blocks
-        if use_thumbnail and len(processed_images) != 1:
-            thumbnail_img = image.resize((image_size, image_size))
-            processed_images.append(thumbnail_img)
-        return processed_images
-
     @staticmethod
     def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
         if bound:
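
The next hunk replaces the fixed ImageNet mean/std with per-image statistics. A minimal self-contained demonstration of that normalization step, using an arbitrary random tensor:

```python
import torch

img = torch.rand(3, 448, 448)  # any CHW image tensor scaled to [0, 1]
mean = img.mean(dim=[1, 2], keepdim=True)                # per-channel mean, shape (3, 1, 1)
std = img.std(dim=[1, 2], keepdim=True).clamp(min=1e-6)  # clamp prevents division by zero
normed = (img - mean) / std
print(normed.mean(dim=[1, 2]))  # ~0 for every channel
print(normed.std(dim=[1, 2]))   # ~1 for every channel
```

Unlike the ImageNet constants, this maps every input to approximately zero mean and unit variance per channel, regardless of its brightness or contrast.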
@@ -160,27 +69,112 @@
     @staticmethod
     def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
-        vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+        try:
+            vr = VideoReader(video_path, ctx=gpu(0), num_threads=1)
+            use_gpu = True
+        except (RuntimeError, OSError) as e:
+            print(
+                f"[WARNING] GPU video decoding failed: {e}. Falling back to CPU."
+            )
+            vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+            use_gpu = False
         max_frame = len(vr) - 1
         fps = float(vr.get_avg_fps())
-        pixel_values_list, num_patches_list = [], []
-        transform = InternVLImageProcessor.build_transform(input_size=input_size)
+        pixel_values_list = []
+        num_patches_list = []
         frame_indices = InternVLImageProcessor.get_index(
             bound, fps, max_frame, first_idx=0, num_segments=num_segments
         )
         for frame_index in frame_indices:
-            img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
-            img = InternVLImageProcessor.dynamic_preprocess(
-                img, image_size=input_size, use_thumbnail=True, max_num=max_num
+            # Load the frame and convert HWC uint8 to CHW float in [0, 1] on CUDA
+            frame = vr[frame_index]
+            if use_gpu:
+                img = frame.cuda().permute(2, 0, 1).float() / 255.0
+            else:
+                img_np = frame.asnumpy()
+                img = torch.from_numpy(img_np).permute(2, 0, 1).cuda().float() / 255.0
+
+            # Normalizing every input with the fixed ImageNet mean/std can hurt
+            # accuracy; using each input image's own mean/std is more accurate.
+            mean = img.mean(dim=[1, 2], keepdim=True)
+            # Prevent division by zero; clamp std to a minimum of 1e-6
+            std = img.std(dim=[1, 2], keepdim=True).clamp(min=1e-6)
+            img = (img - mean) / std
+
+            tiles = InternVLImageProcessor.dynamic_preprocess(
+                img, image_size=input_size, max_num=max_num, use_thumbnail=True
             )
-            pixel_values = [transform(tile) for tile in img]
-            pixel_values = torch.stack(pixel_values)
-            num_patches_list.append(pixel_values.shape[0])
-            pixel_values_list.append(pixel_values)
-        pixel_values = torch.cat(pixel_values_list)
+            pixel_values_list.append(tiles)
+            num_patches_list.append(tiles.shape[0])
+
+        pixel_values = torch.cat(pixel_values_list, dim=0)
         return pixel_values, num_patches_list
+
+    @staticmethod
+    def dynamic_preprocess(tensor, image_size=448, max_num=12, use_thumbnail=False):
+        C, H, W = tensor.shape
+        aspect_ratio = W / H
+
+        # Generate all possible aspect ratios
+        target_ratios = set(
+            (i, j)
+            for n in range(1, max_num + 1)
+            for i in range(1, n + 1)
+            for j in range(1, n + 1)
+            if i * j <= max_num
+        )
+        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+        # Find closest ratio
+        best_ratio_diff = float("inf")
+        best_ratio = (1, 1)
+        for x, y in target_ratios:
+            target_ar = x / y
+            diff = abs(aspect_ratio - target_ar)
+            blocks = x * y
+            best_blocks = best_ratio[0] * best_ratio[1]
+            if diff < best_ratio_diff:
+                best_ratio_diff = diff
+                best_ratio = (x, y)
+            elif diff == best_ratio_diff and blocks > best_blocks:
+                best_ratio = (x, y)
+
+        target_w, target_h = image_size * best_ratio[0], image_size * best_ratio[1]
+        blocks = best_ratio[0] * best_ratio[1]
+
+        # Resize on GPU
+        resized = torch.nn.functional.interpolate(
+            tensor.unsqueeze(0),
+            size=(target_h, target_w),
+            mode="bicubic",
+            align_corners=False,
+        ).squeeze(0)
+
+        # Split into tiles
+        tiles = []
+        for i in range(blocks):
+            x = (i % best_ratio[0]) * image_size
+            y = (i // best_ratio[0]) * image_size
+            tile = resized[:, y : y + image_size, x : x + image_size]
+            tiles.append(tile)
+
+        # Add thumbnail if needed
+        if use_thumbnail and len(tiles) > 1:
+            thumb = torch.nn.functional.interpolate(
+                tensor.unsqueeze(0),
+                size=(image_size, image_size),
+                mode="bicubic",
+                align_corners=False,
+            ).squeeze(0)
+            tiles.append(thumb)
+
+        return torch.stack(tiles).to(torch.bfloat16)
 
     async def process_mm_data_async(
         self, image_data, input_text, request_obj, **kwargs
     ):
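
A usage sketch for the new tensor-based dynamic_preprocess, assuming a CUDA device is available. The expected tile count follows from the ratio search above: a 3:2 input selects the (3, 2) grid, giving 6 tiles plus a thumbnail:

```python
import torch

# 672 / 448 = 1.5, so the closest grid with at most 12 tiles is (3, 2)
img = torch.rand(3, 448, 672, device="cuda")
tiles = InternVLImageProcessor.dynamic_preprocess(
    img, image_size=448, max_num=12, use_thumbnail=True
)
print(tiles.shape)  # torch.Size([7, 3, 448, 448]): 6 tiles + 1 thumbnail
print(tiles.dtype)  # torch.bfloat16
```

Note one behavioral difference: on exact aspect-ratio ties the new loop always prefers the grid with more blocks, whereas the removed PIL version only did so when the image area was large enough.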
@@ -191,53 +185,71 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             discard_alpha_channel=True,
         )
 
-        def process_image_internvl(image, input_size=448, max_num=12):
-            transform = InternVLImageProcessor.build_transform(input_size=input_size)
-            images = InternVLImageProcessor.dynamic_preprocess(
-                image, image_size=input_size, use_thumbnail=True, max_num=max_num
-            )
-            pixel_values = [transform(image) for image in images]
-            pixel_values = torch.stack(pixel_values)
-            return pixel_values
-
         num_patches_list = []
         pixel_values = []
         # Process each input with allocated frames
-        for image_index, (image) in enumerate(base_output.images):
+        for image_index, image in enumerate(base_output.images):
             try:
                 # TODO: video input
-                raw_image = process_image_internvl(image)
-                pixel_value = [raw_image.to(torch.bfloat16)]
-                pixel_values += pixel_value
-                num_patches = raw_image.shape[0]
-                num_patches_list += [num_patches]
+                # Convert the PIL image to a GPU tensor scaled to [0, 1]
+                if isinstance(image, Image.Image):
+                    img_np = np.array(image.convert("RGB"))
+                    tensor = (
+                        torch.from_numpy(img_np).permute(2, 0, 1).cuda().float() / 255.0
+                    )
+                else:
+                    tensor = image.cuda()  # assume it is already a tensor
 
-            except FileNotFoundError as e:
-                print(e)
+                # Normalizing every input with the fixed ImageNet mean/std can hurt
+                # accuracy; using each input image's own mean/std is more accurate.
+                mean = tensor.mean(dim=[1, 2], keepdim=True)
+                # Prevent division by zero; clamp std to a minimum of 1e-6
+                std = tensor.std(dim=[1, 2], keepdim=True).clamp(min=1e-6)
+                tensor = (tensor - mean) / std
+
+                tiles = self.dynamic_preprocess(
+                    tensor, image_size=448, max_num=12, use_thumbnail=True
+                )
+                pixel_values.append(tiles)
+                num_patches_list.append(tiles.shape[0])
+            except Exception as e:
+                print(f"[Error] Failed to process image {image_index}: {e}")
                 return None
 
+        # Concatenate all tiles into a single batch
         pixel_values = torch.cat(pixel_values, dim=0)
 
         original_placeholder = "<<<__IMG_CONTEXT_PLACEHOLDER__>>>"
         input_text = input_text.replace(self.IMG_CONTEXT_TOKEN, original_placeholder)
 
-        for idx, num_patches in enumerate(num_patches_list):
+        input_text_updated = input_text
+        for num_patches in num_patches_list:
             image_tokens = (
                 self.IMG_START_TOKEN
                 + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
                 + self.IMG_END_TOKEN
             )
-            input_text = input_text.replace(original_placeholder, image_tokens, 1)
+            input_text_updated = input_text_updated.replace(
+                original_placeholder, image_tokens, 1
+            )
 
-        input_text = input_text.replace(original_placeholder, self.IMG_CONTEXT_TOKEN)
+        input_text_updated = input_text_updated.replace(
+            original_placeholder, self.IMG_CONTEXT_TOKEN
+        )
 
-        input_ids = self.tokenizer(input_text, return_tensors="pt")[
+        # Tokenize the updated prompt
+        input_ids_tensor = self.tokenizer(input_text_updated, return_tensors="pt")[
             "input_ids"
         ].flatten()
+        input_ids = input_ids_tensor.tolist()
+
+        # Get image token offsets
         image_offsets = self.get_mm_items_offset(
-            input_ids=input_ids,
+            input_ids=input_ids_tensor.to("cuda"),
             mm_token_id=self.mm_tokens.image_token_id,
         )
+
         items = [
             MultimodalDataItem(
                 feature=pixel_values,
@@ -247,7 +259,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
         ]
 
         return {
-            "input_ids": input_ids.tolist(),
+            "input_ids": input_ids,
            "mm_items": items,
            "im_start_id": self.img_start_token_id,
            "im_end_id": self.img_end_token_id,
...
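
For reference, the placeholder expansion in process_mm_data_async works like this toy example; the token strings and the per-tile num_image_token value of 256 are assumptions for illustration only:

```python
IMG_START, IMG_CONTEXT, IMG_END = "<img>", "<IMG_CONTEXT>", "</img>"
placeholder = "<<<__IMG_CONTEXT_PLACEHOLDER__>>>"
num_image_token = 256  # assumed tokens per 448x448 tile

prompt = f"Describe this image: {IMG_CONTEXT}"
prompt = prompt.replace(IMG_CONTEXT, placeholder)

# One entry per image; e.g. a single image that was split into 7 tiles
for num_patches in [7]:
    image_tokens = IMG_START + IMG_CONTEXT * num_image_token * num_patches + IMG_END
    prompt = prompt.replace(placeholder, image_tokens, 1)

print(prompt.count(IMG_CONTEXT))  # 256 * 7 = 1792 context tokens to tokenize
```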