Unverified commit f5722a50, authored by Cyrus Leung and committed by GitHub
Browse files

[V1] Scatter and gather placeholders in the model runner (#15712)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
parent 651cf0fe
...@@ -860,8 +860,8 @@ prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch( ...@@ -860,8 +860,8 @@ prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
) )
``` ```
To accommodate this, instead of a string you can return an instance of {class}`~vllm.multimodal.processing.PromptUpdateDetails` To assign the vision embeddings to only the image tokens, instead of a string
with different `full` and `feature` attributes: you can return an instance of {class}`~vllm.multimodal.processing.PromptUpdateDetails`:
```python ```python
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
...@@ -879,9 +879,9 @@ def get_replacement_fuyu(item_idx: int): ...@@ -879,9 +879,9 @@ def get_replacement_fuyu(item_idx: int):
image_tokens = ([_IMAGE_TOKEN_ID] * ncols + image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
[_NEWLINE_TOKEN_ID]) * nrows [_NEWLINE_TOKEN_ID]) * nrows
return PromptUpdateDetails( return PromptUpdateDetails.select_token_id(
full=image_tokens + [bos_token_id], image_tokens + [bos_token_id],
features=image_tokens, embed_token_id=_IMAGE_TOKEN_ID,
) )
``` ```
...@@ -914,9 +914,9 @@ def _get_prompt_updates( ...@@ -914,9 +914,9 @@ def _get_prompt_updates(
image_tokens = ([_IMAGE_TOKEN_ID] * ncols + image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
[_NEWLINE_TOKEN_ID]) * nrows [_NEWLINE_TOKEN_ID]) * nrows
return PromptUpdateDetails( return PromptUpdateDetails.select_token_id(
full=image_tokens + [bos_token_id], image_tokens + [bos_token_id],
features=image_tokens, embed_token_id=_IMAGE_TOKEN_ID,
) )
return [ return [
......
...@@ -989,9 +989,6 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -989,9 +989,6 @@ See [this page](#generative-models) for more information on how to use generativ
<sup>+</sup> Multiple items can be inputted per text prompt for this modality. <sup>+</sup> Multiple items can be inputted per text prompt for this modality.
:::{important} :::{important}
To use Gemma3 series models, you have to install Hugging Face Transformers library from source via
`pip install git+https://github.com/huggingface/transformers`.
Pan-and-scan image pre-processing is currently supported on V0 (but not V1). Pan-and-scan image pre-processing is currently supported on V0 (but not V1).
You can enable it by passing `--mm-processor-kwargs '{"do_pan_and_scan": True}'`. You can enable it by passing `--mm-processor-kwargs '{"do_pan_and_scan": True}'`.
::: :::
......
...@@ -47,7 +47,7 @@ def run_minicpmo(question: str, audio_count: int) -> ModelRequestData: ...@@ -47,7 +47,7 @@ def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
model=model_name, model=model_name,
trust_remote_code=True, trust_remote_code=True,
max_model_len=4096, max_model_len=4096,
max_num_seqs=5, max_num_seqs=2,
limit_mm_per_prompt={"audio": audio_count}, limit_mm_per_prompt={"audio": audio_count},
) )
......
...@@ -55,7 +55,10 @@ def server(request, audio_assets): ...@@ -55,7 +55,10 @@ def server(request, audio_assets):
for key, value in request.param.items() for key, value in request.param.items()
] ]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteOpenAIServer(MODEL_NAME,
args,
env_dict={"VLLM_AUDIO_FETCH_TIMEOUT":
"30"}) as remote_server:
yield remote_server yield remote_server
......
...@@ -167,7 +167,7 @@ VLM_TEST_SETTINGS = { ...@@ -167,7 +167,7 @@ VLM_TEST_SETTINGS = {
"cherry_blossom": "<image>What is the season?", # noqa: E501 "cherry_blossom": "<image>What is the season?", # noqa: E501
}), }),
multi_image_prompt="<image><image>Describe the two images in detail.", # noqa: E501 multi_image_prompt="<image><image>Describe the two images in detail.", # noqa: E501
max_model_len=8192, max_model_len=4096,
max_num_seqs=2, max_num_seqs=2,
auto_cls=AutoModelForImageTextToText, auto_cls=AutoModelForImageTextToText,
vllm_runner_kwargs={"mm_processor_kwargs": {"crop_to_patches": True}} vllm_runner_kwargs={"mm_processor_kwargs": {"crop_to_patches": True}}
......
...@@ -176,6 +176,8 @@ def test_chat( ...@@ -176,6 +176,8 @@ def test_chat(
model, model,
dtype=dtype, dtype=dtype,
tokenizer_mode="mistral", tokenizer_mode="mistral",
load_format="mistral",
config_format="mistral",
max_model_len=max_model_len, max_model_len=max_model_len,
limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
) as vllm_model: ) as vllm_model:
...@@ -198,22 +200,14 @@ def test_chat( ...@@ -198,22 +200,14 @@ def test_chat(
@large_gpu_test(min_gb=48) @large_gpu_test(min_gb=48)
@pytest.mark.parametrize( @pytest.mark.parametrize("prompt,expected_ranges",
"prompt,expected_ranges", [(_create_engine_inputs_hf(IMG_URLS[:1]),
[(_create_engine_inputs_hf(IMG_URLS[:1]), [{ [PlaceholderRange(offset=11, length=494)]),
"offset": 11, (_create_engine_inputs_hf(IMG_URLS[1:4]), [
"length": 494 PlaceholderRange(offset=11, length=266),
}]), PlaceholderRange(offset=277, length=1056),
(_create_engine_inputs_hf(IMG_URLS[1:4]), [{ PlaceholderRange(offset=1333, length=418)
"offset": 11, ])])
"length": 266
}, {
"offset": 277,
"length": 1056
}, {
"offset": 1333,
"length": 418
}])])
def test_multi_modal_placeholders(vllm_runner, prompt, def test_multi_modal_placeholders(vllm_runner, prompt,
expected_ranges: list[PlaceholderRange], expected_ranges: list[PlaceholderRange],
monkeypatch) -> None: monkeypatch) -> None:
......
...@@ -92,8 +92,8 @@ def _validate_image_prompt_replacements_one( ...@@ -92,8 +92,8 @@ def _validate_image_prompt_replacements_one(
first_placeholder = image_placeholders[0] first_placeholder = image_placeholders[0]
# NOTE: There is a BOS token # NOTE: There is a BOS token
assert first_placeholder["offset"] == 1 assert first_placeholder.offset == 1
assert first_placeholder["length"] == ( assert first_placeholder.length == (
len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs
except Exception as exc: except Exception as exc:
......
...@@ -92,8 +92,8 @@ def _validate_image_prompt_replacements_one( ...@@ -92,8 +92,8 @@ def _validate_image_prompt_replacements_one(
first_placeholder = image_placeholders[0] first_placeholder = image_placeholders[0]
assert first_placeholder["offset"] == 0 assert first_placeholder.offset == 0
assert first_placeholder["length"] == len( assert first_placeholder.length == len(
processed_inputs["prompt_token_ids"]) // num_imgs processed_inputs["prompt_token_ids"]) // num_imgs
except Exception as exc: except Exception as exc:
failed_size_excs.append((image_size, exc)) failed_size_excs.append((image_size, exc))
......
...@@ -277,7 +277,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { ...@@ -277,7 +277,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True, trust_remote_code=True,
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m", "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
extras={"2b": "h2oai/h2ovl-mississippi-2b"}), # noqa: E501 extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501
max_transformers_version="4.48", # noqa: E501
transformers_version_reason="HF model is not compatible."), # noqa: E501
"InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B", "InternVLChatModel": _HfExamplesInfo("OpenGVLab/InternVL2-1B",
extras={"2B": "OpenGVLab/InternVL2-2B"}, # noqa: E501 extras={"2B": "OpenGVLab/InternVL2-2B"}, # noqa: E501
trust_remote_code=True), trust_remote_code=True),
......
...@@ -785,6 +785,7 @@ def test_find_update_tokens( ...@@ -785,6 +785,7 @@ def test_find_update_tokens(
item_idx=0, item_idx=0,
start_idx=6, start_idx=6,
tokens=[32000, 32000], tokens=[32000, 32000],
is_embed=None,
), ),
], ],
"pattern_4": [ "pattern_4": [
...@@ -793,6 +794,7 @@ def test_find_update_tokens( ...@@ -793,6 +794,7 @@ def test_find_update_tokens(
item_idx=0, item_idx=0,
start_idx=3, start_idx=3,
tokens=[32000], tokens=[32000],
is_embed=None,
), ),
], ],
} }
...@@ -807,12 +809,14 @@ def test_find_update_tokens( ...@@ -807,12 +809,14 @@ def test_find_update_tokens(
item_idx=0, item_idx=0,
start_idx=1, start_idx=1,
tokens=[32000, 32000], tokens=[32000, 32000],
is_embed=None,
), ),
PlaceholderFeaturesInfo( PlaceholderFeaturesInfo(
modality="pattern_1", modality="pattern_1",
item_idx=1, item_idx=1,
start_idx=5, start_idx=5,
tokens=[32000, 32000], tokens=[32000, 32000],
is_embed=None,
), ),
], ],
"pattern_3": [ "pattern_3": [
...@@ -821,6 +825,7 @@ def test_find_update_tokens( ...@@ -821,6 +825,7 @@ def test_find_update_tokens(
item_idx=0, item_idx=0,
start_idx=7, start_idx=7,
tokens=[1550, 918, 1550], tokens=[1550, 918, 1550],
is_embed=None,
), ),
], ],
# No match for pattern_4 as it has lower priority than pattern_1 # No match for pattern_4 as it has lower priority than pattern_1
...@@ -835,12 +840,14 @@ def test_find_update_tokens( ...@@ -835,12 +840,14 @@ def test_find_update_tokens(
item_idx=0, item_idx=0,
start_idx=1, start_idx=1,
tokens=[32000, 32000], tokens=[32000, 32000],
is_embed=None,
), ),
PlaceholderFeaturesInfo( PlaceholderFeaturesInfo(
modality="pattern_1", modality="pattern_1",
item_idx=1, item_idx=1,
start_idx=3, start_idx=3,
tokens=[32000, 32000], tokens=[32000, 32000],
is_embed=None,
), ),
], ],
"pattern_4": [ "pattern_4": [
...@@ -849,6 +856,7 @@ def test_find_update_tokens( ...@@ -849,6 +856,7 @@ def test_find_update_tokens(
item_idx=0, item_idx=0,
start_idx=5, start_idx=5,
tokens=[32000], tokens=[32000],
is_embed=None,
), ),
], ],
"pattern_3": [ "pattern_3": [
...@@ -857,6 +865,7 @@ def test_find_update_tokens( ...@@ -857,6 +865,7 @@ def test_find_update_tokens(
item_idx=0, item_idx=0,
start_idx=6, start_idx=6,
tokens=[1550, 918, 1550], tokens=[1550, 918, 1550],
is_embed=None,
), ),
], ],
} }
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import pytest import pytest
import torch import torch
from vllm.multimodal.inputs import MultiModalKwargs from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import sha256 from vllm.utils import sha256
# disable yapf here as it formats differently than isort such that both fail # disable yapf here as it formats differently than isort such that both fail
...@@ -158,13 +158,10 @@ def test_generate_block_hash_extra_keys(): ...@@ -158,13 +158,10 @@ def test_generate_block_hash_extra_keys():
request = make_request( request = make_request(
request_id=0, request_id=0,
prompt_token_ids=[_ for _ in range(20)], prompt_token_ids=[_ for _ in range(20)],
mm_positions=[{ mm_positions=[
"offset": 0, PlaceholderRange(offset=0, length=5),
"length": 5 PlaceholderRange(offset=10, length=5),
}, { ],
"offset": 10,
"length": 5
}],
mm_hashes=["hash1", "hash2"], mm_hashes=["hash1", "hash2"],
) )
...@@ -222,13 +219,10 @@ def test_hash_request_tokens(hash_fn): ...@@ -222,13 +219,10 @@ def test_hash_request_tokens(hash_fn):
request = make_request( request = make_request(
request_id=0, request_id=0,
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
mm_positions=[{ mm_positions=[
"offset": 0, PlaceholderRange(offset=0, length=3),
"length": 3 PlaceholderRange(offset=3, length=3),
}, { ],
"offset": 3,
"length": 3
}],
mm_hashes=["hash1", "hash2"], mm_hashes=["hash1", "hash2"],
) )
...@@ -253,25 +247,19 @@ def test_hash_tokens_different_mm_input(hash_fn): ...@@ -253,25 +247,19 @@ def test_hash_tokens_different_mm_input(hash_fn):
request1 = make_request( request1 = make_request(
request_id=0, request_id=0,
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
mm_positions=[{ mm_positions=[
"offset": 0, PlaceholderRange(offset=0, length=3),
"length": 3 PlaceholderRange(offset=3, length=3),
}, { ],
"offset": 3,
"length": 3
}],
mm_hashes=["hash1", "hash2"], mm_hashes=["hash1", "hash2"],
) )
request2 = make_request( request2 = make_request(
request_id=1, request_id=1,
prompt_token_ids=[_ for _ in range(6)], prompt_token_ids=[_ for _ in range(6)],
mm_positions=[{ mm_positions=[
"offset": 0, PlaceholderRange(offset=0, length=3),
"length": 3 PlaceholderRange(offset=3, length=3),
}, { ],
"offset": 3,
"length": 3
}],
mm_hashes=["hash3", "hash2"], mm_hashes=["hash3", "hash2"],
) )
block_size = 3 block_size = 3
......
...@@ -27,7 +27,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -27,7 +27,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
MultiModalFieldConfig, MultiModalFieldConfig,
PromptReplacement, PromptUpdate, PromptReplacement, PromptUpdate,
encode_tokens) PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -35,7 +35,6 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP ...@@ -35,7 +35,6 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
class AyaVisionImagePixelInputs(TypedDict): class AyaVisionImagePixelInputs(TypedDict):
...@@ -51,13 +50,6 @@ class AyaVisionImagePixelInputs(TypedDict): ...@@ -51,13 +50,6 @@ class AyaVisionImagePixelInputs(TypedDict):
num_patches: torch.Tensor num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`""" """Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class AyaVisionMultiModalProjector(nn.Module): class AyaVisionMultiModalProjector(nn.Module):
...@@ -135,21 +127,20 @@ class AyaVisionProcessingInfo(BaseProcessingInfo): ...@@ -135,21 +127,20 @@ class AyaVisionProcessingInfo(BaseProcessingInfo):
def get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
hf_processor = self.get_hf_processor() hf_processor = self.get_hf_processor()
image_processor = hf_processor.image_processor image_processor = hf_processor.image_processor
image_size = self.get_image_size_with_most_features() image_size = self.get_image_size_with_most_features()
tokenizer = hf_processor.tokenizer
num_patches = self.get_num_patches( num_patches = self.get_num_patches(
image_width=image_size.width, image_width=image_size.width,
image_height=image_size.height, image_height=image_size.height,
size=image_processor.size, size=image_processor.size,
min_patches=image_processor.min_patches, min_patches=image_processor.min_patches,
max_patches=image_processor.max_patches) max_patches=image_processor.max_patches,
image_string = hf_processor._prompt_split_image(num_patches)
x = encode_tokens(
tokenizer,
image_string,
add_special_tokens=False,
) )
return len(x)
img_patches_per_tile = (hf_processor.img_size //
hf_processor.patch_size)**2
return num_patches * img_patches_per_tile
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None} return {"image": None}
...@@ -221,7 +212,6 @@ class AyaVisionMultiModalProcessor( ...@@ -221,7 +212,6 @@ class AyaVisionMultiModalProcessor(
hf_processor = self.info.get_hf_processor(**mm_kwargs) hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_processor = hf_processor.image_processor image_processor = hf_processor.image_processor
hf_config = self.info.get_hf_config()
# HF processor pops the `num_patches` kwarg, which is needed by vLLM # HF processor pops the `num_patches` kwarg, which is needed by vLLM
if (images := if (images :=
mm_data.get("images")) is not None and '<image>' in prompt: mm_data.get("images")) is not None and '<image>' in prompt:
...@@ -234,6 +224,7 @@ class AyaVisionMultiModalProcessor( ...@@ -234,6 +224,7 @@ class AyaVisionMultiModalProcessor(
parsed_images.get_image_size(i) parsed_images.get_image_size(i)
for i in range(len(parsed_images)) for i in range(len(parsed_images))
] ]
num_patches = [ num_patches = [
self.info.get_num_patches( self.info.get_num_patches(
image_width=image_size.width, image_width=image_size.width,
...@@ -243,20 +234,6 @@ class AyaVisionMultiModalProcessor( ...@@ -243,20 +234,6 @@ class AyaVisionMultiModalProcessor(
max_patches=image_processor.max_patches) max_patches=image_processor.max_patches)
for image_size in image_sizes for image_size in image_sizes
] ]
image_tokens_list = [
hf_processor._prompt_split_image(num_patch)
for num_patch in num_patches
]
tokenizer = self.info.get_tokenizer()
image_token_ids = [
tokenizer.encode(image_tokens, add_special_tokens=False)
for image_tokens in image_tokens_list
]
embed_is_patch = [
torch.tensor(image_repl_tokens) == hf_config.image_token_index
for image_repl_tokens in image_token_ids
]
processed_outputs["embed_is_patch"] = embed_is_patch
processed_outputs["num_patches"] = torch.tensor(num_patches) processed_outputs["num_patches"] = torch.tensor(num_patches)
return processed_outputs return processed_outputs
...@@ -271,7 +248,6 @@ class AyaVisionMultiModalProcessor( ...@@ -271,7 +248,6 @@ class AyaVisionMultiModalProcessor(
pixel_values=MultiModalFieldConfig.flat_from_sizes( pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", num_patches), "image", num_patches),
num_patches=MultiModalFieldConfig.batched("image"), num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
) )
...@@ -283,6 +259,7 @@ class AyaVisionMultiModalProcessor( ...@@ -283,6 +259,7 @@ class AyaVisionMultiModalProcessor(
) -> Sequence[PromptUpdate]: ) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token image_token = hf_processor.image_token
img_patch_token = hf_processor.img_patch_token
image_processor = hf_processor.image_processor image_processor = hf_processor.image_processor
def get_replacement(item_idx: int): def get_replacement(item_idx: int):
...@@ -294,8 +271,11 @@ class AyaVisionMultiModalProcessor( ...@@ -294,8 +271,11 @@ class AyaVisionMultiModalProcessor(
image_height=image_size.height, image_height=image_size.height,
size=image_processor.size, size=image_processor.size,
min_patches=image_processor.min_patches, min_patches=image_processor.min_patches,
max_patches=image_processor.max_patches) max_patches=image_processor.max_patches,
return hf_processor._prompt_split_image(num_patches=num_patches) )
repl = hf_processor._prompt_split_image(num_patches=num_patches)
return PromptUpdateDetails.select_text(repl, img_patch_token)
return [ return [
PromptReplacement( PromptReplacement(
...@@ -424,7 +404,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -424,7 +404,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]: self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
num_patches = kwargs.pop("num_patches", None) num_patches = kwargs.pop("num_patches", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
image_embeds = kwargs.pop("image_embeds", None) image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Aya Vision does not support image_embeds." assert image_embeds is None, "Aya Vision does not support image_embeds."
...@@ -436,18 +415,13 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -436,18 +415,13 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of num_patches. " raise ValueError("Incorrect type of num_patches. "
f"Got type: {type(num_patches)}") f"Got type: {type(num_patches)}")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
pixel_values = flatten_bn(pixel_values, concat=True) pixel_values = flatten_bn(pixel_values, concat=True)
num_patches = flatten_bn(num_patches, concat=True) num_patches = flatten_bn(num_patches, concat=True)
embed_is_patch = flatten_bn(embed_is_patch)
return AyaVisionImagePixelInputs( return AyaVisionImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values), pixel_values=self._validate_pixel_values(pixel_values),
num_patches=num_patches, num_patches=num_patches,
embed_is_patch=embed_is_patch,
) )
def get_multimodal_embeddings( def get_multimodal_embeddings(
...@@ -455,11 +429,8 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -455,11 +429,8 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input, **kwargs)
return scatter_patch_features( return self._process_image_input(image_input, **kwargs)
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -471,9 +442,9 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -471,9 +442,9 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids, input_ids=input_ids,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
multimodal_embeddings=select_patch_features( multimodal_embeddings=multimodal_embeddings,
multimodal_embeddings), placeholder_token_id=self.config.image_token_index,
placeholder_token_id=self.config.image_token_index) )
return inputs_embeds return inputs_embeds
......
...@@ -162,9 +162,9 @@ class ChameleonMultiModalProcessor( ...@@ -162,9 +162,9 @@ class ChameleonMultiModalProcessor(
PromptReplacement( PromptReplacement(
modality="image", modality="image",
target=[image_token_id], target=[image_token_id],
replacement=PromptUpdateDetails( replacement=PromptUpdateDetails.select_token_id(
full=([image_start_id] + image_tokens + [image_end_id]), [image_start_id] + image_tokens + [image_end_id],
features=image_tokens, embed_token_id=image_token_id,
), ),
) )
] ]
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
""" PyTorch Fuyu model.""" """ PyTorch Fuyu model."""
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Literal, Optional, Set, Tuple, TypedDict, Union from typing import Literal, Optional, Set, Tuple, TypedDict
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -43,7 +43,6 @@ from vllm.sequence import IntermediateTensors ...@@ -43,7 +43,6 @@ from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
# Cannot find the following 2 numbers from hf config. # Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011 _IMAGE_TOKEN_ID = 71011
...@@ -66,14 +65,6 @@ class FuyuImagePatchInputs(TypedDict): ...@@ -66,14 +65,6 @@ class FuyuImagePatchInputs(TypedDict):
flattened just like `flat_data`. flattened just like `flat_data`.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class FuyuProcessingInfo(BaseProcessingInfo): class FuyuProcessingInfo(BaseProcessingInfo):
...@@ -94,15 +85,7 @@ class FuyuProcessingInfo(BaseProcessingInfo): ...@@ -94,15 +85,7 @@ class FuyuProcessingInfo(BaseProcessingInfo):
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> Mapping[str, int]: ) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features() return {"image": self.get_max_image_tokens()}
max_ncols, max_nrows = self.get_image_feature_grid_size(
image_width=target_width,
image_height=target_height,
)
max_image_tokens = (max_ncols + 1) * max_nrows
return {"image": max_image_tokens}
def get_image_feature_grid_size( def get_image_feature_grid_size(
self, self,
...@@ -128,11 +111,32 @@ class FuyuProcessingInfo(BaseProcessingInfo): ...@@ -128,11 +111,32 @@ class FuyuProcessingInfo(BaseProcessingInfo):
nrows = math.ceil(image_height / patch_height) nrows = math.ceil(image_height / patch_height)
return ncols, nrows return ncols, nrows
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
ncols, nrows = self.get_image_feature_grid_size(
image_width=image_width,
image_height=image_height,
)
return ncols * nrows
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_image_processor() image_processor = self.get_image_processor()
return ImageSize(width=image_processor.size["width"], return ImageSize(width=image_processor.size["width"],
height=image_processor.size["height"]) height=image_processor.size["height"])
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
)
class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]): class FuyuDummyInputsBuilder(BaseDummyInputsBuilder[FuyuProcessingInfo]):
...@@ -192,19 +196,6 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): ...@@ -192,19 +196,6 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
processed_outputs["image_patches"] = image_patches[0] processed_outputs["image_patches"] = image_patches[0]
# get patch grid size for each image
embed_is_patch = []
for image in images:
ncols, nrows = self.info.get_image_feature_grid_size(
image_width=image.width,
image_height=image.height,
)
mask = torch.tensor(([True] * ncols + [False]) * nrows)
embed_is_patch.append(mask)
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs return processed_outputs
def _apply_hf_processor_tokens_only( def _apply_hf_processor_tokens_only(
...@@ -224,8 +215,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): ...@@ -224,8 +215,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict(image_patches=MultiModalFieldConfig.batched("image"), return dict(image_patches=MultiModalFieldConfig.batched("image"))
embed_is_patch=MultiModalFieldConfig.batched("image"))
def _get_prompt_updates( def _get_prompt_updates(
self, self,
...@@ -252,9 +242,9 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): ...@@ -252,9 +242,9 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
image_tokens = ([_IMAGE_TOKEN_ID] * ncols + image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
[_NEWLINE_TOKEN_ID]) * nrows [_NEWLINE_TOKEN_ID]) * nrows
return PromptUpdateDetails( return PromptUpdateDetails.select_token_id(
full=image_tokens + [bos_token_id], image_tokens + [bos_token_id],
features=image_tokens, embed_token_id=_IMAGE_TOKEN_ID,
) )
return [ return [
...@@ -329,20 +319,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -329,20 +319,13 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of image patches. " raise ValueError("Incorrect type of image patches. "
f"Got type: {type(image_patches)}") f"Got type: {type(image_patches)}")
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
image_patches_flat = flatten_bn(image_patches) image_patches_flat = flatten_bn(image_patches)
embed_is_patch = flatten_bn(embed_is_patch)
return FuyuImagePatchInputs( return FuyuImagePatchInputs(
type="image_patches", type="image_patches",
flat_data=self._validate_pixel_values( flat_data=self._validate_pixel_values(
flatten_bn(image_patches_flat, concat=True)), flatten_bn(image_patches_flat, concat=True)),
patches_per_image=[x.size(0) for x in image_patches_flat], patches_per_image=[x.size(0) for x in image_patches_flat],
embed_is_patch=embed_is_patch,
) )
return None return None
...@@ -364,12 +347,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -364,12 +347,7 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input) return self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -379,8 +357,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -379,8 +357,11 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None: if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, input_ids,
select_patch_features(multimodal_embeddings), _IMAGE_TOKEN_ID) inputs_embeds,
multimodal_embeddings,
_IMAGE_TOKEN_ID,
)
return inputs_embeds return inputs_embeds
def forward( def forward(
......
...@@ -25,7 +25,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -25,7 +25,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PlaceholderFeaturesInfo, PlaceholderFeaturesInfo,
PromptReplacement, PromptTargetMatch, PromptReplacement, PromptTargetMatch,
PromptUpdate, PromptUpdateDetails, PromptUpdate, PromptUpdateDetails,
encode_tokens, find_mm_placeholders, find_mm_placeholders,
replace_token_matches) replace_token_matches)
# yapf: enable # yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
...@@ -36,7 +36,6 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA, ...@@ -36,7 +36,6 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -54,14 +53,6 @@ class Gemma3ImagePixelInputs(TypedDict): ...@@ -54,14 +53,6 @@ class Gemma3ImagePixelInputs(TypedDict):
num_patches: torch.Tensor num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`""" """Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
Gemma3ImageInputs = Gemma3ImagePixelInputs Gemma3ImageInputs = Gemma3ImagePixelInputs
...@@ -183,7 +174,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): ...@@ -183,7 +174,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
if processor is None: if processor is None:
processor = self.get_hf_processor() processor = self.get_hf_processor()
image_token = processor.boi_token boi_token = processor.boi_token
num_crops = self.get_num_crops( num_crops = self.get_num_crops(
image_width=image_width, image_width=image_width,
...@@ -192,19 +183,21 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): ...@@ -192,19 +183,21 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
) )
if num_crops == 0: if num_crops == 0:
image_text = image_token image_text = boi_token
else: else:
crops_image_tokens = " ".join(image_token crops_image_tokens = " ".join(boi_token for _ in range(num_crops))
for _ in range(num_crops))
image_text = ( image_text = (
f"Here is the original image {image_token} and here are some " f"Here is the original image {boi_token} and here are some "
f"crops to help you see better {crops_image_tokens}") f"crops to help you see better {crops_image_tokens}")
repl_full = image_text.replace(image_token, repl_full = image_text.replace(boi_token,
processor.full_image_sequence) processor.full_image_sequence)
repl_features = repl_full.strip("\n")
return PromptUpdateDetails(full=repl_full, features=repl_features) tokenizer = processor.tokenizer
vocab = tokenizer.get_vocab()
image_token_id = vocab[tokenizer.image_token]
return PromptUpdateDetails.select_token_id(repl_full, image_token_id)
def get_num_image_tokens( def get_num_image_tokens(
self, self,
...@@ -213,19 +206,17 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): ...@@ -213,19 +206,17 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
image_height: int, image_height: int,
processor: Optional[Gemma3Processor], processor: Optional[Gemma3Processor],
) -> int: ) -> int:
tokenizer = self.get_tokenizer() if processor is None:
image_repl = self.get_image_repl( processor = self.get_hf_processor()
num_crops = self.get_num_crops(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
processor=processor, processor=processor,
) )
image_seq_len = processor.image_seq_length
image_repl_tokens = encode_tokens( return (num_crops + 1) * image_seq_len
tokenizer,
image_repl.features,
add_special_tokens=False,
)
return len(image_repl_tokens)
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor() processor = self.get_hf_processor()
...@@ -301,28 +292,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): ...@@ -301,28 +292,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
] ]
hf_processor = self.info.get_hf_processor(**mm_kwargs) hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_repl_features = [
self.info.get_image_repl(image_width=size.width,
image_height=size.height,
processor=hf_processor).features
for size in image_sizes
]
tokenizer = self.info.get_tokenizer()
image_repls_feature_tokens = [
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
vocab = tokenizer.get_vocab()
image_token_id = vocab[tokenizer.image_token]
embed_is_patch = [
torch.tensor(image_repl_tokens) == image_token_id
for image_repl_tokens in image_repls_feature_tokens
]
processed_outputs["embed_is_patch"] = embed_is_patch
num_crops = [ num_crops = [
self.info.get_num_crops(image_width=size.width, self.info.get_num_crops(image_width=size.width,
image_height=size.height, image_height=size.height,
...@@ -344,7 +313,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): ...@@ -344,7 +313,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
pixel_values=MultiModalFieldConfig.flat_from_sizes( pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops + 1), "image", num_crops + 1),
num_crops=MultiModalFieldConfig.batched("image"), num_crops=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_updates( def _get_prompt_updates(
...@@ -454,6 +422,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): ...@@ -454,6 +422,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
item_idx=p.item_idx, item_idx=p.item_idx,
start_idx=repl_orig_idxs[p.start_idx], start_idx=repl_orig_idxs[p.start_idx],
tokens=p.tokens, tokens=p.tokens,
is_embed=p.is_embed,
) for p in placeholders ) for p in placeholders
] ]
for modality, placeholders in repls.items() for modality, placeholders in repls.items()
...@@ -572,7 +541,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -572,7 +541,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self, **kwargs: object) -> Optional[Gemma3ImageInputs]: self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None) pixel_values = kwargs.pop("pixel_values", None)
num_crops = kwargs.pop("num_crops", None) num_crops = kwargs.pop("num_crops", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
image_embeds = kwargs.pop("image_embeds", None) image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Gemma3 does not support image_embeds." assert image_embeds is None, "Gemma3 does not support image_embeds."
if pixel_values is None: if pixel_values is None:
...@@ -586,19 +554,13 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -586,19 +554,13 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
raise ValueError("Incorrect type of num_crops. " raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}") f"Got type: {type(num_crops)}")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
pixel_values = flatten_bn(pixel_values, concat=True) pixel_values = flatten_bn(pixel_values, concat=True)
num_crops = flatten_bn(num_crops, concat=True) num_crops = flatten_bn(num_crops, concat=True)
embed_is_patch = flatten_bn(embed_is_patch)
return Gemma3ImagePixelInputs( return Gemma3ImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values), pixel_values=self._validate_pixel_values(pixel_values),
num_patches=num_crops + 1, num_patches=num_crops + 1,
embed_is_patch=embed_is_patch,
) )
def _image_pixels_to_features( def _image_pixels_to_features(
...@@ -635,12 +597,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -635,12 +597,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input) return self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -652,7 +609,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -652,7 +609,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
select_patch_features(multimodal_embeddings), multimodal_embeddings,
self.config.image_token_index, self.config.image_token_index,
) )
return inputs_embeds return inputs_embeds
......
...@@ -257,7 +257,7 @@ class H2OVLProcessor(BaseInternVLProcessor): ...@@ -257,7 +257,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
repl_features = IMG_CONTEXT * feature_size repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END repl_full = IMG_START + repl_features + IMG_END
return PromptUpdateDetails(full=repl_full, features=repl_features) return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
def resolve_min_max_num( def resolve_min_max_num(
self, self,
......
...@@ -41,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -41,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, MultiModalDataItems,
MultiModalFieldConfig, MultiModalFieldConfig,
PromptReplacement, PromptUpdate, PromptReplacement, PromptUpdate,
encode_tokens) PromptUpdateDetails)
# yapf: enable # yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -54,7 +54,6 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal ...@@ -54,7 +54,6 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .llama import LlamaModel from .llama import LlamaModel
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix, from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
class Idefics3ImagePixelInputs(TypedDict): class Idefics3ImagePixelInputs(TypedDict):
...@@ -69,14 +68,6 @@ class Idefics3ImagePixelInputs(TypedDict): ...@@ -69,14 +68,6 @@ class Idefics3ImagePixelInputs(TypedDict):
num_patches: torch.Tensor num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`""" """Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class Idefics3ImageEmbeddingInputs(TypedDict): class Idefics3ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
...@@ -86,14 +77,6 @@ class Idefics3ImageEmbeddingInputs(TypedDict): ...@@ -86,14 +77,6 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
`hidden_size` must match the hidden size of language model backbone. `hidden_size` must match the hidden size of language model backbone.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs] ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
...@@ -275,19 +258,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo): ...@@ -275,19 +258,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
image_height: int, image_height: int,
processor: Optional[Idefics3Processor], processor: Optional[Idefics3Processor],
) -> int: ) -> int:
tokenizer = self.get_tokenizer() if processor is None:
image_repl = self.get_image_repl( processor = self.get_hf_processor()
num_patches = self.get_num_patches(
image_width=image_width, image_width=image_width,
image_height=image_height, image_height=image_height,
processor=processor, processor=processor,
) )
image_repl_tokens = encode_tokens( return num_patches * processor.image_seq_len
tokenizer,
image_repl,
add_special_tokens=False,
)
return len(image_repl_tokens)
def get_image_size_with_most_features(self) -> ImageSize: def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor() processor = self.get_hf_processor()
...@@ -364,28 +344,6 @@ class Idefics3MultiModalProcessor( ...@@ -364,28 +344,6 @@ class Idefics3MultiModalProcessor(
] ]
hf_processor = self.info.get_hf_processor(**mm_kwargs) hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_repl_features = [
self.info.get_image_repl(image_width=size.width,
image_height=size.height,
processor=hf_processor)
for size in image_sizes
]
tokenizer = self.info.get_tokenizer()
image_repls_feature_tokens = [
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
vocab = tokenizer.get_vocab()
image_token_id = vocab[hf_processor.image_token.content]
embed_is_patch = [
torch.tensor(image_repl_tokens) == image_token_id
for image_repl_tokens in image_repls_feature_tokens
]
processed_outputs["embed_is_patch"] = embed_is_patch
num_patches = [ num_patches = [
self.info.get_num_patches( self.info.get_num_patches(
image_width=size.width, image_width=size.width,
...@@ -415,7 +373,6 @@ class Idefics3MultiModalProcessor( ...@@ -415,7 +373,6 @@ class Idefics3MultiModalProcessor(
"image", num_patches), "image", num_patches),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
num_patches=MultiModalFieldConfig.batched("image"), num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_updates( def _get_prompt_updates(
...@@ -427,17 +384,22 @@ class Idefics3MultiModalProcessor( ...@@ -427,17 +384,22 @@ class Idefics3MultiModalProcessor(
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token.content image_token = hf_processor.image_token.content
def get_replacement_idefics3(item_idx: int) -> str: def get_replacement_idefics3(item_idx: int) -> PromptUpdateDetails:
images = mm_items.get_items("image", ImageProcessorItems) images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx) image_size = images.get_image_size(item_idx)
return self.info.get_image_repl( image_repl = self.info.get_image_repl(
image_width=image_size.width, image_width=image_size.width,
image_height=image_size.height, image_height=image_size.height,
processor=hf_processor, processor=hf_processor,
) )
return PromptUpdateDetails.select_text(
image_repl,
embed_text=image_token,
)
return [ return [
PromptReplacement( PromptReplacement(
modality="image", modality="image",
...@@ -675,13 +637,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -675,13 +637,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
if pixel_values is None and image_embeds is None: if pixel_values is None and image_embeds is None:
return None return None
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)): if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. " raise ValueError("Incorrect type of image embeddings. "
...@@ -690,7 +645,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -690,7 +645,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
return Idefics3ImageEmbeddingInputs( return Idefics3ImageEmbeddingInputs(
type="image_embeds", type="image_embeds",
data=flatten_bn(image_embeds, concat=True), data=flatten_bn(image_embeds, concat=True),
embed_is_patch=embed_is_patch,
) )
if pixel_values is not None: if pixel_values is not None:
...@@ -718,7 +672,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -718,7 +672,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values=self._validate_pixel_values(pixel_values), pixel_values=self._validate_pixel_values(pixel_values),
pixel_attention_mask=pixel_attention_mask, pixel_attention_mask=pixel_attention_mask,
num_patches=num_patches, num_patches=num_patches,
embed_is_patch=embed_is_patch,
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
...@@ -754,12 +707,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -754,12 +707,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input) return self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -771,7 +719,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -771,7 +719,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
select_patch_features(multimodal_embeddings), multimodal_embeddings,
self.config.image_token_id, self.config.image_token_id,
) )
return inputs_embeds return inputs_embeds
......
...@@ -39,7 +39,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer ...@@ -39,7 +39,6 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
IMG_START = '<img>' IMG_START = '<img>'
IMG_END = '</img>' IMG_END = '</img>'
...@@ -60,14 +59,6 @@ class InternVLImagePixelInputs(TypedDict): ...@@ -60,14 +59,6 @@ class InternVLImagePixelInputs(TypedDict):
num_patches: torch.Tensor num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`""" """Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class InternVLImageEmbeddingInputs(TypedDict): class InternVLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
...@@ -419,24 +410,12 @@ class BaseInternVLProcessor(ABC): ...@@ -419,24 +410,12 @@ class BaseInternVLProcessor(ABC):
torch.tensor([len(item) for item in pixel_values_lst]), torch.tensor([len(item) for item in pixel_values_lst]),
} }
tokenizer = self.tokenizer
image_token_id = self.image_token_id
embed_is_patch = list[torch.Tensor]()
for pixel_values in pixel_values_lst: for pixel_values in pixel_values_lst:
num_patches = pixel_values.shape[0] num_patches = pixel_values.shape[0]
feature_size = num_patches * self.num_image_token feature_size = num_patches * self.num_image_token
image_repl = self.get_image_repl(feature_size, num_patches) image_repl = self.get_image_repl(feature_size, num_patches)
feature_tokens = tokenizer.encode(image_repl.features,
add_special_tokens=False)
text = [t.replace('<image>', image_repl.full, 1) for t in text] text = [t.replace('<image>', image_repl.full, 1) for t in text]
embed_is_patch.append(
torch.tensor(feature_tokens) == image_token_id)
image_inputs["embed_is_patch"] = embed_is_patch
text_inputs = self.tokenizer(text) text_inputs = self.tokenizer(text)
...@@ -460,7 +439,7 @@ class InternVLProcessor(BaseInternVLProcessor): ...@@ -460,7 +439,7 @@ class InternVLProcessor(BaseInternVLProcessor):
repl_features = IMG_CONTEXT * feature_size repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END repl_full = IMG_START + repl_features + IMG_END
return PromptUpdateDetails(full=repl_full, features=repl_features) return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
class BaseInternVLProcessingInfo(BaseProcessingInfo): class BaseInternVLProcessingInfo(BaseProcessingInfo):
...@@ -599,7 +578,6 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -599,7 +578,6 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_patches), "image", image_num_patches),
image_num_patches=MultiModalFieldConfig.batched("image"), image_num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images), image_token_id=MultiModalFieldConfig.shared("image", num_images),
) )
...@@ -831,7 +809,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -831,7 +809,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self, **kwargs: object) -> Optional[InternVLImageInputs]: self, **kwargs: object) -> Optional[InternVLImageInputs]:
pixel_values_flat = kwargs.pop("pixel_values_flat", None) pixel_values_flat = kwargs.pop("pixel_values_flat", None)
image_num_patches = kwargs.pop("image_num_patches", None) image_num_patches = kwargs.pop("image_num_patches", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
image_embeds = kwargs.pop("image_embeds", None) image_embeds = kwargs.pop("image_embeds", None)
if pixel_values_flat is None and image_embeds is None: if pixel_values_flat is None and image_embeds is None:
...@@ -860,20 +837,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -860,20 +837,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
raise ValueError("Incorrect type of image_num_patches. " raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(image_num_patches)}") f"Got type: {type(image_num_patches)}")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True) image_num_patches = flatten_bn(image_num_patches, concat=True)
embed_is_patch = flatten_bn(embed_is_patch)
return InternVLImagePixelInputs( return InternVLImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values_flat=self._validate_pixel_values( pixel_values_flat=self._validate_pixel_values(
pixel_values_flat), pixel_values_flat),
num_patches=image_num_patches, num_patches=image_num_patches,
embed_is_patch=embed_is_patch,
) )
raise AssertionError("This line should be unreachable.") raise AssertionError("This line should be unreachable.")
...@@ -919,15 +890,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -919,15 +890,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input) return self._process_image_input(image_input)
if image_input["type"] != "pixel_values":
return image_features
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -941,7 +904,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -941,7 +904,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
select_patch_features(multimodal_embeddings), multimodal_embeddings,
self.img_context_token_id, self.img_context_token_id,
) )
return inputs_embeds return inputs_embeds
......
...@@ -32,7 +32,8 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ...@@ -32,7 +32,8 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems) ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache, BaseProcessingInfo, ProcessingCache,
PromptReplacement, PromptUpdate) PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -42,8 +43,7 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel ...@@ -42,8 +43,7 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .siglip import SiglipVisionModel from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import (get_vision_encoder_info, scatter_patch_features, from .vision import get_vision_encoder_info
select_patch_features)
class LlavaImagePixelInputs(TypedDict): class LlavaImagePixelInputs(TypedDict):
...@@ -67,14 +67,6 @@ class PixtralHFImagePixelInputs(TypedDict): ...@@ -67,14 +67,6 @@ class PixtralHFImagePixelInputs(TypedDict):
in which case the data is passed as a list instead of a batched tensor. in which case the data is passed as a list instead of a batched tensor.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class LlavaImageEmbeddingInputs(TypedDict): class LlavaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
...@@ -343,23 +335,6 @@ class PixtralHFMultiModalProcessor( ...@@ -343,23 +335,6 @@ class PixtralHFMultiModalProcessor(
for p, (h, w) in zip(pixel_values, image_sizes) for p, (h, w) in zip(pixel_values, image_sizes)
] ]
hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig)
encoder_info = PixtralHFEncoderInfo(vision_config)
tile_sizes = [
encoder_info.get_patch_grid_size(
image_width=pixel_value.shape[-1],
image_height=pixel_value.shape[-2],
) for pixel_value in processed_outputs["pixel_values"]
]
embed_is_patch = [
torch.tensor(([True] * ncols + [False]) * nrows)
for ncols, nrows in tile_sizes
]
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs return processed_outputs
def _get_mm_fields_config( def _get_mm_fields_config(
...@@ -369,7 +344,6 @@ class PixtralHFMultiModalProcessor( ...@@ -369,7 +344,6 @@ class PixtralHFMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict( return dict(
pixel_values=MultiModalFieldConfig.batched("image"), pixel_values=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
) )
...@@ -404,7 +378,7 @@ class PixtralHFMultiModalProcessor( ...@@ -404,7 +378,7 @@ class PixtralHFMultiModalProcessor(
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
tokens[-1] = image_end_id tokens[-1] = image_end_id
return tokens return PromptUpdateDetails.select_token_id(tokens, image_token_id)
return [ return [
PromptReplacement( PromptReplacement(
...@@ -612,17 +586,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -612,17 +586,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
f"Got type: {type(pixel_values)}") f"Got type: {type(pixel_values)}")
if self.config.vision_config.model_type == "pixtral": if self.config.vision_config.model_type == "pixtral":
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
return PixtralHFImagePixelInputs( return PixtralHFImagePixelInputs(
type="pixel_values_pixtral", type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values), pixel_values=flatten_bn(pixel_values),
embed_is_patch=embed_is_patch,
) )
return LlavaImagePixelInputs( return LlavaImagePixelInputs(
...@@ -714,16 +680,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -714,16 +680,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
if image_input is None: if image_input is None:
return None return None
image_features = self._process_image_input(image_input) return self._process_image_input(image_input)
if image_input["type"] != "pixel_values_pixtral":
# The path is used for pixtral (V0 only) and llava (V0/V1)
return image_features
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings( def get_input_embeddings(
self, self,
...@@ -735,7 +692,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -735,7 +692,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, input_ids,
inputs_embeds, inputs_embeds,
select_patch_features(multimodal_embeddings), multimodal_embeddings,
self.config.image_token_index, self.config.image_token_index,
) )
return inputs_embeds return inputs_embeds
......
...@@ -40,7 +40,8 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems, ...@@ -40,7 +40,8 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
DictEmbeddingItems, ModalityData, DictEmbeddingItems, ModalityData,
ModalityDataItems, MultiModalDataItems, ModalityDataItems, MultiModalDataItems,
MultiModalDataParser) MultiModalDataParser)
from vllm.multimodal.processing import PromptReplacement, PromptUpdate from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import ProcessorInputs from vllm.multimodal.profiling import ProcessorInputs
from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6, from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
...@@ -50,7 +51,6 @@ from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6, ...@@ -50,7 +51,6 @@ from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
_minicpmv_field_config) _minicpmv_field_config)
from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn, from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
maybe_prefix) maybe_prefix)
from .vision import scatter_patch_features
CPU_DEVICE = torch.device("cpu") CPU_DEVICE = torch.device("cpu")
...@@ -73,14 +73,6 @@ class MiniCPMOAudioFeatureInputs(TypedDict): ...@@ -73,14 +73,6 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
which equals to `audio_features.shape[-1]` which equals to `audio_features.shape[-1]`
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which audio embeddings correspond
to patch tokens.
Shape: `(batch_size * num_audios, num_embeds)`
"""
class MiniCPMOAudioEmbeddingInputs(TypedDict): class MiniCPMOAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"] type: Literal["audio_embeds"]
...@@ -93,14 +85,6 @@ class MiniCPMOAudioEmbeddingInputs(TypedDict): ...@@ -93,14 +85,6 @@ class MiniCPMOAudioEmbeddingInputs(TypedDict):
Length of each slice may vary, so pass it as a list. Length of each slice may vary, so pass it as a list.
""" """
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which audio embeddings correspond
to patch tokens.
Shape: `(batch_size * num_audios, num_embeds)`
"""
MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
MiniCPMOAudioEmbeddingInputs] MiniCPMOAudioEmbeddingInputs]
...@@ -115,7 +99,6 @@ def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]): ...@@ -115,7 +99,6 @@ def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_features=MultiModalFieldConfig.batched("audio"), audio_features=MultiModalFieldConfig.batched("audio"),
audio_feature_lens=MultiModalFieldConfig.batched("audio"), audio_feature_lens=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.batched("audio"), audio_embeds=MultiModalFieldConfig.batched("audio"),
audio_embed_is_patch=MultiModalFieldConfig.batched("audio"),
audio_token_id=MultiModalFieldConfig.shared("audio", num_audios), audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
) )
...@@ -197,8 +180,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -197,8 +180,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
pool_step = self.get_default_audio_pool_step() pool_step = self.get_default_audio_pool_step()
fbank_feat_in_chunk = 100 fbank_feat_in_chunk = 100
cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1 cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1
num_audio_tokens = (cnn_feat_in_chunk - pool_step) // pool_step + 1 return (cnn_feat_in_chunk - pool_step) // pool_step + 1
return num_audio_tokens + 2 # <audio>(<unk>*N)</audio>
def get_max_audio_chunks_with_most_features(self) -> int: def get_max_audio_chunks_with_most_features(self) -> int:
return 30 return 30
...@@ -209,8 +191,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo): ...@@ -209,8 +191,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
def get_audio_len_by_num_chunks(self, num_chunks: int) -> int: def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
sampling_rate = self.get_default_audio_sampling_rate() sampling_rate = self.get_default_audio_sampling_rate()
# exclude <audio> </audio> num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk()
num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - 2
return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1 return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1
def get_num_frames_with_most_features( def get_num_frames_with_most_features(
...@@ -295,13 +276,6 @@ class MiniCPMOMultiModalProcessor( ...@@ -295,13 +276,6 @@ class MiniCPMOMultiModalProcessor(
if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems): if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems):
audio_inputs = {} audio_inputs = {}
audio_lens = [
self.info.get_audio_len_by_num_chunks(
sum(map(len,
parsed_audios.get(i)["audio_embeds"])))
for i in range(len(parsed_audios))
]
else: else:
audio_inputs = self._base_call_hf_processor( audio_inputs = self._base_call_hf_processor(
prompts=[self.info.audio_pattern] * len(parsed_audios), prompts=[self.info.audio_pattern] * len(parsed_audios),
...@@ -323,27 +297,7 @@ class MiniCPMOMultiModalProcessor( ...@@ -323,27 +297,7 @@ class MiniCPMOMultiModalProcessor(
] ]
audio_inputs["audio_features"] = unpadded_audio_features audio_inputs["audio_features"] = unpadded_audio_features
audio_lens = [
parsed_audios.get_audio_length(i)
for i in range(len(parsed_audios))
]
audio_repl_features = [
self.get_audio_prompt_texts(audio_len) for audio_len in audio_lens
]
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
audio_repls_feature_tokens = [
tokenizer.encode(audio_repl, add_special_tokens=False)
for audio_repl in audio_repl_features
]
embed_is_patch = [
self.get_embed_is_patch(audio_repl_tokens)
for audio_repl_tokens in audio_repls_feature_tokens
]
audio_inputs["audio_embed_is_patch"] = embed_is_patch
unk_token_id = tokenizer.get_vocab()["<unk>"] unk_token_id = tokenizer.get_vocab()["<unk>"]
audio_inputs["audio_token_id"] = torch.tensor(unk_token_id) audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
...@@ -384,7 +338,10 @@ class MiniCPMOMultiModalProcessor( ...@@ -384,7 +338,10 @@ class MiniCPMOMultiModalProcessor(
else: else:
audio_len = audios.get_audio_length(item_idx) audio_len = audios.get_audio_length(item_idx)
return self.get_audio_prompt_texts(audio_len) return PromptUpdateDetails.select_text(
self.get_audio_prompt_texts(audio_len),
"<unk>",
)
return [ return [
*base_updates, *base_updates,
...@@ -713,13 +670,6 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -713,13 +670,6 @@ class MiniCPMO(MiniCPMV2_6):
assert isinstance(audio_token_id, torch.Tensor) assert isinstance(audio_token_id, torch.Tensor)
self.mm_token_ids.add(audio_token_id.flatten().unique().item()) self.mm_token_ids.add(audio_token_id.flatten().unique().item())
audio_embed_is_patch = kwargs.pop("audio_embed_is_patch")
if not isinstance(audio_embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_embed_is_patch. "
f"Got type: {type(audio_embed_is_patch)}")
audio_embed_is_patch = flatten_bn(audio_embed_is_patch)
if audio_embeds is not None: if audio_embeds is not None:
if not isinstance(audio_embeds, (torch.Tensor, list)): if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_embeds. " raise ValueError("Incorrect type of audio_embeds. "
...@@ -730,7 +680,6 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -730,7 +680,6 @@ class MiniCPMO(MiniCPMV2_6):
return MiniCPMOAudioEmbeddingInputs( return MiniCPMOAudioEmbeddingInputs(
type="audio_embeds", type="audio_embeds",
audio_embeds=audio_embeds_flat, audio_embeds=audio_embeds_flat,
embed_is_patch=audio_embed_is_patch,
) )
if not isinstance(audio_features, (torch.Tensor, list)): if not isinstance(audio_features, (torch.Tensor, list)):
...@@ -749,7 +698,6 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -749,7 +698,6 @@ class MiniCPMO(MiniCPMV2_6):
type="audio_features", type="audio_features",
audio_features=audio_features_flat, audio_features=audio_features_flat,
audio_feature_lens=audio_feature_lens_flat, audio_feature_lens=audio_feature_lens_flat,
embed_is_patch=audio_embed_is_patch,
) )
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
...@@ -781,10 +729,6 @@ class MiniCPMO(MiniCPMV2_6): ...@@ -781,10 +729,6 @@ class MiniCPMO(MiniCPMV2_6):
if modality == "audios": if modality == "audios":
audio_input = modalities["audios"] audio_input = modalities["audios"]
audio_features = self._process_audio_input(audio_input) audio_features = self._process_audio_input(audio_input)
multimodal_embeddings += tuple( multimodal_embeddings += tuple(audio_features)
scatter_patch_features(
audio_features,
audio_input["embed_is_patch"],
))
return multimodal_embeddings return multimodal_embeddings
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment