Unverified Commit f1579b22 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[VLM] Generalized prompt updates for multi-modal processor (#13964)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 78648758
......@@ -720,13 +720,13 @@ def _get_mm_fields_config(
:::::
### Prompt replacements
### Prompt updates
Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` to
return a list of {class}`~vllm.multimodal.processing.PromptReplacement` instances.
Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` to
return a list of {class}`~vllm.multimodal.processing.PromptUpdate` instances.
Each {class}`~vllm.multimodal.processing.PromptReplacement` instance specifies a find-and-replace
operation performed by the HF processor.
Each {class}`~vllm.multimodal.processing.PromptUpdate` instance specifies an update operation
(e.g.: insertion, replacement) performed by the HF processor.
::::{tab-set}
:::{tab-item} Basic example: LLaVA
......@@ -743,15 +743,15 @@ for sample in text:
```
It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`).
Based on this, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` as follows:
Based on this, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` as follows:
```python
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
......@@ -859,7 +859,7 @@ prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
)
```
To accommodate this, instead of a string you can return an instance of `PromptReplacementDetails`
To accommodate this, instead of a string you can return an instance of `PromptUpdateDetails`
with different `full` and `feature` attributes:
```python
......@@ -878,7 +878,7 @@ def get_replacement_fuyu(item_idx: int):
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
[_NEWLINE_TOKEN_ID]) * nrows
return PromptReplacementDetails(
return PromptUpdateDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
)
......@@ -888,12 +888,12 @@ Finally, noticing that the HF processor removes the `|ENDOFTEXT|` token from the
we can search for it to conduct the replacement at the start of the string:
```python
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
bos_token_id = hf_config.bos_token_id
assert isinstance(bos_token_id, int)
......@@ -913,7 +913,7 @@ def _get_prompt_replacements(
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
[_NEWLINE_TOKEN_ID]) * nrows
return PromptReplacementDetails(
return PromptUpdateDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
)
......
......@@ -6,11 +6,16 @@ To enable various optimizations in vLLM such as [chunked prefill](#chunked-prefi
Here are the main features of {class}`~vllm.multimodal.processing.BaseMultiModalProcessor`:
## Prompt Replacement Detection
## Prompt Update Detection
One of the main responsibilies of HF processor is to replace input placeholder tokens (e.g. `<image>` for a single image) with feature placeholder tokens (e.g. `<image><image>...<image>`, the number of which equals to the feature size). The information about which tokens have been replaced is key to finding the correspondence between placeholder feature tokens and multi-modal inputs.
One of the main responsibilies of HF processor is to update the prompt with placeholder tokens. For example:
In vLLM, this information is specified using {class}`~vllm.multimodal.processing.PromptReplacement` in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements`. Given this specification, we can automatically detect whether HF has replaced the input placeholder tokens by checking whether the feature placeholder tokens exist in the prompt.
- Insert feature placeholder tokens (e.g. `<image><image>...<image>`, the number of which equals to the feature size) at the start of the string.
- Replace existing input placeholder tokens (e.g. `<image>` for a single image) with feature placeholder tokens (e.g. `<image><image>...<image>`, the number of which equals to the feature size).
The information about which tokens have been updated is key to finding the correspondence between placeholder feature tokens and multi-modal inputs.
In vLLM, this information is specified using {class}`~vllm.multimodal.processing.PromptUpdate` in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates`. We can automatically detect whether HF has updated the prompt by checking the existence of the updated tokens.
## Tokenized Prompt Inputs
......@@ -22,7 +27,7 @@ Consider that HF processors follow these main steps:
1. Tokenize the text
2. Process multi-modal inputs
3. Perform prompt replacement
3. Perform prompt updates
And we require that:
......@@ -44,16 +49,16 @@ Moreover, since the tokenized text has not passed through the HF processor, we h
We work around the first issue by requiring each model to define how to generate dummy text based on the number of multi-modal inputs, via {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs`. This lets us generate dummy text corresponding to the multi-modal inputs and input them together to obtain the processed multi-modal data.
(mm-automatic-prompt-replacement)=
(mm-automatic-prompt-updating)=
### Automatic prompt replacement
### Automatic prompt updating
We address the second issue by implementing model-agnostic code in
{meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_prompt_replacements` to automatically replace input placeholder tokens with feature placeholder tokens based on the specification outputted by {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements`.
{meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_prompt_updates` to automatically update the prompt with feature placeholder tokens based on the specification outputted by {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates`.
### Summary
With the help of dummy text and automatic prompt replacement, our multi-modal processor can finally accept both text and token prompts with multi-modal data. The detailed logic is shown in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_main`.
With the help of dummy text and automatic prompt updating, our multi-modal processor can finally accept both text and token prompts with multi-modal data. The detailed logic is shown in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_main`.
## Processor Output Caching
......@@ -61,4 +66,4 @@ Some HF processors, such as the one for Qwen2-VL, are [very slow](gh-issue:9238)
When new data is passed in, we first check which items are in the cache, and which ones are missing. The missing items are passed into the HF processor in a single batch and cached, before being merged with the existing items in the cache.
Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of the multi-modal inputs, so they can't be passed alongside the text prompt to HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text](#mm-dummy-text) to avoid HF errors. Since this skips HF's prompt replacement code, we apply [automatic prompt replacement](#mm-automatic-prompt-replacement) afterwards to keep the output tokens and multi-modal data consistent with each other.
Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of the multi-modal inputs, so they can't be passed alongside the text prompt to HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text](#mm-dummy-text) to avoid HF errors. Since this skips HF's prompt updating code, we apply [automatic prompt updating](#mm-automatic-prompt-updating) afterwards to keep the output tokens and multi-modal data consistent with each other.
......@@ -14,12 +14,12 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
PromptReplacement,
PromptInsertion, PromptReplacement,
apply_text_matches,
apply_token_matches,
find_mm_placeholders,
find_text_matches, find_token_matches,
iter_token_matches,
replace_text_matches,
replace_token_matches)
iter_token_matches)
# yapf: enable
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
......@@ -102,7 +102,7 @@ def test_iter_token_matches(token_ids, match_ids, expected):
{
"pattern_1": [],
"pattern_2": [],
}
},
),
(
[32000, 32000, 32000, 32000],
......@@ -147,16 +147,22 @@ def test_iter_token_matches(token_ids, match_ids, expected):
),
],
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
# yapf: enable
def test_find_token_matches(prompt, target_by_key, expected_by_key):
def test_find_token_matches(
prompt,
target_by_key,
expected_by_key,
update_type,
):
# Should not be used since there is nothing to convert to token IDs
mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [
PromptReplacement(key, target, []).bind(mock_tokenizer)
prompt_updates = [
update_type(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
result = find_token_matches(prompt, prompt_repls)
result = find_token_matches(prompt, prompt_updates)
# Only displayed on error
print("result:", result)
......@@ -254,16 +260,22 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key):
),
],
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
# yapf: enable
def test_find_text_matches(prompt, target_by_key, expected_by_key):
def test_find_text_matches(
prompt,
target_by_key,
expected_by_key,
update_type,
):
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [
PromptReplacement(key, target, []).bind(mock_tokenizer)
prompt_updates = [
update_type(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
result = find_text_matches(prompt, prompt_repls)
result = find_text_matches(prompt, prompt_updates)
# Only displayed on error
print("result:", result)
......@@ -281,7 +293,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
# yapf: disable
@pytest.mark.parametrize(
("prompt", "target_by_key", "repl_by_key"),
("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501
[
(
"Image:<image>Image:<image><image>!",
......@@ -300,41 +312,47 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
# Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": "?!?",
},
{
PromptInsertion: {
0: "Image:<image>Image:<image><image>!",
1: "Image:<image><image><image>Image:<image><image>!?!?",
2: "Image:<image><image><image><image><image>Image:<image><image>!?!??!?", # noqa: E501
},
PromptReplacement: {
0: "Image:<image>Image:<image><image>!",
1: "<image><image>Image:<image><image>?!?",
2: "<image><image><image><image><image>?!?",
},
},
),
]
)
@pytest.mark.parametrize(
("mm_count", "expected"),
[
(0, "Image:<image>Image:<image><image>!"),
(1, "<image><image>Image:<image><image>?!?"),
(2, "<image><image><image><image><image>?!?"),
]
)
# yapf: enable
def test_find_replace_text(
def test_find_update_text(
prompt,
target_by_key,
repl_by_key,
mm_count,
expected,
expected_by_update_type_mm_count,
):
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object())
mm_prompt_repls = {
key: [
PromptReplacement(key, target,
repl_by_key[key]).bind(mock_tokenizer)
]
for (
update_type,
expected_by_mm_count,
) in expected_by_update_type_mm_count.items():
mm_prompt_updates = {
key:
[update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
for key, target in target_by_key.items()
}
mm_matches = {
key: find_text_matches(prompt, prompt_repls)
for key, prompt_repls in mm_prompt_repls.items()
key: find_text_matches(prompt, updates)
for key, updates in mm_prompt_updates.items()
}
result = replace_text_matches(
for mm_count, expected in expected_by_mm_count.items():
result = apply_text_matches(
prompt,
mm_matches,
{key: mm_count
......@@ -342,6 +360,8 @@ def test_find_replace_text(
)
# Only displayed on error
print("update_type:", update_type)
print("mm_count:", mm_count)
print("mm_matches:", mm_matches)
print("result:", result)
......@@ -351,7 +371,7 @@ def test_find_replace_text(
# yapf: disable
@pytest.mark.parametrize(
("prompt", "target_by_key", "repl_by_key"),
("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501
[
# Tokenized test cases of `test_find_replace_text`
# using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
......@@ -372,41 +392,47 @@ def test_find_replace_text(
# Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": [1550, 918, 1550],
},
{
PromptInsertion: {
0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
1: [1, 9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550], # noqa: E501
2: [1, 9833, 28747, 32000, 32000, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550, 1550, 918, 1550], # noqa: E501
},
PromptReplacement: {
0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
1: [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550], # noqa: E501
2: [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
},
},
),
]
)
@pytest.mark.parametrize(
("mm_count", "expected"),
[
(0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
(1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]),
(2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]),
]
)
# yapf: enable
def test_find_replace_tokens(
def test_find_update_tokens(
prompt,
target_by_key,
repl_by_key,
mm_count,
expected,
expected_by_update_type_mm_count,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())
mm_prompt_repls = {
key: [
PromptReplacement(key, target,
repl_by_key[key]).bind(mock_tokenizer)
]
for (
update_type,
expected_by_mm_count,
) in expected_by_update_type_mm_count.items():
mm_prompt_updates = {
key:
[update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
for key, target in target_by_key.items()
}
mm_matches = {
key: find_token_matches(prompt, prompt_repls)
for key, prompt_repls in mm_prompt_repls.items()
key: find_token_matches(prompt, updates)
for key, updates in mm_prompt_updates.items()
}
result = replace_token_matches(
for mm_count, expected in expected_by_mm_count.items():
result = apply_token_matches(
prompt,
mm_matches,
{key: mm_count
......@@ -414,6 +440,8 @@ def test_find_replace_tokens(
)
# Only displayed on error
print("update_type:", update_type)
print("mm_count:", mm_count)
print("mm_matches:", mm_matches)
print("result:", result)
......@@ -524,22 +552,24 @@ def test_find_replace_tokens(
),
]
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
# yapf: enable
def test_find_mm_placeholders(
repl_by_key,
prompt,
expected,
update_type,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())
mm_prompt_repls = {
key: [PromptReplacement(key, [], repl).bind(mock_tokenizer)]
mm_prompt_updates = {
key: [update_type(key, [], repl).bind(mock_tokenizer)]
for key, repl in repl_by_key.items()
}
result = find_mm_placeholders(
mm_prompt_repls,
mm_prompt_updates,
prompt,
# Effectively match all occurrences in the prompt
{key: 3
......
# SPDX-License-Identifier: Apache-2.0
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
......@@ -26,7 +25,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
......@@ -457,12 +457,12 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]):
pixel_mask=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
......
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from typing import Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
......@@ -19,8 +19,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
BaseProcessingInfo, PromptInsertion,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
......@@ -474,30 +474,24 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
bos_token_id = tokenizer.bos_token_id
assert isinstance(bos_token_id, int)
image_token_id = vocab["<image>"]
num_image_tokens = self.info.get_num_image_tokens()
image_tokens = [image_token_id] * num_image_tokens
return [
PromptReplacement(
PromptInsertion(
modality="image",
target=[bos_token_id],
replacement=PromptReplacementDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
),
target="",
insertion=image_tokens,
)
]
......
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Any, Dict, Iterable, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
from typing import Any, Dict, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
......@@ -35,7 +35,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
......@@ -141,12 +141,12 @@ class ChameleonMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
......@@ -162,7 +162,7 @@ class ChameleonMultiModalProcessor(
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=PromptReplacementDetails(
replacement=PromptUpdateDetails(
full=([image_start_id] + image_tokens + [image_end_id]),
features=image_tokens,
),
......@@ -371,7 +371,7 @@ class ChameleonDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if residual is None:
residual = hidden_states
......
......@@ -3,9 +3,9 @@
# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py
"""Inference-only Deepseek-VL2 model compatible with HuggingFace weights."""
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
......@@ -26,7 +26,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache,
PromptReplacement)
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
......@@ -281,12 +281,12 @@ class DeepseekVL2MultiModalProcessor(
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token_id = hf_processor.image_token_id
......
# SPDX-License-Identifier: Apache-2.0
import math
from collections import OrderedDict
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, OrderedDict,
Set, Tuple, TypedDict, Union)
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
......@@ -24,8 +25,7 @@ from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement,
PromptReplacementDetails)
PromptInsertion, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
......@@ -803,7 +803,7 @@ class Florence2DummyInputsBuilder(
class Florence2MultiModalProcessor(
EncDecMultiModalProcessor[Florence2ProcessingInfo]):
def _hf_processor_applies_repl(
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
......@@ -850,26 +850,22 @@ class Florence2MultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
pad_token_id = hf_config.pad_token_id
bos_token_id = hf_config.bos_token_id
num_image_tokens = self.info.get_max_image_tokens()
image_tokens = [pad_token_id] * num_image_tokens
return [
PromptReplacement(
PromptInsertion(
modality="image",
target=[bos_token_id],
replacement=PromptReplacementDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
),
target="",
insertion=image_tokens,
)
]
......
......@@ -17,8 +17,8 @@
# limitations under the License.
""" PyTorch Fuyu model."""
import math
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict)
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict
import torch
import torch.nn as nn
......@@ -37,7 +37,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
......@@ -203,12 +203,12 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
) -> Mapping[str, MultiModalFieldConfig]:
return dict(image_patches=MultiModalFieldConfig.batched("image"))
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
bos_token_id = hf_config.bos_token_id
assert isinstance(bos_token_id, int)
......@@ -228,7 +228,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
[_NEWLINE_TOKEN_ID]) * nrows
return PromptReplacementDetails(
return PromptUpdateDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
)
......
......@@ -4,7 +4,8 @@
# https://github.com/THUDM/CogAgent
"""Inference-only CogAgent model compatible with THUDM weights."""
from argparse import Namespace
from typing import Literal, Mapping, Optional, TypedDict, Union
from collections.abc import Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union
import torch
from torch import nn
......@@ -32,7 +33,7 @@ from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BatchFeature,
MultiModalFieldConfig,
PromptReplacement)
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig
......@@ -480,7 +481,7 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
def _hf_processor_applies_repl(
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
......@@ -495,12 +496,12 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
boi_token_id = hf_config.boi_token_id
......
......@@ -7,7 +7,8 @@
# Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from typing import Mapping, Optional
from collections.abc import Mapping, Sequence
from typing import Optional
import torch
from PIL import Image
......@@ -20,7 +21,7 @@ from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
PromptReplacementDetails)
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.transformers_utils.tokenizer import AnyTokenizer
......@@ -487,12 +488,12 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
f"{type(self).__name__} does not support processing cache with "
"multi-image support enabled.")
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if "image_num_patches" in out_mm_kwargs:
......@@ -527,7 +528,7 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
if num_patches is not None:
assert isinstance(num_patches, int)
return PromptReplacementDetails(
return PromptUpdateDetails(
full=hf_processor.get_image_repl_full(feature_size,
num_patches),
features=hf_processor.get_image_repl_features(
......
......@@ -16,8 +16,8 @@
"""Inference-only Idefics3 model compatible with HuggingFace weights."""
import math
from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
from collections.abc import Iterable, Mapping, Sequence
from typing import Dict, List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.utils.checkpoint
......@@ -41,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalDataItems,
MultiModalFieldConfig,
PromptReplacement)
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
......@@ -274,12 +274,12 @@ class Idefics3MultimodalProcessor(
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token.content
......
......@@ -7,9 +7,10 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, TypeVar, Union)
from typing import (List, Literal, Optional, Set, Tuple, TypedDict, TypeVar,
Union)
import torch
import torch.nn as nn
......@@ -31,7 +32,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
......@@ -599,12 +600,12 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
image_token_id=MultiModalFieldConfig.shared("image", num_images),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if "image_num_patches" in out_mm_kwargs:
......@@ -636,7 +637,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
if num_patches is not None:
assert isinstance(num_patches, int)
return PromptReplacementDetails(
return PromptUpdateDetails(
full=hf_processor.get_image_repl_full(feature_size,
num_patches),
features=hf_processor.get_image_repl_features(
......
# SPDX-License-Identifier: Apache-2.0
from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, TypeVar, Union)
from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
TypedDict, TypeVar, Union)
import torch
import torch.nn as nn
......@@ -31,7 +32,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache,
PromptReplacement)
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
......@@ -222,12 +223,12 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
) -> Mapping[str, MultiModalFieldConfig]:
raise NotImplementedError
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
......@@ -328,12 +329,12 @@ class PixtralHFMultiModalProcessor(
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
hf_config = self.info.get_hf_config()
tokenizer = self.info.get_tokenizer()
......@@ -789,7 +790,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
"</Image>)", # 3 tokens
])
mantis_mm_repls = self._bind_and_group_repls([
mantis_mm_repls = self._bind_and_group_updates([
PromptReplacement(
modality="image",
target=[image_token_id] * num_image_tokens,
......@@ -797,18 +798,18 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
)
])
prompt_ids, prompt, _ = self._apply_prompt_replacements(
prompt_ids, prompt, _ = self._apply_prompt_updates(
result["prompt_token_ids"],
mantis_mm_repls,
mm_item_counts,
)
unbound_orig_repls = self._get_prompt_replacements(
unbound_orig_repls = self._get_prompt_updates(
mm_items,
hf_processor_mm_kwargs,
mm_kwargs,
)
orig_repls = self._bind_and_group_repls(unbound_orig_repls)
orig_repls = self._bind_and_group_updates(unbound_orig_repls)
mm_placeholders = self._find_mm_placeholders(
orig_repls,
......
# SPDX-License-Identifier: Apache-2.0
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
......@@ -21,7 +21,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
......@@ -183,12 +184,12 @@ class LlavaNextVideoMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values_videos=MultiModalFieldConfig.batched("video"))
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
video_token_id = hf_config.video_token_index
......
# SPDX-License-Identifier: Apache-2.0
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, Union)
from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
TypedDict, Union)
import torch
import torch.nn as nn
......@@ -22,7 +23,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import PromptReplacement
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
......@@ -347,13 +348,13 @@ class LlavaOnevisionMultiModalProcessor(
)
return BatchFeature(combined_outputs)
def _hf_processor_applies_repl(
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> bool:
base_result = super()._hf_processor_applies_repl(
base_result = super()._hf_processor_applies_updates(
prompt_text=prompt_text,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......@@ -361,13 +362,13 @@ class LlavaOnevisionMultiModalProcessor(
return base_result and mm_items.get_count("video", strict=False) == 0
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
image_repls = super()._get_prompt_replacements(
) -> Sequence[PromptUpdate]:
image_repls = super()._get_prompt_updates(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
out_mm_kwargs=out_mm_kwargs,
......@@ -392,7 +393,8 @@ class LlavaOnevisionMultiModalProcessor(
return [video_token_id] * num_video_tokens
return image_repls + [
return [
*image_repls,
PromptReplacement(
modality="video",
target=[video_token_id],
......
......@@ -22,9 +22,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Set, Tuple, TypedDict, Union)
from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
TypedDict, Union)
import torch
from torch import nn
......@@ -356,10 +357,10 @@ class MiniCPMOMultiModalProcessor(
inputs["audio"]["audio_lens"][index])
return super().get_prompt_texts_by_modality(inputs, modality, index)
def _get_prompt_replacements(
def _get_prompt_updates(
self, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs) -> List[PromptReplacement]:
out_mm_kwargs: MultiModalKwargs) -> Sequence[PromptReplacement]:
placeholder = {
"image": self.info.image_pattern,
"video": self.info.video_pattern,
......
......@@ -25,9 +25,10 @@
import math
import re
from collections import Counter
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Set, Tuple, TypedDict, Union)
from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
TypedDict, Union)
import numpy as np
import torch
......@@ -732,7 +733,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
}
}
def _hf_processor_applies_repl(
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
......@@ -740,10 +741,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
) -> bool:
return False
def _get_prompt_replacements(
def _get_prompt_updates(
self, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs) -> List[PromptReplacement]:
out_mm_kwargs: MultiModalKwargs) -> Sequence[PromptReplacement]:
placeholder = {
"image": self.info.image_pattern,
"video": self.info.video_pattern,
......
......@@ -15,8 +15,8 @@
# limitations under the License.
"""PyTorch Mllama model."""
import math
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import numpy as np
import torch
......@@ -59,7 +59,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataDict, MultiModalDataItems)
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement)
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from .clip import CLIPMLP
......@@ -243,12 +243,12 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
image_token_id = self.info.get_hf_config().image_token_index
return [image_token_id] * num_images
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
token_per_chunk = self.info.get_token_per_chunk_from_config()
image_token_id = self.info.get_hf_config().image_token_index
......
# SPDX-License-Identifier: Apache-2.0
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from functools import cached_property, partial
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union, cast)
from typing import List, Optional, Set, Tuple, TypedDict, Union, cast
import numpy as np
import torch
......@@ -46,8 +46,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
BaseProcessingInfo, PromptInsertion,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import JSONTree, json_map_leaves
......@@ -1190,6 +1190,8 @@ class MolmoProcessingInfo(BaseProcessingInfo):
return MolmoProcessorWrapper(processor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
# TODO: Investigate different `embed_is_patch` between cache/no-cache
# in multi-image case
return {"image": 1}
def get_mm_max_tokens_per_item(
......@@ -1328,25 +1330,18 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
img_patch_id=MultiModalFieldConfig.shared("image", num_images),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
image_token_length_w = processor.image_token_length_w
image_token_length_h = processor.image_token_length_h
pooling_size = processor.pooling_size
user_str = "User:"
if processor.always_start_with_space:
user_str = " " + user_str
user_tokens = tokenizer.encode(user_str, add_special_tokens=False)
img_patch_id = processor.image_patch_id
img_col_id = processor.im_col_id
img_start_id = processor.im_start_id
......@@ -1356,7 +1351,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
extra_joint = ([img_start_id] + extra_row * image_token_length_h +
[img_end_id])
def get_replacement_molmo(item_idx: int):
def get_insertion_molmo(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
......@@ -1371,17 +1366,13 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
((nrows + 1) // pooling_size) + [img_end_id])
image_tokens = extra_joint + joint
return PromptReplacementDetails(
full=image_tokens + user_tokens,
features=image_tokens,
)
return image_tokens
return [
PromptReplacement(
PromptInsertion(
modality="image",
target=user_str,
replacement=get_replacement_molmo,
target="<|endoftext|>",
insertion=get_insertion_molmo,
)
]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment