Unverified Commit 34ef6c81 authored by Mick, committed by GitHub

[VLM] Adopt fast image processor by default (#5065)

parent 61172091
@@ -89,5 +89,4 @@ if __name__ == "__main__":
     EvalArgs.add_cli_args(parser)
     args = add_common_sglang_args_and_parse(parser)
     args = parser.parse_args()
-
     eval_mmmu(args)
@@ -7,6 +7,7 @@ import os
 import pprint
 import random
 import re
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Dict, Optional

 import numpy as np
@@ -117,29 +118,38 @@ def prepare_samples(eval_args: EvalArgs):
     # merge all dataset
     dataset = concatenate_datasets(sub_dataset_list)

-    ## prepare images
-    samples = []
-    skip_count = 0
-    # use image file as input to ensure the consistency between sglang and hf
+    # Prepare images in parallel
     images_path = os.path.expanduser("~/.cache/mmmu/images")
     os.makedirs(images_path, exist_ok=True)
     print(f"Saving images to: {images_path}")
-    for i, sample in enumerate(tqdm(dataset)):
+
+    samples = []
+    skip_count = 0
+
+    def process_sample(i, sample):
         sample = process_single_sample(sample)
         sample = construct_prompt(sample, eval_args.config)
         image = sample["image"]
         width, height = image.size
         if width * height >= eval_args.image_pixels_limit:
-            skip_count += 1
-            continue
+            return None, True
         image_path = f"{images_path}/image_{i}.png"
         if not os.path.exists(image_path):
             image.save(image_path)
         sample["image_path"] = image_path
-        samples.append(sample)
+        return sample, False
+
+    with ThreadPoolExecutor() as executor:
+        futures = [
+            executor.submit(process_sample, i, sample)
+            for i, sample in enumerate(dataset)
+        ]
+        for future in tqdm(as_completed(futures), total=len(futures)):
+            sample, skipped = future.result()
+            if skipped:
+                skip_count += 1
+            elif sample:
+                samples.append(sample)

     print(
         f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
...
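The rewritten `prepare_samples` above uses a submit-then-collect pattern: every sample goes to a thread pool up front, and `as_completed` drains results as they finish so the progress bar tracks real work. A minimal, self-contained sketch of that pattern (the `process` function here is only a stand-in for the benchmark's `process_sample`):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

from tqdm import tqdm


def process(i):
    # Stand-in for per-sample work (decode, filter, save image, ...).
    return None if i % 10 == 0 else f"sample_{i}"


results, skipped = [], 0
with ThreadPoolExecutor() as executor:
    futures = [executor.submit(process, i) for i in range(100)]
    # as_completed yields futures as they finish, so tqdm reflects real progress.
    for future in tqdm(as_completed(futures), total=len(futures)):
        out = future.result()
        if out is None:
            skipped += 1
        else:
            results.append(out)

print(f"kept {len(results)}, skipped {skipped}")
```

One consequence of `as_completed` is that results arrive in completion order rather than submission order.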
@@ -45,7 +45,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct

 Please consult the documentation below to learn more about the parameters you may provide when launching a server.

-## Model and tokenizer
+## Model, processor and tokenizer

 * `model_path`: Path to the model that will be served.
 * `tokenizer_path`: Defaults to the `model_path`.
@@ -62,6 +62,7 @@ Please consult the documentation below to learn more about the parameters you may provide when launching a server.
 * `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/token_in_token_out/).
 * `json_model_override_args`: Override model config with the provided JSON.
 * `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model.
+* `disable_fast_image_processor`: Use the base image processor instead of the fast image processor, which is the default (see the sketch after this hunk). For more detail, see: https://huggingface.co/docs/transformers/main/en/main_classes/image_processor#image-processor

 ## Serving: HTTP & API
...
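The fast/base split in the bullet above comes straight from `transformers`: `use_fast=True` resolves to a torch/torchvision-backed image-processor class when the model ships one, while the base variant stays on PIL/NumPy. A hedged sketch of the difference (the checkpoint name is only a placeholder):

```python
from transformers import AutoImageProcessor

model_id = "Qwen/Qwen2.5-VL-7B-Instruct"  # placeholder checkpoint for illustration

fast = AutoImageProcessor.from_pretrained(model_id, use_fast=True)
base = AutoImageProcessor.from_pretrained(model_id, use_fast=False)

# Typically prints something like "Qwen2VLImageProcessorFast" vs. "Qwen2VLImageProcessor".
print(type(fast).__name__, type(base).__name__)
```

At launch time the opt-out is simply `python -m sglang.launch_server --model-path <model> --disable-fast-image-processor`, as wired up in the `server_args.py` hunks further down.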
@@ -215,6 +215,7 @@ def get_processor(
     tokenizer_mode: str = "auto",
     trust_remote_code: bool = False,
     tokenizer_revision: Optional[str] = None,
+    use_fast: Optional[bool] = True,
     **kwargs,
 ):
     # pop 'revision' from kwargs if present.
@@ -232,6 +233,9 @@ def get_processor(
         if "size" not in kwargs:
             kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}

+    if config.model_type not in {"llava", "clip"}:
+        kwargs["use_fast"] = use_fast
+
     processor = AutoProcessor.from_pretrained(
         tokenizer_name,
         *args,
...
@@ -4,14 +4,16 @@ import dataclasses
 import multiprocessing as mp
 import os
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import List, Optional

 import numpy as np
 import PIL
 from decord import VideoReader, cpu
 from PIL import Image
+from transformers import BaseImageProcessorFast

-from sglang.srt.utils import encode_video, load_audio, load_image, logger
+from sglang.srt.managers.schedule_batch import Modality
+from sglang.srt.utils import encode_video, load_audio, load_image


 @dataclasses.dataclass
@@ -78,6 +80,10 @@ class BaseMultimodalProcessor(ABC):
             kwargs["audios"] = audios

         processor = self._processor
+        if hasattr(processor, "image_processor") and isinstance(
+            processor.image_processor, BaseImageProcessorFast
+        ):
+            kwargs["device"] = "cuda"
         result = processor.__call__(
             text=[input_text],
             padding=True,
@@ -111,6 +117,84 @@ class BaseMultimodalProcessor(ABC):

         return estimated_frames_list

+    @staticmethod
+    def _load_single_item(
+        data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
+    ):
+        """Static method that can be pickled for multiprocessing"""
+        try:
+            if is_audio:
+                return load_audio(data)
+            elif is_video:
+                path = data[len("video:") :]
+                return encode_video(path, frame_count_limit)
+            else:
+                img, _ = load_image(data)
+                return img.convert("RGB") if discard_alpha_channel else img
+        except Exception as e:
+            raise RuntimeError(f"Error while loading data {data}: {e}")
+
+    def submit_data_loading_tasks(
+        self,
+        text_parts: List[str],
+        multimodal_tokens: MultimodalSpecialTokens,
+        image_data: Optional[list] = None,
+        audio_data: Optional[list] = None,
+        discard_alpha_channel: bool = True,
+    ):
+        """
+        Load multimodal data in parallel.
+        """
+        # TODO(mick): load from server_args, env, or sampling_params
+        MAX_NUM_FRAMES = 30
+
+        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
+        total_frame_count = sum(estimated_frames_list)
+        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
+        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
+
+        assert len(image_data) == len(estimated_frames_list)
+
+        # Submit all tasks
+        futures = []
+        task_info = []
+        image_index, audio_index = 0, 0
+
+        for text_part in text_parts:
+            if text_part == multimodal_tokens.image_token:
+                data = image_data[image_index]
+                is_video = isinstance(data, str) and data.startswith("video:")
+                estimated_frames = estimated_frames_list[image_index]
+                frame_count_limit = max(1, int(estimated_frames * scaling_factor))
+                futures.append(
+                    self.io_executor.submit(
+                        BaseMultimodalProcessor._load_single_item,
+                        data,
+                        is_video,
+                        False,
+                        frame_count_limit,
+                        discard_alpha_channel,
+                    )
+                )
+                task_info.append((Modality.IMAGE, data, frame_count_limit))
+                image_index += 1
+            elif text_part == multimodal_tokens.audio_token:
+                data = audio_data[audio_index]
+                futures.append(
+                    self.io_executor.submit(
+                        BaseMultimodalProcessor._load_single_item,
+                        data,
+                        False,
+                        True,
+                        None,
+                        discard_alpha_channel,
+                    )
+                )
+                task_info.append((Modality.AUDIO, data, None))
+                audio_index += 1
+
+        return futures, task_info
+
     def load_mm_data(
         self,
         prompt: str,
@@ -155,84 +239,37 @@ class BaseMultimodalProcessor(ABC):
         # split text into list of normal text and special tokens
         text_parts = re.split(pattern, prompt)

-        # TODO(mick): load from server_args, env, or sampling_params
-        MAX_NUM_FRAMES = 30
-        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
-        total_frame_count = sum(estimated_frames_list)
-        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
-        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
-        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
-
-        assert len(image_data) == len(estimated_frames_list)
-        image_index, audio_index = 0, 0
-        hashes, image_sizes, images, audios = [], [], [], []
+        futures, task_info = self.submit_data_loading_tasks(
+            text_parts=text_parts,
+            multimodal_tokens=multimodal_tokens,
+            image_data=image_data,
+            audio_data=audio_data,
+            discard_alpha_channel=discard_alpha_channel,
+        )
+        # Process results
+        image_sizes, images, audios = [], [], []
         new_text = ""
-        for index, text_part in enumerate(text_parts):
-            try:
-                if text_part == multimodal_tokens.image_token:
-                    # load as image
-                    if len(images) >= MAX_NUM_FRAMES:
-                        frames_to_process = 0
-                    else:
-                        estimated_frames = estimated_frames_list[image_index]
-                        frames_to_process = max(
-                            1, int(estimated_frames * scaling_factor)
-                        )
-
-                    if frames_to_process == 0:
-                        frames = []
-                    else:
-                        image_file = image_data[image_index]
-                        if isinstance(image_file, str) and image_file.startswith(
-                            "video:"
-                        ):
-                            # video
-                            path = image_file[len("video:") :]
-                            frames = encode_video(
-                                path, frame_count_limit=frames_to_process
-                            )
-                        else:
-                            # image
-                            raw_image, _size = load_image(image_file)
-                            if discard_alpha_channel:
-                                raw_image = raw_image.convert("RGB")
-                            frames = [raw_image]
-                        if len(frames) == 0:
-                            continue
-
-                    image_sizes += frames[0].size * len(frames)
-
-                    # Generate a hashable value for the image file
-                    if isinstance(image_file, Image.Image):
-                        # For PIL.Image objects, use the ID as a hashable value
-                        hash_value = hash(id(image_file))
-                    else:
-                        # For other types (strings, etc.), use the regular hash
-                        hash_value = hash(image_file)
-
-                    hashes += [hash_value] * len(frames)
-                    images += frames
-                    image_index += 1
-                    if frames_to_process != 0:
-                        new_text += multimodal_tokens.image_token * len(frames)
-                        assert frames_to_process == len(frames)
-                elif text_part == multimodal_tokens.audio_token:
-                    # load as audio
-                    audio_file = audio_data[audio_index]
-                    audio = load_audio(audio_file)
-                    hashes += [hash(audio_file)]
-                    audios += [audio]
-                    audio_index += 1
-                    new_text += multimodal_tokens.audio_token
-                else:
-                    # TODO(mick): handle video
-                    # normal text
-                    new_text += text_part
-            except Exception as e:
-                logger.error(f"An exception occurred while loading images: {e}")
-                raise RuntimeError(f"An exception occurred while loading images: {e}")
+        task_ptr = 0
+
+        for text_part in text_parts:
+            if text_part in multimodal_tokens.collect():
+                task_type, data, frame_limit = task_info[task_ptr]
+                result = futures[task_ptr].result()
+                task_ptr += 1
+
+                if task_type == Modality.IMAGE:
+                    frames = [result] if not isinstance(result, list) else result
+                    if frames:
+                        image_sizes += frames[0].size * len(frames)
+                        images += frames
+                        new_text += multimodal_tokens.image_token * len(frames)
+                elif task_type == Modality.AUDIO:
+                    # audio
+                    audios.append(result)
+                    new_text += multimodal_tokens.audio_token
+                # TODO: handle video
+            else:
+                new_text += text_part

         out = BaseMultiModalProcessorOutput(
             images=images,
...
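Two things stand out in this file's diff: loading is split into a submit phase (`submit_data_loading_tasks`) and a collect phase that walks `futures` and `task_info` in lockstep, and preprocessing is pushed to the GPU whenever the processor is a `BaseImageProcessorFast`. The GPU part is standard `transformers` behaviour for fast image processors; a hedged, standalone sketch (placeholder checkpoint, guarded on CUDA availability):

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, BaseImageProcessorFast

image_processor = AutoImageProcessor.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",  # placeholder checkpoint for illustration
    use_fast=True,
)

kwargs = {}
# Same check as in the diff: only fast (torch-backed) processors accept a target device.
if isinstance(image_processor, BaseImageProcessorFast) and torch.cuda.is_available():
    kwargs["device"] = "cuda"

image = Image.new("RGB", (336, 336))
batch = image_processor(images=[image], return_tensors="pt", **kwargs)
print(batch["pixel_values"].device)  # cuda:0 when resizing/normalization ran on the GPU
```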
@@ -33,7 +33,9 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         base_out = self.load_mm_data(
             prompt=input_ids,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token=processor.image_tag),
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=processor.image_token
+            ),
             max_req_input_len=max_req_input_len,
         )
...
@@ -222,10 +222,10 @@ class MultimodalDataItem:
             # memoryview() doesn't support PyTorch's BFloat16 dtype
             tensor = tensor.float()

+            assert isinstance(tensor, torch.Tensor)
             if tensor.is_cuda:
-                tensor_cpu = torch.frombuffer(
-                    tensor.storage().untyped(), dtype=tensor.dtype, count=tensor.numel()
-                ).clone()
+                # TODO: improve this
+                tensor_cpu = tensor.cpu()
             else:
                 tensor_cpu = tensor
@@ -321,7 +321,6 @@ class MultimodalInputs:
             item.set_pad_value()

         optional_args = [
-            "modalities",
             "im_token_id",
             "im_start_id",
             "im_end_id",
...
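The `schedule_batch.py` change swaps a `torch.frombuffer` round-trip for a plain device-to-host copy before the tensor is hashed for pad values. A small hedged sketch of the new path (the hashing line is only one way to derive a stable value, not necessarily what the surrounding code does downstream):

```python
import torch

tensor = torch.randn(2, 3)
if torch.cuda.is_available():
    tensor = tensor.cuda()

# New path from the diff: a single explicit copy back to host memory.
tensor_cpu = tensor.cpu() if tensor.is_cuda else tensor

# memoryview() cannot wrap bfloat16 (hence the .float() upcast in the surrounding code);
# float32 data can be exposed through NumPy's buffer protocol and hashed.
digest = hash(bytes(memoryview(tensor_cpu.contiguous().numpy())))
print(digest)
```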
@@ -452,6 +452,7 @@ class Scheduler(
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
+                use_fast=not server_args.disable_fast_image_processor,
             )
             self.tokenizer = self.processor.tokenizer
         else:
...
@@ -180,6 +180,7 @@ class TokenizerManager:
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
+                use_fast=not server_args.disable_fast_image_processor,
             )

             # We want to parallelize the image pre-processing so we create an executor for it
...
@@ -462,6 +462,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
         )
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -515,15 +516,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             otherwise it will be `(seq_len,).
             (Use input_metadata.mrope_positions to replace it)
         """
-        is_mrope_enabled = "mrope_section" in self.config.rope_scaling
-        if is_mrope_enabled:
+        if self.is_mrope_enabled:
             positions = forward_batch.mrope_positions

         if not (
             forward_batch.forward_mode.is_decode()
             or not forward_batch.contains_image_inputs()
         ):
-            if is_mrope_enabled:
+            if self.is_mrope_enabled:
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
...
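The Qwen2.5-VL and Qwen2-VL changes are purely about hoisting an invariant: whether mRoPE is enabled depends only on the model config, so it is computed once in `__init__` instead of on every forward call. A toy sketch of the same pattern (all names here are illustrative only):

```python
import torch
from torch import nn


class TinyVLM(nn.Module):
    def __init__(self, rope_scaling: dict):
        super().__init__()
        # Config-derived flag, computed once at construction time.
        self.is_mrope_enabled = "mrope_section" in rope_scaling

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        # The hot path only reads the cached attribute.
        if self.is_mrope_enabled:
            assert positions.ndim == 2 and positions.size(0) == 3
        return positions


model = TinyVLM(rope_scaling={"mrope_section": [16, 24, 24]})
print(model(torch.zeros(3, 8, dtype=torch.long)).shape)  # torch.Size([3, 8])
```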
@@ -467,6 +467,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefix=add_prefix("lm_head", prefix),
         )
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -521,15 +522,14 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             otherwise it will be `(seq_len,).
             (Use input_metadata.mrope_positions to replace it)
         """
-        is_mrope_enabled = "mrope_section" in self.config.rope_scaling
-        if is_mrope_enabled:
+        if self.is_mrope_enabled:
             positions = forward_batch.mrope_positions

         if not (
             forward_batch.forward_mode.is_decode()
             or not forward_batch.contains_image_inputs()
         ):
-            if is_mrope_enabled:
+            if self.is_mrope_enabled:
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
...
@@ -196,6 +196,9 @@ class ServerArgs:
     disaggregation_mode: str = "null"
     disaggregation_bootstrap_port: int = 8998

+    # multimodal
+    disable_fast_image_processor: bool = False
+
     def __post_init__(self):
         # Expert parallelism
         if self.enable_ep_moe:
@@ -979,6 +982,7 @@ class ServerArgs:
         )
         parser.add_argument(
             "--enable-llama4-multimodal",
+            default=ServerArgs.enable_llama4_multimodal,
             action="store_true",
             help="Enable the multimodal functionality for Llama-4.",
         )
@@ -1170,6 +1174,13 @@ class ServerArgs:
             help="Bootstrap server port on the prefill server. Default is 8998.",
         )

+        # Multimodal
+        parser.add_argument(
+            "--disable-fast-image-processor",
+            action="store_true",
+            help="Use the base image processor instead of the fast image processor.",
+        )
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
...
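Finally, the CLI wiring: `--disable-fast-image-processor` is a plain `store_true` flag, and both `Scheduler` and `TokenizerManager` pass `use_fast=not server_args.disable_fast_image_processor` into `get_processor`. A hedged, standalone mirror of that derivation:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-fast-image-processor",
    action="store_true",
    help="Use the base image processor instead of the fast image processor.",
)

# Default: flag absent -> fast image processor stays enabled.
args = parser.parse_args([])
print(not args.disable_fast_image_processor)  # True  -> use_fast=True

# Opt out: flag present -> fall back to the base image processor.
args = parser.parse_args(["--disable-fast-image-processor"])
print(not args.disable_fast_image_processor)  # False -> use_fast=False
```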