Unverified Commit d1ca7df8 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[VLM] Merged multi-modal processor for InternVL-based models (#12553)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: default avatarIsotr0py <2037008807@qq.com>
Co-authored-by: default avatarIsotr0py <2037008807@qq.com>
parent 96b23621
......@@ -250,7 +250,11 @@ def get_max_image_tokens(self) -> int:
And thus, we can override the method as:
```python
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
```
......
......@@ -726,7 +726,7 @@ See [this page](#generative-models) for more information on how to use generativ
* `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc.
*
* ✅︎
*
* \*
- * `Idefics3ForConditionalGeneration`
* Idefics3
* T + I
......@@ -799,7 +799,7 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
- * `NVLM_D_Model`
* NVLM-D 1.0
* T + I<sup>E+</sup>
* T + I<sup>+</sup>
* `nvidia/NVLM-D-72B`, etc.
*
* ✅︎
......@@ -859,7 +859,11 @@ See [this page](#generative-models) for more information on how to use generativ
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.
:::{note}
To use `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
To use DeepSeek-VL2 series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
:::
:::{note}
H2O-VL series models will be available in V1 once we support backends other than FlashAttention.
:::
:::{note}
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
import pytest
import torch
from PIL.Image import Image
from transformers import AutoConfig
# Import the functions to test
from vllm.model_executor.models.h2ovl import (calculate_num_blocks,
image_to_pixel_values_wrapper)
from vllm.multimodal.image import rescale_image_size
models = [
"h2oai/h2ovl-mississippi-800m", # Replace with your actual model names
"h2oai/h2ovl-mississippi-2b",
]
def run_preprocessing_test(
image: Image,
config,
max_dynamic_patch: Optional[int] = None,
) -> Tuple[torch.Tensor, int]:
"""Test the image preprocessing and calculate expected blocks."""
if max_dynamic_patch is None:
max_dynamic_patch = config.max_dynamic_patch
width, height = image.size
use_MSAC = config.use_msac
# Create the mapper function with the provided configuration
mapper = image_to_pixel_values_wrapper(config, max_dynamic_patch, use_MSAC)
pixel_values = mapper(image)
# Calculate the expected number of blocks
if use_MSAC:
# First pass
blocks1, _, _, aspect_ratio = calculate_num_blocks(
width,
height,
config.min_dynamic_patch,
max_dynamic_patch,
config.vision_config.image_size,
use_thumbnail=False, # Thumbnail is handled separately
prior_aspect_ratio=None,
)
# Second pass
blocks2, _, _, _ = calculate_num_blocks(
width,
height,
config.min_dynamic_patch,
max_dynamic_patch,
config.vision_config.image_size,
use_thumbnail=False,
prior_aspect_ratio=aspect_ratio,
)
# Add thumbnail if use_thumbnail is True and total_blocks > 1
if config.use_thumbnail:
blocks1 += 1 if blocks1 > 1 else 0
blocks2 += 1 if blocks2 > 1 else 0
# Total blocks is the sum of blocks from both passes minus overlapping
total_blocks = blocks1 + blocks2 - 1
expected_blocks = total_blocks
else:
blocks, _, _, _ = calculate_num_blocks(
width,
height,
config.min_dynamic_patch,
max_dynamic_patch,
config.vision_config.image_size,
use_thumbnail=False,
prior_aspect_ratio=None,
)
expected_blocks = blocks
if config.use_thumbnail and expected_blocks > 1:
expected_blocks += 1
return pixel_values, expected_blocks
@pytest.mark.parametrize("model_name", models)
@pytest.mark.parametrize(
"size_factors",
[
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("max_dynamic_patch", [None, 2, 4, 8])
def test_image_preprocessing(image_assets, model_name, size_factors,
max_dynamic_patch):
"""Test image preprocessing pipeline with different configurations."""
# Load the configuration from the model
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
for asset in image_assets:
image = asset.pil_image
for factor in size_factors:
scaled_image = rescale_image_size(image, factor)
# Test preprocessing and get expected number of blocks
pixel_values, expected_blocks = run_preprocessing_test(
scaled_image, config, max_dynamic_patch)
# Verify output shapes and properties
actual_blocks = pixel_values.shape[0]
assert actual_blocks == expected_blocks, (
f"Expected {expected_blocks} blocks, got {actual_blocks}")
# Check image dimensions
expected_size = (
3, # Number of channels (C, H, W)
config.vision_config.image_size,
config.vision_config.image_size,
)
for img in pixel_values:
assert img.shape == expected_size, (
f"Expected image size {expected_size}, got {img.shape}")
......@@ -250,6 +250,7 @@ VLM_TEST_SETTINGS = {
max_model_len=8192,
dtype="bfloat16",
use_tokenizer_eos=True,
num_logprobs=10,
patch_hf_runner=model_utils.h2ovl_patch_hf_runner,
),
"idefics3": VLMTestInfo(
......@@ -282,7 +283,6 @@ VLM_TEST_SETTINGS = {
dtype="bfloat16",
use_tokenizer_eos=True,
patch_hf_runner=model_utils.internvl_patch_hf_runner,
marks=[large_gpu_mark(min_gb=32)],
),
"llava_next": VLMTestInfo(
models=["llava-hf/llava-v1.6-mistral-7b-hf"],
......
......@@ -334,12 +334,12 @@ def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
def __init__(self, hf_runner: HfRunner):
self.num_image_token = hf_runner.model.num_image_token
self.tokenizer = hf_runner.tokenizer
self.dtype = hf_runner.model.dtype
self.config = AutoConfig.from_pretrained(hf_runner.model_name,
trust_remote_code=True)
self.vision_config = self.config.vision_config
self.use_thumbnail = self.config.use_thumbnail
self.use_msac = self.config.use_msac
self.min_num = self.config.min_dynamic_patch
self.max_num = self.config.max_dynamic_patch
self.image_size = self.vision_config.image_size
......@@ -348,18 +348,19 @@ def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
**kwargs):
# yapf: disable
from vllm.model_executor.models.h2ovl import (
IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values_h2ovl)
# yapf: enable
images = [images] if isinstance(images, Image) else images
pixel_values = [
image_to_pixel_values(image,
self.image_size,
self.min_num,
self.max_num,
self.use_thumbnail,
use_MSAC=self.config.use_msac).to(
self.dtype) for image in images
image_to_pixel_values_h2ovl(
image,
input_size=self.image_size,
min_num=self.min_num,
max_num=self.max_num,
use_thumbnail=self.use_thumbnail,
use_msac=self.use_msac,
) for image in images
]
num_patches_list = [
pixel_value.shape[0] for pixel_value in pixel_values
......@@ -394,7 +395,6 @@ def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
def __init__(self, hf_runner: HfRunner):
self.num_image_token = hf_runner.model.num_image_token
self.tokenizer = hf_runner.tokenizer
self.dtype = hf_runner.model.dtype
self.config = AutoConfig.from_pretrained(hf_runner.model_name,
trust_remote_code=True)
......@@ -407,13 +407,17 @@ def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
def __call__(self, text: str, images: Union[Image, List[Image]],
**kwargs):
from vllm.model_executor.models.internvl import (
IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
IMG_CONTEXT, IMG_END, IMG_START,
image_to_pixel_values_internvl)
images = [images] if isinstance(images, Image) else images
pixel_values = [
image_to_pixel_values(image, self.image_size, self.min_num,
self.max_num,
self.use_thumbnail).to(self.dtype)
for image in images
image_to_pixel_values_internvl(
image,
input_size=self.image_size,
min_num=self.min_num,
max_num=self.max_num,
use_thumbnail=self.use_thumbnail,
) for image in images
]
num_patches_list = [
pixel_value.shape[0] for pixel_value in pixel_values
......@@ -448,7 +452,8 @@ def _internvl_generate(
) -> torch.LongTensor:
"""Generate method for InternVL2 model without fixed use_cache."""
assert self.img_context_token_id is not None
vit_embeds = self.extract_feature(pixel_values)
target_dtype = next(self.parameters()).dtype
vit_embeds = self.extract_feature(pixel_values.to(target_dtype))
input_embeds = self.language_model.get_input_embeddings()(input_ids)
B, N, C = input_embeds.shape
input_embeds = input_embeds.reshape(B * N, C)
......
......@@ -141,13 +141,14 @@ def _test_processing_correctness(
# yapf: disable
# True if the model supports multiple data items of the modality per request
@pytest.mark.parametrize("model_id", [
"rhymes-ai/Aria",
"Salesforce/blip2-opt-2.7b",
"facebook/chameleon-7b",
"deepseek-ai/deepseek-vl2-tiny",
"adept/fuyu-8b",
"h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B",
"llava-hf/llava-1.5-7b-hf",
"llava-hf/llava-v1.6-mistral-7b-hf",
"llava-hf/LLaVA-NeXT-Video-7B-hf",
......@@ -156,6 +157,7 @@ def _test_processing_correctness(
"mistral-community/pixtral-12b",
"openbmb/MiniCPM-o-2_6",
"openbmb/MiniCPM-V-2_6",
"nvidia/NVLM-D-72B",
"Qwen/Qwen-VL-Chat",
"Qwen/Qwen2-VL-2B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
......
# SPDX-License-Identifier: Apache-2.0
"""Tests for H2OVL's multimodal preprocessing kwargs."""
from typing import Optional
import pytest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import rescale_image_size
from vllm.multimodal.utils import cached_get_tokenizer
from ....conftest import _ImageAssets
from ...utils import build_model_context
@pytest.mark.parametrize("model_id", [
"h2oai/h2ovl-mississippi-800m",
"h2oai/h2ovl-mississippi-2b",
])
@pytest.mark.parametrize(
"size_factors",
[
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("max_dynamic_patch", [1, 2, 4, 8])
@pytest.mark.parametrize("dynamic_image_size", [True, False])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
model_id: str,
image_assets: _ImageAssets,
size_factors: list[int],
max_dynamic_patch: int,
dynamic_image_size: Optional[bool],
num_imgs: int,
):
from vllm.model_executor.models.h2ovl import (calculate_h2ovl_targets,
get_h2ovl_target_ratios)
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
trust_remote_code=True,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=tokenizer,
)
config = processor.info.get_hf_config()
use_msac = config.use_msac
mm_processor_kwargs = {
"max_dynamic_patch": max_dynamic_patch,
}
if dynamic_image_size is not None:
mm_processor_kwargs["dynamic_image_size"] = dynamic_image_size
min_num = config.min_dynamic_patch
max_num = max_dynamic_patch if dynamic_image_size else 1
# Build the image str / prompt based on the number of images we pass
prompt = "<image>" * num_imgs
for asset in image_assets:
for factor in size_factors:
image = rescale_image_size(asset.pil_image, factor)
mm_data = {"image": [image] * num_imgs}
width, height = image.size
# Calculate the expected number of blocks
if num_imgs == 1 and use_msac:
# First pass
blocks1, _, _, aspect_ratio = calculate_h2ovl_targets(
orig_width=width,
orig_height=height,
target_ratios=get_h2ovl_target_ratios(
min_num,
max_num,
prior_aspect_ratio=None,
),
image_size=config.vision_config.image_size,
use_thumbnail=False, # Thumbnail is handled separately
)
# Second pass
blocks2, _, _, _ = calculate_h2ovl_targets(
orig_width=width,
orig_height=height,
target_ratios=get_h2ovl_target_ratios(
min_num,
max_num,
prior_aspect_ratio=aspect_ratio,
),
image_size=config.vision_config.image_size,
use_thumbnail=False,
)
# Add thumbnail if use_thumbnail is True and total_blocks > 1
if config.use_thumbnail:
blocks1 += 1 if blocks1 > 1 else 0
blocks2 += 1 if blocks2 > 1 else 0
# Total blocks is the sum of blocks from both passes minus
# overlapping
total_blocks = blocks1 + blocks2 - 1
expected_num_patches = total_blocks
else:
blocks, _, _, _ = calculate_h2ovl_targets(
orig_width=width,
orig_height=height,
target_ratios=get_h2ovl_target_ratios(
min_num,
max_num,
prior_aspect_ratio=None,
),
image_size=config.vision_config.image_size,
use_thumbnail=False,
)
expected_num_patches = blocks
if config.use_thumbnail and expected_num_patches != 1:
expected_num_patches += 1
processed_inputs = processor.apply(prompt, mm_data,
mm_processor_kwargs)
pixel_shape = (
processed_inputs["mm_kwargs"]["pixel_values_flat"].shape)
assert pixel_shape[0] == expected_num_patches * num_imgs
# SPDX-License-Identifier: Apache-2.0
"""Tests for InternVL's multimodal preprocessing kwargs."""
from typing import Callable, Optional
from typing import Optional
import pytest
from transformers import AutoTokenizer
from vllm.inputs import InputContext, token_inputs
from vllm.multimodal import MultiModalRegistry
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer
from ....conftest import _ImageAssets
from ...utils import build_model_context
models = ["OpenGVLab/InternVL2-2B"]
# Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture()
def input_processor_for_internvl():
from vllm.model_executor.models.internvl import InternVLInputPipeline
pipeline = InternVLInputPipeline('<img>', '</img>', '<IMG_CONTEXT>')
return pipeline.input_processor
@pytest.fixture()
def dummy_data_for_internvl():
from vllm.model_executor.models.internvl import InternVLInputPipeline
pipeline = InternVLInputPipeline('<img>', '</img>', '<IMG_CONTEXT>')
return pipeline.dummy_data
@pytest.fixture()
def get_max_internvl_image_tokens():
from vllm.model_executor.models.internvl import (
get_max_internvl_image_tokens)
return get_max_internvl_image_tokens
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("model_id", ["OpenGVLab/InternVL2-2B"])
@pytest.mark.parametrize("max_dynamic_patch", [1, 4])
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
def test_input_mapper_override(
model: str,
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
model_id: str,
image_assets: _ImageAssets,
max_dynamic_patch: int,
dynamic_image_size: Optional[bool],
):
mm_processor_kwargs = {
"max_dynamic_patch": max_dynamic_patch,
}
if dynamic_image_size is not None:
mm_processor_kwargs["dynamic_image_size"] = dynamic_image_size
expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
if dynamic_image_size is False:
expected_num_patches = 1
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
)
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
image = image_assets[0].pil_image.resize((448 * 2, 448 * 2))
vllm_result = mm_registry.map_input(
ctx.model_config,
{"image": image},
)
assert vllm_result["pixel_values"].size(1) == expected_num_patches
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("max_dynamic_patch", [1, 4, None])
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
def test_max_tokens_override(
get_max_internvl_image_tokens: Callable,
model: str,
max_dynamic_patch: Optional[int],
dynamic_image_size: Optional[bool],
):
"""Ensure get_max_internvl_image_tokens handles mm_processor_kwargs."""
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
)
if max_dynamic_patch is None:
max_dynamic_patch = ctx.get_hf_config().max_dynamic_patch
expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
if dynamic_image_size is False:
expected_num_patches = 1
expected_max_tokens = 256 * expected_num_patches
actual_max_tokens = get_max_internvl_image_tokens(
ctx=InputContext(ctx.model_config),
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
assert expected_max_tokens == actual_max_tokens
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("num_imgs", [1, 2])
@pytest.mark.parametrize("max_dynamic_patch", [1, 4, None])
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
def test_dummy_data_override(
dummy_data_for_internvl: Callable,
model: str,
num_imgs: int,
max_dynamic_patch: Optional[int],
dynamic_image_size: Optional[bool],
):
"""Ensure dummy_data_for_internvl handles kwargs properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the dummy data func.
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
model_name=model_id,
tokenizer_name=model_id,
trust_remote_code=True,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
if max_dynamic_patch is None:
max_dynamic_patch = ctx.get_hf_config().max_dynamic_patch
expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
if dynamic_image_size is False:
expected_num_patches = 1
expected_max_tokens = 256 * expected_num_patches
dummy_data = dummy_data_for_internvl(
ctx=ctx,
seq_len=8192, # Should be bigger than num_imgs * toks_per_img
mm_counts={"image": num_imgs},
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=tokenizer,
)
sequence_data = dummy_data.seq_data
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
image_token_id = tokenizer.encode('<IMG_CONTEXT>',
add_special_tokens=False)[0]
# Ensure we have the right number of placeholders per size
img_tok_count = sequence_data.get_token_ids().count(image_token_id)
assert img_tok_count == expected_max_tokens * num_imgs
mm_processor_kwargs = {
"max_dynamic_patch": max_dynamic_patch,
}
if dynamic_image_size is not None:
mm_processor_kwargs["dynamic_image_size"] = dynamic_image_size
# Build the image str / prompt based on the number of images we pass
prompt = "<image>" * num_imgs
image = image_assets[0].pil_image.resize((448 * 2, 448 * 2))
mm_data = {"image": [image] * num_imgs}
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("max_dynamic_patch", [1, 4])
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_input_processor_override(
input_processor_for_internvl: Callable,
image_assets: _ImageAssets,
model: str,
num_imgs: int,
max_dynamic_patch: int,
dynamic_image_size: Optional[bool],
):
"""Ensure input_processor_for_internvl handles kwargs properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the custom input processor.
expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
if dynamic_image_size is False:
expected_num_patches = 1
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
)
expected_toks_per_img = 256 * expected_num_patches
# Build the image str / prompt based on the number of images we pass
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
placeholders = "<image>" if num_imgs == 1 else "\n".join(
f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1))
prompt = placeholders
images = [image_assets[0].pil_image.resize((448 * 2, 448 * 2))] * num_imgs
inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
prompt=prompt,
multi_modal_data={"image": images})
processed_inputs = input_processor_for_internvl(
ctx,
inputs,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
# Ensure we have the right number of placeholders per num_crops size
image_token_id = tokenizer.encode('<IMG_CONTEXT>',
add_special_tokens=False)[0]
image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
assert img_tok_count == expected_toks_per_img * num_imgs
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape
assert img_tok_count == 256 * expected_num_patches * num_imgs
assert pixel_shape[0] == expected_num_patches * num_imgs
......@@ -43,7 +43,10 @@ def test_processor_max_tokens(model_id):
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
tokenizer=cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
),
)
info = processor.info
......@@ -143,7 +146,10 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs):
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
tokenizer=cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
),
)
image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
......@@ -173,7 +179,10 @@ def test_processor_prompt_replacements_all(model_id, num_imgs):
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
tokenizer=cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
),
)
seen_aspect_ratios = set[float]()
......
......@@ -44,7 +44,10 @@ def test_processor_max_tokens(model_id):
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
tokenizer=cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
),
)
info = processor.info
......@@ -143,7 +146,10 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs):
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
tokenizer=cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
),
)
image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
......@@ -174,7 +180,10 @@ def test_processor_prompt_replacements_all(model_id, num_imgs):
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
tokenizer=cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
),
)
seen_aspect_ratios = set[float]()
......
......@@ -38,7 +38,10 @@ def test_processor_override(
trust_remote_code=True,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=tokenizer,
......
......@@ -33,7 +33,10 @@ def test_processor_override(
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=tokenizer,
......
......@@ -399,7 +399,11 @@ class AriaProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
......
......@@ -407,7 +407,11 @@ class Blip2ProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
......
......@@ -64,7 +64,11 @@ class ChameleonProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
......
......@@ -165,7 +165,11 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
image_width=x[1], image_height=x[0]))
return ImageSize(width=width, height=height)
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
max_image_size = self.get_image_size_with_most_features()
max_image_tokens = self.get_num_image_tokens(
image_height=max_image_size.height,
......
......@@ -80,7 +80,11 @@ class FuyuProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_ncols, max_nrows = self.get_image_feature_grid_size(
......
This diff is collapsed.
This diff is collapsed.
......@@ -125,7 +125,11 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def _apply_feature_select_strategy(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment