Unverified commit f7bee5c8, authored by Cyrus Leung and committed by GitHub
Browse files

[VLM][Bugfix] Enable specifying prompt target via index (#14038)

parent e0734387
...@@ -14,8 +14,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -14,8 +14,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.multimodal.processing import (PlaceholderFeaturesInfo, from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
PromptInsertion, PromptReplacement, PromptIndexTargets, PromptInsertion,
apply_text_matches, PromptReplacement, apply_text_matches,
apply_token_matches, apply_token_matches,
find_mm_placeholders, find_mm_placeholders,
find_text_matches, find_token_matches, find_text_matches, find_token_matches,
...@@ -98,10 +98,20 @@ def test_iter_token_matches(token_ids, match_ids, expected): ...@@ -98,10 +98,20 @@ def test_iter_token_matches(token_ids, match_ids, expected):
{ {
"pattern_1": [], "pattern_1": [],
"pattern_2": [32000], "pattern_2": [32000],
"pattern_3": PromptIndexTargets.start(),
"pattern_4": PromptIndexTargets.prefix([32000]),
"pattern_5": PromptIndexTargets.end(),
}, },
{ {
"pattern_1": [], "pattern_1": [],
"pattern_2": [], "pattern_2": [],
"pattern_3": [
{ "start_idx": 0, "end_idx": 0 },
],
"pattern_4": [],
"pattern_5": [
{ "start_idx": 0, "end_idx": 0 },
],
}, },
), ),
( (
...@@ -110,6 +120,9 @@ def test_iter_token_matches(token_ids, match_ids, expected): ...@@ -110,6 +120,9 @@ def test_iter_token_matches(token_ids, match_ids, expected):
"pattern_1": [32000], "pattern_1": [32000],
"pattern_2": [32000, 32000], "pattern_2": [32000, 32000],
"pattern_3": [32000, 32000, 32000], "pattern_3": [32000, 32000, 32000],
"pattern_4": PromptIndexTargets.start(),
"pattern_5": PromptIndexTargets.prefix([32000]),
"pattern_6": PromptIndexTargets.end(),
}, },
{ {
"pattern_1": [ "pattern_1": [
...@@ -125,6 +138,15 @@ def test_iter_token_matches(token_ids, match_ids, expected): ...@@ -125,6 +138,15 @@ def test_iter_token_matches(token_ids, match_ids, expected):
"pattern_3": [ "pattern_3": [
{ "start_idx": 0, "end_idx": 3 }, { "start_idx": 0, "end_idx": 3 },
], ],
"pattern_4": [
{ "start_idx": 0, "end_idx": 0 },
],
"pattern_5": [
{ "start_idx": 1, "end_idx": 1 },
],
"pattern_6": [
{ "start_idx": 4, "end_idx": 4 },
],
}, },
), ),
( (
...@@ -133,6 +155,9 @@ def test_iter_token_matches(token_ids, match_ids, expected): ...@@ -133,6 +155,9 @@ def test_iter_token_matches(token_ids, match_ids, expected):
"pattern_1": [28747, 32000], "pattern_1": [28747, 32000],
"pattern_2": [28747, 32000, 32000, 32000], "pattern_2": [28747, 32000, 32000, 32000],
"pattern_3": [28747, 0, 32000], "pattern_3": [28747, 0, 32000],
"pattern_4": PromptIndexTargets.start(),
"pattern_5": PromptIndexTargets.prefix([28747, 32000]),
"pattern_6": PromptIndexTargets.end(),
}, },
{ {
"pattern_1": [ "pattern_1": [
...@@ -143,6 +168,13 @@ def test_iter_token_matches(token_ids, match_ids, expected): ...@@ -143,6 +168,13 @@ def test_iter_token_matches(token_ids, match_ids, expected):
{ "start_idx": 1, "end_idx": 5 }, { "start_idx": 1, "end_idx": 5 },
], ],
"pattern_3": [], "pattern_3": [],
"pattern_4": [
{ "start_idx": 0, "end_idx": 0 },
],
"pattern_5": [],
"pattern_6": [
{ "start_idx": 10, "end_idx": 10 },
],
}, },
), ),
], ],
...@@ -189,10 +221,20 @@ def test_find_token_matches( ...@@ -189,10 +221,20 @@ def test_find_token_matches(
{ {
"pattern_1": "", "pattern_1": "",
"pattern_2": "<image>", "pattern_2": "<image>",
"pattern_3": PromptIndexTargets.start(),
"pattern_4": PromptIndexTargets.prefix("<image>"),
"pattern_5": PromptIndexTargets.end(),
}, },
{ {
"pattern_1": [{ "start_idx": 0, "end_idx": 0 }], "pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
"pattern_2": [], "pattern_2": [],
"pattern_3": [
{ "start_idx": 0, "end_idx": 0 },
],
"pattern_4": [],
"pattern_5": [
{ "start_idx": 0, "end_idx": 0 },
],
} }
), ),
( (
...@@ -201,6 +243,9 @@ def test_find_token_matches( ...@@ -201,6 +243,9 @@ def test_find_token_matches(
"pattern_1": "<image>", "pattern_1": "<image>",
"pattern_2": "<image><image>", "pattern_2": "<image><image>",
"pattern_3": "<image><image><image>", "pattern_3": "<image><image><image>",
"pattern_4": PromptIndexTargets.start(),
"pattern_5": PromptIndexTargets.prefix("<image>"),
"pattern_6": PromptIndexTargets.end(),
}, },
{ {
"pattern_1": [ "pattern_1": [
...@@ -216,6 +261,15 @@ def test_find_token_matches( ...@@ -216,6 +261,15 @@ def test_find_token_matches(
"pattern_3": [ "pattern_3": [
{ "start_idx": 0, "end_idx": 21 }, { "start_idx": 0, "end_idx": 21 },
], ],
"pattern_4": [
{ "start_idx": 0, "end_idx": 0 },
],
"pattern_5": [
{ "start_idx": 7, "end_idx": 7 },
],
"pattern_6": [
{ "start_idx": 28, "end_idx": 28 },
],
}, },
), ),
( (
...@@ -224,6 +278,9 @@ def test_find_token_matches( ...@@ -224,6 +278,9 @@ def test_find_token_matches(
"pattern_1": "Image:<image>", "pattern_1": "Image:<image>",
"pattern_2": "Image:<image><image><image>", "pattern_2": "Image:<image><image><image>",
"pattern_3": "Image:<unk><image>", "pattern_3": "Image:<unk><image>",
"pattern_4": PromptIndexTargets.start(),
"pattern_5": PromptIndexTargets.prefix("Image:<image>"),
"pattern_6": PromptIndexTargets.end(),
}, },
{ {
"pattern_1": [ "pattern_1": [
...@@ -234,6 +291,15 @@ def test_find_token_matches( ...@@ -234,6 +291,15 @@ def test_find_token_matches(
{ "start_idx": 0, "end_idx": 27 }, { "start_idx": 0, "end_idx": 27 },
], ],
"pattern_3": [], "pattern_3": [],
"pattern_4": [
{ "start_idx": 0, "end_idx": 0 },
],
"pattern_5": [
{ "start_idx": 13, "end_idx": 13 },
],
"pattern_6": [
{ "start_idx": 48, "end_idx": 48 },
],
}, },
), ),
# Test regex escape # Test regex escape
...@@ -325,6 +391,100 @@ def test_find_text_matches( ...@@ -325,6 +391,100 @@ def test_find_text_matches(
}, },
}, },
), ),
# Test index targets
(
"",
{
"pattern_1": PromptIndexTargets.start(),
"pattern_2": PromptIndexTargets.prefix("<image>"),
"pattern_3": PromptIndexTargets.end(),
},
{
"pattern_1": "1",
"pattern_2": "2",
"pattern_3": "3",
},
{
PromptInsertion: {
0: "",
1: "13",
2: "1133",
},
PromptReplacement: {
0: "",
1: "13",
2: "1133",
},
},
),
(
"<image>",
{
"pattern_1": PromptIndexTargets.start(),
"pattern_2": PromptIndexTargets.prefix("<image>"),
"pattern_3": PromptIndexTargets.end(),
},
{
"pattern_1": "1",
"pattern_2": "2",
"pattern_3": "3",
},
{
PromptInsertion: {
0: "<image>",
1: "1<image>23",
2: "11<image>2233",
},
PromptReplacement: {
0: "<image>",
1: "1<image>23",
2: "11<image>2233",
},
},
),
# Test different replacement per item
(
"<image><image><image>",
{
"pattern_1": "<image>",
},
{
"pattern_1": lambda idx: str(idx + 1),
},
{
PromptInsertion: {
0: "<image><image><image>",
1: "<image>1<image><image>",
2: "<image>12<image><image>",
},
PromptReplacement: {
0: "<image><image><image>",
1: "1<image><image>",
2: "12<image>",
},
},
),
(
"<image><image><image>",
{
"pattern_1": PromptIndexTargets.prefix("<image>"),
},
{
"pattern_1": lambda idx: str(idx + 1),
},
{
PromptInsertion: {
0: "<image><image><image>",
1: "<image>1<image><image>",
2: "<image>12<image><image>",
},
PromptReplacement: {
0: "<image><image><image>",
1: "<image>1<image><image>",
2: "<image>12<image><image>",
},
},
),
] ]
) )
# yapf: enable # yapf: enable
...@@ -405,6 +565,100 @@ def test_find_update_text( ...@@ -405,6 +565,100 @@ def test_find_update_text(
}, },
}, },
), ),
# Test index targets
(
[],
{
"pattern_1": PromptIndexTargets.start(),
"pattern_2": PromptIndexTargets.prefix([32000]),
"pattern_3": PromptIndexTargets.end(),
},
{
"pattern_1": [-1],
"pattern_2": [-2],
"pattern_3": [-3],
},
{
PromptInsertion: {
0: [],
1: [-1, -3],
2: [-1, -1, -3, -3],
},
PromptReplacement: {
0: [],
1: [-1, -3],
2: [-1, -1, -3, -3],
},
},
),
(
[32000],
{
"pattern_1": PromptIndexTargets.start(),
"pattern_2": PromptIndexTargets.prefix([32000]),
"pattern_3": PromptIndexTargets.end(),
},
{
"pattern_1": [-1],
"pattern_2": [-2],
"pattern_3": [-3],
},
{
PromptInsertion: {
0: [32000],
1: [-1, 32000, -2, -3],
2: [-1, -1, 32000, -2, -2, -3, -3],
},
PromptReplacement: {
0: [32000],
1: [-1, 32000, -2, -3],
2: [-1, -1, 32000, -2, -2, -3, -3],
},
},
),
# Test different replacement per item
(
[32000, 32000, 32000],
{
"pattern_1": [32000],
},
{
"pattern_1": lambda idx: [-(idx + 1)],
},
{
PromptInsertion: {
0: [32000, 32000, 32000],
1: [32000, -1, 32000, 32000],
2: [32000, -1, -2, 32000, 32000],
},
PromptReplacement: {
0: [32000, 32000, 32000],
1: [-1, 32000, 32000],
2: [-1, -2, 32000],
},
},
),
(
[32000, 32000, 32000],
{
"pattern_1": PromptIndexTargets.prefix([32000]),
},
{
"pattern_1": lambda idx: [-(idx + 1)],
},
{
PromptInsertion: {
0: [32000, 32000, 32000],
1: [32000, -1, 32000, 32000],
2: [32000, -1, -2, 32000, 32000],
},
PromptReplacement: {
0: [32000, 32000, 32000],
1: [32000, -1, 32000, 32000],
2: [32000, -1, -2, 32000, 32000],
},
},
),
] ]
) )
# yapf: enable # yapf: enable
......
...@@ -19,8 +19,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, ...@@ -19,8 +19,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors) NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptInsertion, BaseProcessingInfo, PromptIndexTargets,
PromptUpdate) PromptInsertion, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -490,7 +490,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]): ...@@ -490,7 +490,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
return [ return [
PromptInsertion( PromptInsertion(
modality="image", modality="image",
target="", target=PromptIndexTargets.start(),
insertion=image_tokens, insertion=image_tokens,
) )
] ]
......
...@@ -25,7 +25,8 @@ from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs ...@@ -25,7 +25,8 @@ from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems
from vllm.multimodal.processing import (BaseProcessingInfo, from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor, EncDecMultiModalProcessor,
PromptInsertion, PromptUpdate) PromptIndexTargets, PromptInsertion,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -864,7 +865,7 @@ class Florence2MultiModalProcessor( ...@@ -864,7 +865,7 @@ class Florence2MultiModalProcessor(
return [ return [
PromptInsertion( PromptInsertion(
modality="image", modality="image",
target="", target=PromptIndexTargets.start(),
insertion=image_tokens, insertion=image_tokens,
) )
] ]
......
...@@ -46,8 +46,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs, ...@@ -46,8 +46,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptInsertion, BaseProcessingInfo, PromptIndexTargets,
PromptUpdate) PromptInsertion, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import JSONTree, json_map_leaves from vllm.utils import JSONTree, json_map_leaves
...@@ -1371,7 +1371,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]): ...@@ -1371,7 +1371,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
return [ return [
PromptInsertion( PromptInsertion(
modality="image", modality="image",
target="<|endoftext|>", target=PromptIndexTargets.prefix("<|endoftext|>"),
insertion=get_insertion_molmo, insertion=get_insertion_molmo,
) )
] ]
......
...@@ -8,7 +8,6 @@ from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, ...@@ -8,7 +8,6 @@ from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from functools import lru_cache from functools import lru_cache
from itertools import groupby
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
TypeVar, Union, cast) TypeVar, Union, cast)
...@@ -40,6 +39,65 @@ PromptSeq = Union[str, list[int]] ...@@ -40,6 +39,65 @@ PromptSeq = Union[str, list[int]]
"""A token sequence (list of token IDs) or text.""" """A token sequence (list of token IDs) or text."""
@dataclass
class PromptIndex:
    """
    A prompt target that resolves to a single index in the prompt.

    Instead of naming a token sequence or text to search for, the target
    is described by a callback that computes the index directly.
    """

    # Called with the tokenizer and the prompt (text or token IDs).
    # Returns the resolved index, or ``None`` when there is no match.
    get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]]
class PromptIndexTargets:
    """Factory helpers that build :class:`PromptIndex` targets."""

    @staticmethod
    def start() -> PromptIndex:
        """
        Resolves to the start of the prompt (before the first token).

        This results in a match even if the prompt is empty.
        """

        def resolve_start(tok: AnyTokenizer, prompt: PromptSeq) -> int:
            # The start is always index 0, whatever the prompt contains.
            return 0

        return PromptIndex(resolve_start)

    @staticmethod
    def prefix(seq: PromptSeq) -> PromptIndex:
        """
        Resolves to a location in the prompt after the given prefix.

        The resolver returns ``None`` when the prompt does not begin with
        the prefix.
        """

        def get_match_index(
            tokenizer: AnyTokenizer,
            prompt: PromptSeq,
        ) -> Optional[int]:
            prompt_is_text = isinstance(prompt, str)
            seq_is_text = isinstance(seq, str)

            # Bring the prefix into the same representation as the prompt
            # so that the slice comparison below is like-for-like.
            if prompt_is_text and not seq_is_text:
                expected = decode_tokens(tokenizer, seq)
            elif seq_is_text and not prompt_is_text:
                expected = encode_tokens(tokenizer, seq)
            else:
                expected = seq

            boundary = len(expected)
            if prompt[:boundary] == expected:
                return boundary

            return None

        return PromptIndex(get_match_index)

    @staticmethod
    def end() -> PromptIndex:
        """
        Resolves to the end of the prompt (after the last token).

        This results in a match even if the prompt is empty.
        """

        def resolve_end(tok: AnyTokenizer, prompt: PromptSeq) -> int:
            # One past the last element; 0 for an empty prompt.
            return len(prompt)

        return PromptIndex(resolve_end)
PromptTarget = Union[PromptSeq, PromptIndex]
"""
The target of a prompt update: either the token sequence or text to update,
or a :class:`PromptIndex` that resolves to an index in the prompt.
"""
@dataclass @dataclass
class PromptUpdateDetails: class PromptUpdateDetails:
"""Details about the token sequence or text that are part of the update.""" """Details about the token sequence or text that are part of the update."""
...@@ -84,7 +142,7 @@ class UpdateMode(str, Enum): ...@@ -84,7 +142,7 @@ class UpdateMode(str, Enum):
@dataclass @dataclass
class PromptUpdate: class PromptUpdate(ABC):
""" """
Defines how to update a prompt with placeholder tokens. Defines how to update a prompt with placeholder tokens.
""" """
...@@ -92,7 +150,7 @@ class PromptUpdate: ...@@ -92,7 +150,7 @@ class PromptUpdate:
modality: str modality: str
"""The modality for which the update is made.""" """The modality for which the update is made."""
target: PromptSeq target: PromptTarget
"""The token sequence (or text) to update.""" """The token sequence (or text) to update."""
@property @property
...@@ -122,24 +180,43 @@ class PromptInsertion(PromptUpdate): ...@@ -122,24 +180,43 @@ class PromptInsertion(PromptUpdate):
Example: Example:
For each image, insert a number of ``<image>`` feature placeholders For each image, insert a number of ``<image>`` feature placeholders
equal to the feature size of the vision encoder at the start of the equal to the feature size of the vision encoder after the ``<s>`` token:
prompt:
.. code-block:: python .. code-block:: python
PromptInsertion( PromptInsertion(
modality="image", modality="image",
target="", target="<s>",
insertion="<image>" * image_feature_size, insertion="<image>" * image_feature_size,
) )
As above, but insert after the ``<s>`` token: Insert these tokens at the start of the prompt:
.. code-block:: python .. code-block:: python
PromptInsertion( PromptInsertion(
modality="image", modality="image",
target="<s>", target=PromptIndexTargets.start(),
insertion="<image>" * image_feature_size,
)
Insert these tokens after a prefix ``Images:``:
.. code-block:: python
PromptInsertion(
modality="image",
target=PromptIndexTargets.prefix("Images:"),
insertion="<image>" * image_feature_size,
)
Insert these tokens at the end of the prompt:
.. code-block:: python
PromptInsertion(
modality="image",
target=PromptIndexTargets.end(),
insertion="<image>" * image_feature_size, insertion="<image>" * image_feature_size,
) )
""" """
...@@ -345,10 +422,14 @@ class BoundPromptUpdate: ...@@ -345,10 +422,14 @@ class BoundPromptUpdate:
return self._origin.modality return self._origin.modality
@property @property
def target(self) -> _BoundPromptSequence: def target(self) -> Union[_BoundPromptSequence, PromptIndex]:
"""The token sequence (or text) to update.""" """The token sequence (or text) to update."""
return _BoundPromptSequence.from_seq(self.tokenizer, target = self._origin.target
self._origin.target)
if isinstance(target, PromptIndex):
return target
return _BoundPromptSequence.from_seq(self.tokenizer, target)
@property @property
def content(self) -> PromptUpdateContent: def content(self) -> PromptUpdateContent:
...@@ -447,6 +528,19 @@ class _PromptTargetMatch(ABC): ...@@ -447,6 +528,19 @@ class _PromptTargetMatch(ABC):
f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})") f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")
@dataclass(repr=False)
class _PromptTargetIndexMatch(_PromptTargetMatch):
    """A zero-width match located at a resolved prompt index."""

    # Index in the prompt that the target resolved to.
    match_idx: int

    @property
    def start_idx(self) -> int:
        """Start of the match; always equal to :attr:`end_idx`."""
        return self.match_idx

    @property
    def end_idx(self) -> int:
        """End of the match; always equal to :attr:`start_idx`."""
        return self.match_idx
@dataclass(repr=False) @dataclass(repr=False)
class _PromptTargetTokenMatch(_PromptTargetMatch): class _PromptTargetTokenMatch(_PromptTargetMatch):
match: _TokenMatch match: _TokenMatch
...@@ -496,9 +590,24 @@ def find_token_matches( ...@@ -496,9 +590,24 @@ def find_token_matches(
prompt_updates: Sequence[BoundPromptUpdate], prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[_PromptTargetMatch]: ) -> Sequence[_PromptTargetMatch]:
"""Return each target of :code:`prompt_updates` found in :code:`prompt`.""" """Return each target of :code:`prompt_updates` found in :code:`prompt`."""
def get_matches(update: BoundPromptUpdate):
target = update.target
if isinstance(target, PromptIndex):
match_idx = target.get_match_index(update.tokenizer, prompt)
if match_idx is None:
return []
return [_PromptTargetIndexMatch(update, match_idx)]
return [
_PromptTargetTokenMatch(update, match)
for match in iter_token_matches(prompt, target.token_ids)
]
return [ return [
_PromptTargetTokenMatch(update, match) for update in prompt_updates match for update in prompt_updates for match in get_matches(update)
for match in iter_token_matches(prompt, update.target.token_ids)
] ]
...@@ -507,9 +616,24 @@ def find_text_matches( ...@@ -507,9 +616,24 @@ def find_text_matches(
prompt_updates: Sequence[BoundPromptUpdate], prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[_PromptTargetMatch]: ) -> Sequence[_PromptTargetMatch]:
"""Return each target of :code:`prompt_updates` found in :code:`prompt`.""" """Return each target of :code:`prompt_updates` found in :code:`prompt`."""
def get_matches(update: BoundPromptUpdate):
target = update.target
if isinstance(target, PromptIndex):
match_idx = target.get_match_index(update.tokenizer, prompt)
if match_idx is None:
return []
return [_PromptTargetIndexMatch(update, match_idx)]
return [
_PromptTargetTextMatch(update, match)
for match in re.finditer(re.escape(target.text), prompt)
]
return [ return [
_PromptTargetTextMatch(update, match) for update in prompt_updates match for update in prompt_updates for match in get_matches(update)
for match in re.finditer(re.escape(update.target.text), prompt)
] ]
...@@ -547,45 +671,39 @@ def _apply_matches( ...@@ -547,45 +671,39 @@ def _apply_matches(
prev_end_idx = 0 prev_end_idx = 0
next_idx_by_modality = defaultdict[str, int](lambda: 0) next_idx_by_modality = defaultdict[str, int](lambda: 0)
for (start_idx, end_idx), group in groupby( for match in _resolve_matches(prompt, mm_matches):
_resolve_matches(prompt, mm_matches), modality = match.modality
key=lambda x: (x.start_idx, x.end_idx),
): item_start_idx = next_idx_by_modality[modality]
matches = tuple(group) max_item_count = mm_item_counts.get(modality, 0)
assert len(matches) == 1 if item_start_idx >= max_item_count:
continue
for match in matches:
modality = match.modality start_idx = match.start_idx
end_idx = match.end_idx
origin = match._origin
mode = origin.mode
if mode == UpdateMode.INSERT:
out_seqs.append(prompt[prev_end_idx:end_idx])
num_inserts = max_item_count
elif mode == UpdateMode.REPLACE:
out_seqs.append(prompt[prev_end_idx:start_idx])
num_inserts = max_item_count if start_idx == end_idx else 1
else:
assert_never(mode)
item_idx = next_idx_by_modality[modality] item_end_idx = min(item_start_idx + num_inserts, max_item_count)
if item_idx >= mm_item_counts.get(modality, 0):
continue
origin = match._origin for item_idx in range(item_start_idx, item_end_idx):
content = origin.get_content(item_idx) content = origin.get_content(item_idx)
mode = origin.mode insert_seq = (content.full.text if isinstance(prompt, str) else
content.full.token_ids)
if mode == UpdateMode.INSERT:
out_seqs.append(prompt[prev_end_idx:end_idx])
num_inserts = mm_item_counts.get(modality, 0)
elif mode == UpdateMode.REPLACE:
out_seqs.append(prompt[prev_end_idx:start_idx])
num_inserts = 1
else:
assert_never(mode)
for _ in range(num_inserts):
if item_idx >= mm_item_counts.get(modality, 0):
continue
if isinstance(prompt, str):
out_seqs.append(content.full.text)
else:
out_seqs.append(content.full.token_ids)
next_idx_by_modality[modality] += 1 out_seqs.append(insert_seq)
prev_end_idx = end_idx prev_end_idx = end_idx
next_idx_by_modality[modality] += item_end_idx - item_start_idx
out_seqs.append(prompt[prev_end_idx:]) out_seqs.append(prompt[prev_end_idx:])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment