"vscode:/vscode.git/clone" did not exist on "d78789ac16870809d64378105f200049cae95112"
Unverified Commit eed11ebe authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[VLM] Merged multi-modal processors for LLaVA-NeXT-Video and LLaVA-OneVision (#11717)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 300acb83
import pytest
from PIL import Image
from transformers import AutoTokenizer
from vllm.inputs import InputProcessingContext
from ....utils import build_model_context
# Fixtures lazy import to avoid initializing CUDA during test collection
@pytest.fixture()
def processor_for_llava_next():
from vllm.model_executor.models.llava_next import (
LlavaNextMultiModalProcessor)
return LlavaNextMultiModalProcessor
# FIXME: image_size [(198, 176), (176, 198)]
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488),
(488, 183)])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements(
processor_for_llava_next,
model_id: str,
image_size: tuple[int, int],
num_imgs: int,
):
"""
Ensure LlavaNextMultiModalProcessor handles prompt replacement properly.
"""
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
# Build the image str / prompt based on the number of images we pass
prompt = "<image>" * num_imgs
mm_data = {"image": [Image.new("RGB", size=image_size)] * num_imgs}
# The processor will throw an error if there is a mismatch
# in the prompt replacements
processor = processor_for_llava_next(ctx)
processed_inputs = processor.apply(prompt, mm_data, {})
image_placeholders = processed_inputs["mm_placeholders"]["image"]
assert len(image_placeholders) == num_imgs
first_placeholder = image_placeholders[0]
# NOTE: There is a BOS token
assert first_placeholder["offset"] == 1
assert first_placeholder["length"] == (
len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs
import pytest
from PIL import Image
from transformers import AutoTokenizer
from vllm.inputs import InputProcessingContext
from ....utils import build_model_context
# Fixtures lazy import to avoid initializing CUDA during test collection
@pytest.fixture()
def processor_for_llava_onevision():
from vllm.model_executor.models.llava_onevision import (
LlavaOnevisionMultiModalProcessor)
return LlavaOnevisionMultiModalProcessor
@pytest.mark.parametrize("model_id",
["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488),
(488, 183), (198, 176), (176, 198)])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements(
processor_for_llava_onevision,
model_id: str,
image_size: tuple[int, int],
num_imgs: int,
):
"""
Ensure LlavaOnevisionMultiModalProcessor handles prompt replacement
properly.
"""
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)
# Build the image str / prompt based on the number of images we pass
prompt = "<image>" * num_imgs
mm_data = {"image": [Image.new("RGB", size=image_size)] * num_imgs}
# The processor will throw an error if there is a mismatch
# in the prompt replacements
processor = processor_for_llava_onevision(ctx)
processed_inputs = processor.apply(prompt, mm_data, {})
image_placeholders = processed_inputs["mm_placeholders"]["image"]
assert len(image_placeholders) == num_imgs
first_placeholder = image_placeholders[0]
# NOTE: There is a BOS token
assert first_placeholder["offset"] == 0
assert first_placeholder["length"] == len(
processed_inputs["prompt_token_ids"]) // num_imgs
"""Tests for phi3v's multimodal preprocessing kwargs.""" """Tests for phi3v's multimodal preprocessing kwargs."""
from typing import Optional
import pytest import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
...@@ -10,8 +8,6 @@ from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID ...@@ -10,8 +8,6 @@ from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID
from .....conftest import _ImageAssets from .....conftest import _ImageAssets
from ....utils import build_model_context from ....utils import build_model_context
models = ["microsoft/Phi-3.5-vision-instruct"]
# Wrap lazy imports to avoid initializing CUDA during test collection # Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture() @pytest.fixture()
...@@ -20,40 +16,40 @@ def processor_for_phi3v(): ...@@ -20,40 +16,40 @@ def processor_for_phi3v():
return Phi3VMultiModalProcessor return Phi3VMultiModalProcessor
@pytest.mark.parametrize("model", models) @pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
# yapf: disable
@pytest.mark.parametrize( @pytest.mark.parametrize(
"num_crops,expected_toks_per_img", ("mm_processor_kwargs", "expected_toks_per_img"),
[ [
(4, 757), ({"num_crops": 4}, 757),
(16, 1921), ({"num_crops": 16}, 1921),
# the default num_crops of phi-3.5-vision is 4 # the default num_crops of phi-3.5-vision is 4
(None, 757), ({}, 757),
]) ])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(processor_for_phi3v, image_assets: _ImageAssets, def test_processor_override(
model: str, num_crops: Optional[int], processor_for_phi3v,
expected_toks_per_img: int, num_imgs: int): image_assets: _ImageAssets,
model_id: str,
mm_processor_kwargs: dict[str, int],
expected_toks_per_img: int,
num_imgs: int,
):
"""Ensure input_processor_for_phi3v handles num_crops properly.""" """Ensure input_processor_for_phi3v handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the custom input processor.
ctx = build_model_context( ctx = build_model_context(
model_name=model, model_name=model_id,
tokenizer_name=model, tokenizer_name=model_id,
trust_remote_code=True, trust_remote_code=True,
limit_mm_per_prompt={"image": num_imgs}, limit_mm_per_prompt={"image": num_imgs},
) )
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer) ctx = InputProcessingContext(ctx.model_config, tokenizer)
# Build the image str / prompt based on the number of images we pass # Build the image str / prompt based on the number of images we pass
img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)]) img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
images = [image_assets[0].pil_image] * num_imgs mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
mm_data = {"image": images}
mm_processor_kwargs = {}
if num_crops is not None:
mm_processor_kwargs = {"num_crops": num_crops}
processor = processor_for_phi3v(ctx) processor = processor_for_phi3v(ctx)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
......
from typing import Any, Dict, Tuple
import pytest import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
...@@ -8,56 +6,45 @@ from vllm.inputs import InputProcessingContext ...@@ -8,56 +6,45 @@ from vllm.inputs import InputProcessingContext
from .....conftest import _ImageAssets from .....conftest import _ImageAssets
from ....utils import build_model_context from ....utils import build_model_context
MODEL = "Qwen/Qwen2-VL-2B-Instruct"
MIN_PIXELS = "min_pixels"
MAX_PIXELS = "max_pixels"
# Fixtures lazy import to avoid initializing CUDA during test collection # Fixtures lazy import to avoid initializing CUDA during test collection
# NOTE: Qwen2VL supports multiple input modalities, so it registers multiple
# input mappers.
@pytest.fixture() @pytest.fixture()
def processor_for_qwen2_vl(): def processor_for_qwen2_vl():
from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor
return Qwen2VLMultiModalProcessor return Qwen2VLMultiModalProcessor
@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])
# yapf: disable
@pytest.mark.parametrize( @pytest.mark.parametrize(
"mm_processor_kwargs, expected_toks_per_img, expected_pixels_shape", [ ("mm_processor_kwargs", "expected_toks_per_img", "expected_pixels_shape"), [
({}, 1426, (5704, 1176)), ({}, 1426, (5704, 1176)),
({ ({"min_pixels": 64**2, "max_pixels": 512**2}, 330, (1320, 1176)),
MIN_PIXELS: 64**2,
MAX_PIXELS: 512**2
}, 330, (1320, 1176)),
]) ])
@pytest.mark.parametrize("model", [MODEL]) # yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2]) @pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override( def test_processor_override(
processor_for_qwen2_vl, processor_for_qwen2_vl,
image_assets: _ImageAssets, image_assets: _ImageAssets,
model: str, model_id: str,
mm_processor_kwargs: Dict[str, Any], mm_processor_kwargs: dict[str, object],
expected_toks_per_img: int, expected_toks_per_img: int,
expected_pixels_shape: Tuple[int, int], expected_pixels_shape: tuple[int, int],
num_imgs: int, num_imgs: int,
): ):
"""Ensure Qwen2VLMultiModalProcessor handles min/max pixels properly.""" """Ensure Qwen2VLMultiModalProcessor handles min/max pixels properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the custom input processor.
ctx = build_model_context( ctx = build_model_context(
model_name=model, model_name=model_id,
tokenizer_name=model, tokenizer_name=model_id,
mm_processor_kwargs=None, mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs}, limit_mm_per_prompt={"image": num_imgs},
) )
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer) ctx = InputProcessingContext(ctx.model_config, tokenizer)
# Build the image str / prompt based on the number of images we pass # Build the image str / prompt based on the number of images we pass
prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
images = [image_assets[0].pil_image] * num_imgs mm_data = {"image": [image_assets[0].pil_image] * num_imgs}
mm_data = {"image": images}
processor = processor_for_qwen2_vl(ctx) processor = processor_for_qwen2_vl(ctx)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
......
...@@ -274,10 +274,8 @@ VLM_TEST_SETTINGS = { ...@@ -274,10 +274,8 @@ VLM_TEST_SETTINGS = {
), ),
limit_mm_per_prompt={"image": 4}, limit_mm_per_prompt={"image": 4},
)], )],
# Llava-next tests fixed sizes & the default size factors
image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
), ),
"llava_one_vision": VLMTestInfo( "llava_onevision": VLMTestInfo(
models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"], models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"],
test_type=VLMTestType.CUSTOM_INPUTS, test_type=VLMTestType.CUSTOM_INPUTS,
prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
...@@ -288,8 +286,6 @@ VLM_TEST_SETTINGS = { ...@@ -288,8 +286,6 @@ VLM_TEST_SETTINGS = {
), ),
auto_cls=AutoModelForVision2Seq, auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output, vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output,
# Llava-one-vision tests fixed sizes & the default size factors
image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
custom_test_opts=[CustomTestOptions( custom_test_opts=[CustomTestOptions(
inputs=custom_inputs.multi_video_multi_aspect_ratio_inputs( inputs=custom_inputs.multi_video_multi_aspect_ratio_inputs(
formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
...@@ -306,7 +302,6 @@ VLM_TEST_SETTINGS = { ...@@ -306,7 +302,6 @@ VLM_TEST_SETTINGS = {
max_model_len=4096, max_model_len=4096,
auto_cls=AutoModelForVision2Seq, auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output, vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output,
image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
), ),
"mantis": VLMTestInfo( "mantis": VLMTestInfo(
models=["TIGER-Lab/Mantis-8B-siglip-llama3"], models=["TIGER-Lab/Mantis-8B-siglip-llama3"],
...@@ -431,7 +426,7 @@ VLM_TEST_SETTINGS = { ...@@ -431,7 +426,7 @@ VLM_TEST_SETTINGS = {
) for inp in custom_inputs.different_patch_input_cases_internvl() ) for inp in custom_inputs.different_patch_input_cases_internvl()
], ],
), ),
"llava_one_vision-multiple-images": VLMTestInfo( "llava_onevision-multiple-images": VLMTestInfo(
models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"], models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"],
test_type=VLMTestType.CUSTOM_INPUTS, test_type=VLMTestType.CUSTOM_INPUTS,
max_model_len=16384, max_model_len=16384,
......
...@@ -427,130 +427,3 @@ def test_qwen2_vl_video_embeddings_input(vllm_runner, video_assets, model, ...@@ -427,130 +427,3 @@ def test_qwen2_vl_video_embeddings_input(vllm_runner, video_assets, model,
mm_limit=1, mm_limit=1,
tensor_parallel_size=1, tensor_parallel_size=1,
) )
def run_chunked_prefill_test(
vllm_runner: Type[VllmRunner],
inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]],
model: str,
*,
dtype: str,
max_tokens: int,
num_logprobs: int,
mm_limit: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Compare inference result between
chunked prefill disabled and chunked prefill enabled
"""
# NOTE:
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
task="generate",
max_model_len=4000,
max_num_seqs=4,
dtype=dtype,
limit_mm_per_prompt={
"image": mm_limit,
"video": mm_limit
},
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend
) as vllm_model:
outputs_per_case = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images or None,
videos=videos or None)
for prompts, images, videos in inputs
]
with vllm_runner(
model,
task="generate",
max_model_len=4000,
max_num_seqs=4,
dtype=dtype,
limit_mm_per_prompt={
"image": mm_limit,
"video": mm_limit
},
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enable_chunked_prefill=True,
# should be small enough to ensure prefilling is chunked
max_num_batched_tokens=32,
mm_processor_kwargs={
"max_pixels": 16 * 28 * 28,
}) as vllm_model_chunked:
outputs_per_case_chunked = [
vllm_model_chunked.generate_greedy_logprobs(
prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images or None,
videos=videos or None) for prompts, images, videos in inputs
]
for outputs, \
outputs_chunked \
in zip(outputs_per_case,
outputs_per_case_chunked):
check_logprobs_close(
outputs_0_lst=outputs,
outputs_1_lst=outputs_chunked,
name_0="non_chunked",
name_1="chunked",
)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [1])
@pytest.mark.parametrize("num_logprobs", [10])
def test_qwen2_vl_mrope_chunked_prefill(vllm_runner, example_prompts,
model: str, dtype: str,
max_tokens: int,
num_logprobs: int) -> None:
"""
Test Qwen2-VL's chunked prefill with M-RoPE
"""
prompts = [
qwen2_vl_chat_template(IMAGE_PLACEHOLDER, prompt)
for prompt in example_prompts[:1]
]
# 1. Qwen2-VL's M-RoPE works only when there are some multi-modal inputs,
# so an image is included in the inputs
# 2. however, Qwen2-VL currently won't work properly
# when chunked prefill is enabled and there are some multi-modal inputs,
# here use a hacky way: provide a **zero-length** image to make it happy
#
# and finally we achieved:
# (1) chunked_prefill enabled; (2) M-RoPE works; to continue our tests
zero_len_image = {
"image_embeds": torch.empty((0, MODEL_HIDDEN_SIZE)),
"image_grid_thw": torch.tensor([[0, 0, 0]])
}
images = [zero_len_image] * len(prompts)
inputs_per_case: List[Tuple[List[str], PromptImageInput,
PromptVideoInput]] = [
(prompts, images, []),
]
run_chunked_prefill_test(
vllm_runner,
inputs_per_case,
model,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
mm_limit=1,
tensor_parallel_size=1,
)
...@@ -11,8 +11,8 @@ from vllm.config import ModelConfig ...@@ -11,8 +11,8 @@ from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement, from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
_PlaceholderInfo, find_text_matches, _PlaceholderInfo, find_mm_placeholders,
find_token_matches, iter_placeholders, find_text_matches, find_token_matches,
iter_token_matches, iter_token_matches,
replace_text_matches, replace_text_matches,
replace_token_matches) replace_token_matches)
...@@ -314,21 +314,27 @@ def test_find_replace_text( ...@@ -314,21 +314,27 @@ def test_find_replace_text(
# Should not be used since there is nothing to convert to text # Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [ mm_prompt_repls = {
PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer) key: [
PromptReplacement(key, target,
repl_by_key[key]).bind(mock_tokenizer)
]
for key, target in target_by_key.items() for key, target in target_by_key.items()
] }
matches = find_text_matches(prompt, prompt_repls) mm_matches = {
key: find_text_matches(prompt, prompt_repls)
for key, prompt_repls in mm_prompt_repls.items()
}
result = replace_text_matches( result = replace_text_matches(
prompt, prompt,
matches, mm_matches,
{key: mm_count {key: mm_count
for key in repl_by_key}, for key in repl_by_key},
) )
# Only displayed on error # Only displayed on error
print("matches:", matches) print("mm_matches:", mm_matches)
print("result:", result) print("result:", result)
# Manually constructed results # Manually constructed results
...@@ -380,21 +386,27 @@ def test_find_replace_tokens( ...@@ -380,21 +386,27 @@ def test_find_replace_tokens(
# Should not be used since there is nothing to convert to tokens # Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [ mm_prompt_repls = {
PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer) key: [
PromptReplacement(key, target,
repl_by_key[key]).bind(mock_tokenizer)
]
for key, target in target_by_key.items() for key, target in target_by_key.items()
] }
matches = find_token_matches(prompt, prompt_repls) mm_matches = {
key: find_token_matches(prompt, prompt_repls)
for key, prompt_repls in mm_prompt_repls.items()
}
result = replace_token_matches( result = replace_token_matches(
prompt, prompt,
matches, mm_matches,
{key: mm_count {key: mm_count
for key in repl_by_key}, for key in repl_by_key},
) )
# Only displayed on error # Only displayed on error
print("matches:", matches) print("mm_matches:", mm_matches)
print("result:", result) print("result:", result)
# Manually constructed results # Manually constructed results
...@@ -417,58 +429,76 @@ def test_find_replace_tokens( ...@@ -417,58 +429,76 @@ def test_find_replace_tokens(
[ [
( (
[1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
[ {
_PlaceholderInfo( "pattern_1": [
modality="pattern_1", _PlaceholderInfo(
start_idx=6, modality="pattern_1",
replacement=[32000, 32000], item_idx=0,
), start_idx=6,
], replacement=[32000, 32000],
),
],
}
), ),
( (
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550], [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
[ {
_PlaceholderInfo( "pattern_1": [
modality="pattern_1", _PlaceholderInfo(
start_idx=1, modality="pattern_1",
replacement=[32000, 32000], item_idx=0,
), start_idx=1,
_PlaceholderInfo( replacement=[32000, 32000],
modality="pattern_1", ),
start_idx=5, _PlaceholderInfo(
replacement=[32000, 32000], modality="pattern_1",
), item_idx=1,
_PlaceholderInfo( start_idx=5,
modality="pattern_3", replacement=[32000, 32000],
start_idx=7, ),
replacement=[1550, 918, 1550], ],
), "pattern_3": [
], _PlaceholderInfo(
modality="pattern_3",
item_idx=0,
start_idx=7,
replacement=[1550, 918, 1550],
),
],
}
), ),
( (
[1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550], [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
[ {
_PlaceholderInfo( "pattern_1": [
modality="pattern_1", _PlaceholderInfo(
start_idx=1, modality="pattern_1",
replacement=[32000, 32000], item_idx=0,
), start_idx=1,
_PlaceholderInfo( replacement=[32000, 32000],
modality="pattern_1", ),
start_idx=3, _PlaceholderInfo(
replacement=[32000, 32000], modality="pattern_1",
), item_idx=1,
_PlaceholderInfo( start_idx=3,
modality="pattern_3", replacement=[32000, 32000],
start_idx=6, ),
replacement=[1550, 918, 1550], ],
), "pattern_3": [
], _PlaceholderInfo(
modality="pattern_3",
item_idx=0,
start_idx=6,
replacement=[1550, 918, 1550],
),
],
}
), ),
] ]
) )
# yapf: enable # yapf: enable
def test_iter_placeholders( def test_find_mm_placeholders(
repl_by_key, repl_by_key,
prompt, prompt,
expected, expected,
...@@ -476,19 +506,18 @@ def test_iter_placeholders( ...@@ -476,19 +506,18 @@ def test_iter_placeholders(
# Should not be used since there is nothing to convert to tokens # Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [ mm_prompt_repls = {
PromptReplacement(key, [], repl).bind(mock_tokenizer) key: [PromptReplacement(key, [], repl).bind(mock_tokenizer)]
for key, repl in repl_by_key.items() for key, repl in repl_by_key.items()
] }
result = list( result = find_mm_placeholders(
iter_placeholders( mm_prompt_repls,
prompt_repls, prompt,
prompt, # Effectively match all occurrences in the prompt
# Effectively match all occurrences in the prompt {key: 3
{key: 3 for key in repl_by_key},
for key in repl_by_key}, )
))
# Only displayed on error # Only displayed on error
print("result:", result) print("result:", result)
...@@ -694,7 +723,10 @@ def _test_processing_cache_correctness( ...@@ -694,7 +723,10 @@ def _test_processing_cache_correctness(
} }
mm_counts = {k: len(vs) for k, vs in mm_data.items()} mm_counts = {k: len(vs) for k, vs in mm_data.items()}
prompt = baseline_processor._get_dummy_mm_inputs(mm_counts).prompt_text prompt = baseline_processor._get_dummy_processor_inputs(
model_config.max_model_len,
mm_counts,
).prompt_text
# Drop unnecessary keys and test single -> multi conversion # Drop unnecessary keys and test single -> multi conversion
if rng.rand() < simplify_rate: if rng.rand() < simplify_rate:
...@@ -728,6 +760,8 @@ def _test_processing_cache_correctness( ...@@ -728,6 +760,8 @@ def _test_processing_cache_correctness(
("adept/fuyu-8b", {"image": False}), ("adept/fuyu-8b", {"image": False}),
("llava-hf/llava-1.5-7b-hf", {"image": True}), ("llava-hf/llava-1.5-7b-hf", {"image": True}),
("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}), ("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
("llava-hf/LLaVA-NeXT-Video-7B-hf", {"video": False}),
("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", {"image": True, "video": True}), # noqa: E501
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}), ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
("mistral-community/pixtral-12b", {"image": True}), ("mistral-community/pixtral-12b", {"image": True}),
("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}), ("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),
......
...@@ -456,7 +456,7 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor): ...@@ -456,7 +456,7 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor):
hf_config = self.ctx.get_hf_config() hf_config = self.ctx.get_hf_config()
return max(hf_config.projector_patch_to_query_dict.values()) return max(hf_config.projector_patch_to_query_dict.values())
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()} return {"image": self._get_num_image_tokens()}
def _get_mm_fields_config( def _get_mm_fields_config(
...@@ -488,8 +488,9 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor): ...@@ -488,8 +488,9 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor):
) )
] ]
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config() hf_config = self.ctx.get_hf_config()
......
...@@ -405,7 +405,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor): ...@@ -405,7 +405,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):
hf_config = self.ctx.get_hf_config(Blip2Config) hf_config = self.ctx.get_hf_config(Blip2Config)
return hf_config.num_query_tokens return hf_config.num_query_tokens
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()} return {"image": self._get_num_image_tokens()}
def _get_hf_processor(self) -> Blip2Processor: def _get_hf_processor(self) -> Blip2Processor:
...@@ -457,8 +457,9 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor): ...@@ -457,8 +457,9 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor):
return result return result
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config(Blip2Config) hf_config = self.ctx.get_hf_config(Blip2Config)
......
...@@ -57,7 +57,7 @@ class ChameleonMultiModalProcessor(BaseMultiModalProcessor): ...@@ -57,7 +57,7 @@ class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
processor = self._get_hf_processor() processor = self._get_hf_processor()
return processor.image_seq_length return processor.image_seq_length
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_num_image_tokens()} return {"image": self._get_num_image_tokens()}
def _get_hf_processor(self) -> ChameleonProcessor: def _get_hf_processor(self) -> ChameleonProcessor:
...@@ -90,8 +90,9 @@ class ChameleonMultiModalProcessor(BaseMultiModalProcessor): ...@@ -90,8 +90,9 @@ class ChameleonMultiModalProcessor(BaseMultiModalProcessor):
) )
] ]
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
config = self.ctx.get_hf_config(ChameleonConfig) config = self.ctx.get_hf_config(ChameleonConfig)
......
...@@ -164,15 +164,18 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]): ...@@ -164,15 +164,18 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
def get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
return get_max_clip_image_tokens(self.vision_config) return get_max_clip_image_tokens(self.vision_config)
def get_num_patches(self) -> int: def get_image_size(self) -> int:
return self.vision_config.image_size
def get_patch_size(self) -> int:
return self.vision_config.patch_size
def get_patch_grid_length(self) -> int:
return get_clip_patch_grid_length( return get_clip_patch_grid_length(
image_size=self.vision_config.image_size, image_size=self.vision_config.image_size,
patch_size=self.vision_config.patch_size, patch_size=self.vision_config.patch_size,
) )
def get_image_size(self) -> int:
return self.vision_config.image_size
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module): class CLIPVisionEmbeddings(nn.Module):
......
...@@ -96,7 +96,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor): ...@@ -96,7 +96,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
nrows = math.ceil(image_height / 30) nrows = math.ceil(image_height / 30)
return ncols, nrows return ncols, nrows
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
target_width, target_height = self._get_image_target_size() target_width, target_height = self._get_image_target_size()
max_ncols, max_nrows = self._get_image_feature_grid_size( max_ncols, max_nrows = self._get_image_feature_grid_size(
...@@ -208,8 +208,9 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor): ...@@ -208,8 +208,9 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor):
return result return result
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
target_width, target_height = self._get_image_target_size() target_width, target_height = self._get_image_target_size()
......
...@@ -25,11 +25,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, ...@@ -25,11 +25,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
NestedTensors) NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize) ImageSize)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (InputProcessingContext,
InputProcessingContext,
MultiModalDataItems, ProcessingCache, MultiModalDataItems, ProcessingCache,
ProcessorInputs, PromptReplacement, ProcessorInputs, PromptReplacement)
full_groupby_modality)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .clip import CLIPVisionModel from .clip import CLIPVisionModel
...@@ -39,7 +37,7 @@ from .pixtral import (PixtralHFVisionModel, ...@@ -39,7 +37,7 @@ from .pixtral import (PixtralHFVisionModel,
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import vision_encoder_info from .vision import BaseVisionLanguageMultiModalProcessor
class LlavaImagePixelInputs(TypedDict): class LlavaImagePixelInputs(TypedDict):
...@@ -100,19 +98,7 @@ class LlavaLikeConfig(Protocol): ...@@ -100,19 +98,7 @@ class LlavaLikeConfig(Protocol):
vision_feature_layer: Final[Union[int, List[int]]] vision_feature_layer: Final[Union[int, List[int]]]
class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor): class BaseLlavaMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
def __init__(self,
ctx: InputProcessingContext,
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
super().__init__(ctx,
cache=cache,
enable_sanity_checks=enable_sanity_checks)
vision_config = self._get_hf_config().vision_config
self._vision_encoder_info = vision_encoder_info(vision_config)
@abstractmethod @abstractmethod
def _get_hf_config(self) -> LlavaLikeConfig: def _get_hf_config(self) -> LlavaLikeConfig:
...@@ -121,6 +107,19 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor): ...@@ -121,6 +107,19 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self._get_max_image_tokens()}
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _apply_feature_select_strategy( def _apply_feature_select_strategy(
self, self,
strategy: str, strategy: str,
...@@ -142,19 +141,6 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor): ...@@ -142,19 +141,6 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
self._vision_encoder_info.get_max_image_tokens(), self._vision_encoder_info.get_max_image_tokens(),
) )
def get_mm_max_tokens_per_item(self) -> Mapping[str, int]:
return {"image": self._get_max_image_tokens()}
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_dummy_image_size(self) -> ImageSize: def _get_dummy_image_size(self) -> ImageSize:
image_size = self._vision_encoder_info.get_image_size() image_size = self._vision_encoder_info.get_image_size()
return ImageSize(image_size, image_size) return ImageSize(image_size, image_size)
...@@ -163,8 +149,9 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor): ...@@ -163,8 +149,9 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor):
def _get_image_token(self) -> str: def _get_image_token(self) -> str:
raise NotImplementedError raise NotImplementedError
def _get_dummy_mm_inputs( def _get_dummy_processor_inputs(
self, self,
seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> ProcessorInputs: ) -> ProcessorInputs:
num_images = mm_counts.get("image", 0) num_images = mm_counts.get("image", 0)
...@@ -709,7 +696,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -709,7 +696,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
"</Image>)", # 3 tokens "</Image>)", # 3 tokens
]) ])
mantis_repls = self._bind_prompt_replacements([ mantis_mm_repls = self._bind_and_group_repls([
PromptReplacement( PromptReplacement(
modality="image", modality="image",
target=[image_token_id] * num_image_tokens, target=[image_token_id] * num_image_tokens,
...@@ -719,7 +706,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -719,7 +706,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
prompt_ids, prompt_text, _ = self._apply_prompt_replacements( prompt_ids, prompt_text, _ = self._apply_prompt_replacements(
result["prompt_token_ids"], result["prompt_token_ids"],
mantis_repls, mantis_mm_repls,
mm_item_counts, mm_item_counts,
) )
...@@ -728,15 +715,19 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -728,15 +715,19 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
mm_kwargs, mm_kwargs,
) )
orig_repls = self._bind_prompt_replacements(unbound_orig_repls) orig_repls = self._bind_and_group_repls(unbound_orig_repls)
mm_placeholders = self._find_mm_placeholders(
orig_repls,
prompt_ids,
mm_item_counts,
)
all_placeholders = self._find_placeholders(orig_repls, prompt_ids, self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
mm_item_counts)
assert len(all_placeholders) == mm_item_counts.get("image", 0)
mm_placeholders = { mm_placeholder_ranges = {
modality: [item.to_range() for item in items] modality: [item.to_range() for item in placeholders]
for modality, items in full_groupby_modality(all_placeholders) for modality, placeholders in mm_placeholders.items()
} }
return MultiModalInputsV2( return MultiModalInputsV2(
...@@ -744,7 +735,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -744,7 +735,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
prompt=prompt_text, prompt=prompt_text,
prompt_token_ids=prompt_ids, prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs, mm_kwargs=mm_kwargs,
mm_placeholders=mm_placeholders, mm_placeholders=mm_placeholder_ranges,
) )
......
...@@ -67,9 +67,6 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -67,9 +67,6 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
def _get_hf_processor(self) -> LlavaNextProcessor: def _get_hf_processor(self) -> LlavaNextProcessor:
return self.ctx.get_hf_processor(LlavaNextProcessor) return self.ctx.get_hf_processor(LlavaNextProcessor)
def _get_image_token(self) -> str:
return self._get_hf_processor().image_token
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
...@@ -81,6 +78,9 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -81,6 +78,9 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
) )
def _get_image_token(self) -> str:
return self._get_hf_processor().image_token
def _get_max_image_tokens(self) -> int: def _get_max_image_tokens(self) -> int:
largest_feature_size, _ = self._get_pinpoint_with_most_features() largest_feature_size, _ = self._get_pinpoint_with_most_features()
return largest_feature_size return largest_feature_size
...@@ -97,20 +97,20 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -97,20 +97,20 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
image_height: int, image_height: int,
) -> int: ) -> int:
hf_config = self._get_hf_config() hf_config = self._get_hf_config()
vision_encoder_info = self._vision_encoder_info
base_feature_size = self._apply_feature_select_strategy( base_feature_size = self._apply_feature_select_strategy(
hf_config.vision_feature_select_strategy, hf_config.vision_feature_select_strategy,
self._vision_encoder_info.get_num_image_tokens( vision_encoder_info.get_num_image_tokens(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
), ),
) )
num_patches = self._vision_encoder_info.get_num_patches()
num_patch_height, num_patch_width = get_anyres_image_grid_shape( num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_size=(image_height, image_width), image_size=(image_height, image_width),
grid_pinpoints=hf_config.image_grid_pinpoints, grid_pinpoints=hf_config.image_grid_pinpoints,
patch_size=self._vision_encoder_info.get_image_size(), patch_size=vision_encoder_info.get_image_size(),
) )
( (
...@@ -119,7 +119,7 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -119,7 +119,7 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
) = self._get_num_unpadded_features( ) = self._get_num_unpadded_features(
original_height=image_height, original_height=image_height,
original_width=image_width, original_width=image_width,
npatches=num_patches, npatches=vision_encoder_info.get_patch_grid_length(),
num_patch_height=num_patch_height, num_patch_height=num_patch_height,
num_patch_width=num_patch_width, num_patch_width=num_patch_width,
) )
...@@ -155,6 +155,7 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -155,6 +155,7 @@ class LlavaNextMultiModalProcessor(LlavaMultiModalProcessor):
unpadded_features = current_height * current_width unpadded_features = current_height * current_width
newline_features = current_height newline_features = current_height
return (unpadded_features, newline_features) return (unpadded_features, newline_features)
def _get_pinpoint_with_most_features(self) -> tuple[int, ImageSize]: def _get_pinpoint_with_most_features(self) -> tuple[int, ImageSize]:
......
...@@ -3,38 +3,32 @@ from functools import cached_property ...@@ -3,38 +3,32 @@ from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union) TypedDict, Union)
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import (CLIPVisionConfig, LlavaNextVideoConfig, from transformers import (BatchFeature, LlavaNextVideoConfig,
SiglipVisionConfig) LlavaNextVideoProcessor)
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
repeat_and_pad_placeholder_tokens) VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import (MultiModalFieldConfig, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava from .llava import init_vision_tower_for_llava
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import SiglipVisionModel
dummy_seq_data_for_siglip)
from .utils import (AutoWeightsLoader, init_vllm_registered_model, from .utils import (AutoWeightsLoader, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import BaseVisionLanguageMultiModalProcessor
# For profile run
_MAX_FRAMES_PER_VIDEO = 32
_MAX_NUM_VIDEOS = 1
class LlavaNextVideoPixelInputs(TypedDict): class LlavaNextVideoPixelInputs(TypedDict):
...@@ -50,143 +44,148 @@ class LlavaNextVideoPixelInputs(TypedDict): ...@@ -50,143 +44,148 @@ class LlavaNextVideoPixelInputs(TypedDict):
""" """
def get_llava_next_video_frame_feature_size( class LlavaNextVideoMultiModalProcessor(BaseVisionLanguageMultiModalProcessor):
hf_config: LlavaNextVideoConfig) -> int:
# Support both CLIPVisionConfig and SiglipVisionConfig
image_size = hf_config.vision_config.image_size
patch_size = hf_config.vision_config.patch_size
spatial_pool_stride = hf_config.spatial_pool_stride
return int((image_size / patch_size / spatial_pool_stride)**2) def _get_hf_config(self) -> LlavaNextVideoConfig:
return self.ctx.get_hf_config(LlavaNextVideoConfig)
def _get_hf_processor(self) -> LlavaNextVideoProcessor:
return self.ctx.get_hf_processor(LlavaNextVideoProcessor)
def _get_max_llm_tokens(ctx: InputContext) -> int: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
""" return {"video": 1}
Calculated from the maximum video frames under the context length
constraints of the language model. def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
""" num_frames = self._get_dummy_num_frames(seq_len)
hf_text_config = ctx.model_config.hf_text_config max_video_tokens = self._get_max_video_tokens(num_frames)
model_config = ctx.model_config
max_tokens = model_config.max_model_len return {"video": max_video_tokens}
rope_scaling = model_config.rope_scaling
def _get_mm_fields_config(
if rope_scaling: self,
rope_scaling_factor = hf_text_config.rope_scaling["factor"] hf_inputs: BatchFeature,
else: hf_processor_mm_kwargs: Mapping[str, object],
rope_scaling_factor = 1 ) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values_videos=MultiModalFieldConfig.batched("video"))
max_tokens *= rope_scaling_factor
def _get_num_frame_tokens(
return max_tokens self,
*,
image_width: int,
def get_max_llava_next_video_tokens(ctx: InputContext) -> int: image_height: int,
# Currently set to 32 frames ) -> int:
# TODO: max_tokens = _get_max_llm_tokens(ctx) hf_config = self._get_hf_config()
hf_config = ctx.get_hf_config(LlavaNextVideoConfig) spatial_pool_stride = hf_config.spatial_pool_stride
tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
return _MAX_FRAMES_PER_VIDEO * tokens_per_frame
def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
vision_config = hf_config.vision_config
# TODO: support multiple videos
num_videos = mm_counts["video"]
if num_videos != _MAX_NUM_VIDEOS:
raise NotImplementedError(
f"Only {_MAX_NUM_VIDEOS} videos are supported")
# TODO: support configuring the number of frames
frames_per_video = _MAX_FRAMES_PER_VIDEO
# num_images = num_videos * frames_per_video
# fills the sequence with as longer video data as possible
tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
video_feature_size = frames_per_video * tokens_per_frame
if isinstance(vision_config, CLIPVisionConfig):
seq_data, ranges = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
mm_key="video",
)
pil_frame = dummy_image_for_clip(vision_config, num_images=1) patch_grid_length = self._vision_encoder_info.get_patch_grid_length()
np_frame = np.array(pil_frame["image"]) pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)
mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
mm_data = {"video": mm_data_per_video} return pooled_grid_length * pooled_grid_length
return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig): def _get_num_video_tokens(
seq_data, ranges = dummy_seq_data_for_siglip( self,
vision_config, *,
seq_len, image_width: int,
num_videos, image_height: int,
image_token_id=hf_config.video_token_index, num_frames: int,
image_feature_size_override=video_feature_size, ) -> int:
mm_key="video", num_frame_tokens = self._get_num_frame_tokens(
image_width=image_width,
image_height=image_height,
) )
pil_frame = dummy_image_for_siglip(vision_config, num_images=1) return num_frame_tokens * num_frames
np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
mm_data = {"video": mm_data_per_video}
return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}" def _get_max_video_tokens(self, num_frames: int) -> int:
raise NotImplementedError(msg) return self._get_num_video_tokens(image_width=999999,
image_height=999999,
num_frames=num_frames)
def _get_max_video_frames(self, max_tokens: int) -> int:
num_frames = 0
def input_processor_for_llava_next_video(ctx: InputContext, while True:
inputs: DecoderOnlyInputs): next_num_frames = num_frames + 1
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "video" not in multi_modal_data:
return inputs
if "multi_modal_placeholders" in inputs and "video" in inputs[ if self._get_max_video_tokens(next_num_frames) > max_tokens:
"multi_modal_placeholders"]: break
# The inputs already have placeholders.
return inputs
video_data = multi_modal_data["video"] num_frames = next_num_frames
model_config = ctx.model_config return num_frames
hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
vision_config = hf_config.vision_config
if isinstance(video_data, np.ndarray): def _get_dummy_num_frames(self, seq_len: int) -> int:
# Supports both CLIP and Siglip mm_config = self.ctx.get_mm_config()
num_frames = video_data.shape[0] max_videos = mm_config.limit_per_prompt.get("video", 1)
frame_feature_size = \
get_llava_next_video_frame_feature_size(hf_config)
video_feature_size = num_frames * frame_feature_size
tokenizer = cached_get_tokenizer(model_config.tokenizer) max_total_frames = self._get_max_video_frames(seq_len)
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( return max(max_total_frames // max(max_videos, 1), 1)
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=hf_config.video_token_index,
repeat_count=video_feature_size,
)
return token_inputs(prompt_token_ids=new_token_ids, def _get_dummy_image_size(self) -> ImageSize:
prompt=new_prompt, image_size = self._vision_encoder_info.get_image_size()
multi_modal_data=multi_modal_data, return ImageSize(image_size, image_size)
multi_modal_placeholders={"video": ranges})
elif is_list_of(video_data, np.ndarray): def _get_video_token(self) -> str:
raise NotImplementedError( return self._get_hf_processor().video_token
"Processing multiple videos is not supported")
msg = f"Unsupported vision config: {type(vision_config)}" def _get_prompt_replacements(
raise NotImplementedError(msg) self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self._get_hf_config()
video_token_id = hf_config.video_token_index
def get_replacement(item_idx: int):
videos = mm_items.get_items(
"video", (VideoEmbeddingItems, VideoProcessorItems))
if isinstance(videos, VideoEmbeddingItems):
num_video_tokens = videos.get_feature_size(item_idx)
else:
image_size = videos.get_frame_size(item_idx)
num_video_tokens = self._get_num_video_tokens(
image_width=image_size.width,
image_height=image_size.height,
num_frames=videos.get_num_frames(item_idx),
)
return [video_token_id] * num_video_tokens
return [
PromptReplacement(
modality="video",
target=[video_token_id],
replacement=get_replacement,
),
]
def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_videos = mm_counts.get("video", 0)
video_token = self._get_video_token()
target_width, target_height = self._get_dummy_image_size()
mm_data = {
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_videos=num_videos,
)
}
return ProcessorInputs(
prompt_text=video_token * num_videos,
mm_data=mm_data,
)
# adopted from transformers modeling_llava_next_video.py # adopted from transformers modeling_llava_next_video.py
...@@ -246,11 +245,7 @@ class LlavaNextMultiModalProjector(nn.Module): ...@@ -246,11 +245,7 @@ class LlavaNextMultiModalProjector(nn.Module):
return hidden_states return hidden_states
@MULTIMODAL_REGISTRY.register_input_mapper("video") @MULTIMODAL_REGISTRY.register_processor(LlavaNextVideoMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"video", get_max_llava_next_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video)
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
......
...@@ -3,47 +3,36 @@ from functools import cached_property ...@@ -3,47 +3,36 @@ from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union) TypedDict, Union)
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image from transformers import (BatchFeature, LlavaOnevisionConfig,
from transformers import (CLIPVisionConfig, LlavaOnevisionConfig, LlavaOnevisionProcessor)
SiglipVisionConfig)
from transformers.models.llava_onevision.modeling_llava_onevision import ( from transformers.models.llava_onevision.modeling_llava_onevision import (
get_anyres_image_grid_shape, unpad_image) get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.parse import (MultiModalDataItems, VideoEmbeddingItems,
repeat_and_pad_placeholder_tokens) VideoProcessorItems)
from vllm.multimodal.processing import (MultiModalFieldConfig, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_seq_data_for_clip, from .clip import CLIPVisionModel
dummy_video_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava from .llava import init_vision_tower_for_llava
from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip, from .llava_next import LlavaNextMultiModalProcessor
dummy_video_for_siglip, get_siglip_image_feature_size, from .siglip import SiglipVisionModel
get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
# For profile run
_MAX_FRAMES_PER_VIDEO = 16
class LlavaOnevisionVideoPixelInputs(TypedDict): class LlavaOnevisionVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"] type: Literal["pixel_values_videos"]
...@@ -92,286 +81,251 @@ LlavaOnevisionMultiInputs = Union[LlavaOnevisionImageInputs, ...@@ -92,286 +81,251 @@ LlavaOnevisionMultiInputs = Union[LlavaOnevisionImageInputs,
LlavaOnevisionVideoPixelInputs] LlavaOnevisionVideoPixelInputs]
def _get_llava_onevision_image_unppaded_feature_size(height, width, patches, class LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor):
scale_height,
scale_width):
current_height = patches * scale_height
current_width = patches * scale_width
original_aspect_ratio = width / height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
new_height = int(height * (current_width / width))
padding = (current_height - new_height) // 2
current_height -= padding * 2
else:
new_width = int(width * (current_height / height))
padding = (current_width - new_width) // 2
current_width -= padding * 2
unpadded_features = current_height * current_width
newline_features = current_height
ratio = math.sqrt(current_height * current_width / (9 * patches**2))
if ratio > 1.1:
unpadded_features = int(current_height // ratio) * int(
current_width // ratio)
newline_features = int(current_height // ratio)
return (unpadded_features, newline_features)
def get_llava_onevision_image_feature_size(
hf_config: LlavaOnevisionConfig,
*,
input_height: int,
input_width: int,
) -> int:
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
num_patches = get_clip_patch_grid_length(
image_size=vision_config.image_size,
patch_size=vision_config.patch_size,
)
base_feature_size = get_clip_image_feature_size(vision_config)
elif isinstance(vision_config, SiglipVisionConfig):
num_patches = get_siglip_patch_grid_length(
image_size=vision_config.image_size,
patch_size=vision_config.patch_size,
)
base_feature_size = get_siglip_image_feature_size(vision_config)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
strategy = hf_config.vision_feature_select_strategy
if strategy == "default":
base_feature_size -= 1
elif strategy == "full":
pass
else:
raise ValueError(f"Unexpected select feature strategy: {strategy}")
num_patch_height, num_patch_width = get_anyres_image_grid_shape( def _get_hf_config(self) -> LlavaOnevisionConfig:
image_size=(input_height, input_width), return self.ctx.get_hf_config(LlavaOnevisionConfig)
grid_pinpoints=hf_config.image_grid_pinpoints,
patch_size=vision_config.image_size,
)
( def _get_hf_processor(self) -> LlavaOnevisionProcessor:
unpadded_feature_size, return self.ctx.get_hf_processor(LlavaOnevisionProcessor)
newline_feature_size,
) = _get_llava_onevision_image_unppaded_feature_size( def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
input_height, input_width, num_patches, num_patch_height, return {"image": None, "video": None}
num_patch_width)
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return unpadded_feature_size + newline_feature_size + base_feature_size max_image_tokens = self._get_max_image_tokens()
num_frames = self._get_dummy_num_frames(seq_len)
def get_max_llava_onevision_image_tokens(ctx: InputContext): max_video_tokens = self._get_max_video_tokens(num_frames)
return get_llava_onevision_image_feature_size(
ctx.get_hf_config(LlavaOnevisionConfig), return {
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, "image": max_image_tokens,
input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, "video": max_video_tokens,
) }
def _get_mm_fields_config(
def get_llava_onevision_video_frame_feature_size( self,
hf_config: LlavaOnevisionConfig) -> int: hf_inputs: BatchFeature,
# Support both CLIPVisionConfig and SiglipVisionConfig hf_processor_mm_kwargs: Mapping[str, object],
image_size = hf_config.vision_config.image_size ) -> Mapping[str, MultiModalFieldConfig]:
patch_size = hf_config.vision_config.patch_size return dict(
spatial_pool_stride = hf_config.spatial_pool_stride if hasattr( pixel_values=MultiModalFieldConfig.batched("image"),
hf_config, "spatial_pool_stride") else 2 image_sizes=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
height = width = image_size // patch_size pixel_values_videos=MultiModalFieldConfig.batched("video"),
return math.ceil(height / spatial_pool_stride) * math.ceil(
width / spatial_pool_stride)
def get_llava_onevision_video_tokens(ctx: InputContext,
num_frames: int) -> int:
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
# TODO: support configuring (not supported by HF right now)
num_token_image_newline = 1
tokens_per_frame = get_llava_onevision_video_frame_feature_size(hf_config)
video_feature_size = num_frames * tokens_per_frame + num_token_image_newline
return video_feature_size
def get_max_llava_onevision_video_tokens(ctx: InputContext) -> int:
return get_llava_onevision_video_tokens(ctx, _MAX_FRAMES_PER_VIDEO)
def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
vision_config = hf_config.vision_config
num_videos = mm_counts["video"]
# TODO: support configuring the number of frames
num_frames = _MAX_FRAMES_PER_VIDEO
video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
if isinstance(vision_config, CLIPVisionConfig):
seq_data, ranges = dummy_seq_data_for_clip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
mm_key="video")
mm_data = dummy_video_for_clip(vision_config,
num_frames=num_frames,
num_videos=num_videos)
return DummyData(seq_data, mm_data, ranges)
elif isinstance(vision_config, SiglipVisionConfig):
seq_data, ranges = dummy_seq_data_for_siglip(
vision_config,
seq_len,
num_videos,
image_token_id=hf_config.video_token_index,
image_feature_size_override=video_feature_size,
mm_key="video")
mm_data = dummy_video_for_siglip(vision_config,
num_frames=num_frames,
num_videos=num_videos)
return DummyData(seq_data, mm_data, ranges)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def input_processor_when_multimodal_input_image(ctx: InputContext,
inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
vision_config = hf_config.vision_config
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
width, height = image_data.size
image_feature_size = get_llava_onevision_image_feature_size(
hf_config,
input_height=height,
input_width=width,
)
elif is_list_of(image_data, Image.Image):
image_feature_size = [
get_llava_onevision_image_feature_size(hf_config,
input_height=img.height,
input_width=img.width)
for img in image_data
]
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
elif is_list_of(image_data, torch.Tensor):
image_feature_size = [item.shape[1] for item in image_data]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
return input_processor_for_clip(
model_config,
vision_config,
inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
) )
elif isinstance(vision_config, SiglipVisionConfig):
return input_processor_for_siglip( def _get_num_unpadded_features(
model_config, self,
vision_config, *,
inputs, original_height: int,
image_token_id=hf_config.image_token_index, original_width: int,
image_feature_size_override=image_feature_size, npatches: int,
num_patch_height: int,
num_patch_width: int,
) -> tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
new_height = int(original_height *
(current_width / original_width))
padding = (current_height - new_height) // 2
current_height -= padding * 2
else:
new_width = int(original_width *
(current_height / original_height))
padding = (current_width - new_width) // 2
current_width -= padding * 2
unpadded_features = current_height * current_width
newline_features = current_height
ratio = math.sqrt(current_height * current_width / (9 * npatches**2))
if ratio > 1.1:
unpadded_features = int(current_height // ratio) * int(
current_width // ratio)
newline_features = int(current_height // ratio)
return (unpadded_features, newline_features)
def _get_num_frame_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
hf_config = self._get_hf_config()
spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2)
patch_grid_length = self._vision_encoder_info.get_patch_grid_length()
pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride)
return pooled_grid_length * pooled_grid_length
def _get_num_video_tokens(
self,
*,
image_width: int,
image_height: int,
num_frames: int,
) -> int:
num_frame_tokens = self._get_num_frame_tokens(
image_width=image_width,
image_height=image_height,
) )
msg = f"Unsupported vision config: {type(vision_config)}" return num_frame_tokens * num_frames + 1 # Newline token
raise NotImplementedError(msg)
def _get_max_video_tokens(self, num_frames: int) -> int:
return self._get_num_video_tokens(image_width=999999,
image_height=999999,
num_frames=num_frames)
def _get_max_video_frames(self, max_tokens: int) -> int:
num_frames = 0
def input_processor_when_multimodal_input_video(ctx: InputContext, while True:
inputs: DecoderOnlyInputs): next_num_frames = num_frames + 1
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "video" not in multi_modal_data:
return inputs
video_data = multi_modal_data["video"]
model_config = ctx.model_config if self._get_max_video_tokens(next_num_frames) > max_tokens:
hf_config = ctx.get_hf_config(LlavaOnevisionConfig) break
if isinstance(video_data, np.ndarray): num_frames = next_num_frames
# Supports both CLIP and Siglip
num_frames = video_data.shape[0]
video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames)
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( return num_frames
tokenizer,
inputs.get("prompt"), def _get_dummy_num_frames(self, seq_len: int) -> int:
inputs["prompt_token_ids"], mm_config = self.ctx.get_mm_config()
placeholder_token_id=hf_config.video_token_index, max_images = mm_config.limit_per_prompt.get("image", 1)
repeat_count=video_feature_size, max_videos = mm_config.limit_per_prompt.get("video", 1)
max_image_tokens = self._get_max_image_tokens() * max_images
max_total_frames = self._get_max_video_frames(seq_len -
max_image_tokens)
return max(max_total_frames // max(max_videos, 1), 1)
def _get_video_token(self) -> str:
return self._get_hf_processor().video_token
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
mm_data = dict(mm_data)
videos = mm_data.pop("videos", [])
assert isinstance(videos, list)
if not videos:
return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
video_token = self._get_video_token()
# LLaVA-OneVision processor doesn't support multiple videos
# with different sizes when converting back to tensors
text_image_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
pixel_values_videos = []
for video in videos:
item_processor_data = dict(prompt=video_token, videos=video)
item_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=item_processor_data,
mm_kwargs=mm_kwargs,
)
pixel_values_videos.append(
item_outputs.pop("pixel_values_videos")[0])
combined_outputs = dict(
**text_image_outputs,
pixel_values_videos=pixel_values_videos,
) )
return BatchFeature(combined_outputs)
return token_inputs(prompt_token_ids=new_token_ids, def _get_prompt_replacements(
prompt=new_prompt, self,
multi_modal_data=multi_modal_data, mm_items: MultiModalDataItems,
multi_modal_placeholders={"video": ranges}) hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
elif is_list_of(video_data, np.ndarray): ) -> list[PromptReplacement]:
video_feature_size = [] image_repls = super()._get_prompt_replacements(
for video in video_data: mm_items=mm_items,
num_frames = video.shape[0] hf_processor_mm_kwargs=hf_processor_mm_kwargs,
video_feature_size.append( out_mm_kwargs=out_mm_kwargs,
get_llava_onevision_video_tokens(ctx, num_frames))
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=hf_config.video_token_index,
repeat_count=video_feature_size,
) )
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"video": ranges})
else:
raise TypeError(f"Invalid video type: {type(video_data)}")
msg = f"Unsupported video type: {type(video_data)}" hf_config = self._get_hf_config()
raise NotImplementedError(msg) video_token_id = hf_config.video_token_index
def get_video_replacement(item_idx: int):
videos = mm_items.get_items(
"video", (VideoEmbeddingItems, VideoProcessorItems))
def input_processor_for_llava_onevision(ctx: InputContext, if isinstance(videos, VideoEmbeddingItems):
inputs: DecoderOnlyInputs): num_video_tokens = videos.get_feature_size(item_idx)
multi_modal_data = inputs.get("multi_modal_data") else:
if multi_modal_data is None or ("video" not in multi_modal_data image_size = videos.get_frame_size(item_idx)
and "image" not in multi_modal_data): num_video_tokens = self._get_num_video_tokens(
return inputs image_width=image_size.width,
if "image" in multi_modal_data: image_height=image_size.height,
return input_processor_when_multimodal_input_image(ctx, inputs) num_frames=videos.get_num_frames(item_idx),
if "video" in multi_modal_data: )
return input_processor_when_multimodal_input_video(ctx, inputs)
return [video_token_id] * num_video_tokens
msg = "Unsupported multi data type" return image_repls + [
raise NotImplementedError(msg) PromptReplacement(
modality="video",
target=[video_token_id],
replacement=get_video_replacement,
),
]
def _get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
image_token = self._get_image_token()
video_token = self._get_video_token()
target_width, target_height = self._get_dummy_image_size()
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
"video":
self._get_dummy_videos(
width=target_width,
height=target_height,
num_frames=self._get_dummy_num_frames(seq_len),
num_videos=num_videos,
)
}
return ProcessorInputs(
prompt_text=image_token * num_images + video_token * num_videos,
mm_data=mm_data,
)
class LlavaOnevisionMultiModalProjector(nn.Module): class LlavaOnevisionMultiModalProjector(nn.Module):
...@@ -394,14 +348,7 @@ class LlavaOnevisionMultiModalProjector(nn.Module): ...@@ -394,14 +348,7 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
return hidden_states return hidden_states
@MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_processor(LlavaOnevisionMultiModalProcessor)
@MULTIMODAL_REGISTRY.register_input_mapper("video")
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"image", get_max_llava_onevision_image_tokens)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"video", get_max_llava_onevision_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_onevision)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_onevision)
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment