Unverified Commit 34ef6c81 authored by Mick, committed by GitHub

[VLM] Adopt fast image processor by default (#5065)

parent 61172091
......@@ -89,5 +89,4 @@ if __name__ == "__main__":
EvalArgs.add_cli_args(parser)
args = add_common_sglang_args_and_parse(parser)
args = parser.parse_args()
eval_mmmu(args)
......@@ -7,6 +7,7 @@ import os
import pprint
import random
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Optional
import numpy as np
......@@ -117,29 +118,38 @@ def prepare_samples(eval_args: EvalArgs):
# merge all datasets
dataset = concatenate_datasets(sub_dataset_list)
## prepare images
samples = []
skip_count = 0
# use image file as input to ensure the consistency between sglang and hf
# Prepare images in parallel
images_path = os.path.expanduser("~/.cache/mmmu/images")
os.makedirs(images_path, exist_ok=True)
print(f"Saving images to: {images_path}")
for i, sample in enumerate(tqdm(dataset)):
samples = []
skip_count = 0
def process_sample(i, sample):
sample = process_single_sample(sample)
sample = construct_prompt(sample, eval_args.config)
image = sample["image"]
width, height = image.size
if width * height >= eval_args.image_pixels_limit:
skip_count += 1
continue
return None, True
image_path = f"{images_path}/image_{i}.png"
if not os.path.exists(image_path):
image.save(image_path)
sample["image_path"] = image_path
samples.append(sample)
return sample, False
with ThreadPoolExecutor() as executor:
futures = [
executor.submit(process_sample, i, sample)
for i, sample in enumerate(dataset)
]
for future in tqdm(as_completed(futures), total=len(futures)):
sample, skipped = future.result()
if skipped:
skip_count += 1
elif sample:
samples.append(sample)
print(
f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
......
......@@ -45,7 +45,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
Please consult the documentation below to learn more about the parameters you may provide when launching a server.
## Model and tokenizer
## Model, processor and tokenizer
* `model_path`: Path to the model that will be served.
* `tokenizer_path`: Defaults to the `model_path`.
......@@ -62,6 +62,7 @@ Please consult the documentation below to learn more about the parameters you ma
* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/token_in_token_out/).
* `json_model_override_args`: Override model config with the provided JSON.
* `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model.
* `disable_fast_image_processor`: Use the base (slow) image processor instead of the fast image processor, which is enabled by default (see the example below). For more details, see: https://huggingface.co/docs/transformers/main/en/main_classes/image_processor#image-processor
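A minimal sketch of turning the fast path off from Python. It assumes that `sglang.Engine` forwards `ServerArgs` fields, including the new `disable_fast_image_processor`, as keyword arguments; the checkpoint name is only a placeholder:

```python
# Sketch only; assumes Engine passes ServerArgs fields through as kwargs.
import sglang as sgl

engine = sgl.Engine(
    model_path="Qwen/Qwen2.5-VL-7B-Instruct",   # placeholder VLM checkpoint
    disable_fast_image_processor=True,          # fall back to the base (slow) processor
)
```

On the command line, the equivalent is adding `--disable-fast-image-processor` to the `python -m sglang.launch_server` invocation shown earlier; omitting the flag keeps the fast image processor, which this change makes the default.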
## Serving: HTTP & API
......
......@@ -215,6 +215,7 @@ def get_processor(
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tokenizer_revision: Optional[str] = None,
use_fast: Optional[bool] = True,
**kwargs,
):
# pop 'revision' from kwargs if present.
......@@ -232,6 +233,9 @@ def get_processor(
if "size" not in kwargs:
kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
if config.model_type not in {"llava", "clip"}:
kwargs["use_fast"] = use_fast
processor = AutoProcessor.from_pretrained(
tokenizer_name,
*args,
......
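For reference, a hedged illustration (separate from this diff) of what the forwarded `use_fast` flag selects on the Hugging Face side: with `use_fast=True`, `AutoImageProcessor.from_pretrained` returns the torch/torchvision-backed fast processor when the model ships one, while `use_fast=False` keeps the PIL/NumPy-based slow processor. The checkpoint and printed class names are examples only:

```python
# Illustration only; the concrete classes depend on the model and transformers version.
from transformers import AutoImageProcessor

fast = AutoImageProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", use_fast=True)
slow = AutoImageProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", use_fast=False)
print(type(fast).__name__)  # e.g. Qwen2VLImageProcessorFast
print(type(slow).__name__)  # e.g. Qwen2VLImageProcessor
```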
......@@ -4,14 +4,16 @@ import dataclasses
import multiprocessing as mp
import os
from abc import ABC, abstractmethod
from typing import Optional
from typing import List, Optional
import numpy as np
import PIL
from decord import VideoReader, cpu
from PIL import Image
from transformers import BaseImageProcessorFast
from sglang.srt.utils import encode_video, load_audio, load_image, logger
from sglang.srt.managers.schedule_batch import Modality
from sglang.srt.utils import encode_video, load_audio, load_image
@dataclasses.dataclass
......@@ -78,6 +80,10 @@ class BaseMultimodalProcessor(ABC):
kwargs["audios"] = audios
processor = self._processor
if hasattr(processor, "image_processor") and isinstance(
processor.image_processor, BaseImageProcessorFast
):
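# Fast image processors run their tensor-based transforms on a target device,
# so requesting "cuda" here moves image preprocessing onto the GPU
# (assumption: a CUDA device is available whenever a fast processor is in use).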
kwargs["device"] = "cuda"
result = processor.__call__(
text=[input_text],
padding=True,
......@@ -111,6 +117,84 @@ class BaseMultimodalProcessor(ABC):
return estimated_frames_list
@staticmethod
def _load_single_item(
data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
):
"""Static method that can be pickled for multiprocessing"""
try:
if is_audio:
return load_audio(data)
elif is_video:
path = data[len("video:") :]
return encode_video(path, frame_count_limit)
else:
img, _ = load_image(data)
return img.convert("RGB") if discard_alpha_channel else img
except Exception as e:
raise RuntimeError(f"Error while loading data {data}: {e}")
def submit_data_loading_tasks(
self,
text_parts: List[str],
multimodal_tokens: MultimodalSpecialTokens,
image_data: Optional[list] = None,
audio_data: Optional[list] = None,
discard_alpha_channel: bool = True,
):
"""
Load multimodal data in parallel.
"""
# TODO(mick): load from server_args, env, or sampling_params
MAX_NUM_FRAMES = 30
estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
total_frame_count = sum(estimated_frames_list)
# A heuristic upper bound on the fraction of frames to embed across all visual inputs;
# e.g., 0.1 means roughly 1 out of every 10 input frames is used.
scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
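# Worked example: MAX_NUM_FRAMES = 30 with total_frame_count = 120 gives
# scaling_factor = 0.25, so an input estimated at 40 frames is capped below to
# max(1, int(40 * 0.25)) = 10 frames.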
assert len(image_data) == len(estimated_frames_list)
# Submit all tasks
futures = []
task_info = []
image_index, audio_index = 0, 0
for text_part in text_parts:
if text_part == multimodal_tokens.image_token:
data = image_data[image_index]
is_video = isinstance(data, str) and data.startswith("video:")
estimated_frames = estimated_frames_list[image_index]
frame_count_limit = max(1, int(estimated_frames * scaling_factor))
futures.append(
self.io_executor.submit(
BaseMultimodalProcessor._load_single_item,
data,
is_video,
False,
frame_count_limit,
discard_alpha_channel,
)
)
task_info.append((Modality.IMAGE, data, frame_count_limit))
image_index += 1
elif text_part == multimodal_tokens.audio_token:
data = audio_data[audio_index]
futures.append(
self.io_executor.submit(
BaseMultimodalProcessor._load_single_item,
data,
False,
True,
None,
discard_alpha_channel,
)
)
task_info.append((Modality.AUDIO, data, None))
audio_index += 1
return futures, task_info
def load_mm_data(
self,
prompt: str,
......@@ -155,84 +239,37 @@ class BaseMultimodalProcessor(ABC):
# split text into list of normal text and special tokens
text_parts = re.split(pattern, prompt)
# TODO(mick): load from server_args, env, or sampling_params
MAX_NUM_FRAMES = 30
estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
total_frame_count = sum(estimated_frames_list)
# a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
# e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
assert len(image_data) == len(estimated_frames_list)
image_index, audio_index = 0, 0
hashes, image_sizes, images, audios = [], [], [], []
futures, task_info = self.submit_data_loading_tasks(
text_parts=text_parts,
multimodal_tokens=multimodal_tokens,
image_data=image_data,
audio_data=audio_data,
discard_alpha_channel=discard_alpha_channel,
)
# Process results
image_sizes, images, audios = [], [], []
new_text = ""
for index, text_part in enumerate(text_parts):
try:
if text_part == multimodal_tokens.image_token:
# load as image
if len(images) >= MAX_NUM_FRAMES:
frames_to_process = 0
else:
estimated_frames = estimated_frames_list[image_index]
frames_to_process = max(
1, int(estimated_frames * scaling_factor)
)
if frames_to_process == 0:
frames = []
else:
image_file = image_data[image_index]
if isinstance(image_file, str) and image_file.startswith(
"video:"
):
# video
path = image_file[len("video:") :]
frames = encode_video(
path, frame_count_limit=frames_to_process
)
else:
# image
raw_image, _size = load_image(image_file)
if discard_alpha_channel:
raw_image = raw_image.convert("RGB")
frames = [raw_image]
if len(frames) == 0:
continue
image_sizes += frames[0].size * len(frames)
# Generate a hashable value for the image file
if isinstance(image_file, Image.Image):
# For PIL.Image objects, use the ID as a hashable value
hash_value = hash(id(image_file))
else:
# For other types (strings, etc.), use the regular hash
hash_value = hash(image_file)
hashes += [hash_value] * len(frames)
images += frames
image_index += 1
if frames_to_process != 0:
task_ptr = 0
for text_part in text_parts:
if text_part in multimodal_tokens.collect():
task_type, data, frame_limit = task_info[task_ptr]
result = futures[task_ptr].result()
task_ptr += 1
if task_type == Modality.IMAGE:
frames = [result] if not isinstance(result, list) else result
if frames:
image_sizes += frames[0].size * len(frames)
images += frames
new_text += multimodal_tokens.image_token * len(frames)
assert frames_to_process == len(frames)
elif text_part == multimodal_tokens.audio_token:
# load as audio
audio_file = audio_data[audio_index]
audio = load_audio(audio_file)
hashes += [hash(audio_file)]
audios += [audio]
audio_index += 1
elif task_type == Modality.AUDIO:
# audio
audios.append(result)
new_text += multimodal_tokens.audio_token
else:
# TODO(mick): handle video
# normal text
new_text += text_part
except Exception as e:
logger.error(f"An exception occurred while loading images: {e}")
raise RuntimeError(f"An exception occurred while loading images: {e}")
# TODO: handle video
else:
new_text += text_part
out = BaseMultiModalProcessorOutput(
images=images,
......
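The refactor above splits loading into two passes: `submit_data_loading_tasks` submits one executor task per multimodal token, and `load_mm_data` then walks the text parts again, consuming `futures` and `task_info` in the same order so every result lines up with its placeholder token. A minimal, self-contained sketch of that pattern, with hypothetical data and a stand-in loader (not sglang code):

```python
from concurrent.futures import ThreadPoolExecutor

tokens = ["<image>", "some text", "<image>", "<audio>"]
data = {"<image>": ["img0.png", "img1.png"], "<audio>": ["clip0.wav"]}

def load(kind, path):
    # stand-in for load_image / load_audio / encode_video
    return f"{kind}:{path}"

with ThreadPoolExecutor() as pool:
    futures, task_info, idx = [], [], {"<image>": 0, "<audio>": 0}
    for tok in tokens:                      # pass 1: submit in token order
        if tok in data:
            futures.append(pool.submit(load, tok, data[tok][idx[tok]]))
            task_info.append(tok)
            idx[tok] += 1
    results, ptr = [], 0
    for tok in tokens:                      # pass 2: consume in the same order
        if tok in data:
            results.append((task_info[ptr], futures[ptr].result()))
            ptr += 1

print(results)
# [('<image>', '<image>:img0.png'), ('<image>', '<image>:img1.png'), ('<audio>', '<audio>:clip0.wav')]
```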
......@@ -33,7 +33,9 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
base_out = self.load_mm_data(
prompt=input_ids,
image_data=image_data,
multimodal_tokens=MultimodalSpecialTokens(image_token=processor.image_tag),
multimodal_tokens=MultimodalSpecialTokens(
image_token=processor.image_token
),
max_req_input_len=max_req_input_len,
)
......
......@@ -222,10 +222,10 @@ class MultimodalDataItem:
# memoryview() doesn't support PyTorch's BFloat16 dtype
tensor = tensor.float()
assert isinstance(tensor, torch.Tensor)
if tensor.is_cuda:
tensor_cpu = torch.frombuffer(
tensor.storage().untyped(), dtype=tensor.dtype, count=tensor.numel()
).clone()
# TODO: improve this
tensor_cpu = tensor.cpu()
else:
tensor_cpu = tensor
......@@ -321,7 +321,6 @@ class MultimodalInputs:
item.set_pad_value()
optional_args = [
"modalities",
"im_token_id",
"im_start_id",
"im_end_id",
......
......@@ -452,6 +452,7 @@ class Scheduler(
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
use_fast=not server_args.disable_fast_image_processor,
)
self.tokenizer = self.processor.tokenizer
else:
......
......@@ -180,6 +180,7 @@ class TokenizerManager:
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
use_fast=not server_args.disable_fast_image_processor,
)
# We want to parallelize the image pre-processing so we create an executor for it
......
......@@ -462,6 +462,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
......@@ -515,15 +516,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
otherwise it will be `(seq_len,)`.
(Use input_metadata.mrope_positions to replace it)
"""
is_mrope_enabled = "mrope_section" in self.config.rope_scaling
if is_mrope_enabled:
if self.is_mrope_enabled:
positions = forward_batch.mrope_positions
if not (
forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs()
):
if is_mrope_enabled:
if self.is_mrope_enabled:
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"
......
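For context on the hoisted check: whether mrope is enabled is a static property of the model config, so it can be computed once in `__init__` instead of on every forward pass. A hedged illustration of the relevant config shape; the values mirror the published Qwen2-VL-style `rope_scaling` and are examples only:

```python
# Example rope_scaling entry of an mrope-enabled config (values illustrative).
rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}

is_mrope_enabled = "mrope_section" in rope_scaling
print(is_mrope_enabled)  # True -> positions must have shape (3, seq_len)
```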
......@@ -467,6 +467,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
prefix=add_prefix("lm_head", prefix),
)
self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
......@@ -521,15 +522,14 @@ class Qwen2VLForConditionalGeneration(nn.Module):
otherwise it will be `(seq_len,)`.
(Use input_metadata.mrope_positions to replace it)
"""
is_mrope_enabled = "mrope_section" in self.config.rope_scaling
if is_mrope_enabled:
if self.is_mrope_enabled:
positions = forward_batch.mrope_positions
if not (
forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs()
):
if is_mrope_enabled:
if self.is_mrope_enabled:
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"
......
......@@ -196,6 +196,9 @@ class ServerArgs:
disaggregation_mode: str = "null"
disaggregation_bootstrap_port: int = 8998
# multimodal
disable_fast_image_processor: bool = False
def __post_init__(self):
# Expert parallelism
if self.enable_ep_moe:
......@@ -979,6 +982,7 @@ class ServerArgs:
)
parser.add_argument(
"--enable-llama4-multimodal",
default=ServerArgs.enable_llama4_multimodal,
action="store_true",
help="Enable the multimodal functionality for Llama-4.",
)
......@@ -1170,6 +1174,13 @@ class ServerArgs:
help="Bootstrap server port on the prefill server. Default is 8998.",
)
# Multimodal
parser.add_argument(
"--disable-fast-image-processor",
action="store_true",
help="Adopt base image processor instead of fast image processor.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size
......