Unverified Commit f1579b22 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[VLM] Generalized prompt updates for multi-modal processor (#13964)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 78648758
...@@ -720,13 +720,13 @@ def _get_mm_fields_config( ...@@ -720,13 +720,13 @@ def _get_mm_fields_config(
::::: :::::
### Prompt replacements ### Prompt updates
Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` to Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` to
return a list of {class}`~vllm.multimodal.processing.PromptReplacement` instances. return a list of {class}`~vllm.multimodal.processing.PromptUpdate` instances.
Each {class}`~vllm.multimodal.processing.PromptReplacement` instance specifies a find-and-replace Each {class}`~vllm.multimodal.processing.PromptUpdate` instance specifies an update operation
operation performed by the HF processor. (e.g.: insertion, replacement) performed by the HF processor.
::::{tab-set} ::::{tab-set}
:::{tab-item} Basic example: LLaVA :::{tab-item} Basic example: LLaVA
...@@ -743,15 +743,15 @@ for sample in text: ...@@ -743,15 +743,15 @@ for sample in text:
``` ```
It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`). It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`).
Based on this, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` as follows: Based on this, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` as follows:
```python ```python
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index image_token_id = hf_config.image_token_index
...@@ -859,7 +859,7 @@ prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch( ...@@ -859,7 +859,7 @@ prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
) )
``` ```
To accommodate this, instead of a string you can return an instance of `PromptReplacementDetails` To accommodate this, instead of a string you can return an instance of `PromptUpdateDetails`
with different `full` and `feature` attributes: with different `full` and `feature` attributes:
```python ```python
...@@ -878,7 +878,7 @@ def get_replacement_fuyu(item_idx: int): ...@@ -878,7 +878,7 @@ def get_replacement_fuyu(item_idx: int):
image_tokens = ([_IMAGE_TOKEN_ID] * ncols + image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
[_NEWLINE_TOKEN_ID]) * nrows [_NEWLINE_TOKEN_ID]) * nrows
return PromptReplacementDetails( return PromptUpdateDetails(
full=image_tokens + [bos_token_id], full=image_tokens + [bos_token_id],
features=image_tokens, features=image_tokens,
) )
...@@ -888,12 +888,12 @@ Finally, noticing that the HF processor removes the `|ENDOFTEXT|` token from the ...@@ -888,12 +888,12 @@ Finally, noticing that the HF processor removes the `|ENDOFTEXT|` token from the
we can search for it to conduct the replacement at the start of the string: we can search for it to conduct the replacement at the start of the string:
```python ```python
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
bos_token_id = hf_config.bos_token_id bos_token_id = hf_config.bos_token_id
assert isinstance(bos_token_id, int) assert isinstance(bos_token_id, int)
...@@ -913,7 +913,7 @@ def _get_prompt_replacements( ...@@ -913,7 +913,7 @@ def _get_prompt_replacements(
image_tokens = ([_IMAGE_TOKEN_ID] * ncols + image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
[_NEWLINE_TOKEN_ID]) * nrows [_NEWLINE_TOKEN_ID]) * nrows
return PromptReplacementDetails( return PromptUpdateDetails(
full=image_tokens + [bos_token_id], full=image_tokens + [bos_token_id],
features=image_tokens, features=image_tokens,
) )
......
...@@ -6,11 +6,16 @@ To enable various optimizations in vLLM such as [chunked prefill](#chunked-prefi ...@@ -6,11 +6,16 @@ To enable various optimizations in vLLM such as [chunked prefill](#chunked-prefi
Here are the main features of {class}`~vllm.multimodal.processing.BaseMultiModalProcessor`: Here are the main features of {class}`~vllm.multimodal.processing.BaseMultiModalProcessor`:
## Prompt Replacement Detection ## Prompt Update Detection
One of the main responsibilies of HF processor is to replace input placeholder tokens (e.g. `<image>` for a single image) with feature placeholder tokens (e.g. `<image><image>...<image>`, the number of which equals to the feature size). The information about which tokens have been replaced is key to finding the correspondence between placeholder feature tokens and multi-modal inputs. One of the main responsibilies of HF processor is to update the prompt with placeholder tokens. For example:
In vLLM, this information is specified using {class}`~vllm.multimodal.processing.PromptReplacement` in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements`. Given this specification, we can automatically detect whether HF has replaced the input placeholder tokens by checking whether the feature placeholder tokens exist in the prompt. - Insert feature placeholder tokens (e.g. `<image><image>...<image>`, the number of which equals to the feature size) at the start of the string.
- Replace existing input placeholder tokens (e.g. `<image>` for a single image) with feature placeholder tokens (e.g. `<image><image>...<image>`, the number of which equals to the feature size).
The information about which tokens have been updated is key to finding the correspondence between placeholder feature tokens and multi-modal inputs.
In vLLM, this information is specified using {class}`~vllm.multimodal.processing.PromptUpdate` in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates`. We can automatically detect whether HF has updated the prompt by checking the existence of the updated tokens.
## Tokenized Prompt Inputs ## Tokenized Prompt Inputs
...@@ -22,7 +27,7 @@ Consider that HF processors follow these main steps: ...@@ -22,7 +27,7 @@ Consider that HF processors follow these main steps:
1. Tokenize the text 1. Tokenize the text
2. Process multi-modal inputs 2. Process multi-modal inputs
3. Perform prompt replacement 3. Perform prompt updates
And we require that: And we require that:
...@@ -44,16 +49,16 @@ Moreover, since the tokenized text has not passed through the HF processor, we h ...@@ -44,16 +49,16 @@ Moreover, since the tokenized text has not passed through the HF processor, we h
We work around the first issue by requiring each model to define how to generate dummy text based on the number of multi-modal inputs, via {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs`. This lets us generate dummy text corresponding to the multi-modal inputs and input them together to obtain the processed multi-modal data. We work around the first issue by requiring each model to define how to generate dummy text based on the number of multi-modal inputs, via {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs`. This lets us generate dummy text corresponding to the multi-modal inputs and input them together to obtain the processed multi-modal data.
(mm-automatic-prompt-replacement)= (mm-automatic-prompt-updating)=
### Automatic prompt replacement ### Automatic prompt updating
We address the second issue by implementing model-agnostic code in We address the second issue by implementing model-agnostic code in
{meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_prompt_replacements` to automatically replace input placeholder tokens with feature placeholder tokens based on the specification outputted by {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements`. {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_prompt_updates` to automatically update the prompt with feature placeholder tokens based on the specification outputted by {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates`.
### Summary ### Summary
With the help of dummy text and automatic prompt replacement, our multi-modal processor can finally accept both text and token prompts with multi-modal data. The detailed logic is shown in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_main`. With the help of dummy text and automatic prompt updating, our multi-modal processor can finally accept both text and token prompts with multi-modal data. The detailed logic is shown in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_main`.
## Processor Output Caching ## Processor Output Caching
...@@ -61,4 +66,4 @@ Some HF processors, such as the one for Qwen2-VL, are [very slow](gh-issue:9238) ...@@ -61,4 +66,4 @@ Some HF processors, such as the one for Qwen2-VL, are [very slow](gh-issue:9238)
When new data is passed in, we first check which items are in the cache, and which ones are missing. The missing items are passed into the HF processor in a single batch and cached, before being merged with the existing items in the cache. When new data is passed in, we first check which items are in the cache, and which ones are missing. The missing items are passed into the HF processor in a single batch and cached, before being merged with the existing items in the cache.
Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of the multi-modal inputs, so they can't be passed alongside the text prompt to HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text](#mm-dummy-text) to avoid HF errors. Since this skips HF's prompt replacement code, we apply [automatic prompt replacement](#mm-automatic-prompt-replacement) afterwards to keep the output tokens and multi-modal data consistent with each other. Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of the multi-modal inputs, so they can't be passed alongside the text prompt to HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text](#mm-dummy-text) to avoid HF errors. Since this skips HF's prompt updating code, we apply [automatic prompt updating](#mm-automatic-prompt-updating) afterwards to keep the output tokens and multi-modal data consistent with each other.
...@@ -14,12 +14,12 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -14,12 +14,12 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.multimodal.processing import (PlaceholderFeaturesInfo, from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
PromptReplacement, PromptInsertion, PromptReplacement,
apply_text_matches,
apply_token_matches,
find_mm_placeholders, find_mm_placeholders,
find_text_matches, find_token_matches, find_text_matches, find_token_matches,
iter_token_matches, iter_token_matches)
replace_text_matches,
replace_token_matches)
# yapf: enable # yapf: enable
from vllm.multimodal.profiling import MultiModalProfiler from vllm.multimodal.profiling import MultiModalProfiler
from vllm.transformers_utils.tokenizer import (AnyTokenizer, from vllm.transformers_utils.tokenizer import (AnyTokenizer,
...@@ -102,7 +102,7 @@ def test_iter_token_matches(token_ids, match_ids, expected): ...@@ -102,7 +102,7 @@ def test_iter_token_matches(token_ids, match_ids, expected):
{ {
"pattern_1": [], "pattern_1": [],
"pattern_2": [], "pattern_2": [],
} },
), ),
( (
[32000, 32000, 32000, 32000], [32000, 32000, 32000, 32000],
...@@ -147,16 +147,22 @@ def test_iter_token_matches(token_ids, match_ids, expected): ...@@ -147,16 +147,22 @@ def test_iter_token_matches(token_ids, match_ids, expected):
), ),
], ],
) )
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
# yapf: enable # yapf: enable
def test_find_token_matches(prompt, target_by_key, expected_by_key): def test_find_token_matches(
prompt,
target_by_key,
expected_by_key,
update_type,
):
# Should not be used since there is nothing to convert to token IDs # Should not be used since there is nothing to convert to token IDs
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [ prompt_updates = [
PromptReplacement(key, target, []).bind(mock_tokenizer) update_type(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items() for key, target in target_by_key.items()
] ]
result = find_token_matches(prompt, prompt_repls) result = find_token_matches(prompt, prompt_updates)
# Only displayed on error # Only displayed on error
print("result:", result) print("result:", result)
...@@ -254,16 +260,22 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key): ...@@ -254,16 +260,22 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key):
), ),
], ],
) )
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
# yapf: enable # yapf: enable
def test_find_text_matches(prompt, target_by_key, expected_by_key): def test_find_text_matches(
prompt,
target_by_key,
expected_by_key,
update_type,
):
# Should not be used since there is nothing to convert to text # Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [ prompt_updates = [
PromptReplacement(key, target, []).bind(mock_tokenizer) update_type(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items() for key, target in target_by_key.items()
] ]
result = find_text_matches(prompt, prompt_repls) result = find_text_matches(prompt, prompt_updates)
# Only displayed on error # Only displayed on error
print("result:", result) print("result:", result)
...@@ -281,7 +293,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): ...@@ -281,7 +293,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
# yapf: disable # yapf: disable
@pytest.mark.parametrize( @pytest.mark.parametrize(
("prompt", "target_by_key", "repl_by_key"), ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501
[ [
( (
"Image:<image>Image:<image><image>!", "Image:<image>Image:<image><image>!",
...@@ -300,58 +312,66 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key): ...@@ -300,58 +312,66 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
# Test dynamic replacement (beyond the form of `unit * count`) # Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": "?!?", "pattern_3": "?!?",
}, },
{
PromptInsertion: {
0: "Image:<image>Image:<image><image>!",
1: "Image:<image><image><image>Image:<image><image>!?!?",
2: "Image:<image><image><image><image><image>Image:<image><image>!?!??!?", # noqa: E501
},
PromptReplacement: {
0: "Image:<image>Image:<image><image>!",
1: "<image><image>Image:<image><image>?!?",
2: "<image><image><image><image><image>?!?",
},
},
), ),
] ]
) )
@pytest.mark.parametrize(
("mm_count", "expected"),
[
(0, "Image:<image>Image:<image><image>!"),
(1, "<image><image>Image:<image><image>?!?"),
(2, "<image><image><image><image><image>?!?"),
]
)
# yapf: enable # yapf: enable
def test_find_replace_text( def test_find_update_text(
prompt, prompt,
target_by_key, target_by_key,
repl_by_key, repl_by_key,
mm_count, expected_by_update_type_mm_count,
expected,
): ):
# Should not be used since there is nothing to convert to text # Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
mm_prompt_repls = { for (
key: [ update_type,
PromptReplacement(key, target, expected_by_mm_count,
repl_by_key[key]).bind(mock_tokenizer) ) in expected_by_update_type_mm_count.items():
] mm_prompt_updates = {
for key, target in target_by_key.items() key:
} [update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
mm_matches = { for key, target in target_by_key.items()
key: find_text_matches(prompt, prompt_repls) }
for key, prompt_repls in mm_prompt_repls.items() mm_matches = {
} key: find_text_matches(prompt, updates)
for key, updates in mm_prompt_updates.items()
result = replace_text_matches( }
prompt,
mm_matches, for mm_count, expected in expected_by_mm_count.items():
{key: mm_count result = apply_text_matches(
for key in repl_by_key}, prompt,
) mm_matches,
{key: mm_count
# Only displayed on error for key in repl_by_key},
print("mm_matches:", mm_matches) )
print("result:", result)
# Only displayed on error
# Manually constructed results print("update_type:", update_type)
assert result == expected print("mm_count:", mm_count)
print("mm_matches:", mm_matches)
print("result:", result)
# Manually constructed results
assert result == expected
# yapf: disable # yapf: disable
@pytest.mark.parametrize( @pytest.mark.parametrize(
("prompt", "target_by_key", "repl_by_key"), ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501
[ [
# Tokenized test cases of `test_find_replace_text` # Tokenized test cases of `test_find_replace_text`
# using the vocab of llava-hf/llava-v1.6-mistral-7b-hf # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
...@@ -372,53 +392,61 @@ def test_find_replace_text( ...@@ -372,53 +392,61 @@ def test_find_replace_text(
# Test dynamic replacement (beyond the form of `unit * count`) # Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": [1550, 918, 1550], "pattern_3": [1550, 918, 1550],
}, },
{
PromptInsertion: {
0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
1: [1, 9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550], # noqa: E501
2: [1, 9833, 28747, 32000, 32000, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550, 1550, 918, 1550], # noqa: E501
},
PromptReplacement: {
0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
1: [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550], # noqa: E501
2: [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
},
},
), ),
] ]
) )
@pytest.mark.parametrize(
("mm_count", "expected"),
[
(0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
(1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]),
(2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]),
]
)
# yapf: enable # yapf: enable
def test_find_replace_tokens( def test_find_update_tokens(
prompt, prompt,
target_by_key, target_by_key,
repl_by_key, repl_by_key,
mm_count, expected_by_update_type_mm_count,
expected,
): ):
# Should not be used since there is nothing to convert to tokens # Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
mm_prompt_repls = { for (
key: [ update_type,
PromptReplacement(key, target, expected_by_mm_count,
repl_by_key[key]).bind(mock_tokenizer) ) in expected_by_update_type_mm_count.items():
] mm_prompt_updates = {
for key, target in target_by_key.items() key:
} [update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
mm_matches = { for key, target in target_by_key.items()
key: find_token_matches(prompt, prompt_repls) }
for key, prompt_repls in mm_prompt_repls.items() mm_matches = {
} key: find_token_matches(prompt, updates)
for key, updates in mm_prompt_updates.items()
result = replace_token_matches( }
prompt,
mm_matches, for mm_count, expected in expected_by_mm_count.items():
{key: mm_count result = apply_token_matches(
for key in repl_by_key}, prompt,
) mm_matches,
{key: mm_count
# Only displayed on error for key in repl_by_key},
print("mm_matches:", mm_matches) )
print("result:", result)
# Only displayed on error
# Manually constructed results print("update_type:", update_type)
assert result == expected print("mm_count:", mm_count)
print("mm_matches:", mm_matches)
print("result:", result)
# Manually constructed results
assert result == expected
# yapf: disable # yapf: disable
...@@ -524,22 +552,24 @@ def test_find_replace_tokens( ...@@ -524,22 +552,24 @@ def test_find_replace_tokens(
), ),
] ]
) )
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
# yapf: enable # yapf: enable
def test_find_mm_placeholders( def test_find_mm_placeholders(
repl_by_key, repl_by_key,
prompt, prompt,
expected, expected,
update_type,
): ):
# Should not be used since there is nothing to convert to tokens # Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object()) mock_tokenizer = cast(AnyTokenizer, object())
mm_prompt_repls = { mm_prompt_updates = {
key: [PromptReplacement(key, [], repl).bind(mock_tokenizer)] key: [update_type(key, [], repl).bind(mock_tokenizer)]
for key, repl in repl_by_key.items() for key, repl in repl_by_key.items()
} }
result = find_mm_placeholders( result = find_mm_placeholders(
mm_prompt_repls, mm_prompt_updates,
prompt, prompt,
# Effectively match all occurrences in the prompt # Effectively match all occurrences in the prompt
{key: 3 {key: 3
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable, Mapping, Sequence
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, from typing import List, Optional, Set, Tuple, TypedDict, Union
Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -26,7 +25,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, ...@@ -26,7 +25,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors) NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement) BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -457,12 +457,12 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]): ...@@ -457,12 +457,12 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]):
pixel_mask=MultiModalFieldConfig.batched("image"), pixel_mask=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index image_token_id = hf_config.image_token_index
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property from functools import cached_property
from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple, from typing import Literal, Optional, Set, Tuple, TypedDict, Union
TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -19,8 +19,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, ...@@ -19,8 +19,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors) NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptInsertion,
PromptReplacementDetails) PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -474,30 +474,24 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): ...@@ -474,30 +474,24 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab() vocab = tokenizer.get_vocab()
bos_token_id = tokenizer.bos_token_id
assert isinstance(bos_token_id, int)
image_token_id = vocab["<image>"] image_token_id = vocab["<image>"]
num_image_tokens = self.info.get_num_image_tokens() num_image_tokens = self.info.get_num_image_tokens()
image_tokens = [image_token_id] * num_image_tokens image_tokens = [image_token_id] * num_image_tokens
return [ return [
PromptReplacement( PromptInsertion(
modality="image", modality="image",
target=[bos_token_id], target="",
replacement=PromptReplacementDetails( insertion=image_tokens,
full=image_tokens + [bos_token_id],
features=image_tokens,
),
) )
] ]
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property from functools import cached_property
from typing import (Any, Dict, Iterable, Literal, Mapping, Optional, Set, from typing import Any, Dict, Literal, Optional, Set, Tuple, TypedDict, Union
Tuple, TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -35,7 +35,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, ...@@ -35,7 +35,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -141,12 +141,12 @@ class ChameleonMultiModalProcessor( ...@@ -141,12 +141,12 @@ class ChameleonMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image")) return dict(pixel_values=MultiModalFieldConfig.batched("image"))
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab() vocab = tokenizer.get_vocab()
...@@ -162,7 +162,7 @@ class ChameleonMultiModalProcessor( ...@@ -162,7 +162,7 @@ class ChameleonMultiModalProcessor(
PromptReplacement( PromptReplacement(
modality="image", modality="image",
target=[image_token_id], target=[image_token_id],
replacement=PromptReplacementDetails( replacement=PromptUpdateDetails(
full=([image_start_id] + image_tokens + [image_end_id]), full=([image_start_id] + image_tokens + [image_end_id]),
features=image_tokens, features=image_tokens,
), ),
...@@ -371,7 +371,7 @@ class ChameleonDecoderLayer(nn.Module): ...@@ -371,7 +371,7 @@ class ChameleonDecoderLayer(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
......
...@@ -3,9 +3,9 @@ ...@@ -3,9 +3,9 @@
# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py # adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py
"""Inference-only Deepseek-VL2 model compatible with HuggingFace weights.""" """Inference-only Deepseek-VL2 model compatible with HuggingFace weights."""
import math import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -26,7 +26,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ...@@ -26,7 +26,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems) ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache, BaseProcessingInfo, ProcessingCache,
PromptReplacement) PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config, from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
...@@ -281,12 +281,12 @@ class DeepseekVL2MultiModalProcessor( ...@@ -281,12 +281,12 @@ class DeepseekVL2MultiModalProcessor(
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token_id = hf_processor.image_token_id image_token_id = hf_processor.image_token_id
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import math import math
from collections import OrderedDict
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, OrderedDict, from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
Set, Tuple, TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -24,8 +25,7 @@ from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs ...@@ -24,8 +25,7 @@ from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems
from vllm.multimodal.processing import (BaseProcessingInfo, from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor, EncDecMultiModalProcessor,
PromptReplacement, PromptInsertion, PromptUpdate)
PromptReplacementDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -803,7 +803,7 @@ class Florence2DummyInputsBuilder( ...@@ -803,7 +803,7 @@ class Florence2DummyInputsBuilder(
class Florence2MultiModalProcessor( class Florence2MultiModalProcessor(
EncDecMultiModalProcessor[Florence2ProcessingInfo]): EncDecMultiModalProcessor[Florence2ProcessingInfo]):
def _hf_processor_applies_repl( def _hf_processor_applies_updates(
self, self,
prompt_text: str, prompt_text: str,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
...@@ -850,26 +850,22 @@ class Florence2MultiModalProcessor( ...@@ -850,26 +850,22 @@ class Florence2MultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image")) return dict(pixel_values=MultiModalFieldConfig.batched("image"))
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
pad_token_id = hf_config.pad_token_id pad_token_id = hf_config.pad_token_id
bos_token_id = hf_config.bos_token_id
num_image_tokens = self.info.get_max_image_tokens() num_image_tokens = self.info.get_max_image_tokens()
image_tokens = [pad_token_id] * num_image_tokens image_tokens = [pad_token_id] * num_image_tokens
return [ return [
PromptReplacement( PromptInsertion(
modality="image", modality="image",
target=[bos_token_id], target="",
replacement=PromptReplacementDetails( insertion=image_tokens,
full=image_tokens + [bos_token_id],
features=image_tokens,
),
) )
] ]
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
# limitations under the License. # limitations under the License.
""" PyTorch Fuyu model.""" """ PyTorch Fuyu model."""
import math import math
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, from collections.abc import Iterable, Mapping, Sequence
TypedDict) from typing import List, Literal, Optional, Set, Tuple, TypedDict
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -37,7 +37,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, ...@@ -37,7 +37,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -203,12 +203,12 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): ...@@ -203,12 +203,12 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict(image_patches=MultiModalFieldConfig.batched("image")) return dict(image_patches=MultiModalFieldConfig.batched("image"))
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
bos_token_id = hf_config.bos_token_id bos_token_id = hf_config.bos_token_id
assert isinstance(bos_token_id, int) assert isinstance(bos_token_id, int)
...@@ -228,7 +228,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]): ...@@ -228,7 +228,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
image_tokens = ([_IMAGE_TOKEN_ID] * ncols + image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
[_NEWLINE_TOKEN_ID]) * nrows [_NEWLINE_TOKEN_ID]) * nrows
return PromptReplacementDetails( return PromptUpdateDetails(
full=image_tokens + [bos_token_id], full=image_tokens + [bos_token_id],
features=image_tokens, features=image_tokens,
) )
......
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
# https://github.com/THUDM/CogAgent # https://github.com/THUDM/CogAgent
"""Inference-only CogAgent model compatible with THUDM weights.""" """Inference-only CogAgent model compatible with THUDM weights."""
from argparse import Namespace from argparse import Namespace
from typing import Literal, Mapping, Optional, TypedDict, Union from collections.abc import Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union
import torch import torch
from torch import nn from torch import nn
...@@ -32,7 +33,7 @@ from vllm.multimodal.parse import MultiModalDataItems ...@@ -32,7 +33,7 @@ from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BatchFeature, BaseProcessingInfo, BatchFeature,
MultiModalFieldConfig, MultiModalFieldConfig,
PromptReplacement) PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig from vllm.transformers_utils.configs import ChatGLMConfig
...@@ -480,7 +481,7 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]): ...@@ -480,7 +481,7 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
def _hf_processor_applies_repl( def _hf_processor_applies_updates(
self, self,
prompt_text: str, prompt_text: str,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
...@@ -495,12 +496,12 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]): ...@@ -495,12 +496,12 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image")) return dict(pixel_values=MultiModalFieldConfig.batched("image"))
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
boi_token_id = hf_config.boi_token_id boi_token_id = hf_config.boi_token_id
......
...@@ -7,7 +7,8 @@ ...@@ -7,7 +7,8 @@
# Copyright (c) 2024 H2O.AI # Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details] # Licensed under Apache 2.0 License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from typing import Mapping, Optional from collections.abc import Mapping, Sequence
from typing import Optional
import torch import torch
from PIL import Image from PIL import Image
...@@ -20,7 +21,7 @@ from vllm.multimodal.inputs import MultiModalKwargs ...@@ -20,7 +21,7 @@ from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement, from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
PromptReplacementDetails) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
...@@ -487,12 +488,12 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo] ...@@ -487,12 +488,12 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
f"{type(self).__name__} does not support processing cache with " f"{type(self).__name__} does not support processing cache with "
"multi-image support enabled.") "multi-image support enabled.")
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if "image_num_patches" in out_mm_kwargs: if "image_num_patches" in out_mm_kwargs:
...@@ -527,7 +528,7 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo] ...@@ -527,7 +528,7 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
if num_patches is not None: if num_patches is not None:
assert isinstance(num_patches, int) assert isinstance(num_patches, int)
return PromptReplacementDetails( return PromptUpdateDetails(
full=hf_processor.get_image_repl_full(feature_size, full=hf_processor.get_image_repl_full(feature_size,
num_patches), num_patches),
features=hf_processor.get_image_repl_features( features=hf_processor.get_image_repl_features(
......
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
"""Inference-only Idefics3 model compatible with HuggingFace weights.""" """Inference-only Idefics3 model compatible with HuggingFace weights."""
import math import math
from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Set, from collections.abc import Iterable, Mapping, Sequence
Tuple, TypedDict, Union) from typing import Dict, List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -41,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -41,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
MultiModalDataItems, MultiModalDataItems,
MultiModalFieldConfig, MultiModalFieldConfig,
PromptReplacement) PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -274,12 +274,12 @@ class Idefics3MultimodalProcessor( ...@@ -274,12 +274,12 @@ class Idefics3MultimodalProcessor(
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token.content image_token = hf_processor.image_token.content
......
...@@ -7,9 +7,10 @@ ...@@ -7,9 +7,10 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, from typing import (List, Literal, Optional, Set, Tuple, TypedDict, TypeVar,
TypedDict, TypeVar, Union) Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -31,7 +32,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ...@@ -31,7 +32,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems) ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails) PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
...@@ -599,12 +600,12 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -599,12 +600,12 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
image_token_id=MultiModalFieldConfig.shared("image", num_images), image_token_id=MultiModalFieldConfig.shared("image", num_images),
) )
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if "image_num_patches" in out_mm_kwargs: if "image_num_patches" in out_mm_kwargs:
...@@ -636,7 +637,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -636,7 +637,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
if num_patches is not None: if num_patches is not None:
assert isinstance(num_patches, int) assert isinstance(num_patches, int)
return PromptReplacementDetails( return PromptUpdateDetails(
full=hf_processor.get_image_repl_full(feature_size, full=hf_processor.get_image_repl_full(feature_size,
num_patches), num_patches),
features=hf_processor.get_image_repl_features( features=hf_processor.get_image_repl_features(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional, from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
Protocol, Set, Tuple, TypedDict, TypeVar, Union) TypedDict, TypeVar, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -31,7 +32,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ...@@ -31,7 +32,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems) ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache, BaseProcessingInfo, ProcessingCache,
PromptReplacement) PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -222,12 +223,12 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -222,12 +223,12 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
raise NotImplementedError raise NotImplementedError
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index image_token_id = hf_config.image_token_index
...@@ -328,12 +329,12 @@ class PixtralHFMultiModalProcessor( ...@@ -328,12 +329,12 @@ class PixtralHFMultiModalProcessor(
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
) )
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
...@@ -789,7 +790,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -789,7 +790,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
"</Image>)", # 3 tokens "</Image>)", # 3 tokens
]) ])
mantis_mm_repls = self._bind_and_group_repls([ mantis_mm_repls = self._bind_and_group_updates([
PromptReplacement( PromptReplacement(
modality="image", modality="image",
target=[image_token_id] * num_image_tokens, target=[image_token_id] * num_image_tokens,
...@@ -797,18 +798,18 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor): ...@@ -797,18 +798,18 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
) )
]) ])
prompt_ids, prompt, _ = self._apply_prompt_replacements( prompt_ids, prompt, _ = self._apply_prompt_updates(
result["prompt_token_ids"], result["prompt_token_ids"],
mantis_mm_repls, mantis_mm_repls,
mm_item_counts, mm_item_counts,
) )
unbound_orig_repls = self._get_prompt_replacements( unbound_orig_repls = self._get_prompt_updates(
mm_items, mm_items,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
mm_kwargs, mm_kwargs,
) )
orig_repls = self._bind_and_group_repls(unbound_orig_repls) orig_repls = self._bind_and_group_updates(unbound_orig_repls)
mm_placeholders = self._find_mm_placeholders( mm_placeholders = self._find_mm_placeholders(
orig_repls, orig_repls,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import math import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -21,7 +21,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, ...@@ -21,7 +21,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems) VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement) BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
...@@ -183,12 +184,12 @@ class LlavaNextVideoMultiModalProcessor( ...@@ -183,12 +184,12 @@ class LlavaNextVideoMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values_videos=MultiModalFieldConfig.batched("video")) return dict(pixel_values_videos=MultiModalFieldConfig.batched("video"))
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config() hf_config = self.info.get_hf_config()
video_token_id = hf_config.video_token_index video_token_id = hf_config.video_token_index
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import math import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional, from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
Protocol, Set, Tuple, TypedDict, Union) TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -22,7 +23,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, ...@@ -22,7 +23,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors) NestedTensors)
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems) VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import PromptReplacement from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.multimodal.profiling import ProcessorInputs from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
...@@ -347,13 +348,13 @@ class LlavaOnevisionMultiModalProcessor( ...@@ -347,13 +348,13 @@ class LlavaOnevisionMultiModalProcessor(
) )
return BatchFeature(combined_outputs) return BatchFeature(combined_outputs)
def _hf_processor_applies_repl( def _hf_processor_applies_updates(
self, self,
prompt_text: str, prompt_text: str,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> bool: ) -> bool:
base_result = super()._hf_processor_applies_repl( base_result = super()._hf_processor_applies_updates(
prompt_text=prompt_text, prompt_text=prompt_text,
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
...@@ -361,13 +362,13 @@ class LlavaOnevisionMultiModalProcessor( ...@@ -361,13 +362,13 @@ class LlavaOnevisionMultiModalProcessor(
return base_result and mm_items.get_count("video", strict=False) == 0 return base_result and mm_items.get_count("video", strict=False) == 0
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
image_repls = super()._get_prompt_replacements( image_repls = super()._get_prompt_updates(
mm_items=mm_items, mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
out_mm_kwargs=out_mm_kwargs, out_mm_kwargs=out_mm_kwargs,
...@@ -392,7 +393,8 @@ class LlavaOnevisionMultiModalProcessor( ...@@ -392,7 +393,8 @@ class LlavaOnevisionMultiModalProcessor(
return [video_token_id] * num_video_tokens return [video_token_id] * num_video_tokens
return image_repls + [ return [
*image_repls,
PromptReplacement( PromptReplacement(
modality="video", modality="video",
target=[video_token_id], target=[video_token_id],
......
...@@ -22,9 +22,10 @@ ...@@ -22,9 +22,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights.""" """Inference-only MiniCPM-O model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from functools import partial from functools import partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
Optional, Set, Tuple, TypedDict, Union) TypedDict, Union)
import torch import torch
from torch import nn from torch import nn
...@@ -356,10 +357,10 @@ class MiniCPMOMultiModalProcessor( ...@@ -356,10 +357,10 @@ class MiniCPMOMultiModalProcessor(
inputs["audio"]["audio_lens"][index]) inputs["audio"]["audio_lens"][index])
return super().get_prompt_texts_by_modality(inputs, modality, index) return super().get_prompt_texts_by_modality(inputs, modality, index)
def _get_prompt_replacements( def _get_prompt_updates(
self, mm_items: MultiModalDataItems, self, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any], hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs) -> List[PromptReplacement]: out_mm_kwargs: MultiModalKwargs) -> Sequence[PromptReplacement]:
placeholder = { placeholder = {
"image": self.info.image_pattern, "image": self.info.image_pattern,
"video": self.info.video_pattern, "video": self.info.video_pattern,
......
...@@ -25,9 +25,10 @@ ...@@ -25,9 +25,10 @@
import math import math
import re import re
from collections import Counter from collections import Counter
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial from functools import cached_property, partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
Optional, Set, Tuple, TypedDict, Union) TypedDict, Union)
import numpy as np import numpy as np
import torch import torch
...@@ -732,7 +733,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -732,7 +733,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
} }
} }
def _hf_processor_applies_repl( def _hf_processor_applies_updates(
self, self,
prompt_text: str, prompt_text: str,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
...@@ -740,10 +741,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -740,10 +741,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
) -> bool: ) -> bool:
return False return False
def _get_prompt_replacements( def _get_prompt_updates(
self, mm_items: MultiModalDataItems, self, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any], hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs) -> List[PromptReplacement]: out_mm_kwargs: MultiModalKwargs) -> Sequence[PromptReplacement]:
placeholder = { placeholder = {
"image": self.info.image_pattern, "image": self.info.image_pattern,
"video": self.info.video_pattern, "video": self.info.video_pattern,
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
# limitations under the License. # limitations under the License.
"""PyTorch Mllama model.""" """PyTorch Mllama model."""
import math import math
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, from collections.abc import Iterable, Mapping, Sequence
TypedDict, Union) from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import numpy as np import numpy as np
import torch import torch
...@@ -59,7 +59,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, ...@@ -59,7 +59,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataDict, MultiModalDataItems) MultiModalDataDict, MultiModalDataItems)
from vllm.multimodal.processing import (BaseProcessingInfo, from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor, EncDecMultiModalProcessor,
PromptReplacement) PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from .clip import CLIPMLP from .clip import CLIPMLP
...@@ -243,12 +243,12 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo] ...@@ -243,12 +243,12 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
image_token_id = self.info.get_hf_config().image_token_index image_token_id = self.info.get_hf_config().image_token_index
return [image_token_id] * num_images return [image_token_id] * num_images
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
token_per_chunk = self.info.get_token_per_chunk_from_config() token_per_chunk = self.info.get_token_per_chunk_from_config()
image_token_id = self.info.get_hf_config().image_token_index image_token_id = self.info.get_hf_config().image_token_index
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import math import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property, partial from functools import cached_property, partial
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, from typing import List, Optional, Set, Tuple, TypedDict, Union, cast
Union, cast)
import numpy as np import numpy as np
import torch import torch
...@@ -46,8 +46,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, ...@@ -46,8 +46,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptInsertion,
PromptReplacementDetails) PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import JSONTree, json_map_leaves from vllm.utils import JSONTree, json_map_leaves
...@@ -1190,6 +1190,8 @@ class MolmoProcessingInfo(BaseProcessingInfo): ...@@ -1190,6 +1190,8 @@ class MolmoProcessingInfo(BaseProcessingInfo):
return MolmoProcessorWrapper(processor) return MolmoProcessorWrapper(processor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
# TODO: Investigate different `embed_is_patch` between cache/no-cache
# in multi-image case
return {"image": 1} return {"image": 1}
def get_mm_max_tokens_per_item( def get_mm_max_tokens_per_item(
...@@ -1328,25 +1330,18 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): ...@@ -1328,25 +1330,18 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
img_patch_id=MultiModalFieldConfig.shared("image", num_images), img_patch_id=MultiModalFieldConfig.shared("image", num_images),
) )
def _get_prompt_replacements( def _get_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs, out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]: ) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
image_token_length_w = processor.image_token_length_w image_token_length_w = processor.image_token_length_w
image_token_length_h = processor.image_token_length_h image_token_length_h = processor.image_token_length_h
pooling_size = processor.pooling_size pooling_size = processor.pooling_size
user_str = "User:"
if processor.always_start_with_space:
user_str = " " + user_str
user_tokens = tokenizer.encode(user_str, add_special_tokens=False)
img_patch_id = processor.image_patch_id img_patch_id = processor.image_patch_id
img_col_id = processor.im_col_id img_col_id = processor.im_col_id
img_start_id = processor.im_start_id img_start_id = processor.im_start_id
...@@ -1356,7 +1351,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): ...@@ -1356,7 +1351,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
extra_joint = ([img_start_id] + extra_row * image_token_length_h + extra_joint = ([img_start_id] + extra_row * image_token_length_h +
[img_end_id]) [img_end_id])
def get_replacement_molmo(item_idx: int): def get_insertion_molmo(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems) images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx) image_size = images.get_image_size(item_idx)
...@@ -1371,17 +1366,13 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): ...@@ -1371,17 +1366,13 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
((nrows + 1) // pooling_size) + [img_end_id]) ((nrows + 1) // pooling_size) + [img_end_id])
image_tokens = extra_joint + joint image_tokens = extra_joint + joint
return image_tokens
return PromptReplacementDetails(
full=image_tokens + user_tokens,
features=image_tokens,
)
return [ return [
PromptReplacement( PromptInsertion(
modality="image", modality="image",
target=user_str, target="<|endoftext|>",
replacement=get_replacement_molmo, insertion=get_insertion_molmo,
) )
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment