Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f7bee5c8
Unverified
Commit
f7bee5c8
authored
Feb 28, 2025
by
Cyrus Leung
Committed by
GitHub
Feb 28, 2025
Browse files
[VLM][Bugfix] Enable specifying prompt target via index (#14038)
parent
e0734387
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
432 additions
and
59 deletions
+432
-59
tests/multimodal/test_processing.py
tests/multimodal/test_processing.py
+256
-2
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+3
-3
vllm/model_executor/models/florence2.py
vllm/model_executor/models/florence2.py
+3
-2
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+3
-3
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+167
-49
No files found.
tests/multimodal/test_processing.py
View file @
f7bee5c8
...
...
@@ -14,8 +14,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.multimodal.processing
import
(
PlaceholderFeaturesInfo
,
PromptIn
sertion
,
PromptReplacement
,
apply_text_matches
,
PromptIn
dexTargets
,
PromptInsertion
,
PromptReplacement
,
apply_text_matches
,
apply_token_matches
,
find_mm_placeholders
,
find_text_matches
,
find_token_matches
,
...
...
@@ -98,10 +98,20 @@ def test_iter_token_matches(token_ids, match_ids, expected):
{
"pattern_1"
:
[],
"pattern_2"
:
[
32000
],
"pattern_3"
:
PromptIndexTargets
.
start
(),
"pattern_4"
:
PromptIndexTargets
.
prefix
([
32000
]),
"pattern_5"
:
PromptIndexTargets
.
end
(),
},
{
"pattern_1"
:
[],
"pattern_2"
:
[],
"pattern_3"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
0
},
],
"pattern_4"
:
[],
"pattern_5"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
0
},
],
},
),
(
...
...
@@ -110,6 +120,9 @@ def test_iter_token_matches(token_ids, match_ids, expected):
"pattern_1"
:
[
32000
],
"pattern_2"
:
[
32000
,
32000
],
"pattern_3"
:
[
32000
,
32000
,
32000
],
"pattern_4"
:
PromptIndexTargets
.
start
(),
"pattern_5"
:
PromptIndexTargets
.
prefix
([
32000
]),
"pattern_6"
:
PromptIndexTargets
.
end
(),
},
{
"pattern_1"
:
[
...
...
@@ -125,6 +138,15 @@ def test_iter_token_matches(token_ids, match_ids, expected):
"pattern_3"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
3
},
],
"pattern_4"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
0
},
],
"pattern_5"
:
[
{
"start_idx"
:
1
,
"end_idx"
:
1
},
],
"pattern_6"
:
[
{
"start_idx"
:
4
,
"end_idx"
:
4
},
],
},
),
(
...
...
@@ -133,6 +155,9 @@ def test_iter_token_matches(token_ids, match_ids, expected):
"pattern_1"
:
[
28747
,
32000
],
"pattern_2"
:
[
28747
,
32000
,
32000
,
32000
],
"pattern_3"
:
[
28747
,
0
,
32000
],
"pattern_4"
:
PromptIndexTargets
.
start
(),
"pattern_5"
:
PromptIndexTargets
.
prefix
([
28747
,
32000
]),
"pattern_6"
:
PromptIndexTargets
.
end
(),
},
{
"pattern_1"
:
[
...
...
@@ -143,6 +168,13 @@ def test_iter_token_matches(token_ids, match_ids, expected):
{
"start_idx"
:
1
,
"end_idx"
:
5
},
],
"pattern_3"
:
[],
"pattern_4"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
0
},
],
"pattern_5"
:
[],
"pattern_6"
:
[
{
"start_idx"
:
10
,
"end_idx"
:
10
},
],
},
),
],
...
...
@@ -189,10 +221,20 @@ def test_find_token_matches(
{
"pattern_1"
:
""
,
"pattern_2"
:
"<image>"
,
"pattern_3"
:
PromptIndexTargets
.
start
(),
"pattern_4"
:
PromptIndexTargets
.
prefix
(
"<image>"
),
"pattern_5"
:
PromptIndexTargets
.
end
(),
},
{
"pattern_1"
:
[{
"start_idx"
:
0
,
"end_idx"
:
0
}],
"pattern_2"
:
[],
"pattern_3"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
0
},
],
"pattern_4"
:
[],
"pattern_5"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
0
},
],
}
),
(
...
...
@@ -201,6 +243,9 @@ def test_find_token_matches(
"pattern_1"
:
"<image>"
,
"pattern_2"
:
"<image><image>"
,
"pattern_3"
:
"<image><image><image>"
,
"pattern_4"
:
PromptIndexTargets
.
start
(),
"pattern_5"
:
PromptIndexTargets
.
prefix
(
"<image>"
),
"pattern_6"
:
PromptIndexTargets
.
end
(),
},
{
"pattern_1"
:
[
...
...
@@ -216,6 +261,15 @@ def test_find_token_matches(
"pattern_3"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
21
},
],
"pattern_4"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
0
},
],
"pattern_5"
:
[
{
"start_idx"
:
7
,
"end_idx"
:
7
},
],
"pattern_6"
:
[
{
"start_idx"
:
28
,
"end_idx"
:
28
},
],
},
),
(
...
...
@@ -224,6 +278,9 @@ def test_find_token_matches(
"pattern_1"
:
"Image:<image>"
,
"pattern_2"
:
"Image:<image><image><image>"
,
"pattern_3"
:
"Image:<unk><image>"
,
"pattern_4"
:
PromptIndexTargets
.
start
(),
"pattern_5"
:
PromptIndexTargets
.
prefix
(
"Image:<image>"
),
"pattern_6"
:
PromptIndexTargets
.
end
(),
},
{
"pattern_1"
:
[
...
...
@@ -234,6 +291,15 @@ def test_find_token_matches(
{
"start_idx"
:
0
,
"end_idx"
:
27
},
],
"pattern_3"
:
[],
"pattern_4"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
0
},
],
"pattern_5"
:
[
{
"start_idx"
:
13
,
"end_idx"
:
13
},
],
"pattern_6"
:
[
{
"start_idx"
:
48
,
"end_idx"
:
48
},
],
},
),
# Test regex escape
...
...
@@ -325,6 +391,100 @@ def test_find_text_matches(
},
},
),
# Test index targets
(
""
,
{
"pattern_1"
:
PromptIndexTargets
.
start
(),
"pattern_2"
:
PromptIndexTargets
.
prefix
(
"<image>"
),
"pattern_3"
:
PromptIndexTargets
.
end
(),
},
{
"pattern_1"
:
"1"
,
"pattern_2"
:
"2"
,
"pattern_3"
:
"3"
,
},
{
PromptInsertion
:
{
0
:
""
,
1
:
"13"
,
2
:
"1133"
,
},
PromptReplacement
:
{
0
:
""
,
1
:
"13"
,
2
:
"1133"
,
},
},
),
(
"<image>"
,
{
"pattern_1"
:
PromptIndexTargets
.
start
(),
"pattern_2"
:
PromptIndexTargets
.
prefix
(
"<image>"
),
"pattern_3"
:
PromptIndexTargets
.
end
(),
},
{
"pattern_1"
:
"1"
,
"pattern_2"
:
"2"
,
"pattern_3"
:
"3"
,
},
{
PromptInsertion
:
{
0
:
"<image>"
,
1
:
"1<image>23"
,
2
:
"11<image>2233"
,
},
PromptReplacement
:
{
0
:
"<image>"
,
1
:
"1<image>23"
,
2
:
"11<image>2233"
,
},
},
),
# Test different replacement per item
(
"<image><image><image>"
,
{
"pattern_1"
:
"<image>"
,
},
{
"pattern_1"
:
lambda
idx
:
str
(
idx
+
1
),
},
{
PromptInsertion
:
{
0
:
"<image><image><image>"
,
1
:
"<image>1<image><image>"
,
2
:
"<image>12<image><image>"
,
},
PromptReplacement
:
{
0
:
"<image><image><image>"
,
1
:
"1<image><image>"
,
2
:
"12<image>"
,
},
},
),
(
"<image><image><image>"
,
{
"pattern_1"
:
PromptIndexTargets
.
prefix
(
"<image>"
),
},
{
"pattern_1"
:
lambda
idx
:
str
(
idx
+
1
),
},
{
PromptInsertion
:
{
0
:
"<image><image><image>"
,
1
:
"<image>1<image><image>"
,
2
:
"<image>12<image><image>"
,
},
PromptReplacement
:
{
0
:
"<image><image><image>"
,
1
:
"<image>1<image><image>"
,
2
:
"<image>12<image><image>"
,
},
},
),
]
)
# yapf: enable
...
...
@@ -405,6 +565,100 @@ def test_find_update_text(
},
},
),
# Test index targets
(
[],
{
"pattern_1"
:
PromptIndexTargets
.
start
(),
"pattern_2"
:
PromptIndexTargets
.
prefix
([
32000
]),
"pattern_3"
:
PromptIndexTargets
.
end
(),
},
{
"pattern_1"
:
[
-
1
],
"pattern_2"
:
[
-
2
],
"pattern_3"
:
[
-
3
],
},
{
PromptInsertion
:
{
0
:
[],
1
:
[
-
1
,
-
3
],
2
:
[
-
1
,
-
1
,
-
3
,
-
3
],
},
PromptReplacement
:
{
0
:
[],
1
:
[
-
1
,
-
3
],
2
:
[
-
1
,
-
1
,
-
3
,
-
3
],
},
},
),
(
[
32000
],
{
"pattern_1"
:
PromptIndexTargets
.
start
(),
"pattern_2"
:
PromptIndexTargets
.
prefix
([
32000
]),
"pattern_3"
:
PromptIndexTargets
.
end
(),
},
{
"pattern_1"
:
[
-
1
],
"pattern_2"
:
[
-
2
],
"pattern_3"
:
[
-
3
],
},
{
PromptInsertion
:
{
0
:
[
32000
],
1
:
[
-
1
,
32000
,
-
2
,
-
3
],
2
:
[
-
1
,
-
1
,
32000
,
-
2
,
-
2
,
-
3
,
-
3
],
},
PromptReplacement
:
{
0
:
[
32000
],
1
:
[
-
1
,
32000
,
-
2
,
-
3
],
2
:
[
-
1
,
-
1
,
32000
,
-
2
,
-
2
,
-
3
,
-
3
],
},
},
),
# Test different replacement per item
(
[
32000
,
32000
,
32000
],
{
"pattern_1"
:
[
32000
],
},
{
"pattern_1"
:
lambda
idx
:
[
-
(
idx
+
1
)],
},
{
PromptInsertion
:
{
0
:
[
32000
,
32000
,
32000
],
1
:
[
32000
,
-
1
,
32000
,
32000
],
2
:
[
32000
,
-
1
,
-
2
,
32000
,
32000
],
},
PromptReplacement
:
{
0
:
[
32000
,
32000
,
32000
],
1
:
[
-
1
,
32000
,
32000
],
2
:
[
-
1
,
-
2
,
32000
],
},
},
),
(
[
32000
,
32000
,
32000
],
{
"pattern_1"
:
PromptIndexTargets
.
prefix
([
32000
]),
},
{
"pattern_1"
:
lambda
idx
:
[
-
(
idx
+
1
)],
},
{
PromptInsertion
:
{
0
:
[
32000
,
32000
,
32000
],
1
:
[
32000
,
-
1
,
32000
,
32000
],
2
:
[
32000
,
-
1
,
-
2
,
32000
,
32000
],
},
PromptReplacement
:
{
0
:
[
32000
,
32000
,
32000
],
1
:
[
32000
,
-
1
,
32000
,
32000
],
2
:
[
32000
,
-
1
,
-
2
,
32000
,
32000
],
},
},
),
]
)
# yapf: enable
...
...
vllm/model_executor/models/blip2.py
View file @
f7bee5c8
...
...
@@ -19,8 +19,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors
)
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptIn
sertion
,
PromptUpdate
)
BaseProcessingInfo
,
PromptIn
dexTargets
,
PromptInsertion
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -490,7 +490,7 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
return
[
PromptInsertion
(
modality
=
"image"
,
target
=
""
,
target
=
PromptIndexTargets
.
start
()
,
insertion
=
image_tokens
,
)
]
...
...
vllm/model_executor/models/florence2.py
View file @
f7bee5c8
...
...
@@ -25,7 +25,8 @@ from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from
vllm.multimodal.parse
import
MultiModalDataDict
,
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseProcessingInfo
,
EncDecMultiModalProcessor
,
PromptInsertion
,
PromptUpdate
)
PromptIndexTargets
,
PromptInsertion
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -864,7 +865,7 @@ class Florence2MultiModalProcessor(
return
[
PromptInsertion
(
modality
=
"image"
,
target
=
""
,
target
=
PromptIndexTargets
.
start
()
,
insertion
=
image_tokens
,
)
]
...
...
vllm/model_executor/models/molmo.py
View file @
f7bee5c8
...
...
@@ -46,8 +46,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptIn
sertion
,
PromptUpdate
)
BaseProcessingInfo
,
PromptIn
dexTargets
,
PromptInsertion
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
JSONTree
,
json_map_leaves
...
...
@@ -1371,7 +1371,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
return
[
PromptInsertion
(
modality
=
"image"
,
target
=
"<|endoftext|>"
,
target
=
PromptIndexTargets
.
prefix
(
"<|endoftext|>"
)
,
insertion
=
get_insertion_molmo
,
)
]
...
...
vllm/multimodal/processing.py
View file @
f7bee5c8
...
...
@@ -8,7 +8,6 @@ from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
from
dataclasses
import
dataclass
,
field
from
enum
import
Enum
from
functools
import
lru_cache
from
itertools
import
groupby
from
typing
import
(
TYPE_CHECKING
,
Generic
,
NamedTuple
,
Optional
,
Protocol
,
TypeVar
,
Union
,
cast
)
...
...
@@ -40,6 +39,65 @@ PromptSeq = Union[str, list[int]]
"""A token sequence (list of token IDs) or text."""
@
dataclass
class
PromptIndex
:
"""Resolves to an index in the prompt."""
get_match_index
:
Callable
[[
AnyTokenizer
,
PromptSeq
],
Optional
[
int
]]
class
PromptIndexTargets
:
@
staticmethod
def
start
()
->
PromptIndex
:
"""
Resolves to the start of the prompt (before the first token).
This results in a match even if the prompt is empty.
"""
return
PromptIndex
(
lambda
tok
,
prompt
:
0
)
@
staticmethod
def
prefix
(
seq
:
PromptSeq
)
->
PromptIndex
:
"""
Resolves to a location in the prompt after the given prefix.
"""
def
get_match_index
(
tokenizer
:
AnyTokenizer
,
prompt
:
PromptSeq
,
)
->
Optional
[
int
]:
prefix
=
seq
if
isinstance
(
prompt
,
str
):
if
not
isinstance
(
prefix
,
str
):
# Make both `str`
prefix
=
decode_tokens
(
tokenizer
,
prefix
)
else
:
if
isinstance
(
prefix
,
str
):
# Make both `list[int]`
prefix
=
encode_tokens
(
tokenizer
,
prefix
)
match_idx
=
len
(
prefix
)
return
match_idx
if
prompt
[:
match_idx
]
==
prefix
else
None
return
PromptIndex
(
get_match_index
)
@
staticmethod
def
end
()
->
PromptIndex
:
"""
Resolves to the end of the prompt (after the last token).
This results in a match even if the prompt is empty.
"""
return
PromptIndex
(
lambda
tok
,
prompt
:
len
(
prompt
))
PromptTarget
=
Union
[
PromptSeq
,
PromptIndex
]
"""
The token sequence or text to update.
"""
@
dataclass
class
PromptUpdateDetails
:
"""Details about the token sequence or text that are part of the update."""
...
...
@@ -84,7 +142,7 @@ class UpdateMode(str, Enum):
@
dataclass
class
PromptUpdate
:
class
PromptUpdate
(
ABC
)
:
"""
Defines how to update a prompt with placeholder tokens.
"""
...
...
@@ -92,7 +150,7 @@ class PromptUpdate:
modality
:
str
"""The modality for which the update is made."""
target
:
Prompt
Seq
target
:
Prompt
Target
"""The token sequence (or text) to update."""
@
property
...
...
@@ -122,24 +180,43 @@ class PromptInsertion(PromptUpdate):
Example:
For each image, insert a number of ``<image>`` feature placeholders
equal to the feature size of the vision encoder at the start of the
prompt:
equal to the feature size of the vision encoder after the ``<s>`` token:
.. code-block:: python
PromptInsertion(
modality="image",
target="",
target="
<s>
",
insertion="<image>" * image_feature_size,
)
As above, but insert after the ``<s>`` token
:
Insert these tokens at the start of the prompt
:
.. code-block:: python
PromptInsertion(
modality="image",
target="<s>",
target=PromptIndexTargets.start(),
insertion="<image>" * image_feature_size,
)
Insert these tokens after a prefix ``Images:``:
.. code-block:: python
PromptInsertion(
modality="image",
target=PromptIndexTargets.prefix("Images:"),
insertion="<image>" * image_feature_size,
)
Insert these tokens at the end of the prompt:
.. code-block:: python
PromptInsertion(
modality="image",
target=PromptIndexTargets.end(),
insertion="<image>" * image_feature_size,
)
"""
...
...
@@ -345,10 +422,14 @@ class BoundPromptUpdate:
return
self
.
_origin
.
modality
@
property
def
target
(
self
)
->
_BoundPromptSequence
:
def
target
(
self
)
->
Union
[
_BoundPromptSequence
,
PromptIndex
]
:
"""The token sequence (or text) to update."""
return
_BoundPromptSequence
.
from_seq
(
self
.
tokenizer
,
self
.
_origin
.
target
)
target
=
self
.
_origin
.
target
if
isinstance
(
target
,
PromptIndex
):
return
target
return
_BoundPromptSequence
.
from_seq
(
self
.
tokenizer
,
target
)
@
property
def
content
(
self
)
->
PromptUpdateContent
:
...
...
@@ -447,6 +528,19 @@ class _PromptTargetMatch(ABC):
f
"start_idx=
{
self
.
start_idx
!
r
}
, end_idx=
{
self
.
end_idx
!
r
}
)"
)
@
dataclass
(
repr
=
False
)
class
_PromptTargetIndexMatch
(
_PromptTargetMatch
):
match_idx
:
int
@
property
def
start_idx
(
self
)
->
int
:
return
self
.
match_idx
@
property
def
end_idx
(
self
)
->
int
:
return
self
.
match_idx
@
dataclass
(
repr
=
False
)
class
_PromptTargetTokenMatch
(
_PromptTargetMatch
):
match
:
_TokenMatch
...
...
@@ -496,9 +590,24 @@ def find_token_matches(
prompt_updates
:
Sequence
[
BoundPromptUpdate
],
)
->
Sequence
[
_PromptTargetMatch
]:
"""Return each target of :code:`prompt_updates` found in :code:`prompt`."""
def
get_matches
(
update
:
BoundPromptUpdate
):
target
=
update
.
target
if
isinstance
(
target
,
PromptIndex
):
match_idx
=
target
.
get_match_index
(
update
.
tokenizer
,
prompt
)
if
match_idx
is
None
:
return
[]
return
[
_PromptTargetIndexMatch
(
update
,
match_idx
)]
return
[
_PromptTargetTokenMatch
(
update
,
match
)
for
update
in
prompt_updates
for
match
in
iter_token_matches
(
prompt
,
update
.
target
.
token_ids
)
_PromptTargetTokenMatch
(
update
,
match
)
for
match
in
iter_token_matches
(
prompt
,
target
.
token_ids
)
]
return
[
match
for
update
in
prompt_updates
for
match
in
get_matches
(
update
)
]
...
...
@@ -507,9 +616,24 @@ def find_text_matches(
prompt_updates
:
Sequence
[
BoundPromptUpdate
],
)
->
Sequence
[
_PromptTargetMatch
]:
"""Return each target of :code:`prompt_updates` found in :code:`prompt`."""
def
get_matches
(
update
:
BoundPromptUpdate
):
target
=
update
.
target
if
isinstance
(
target
,
PromptIndex
):
match_idx
=
target
.
get_match_index
(
update
.
tokenizer
,
prompt
)
if
match_idx
is
None
:
return
[]
return
[
_PromptTargetIndexMatch
(
update
,
match_idx
)]
return
[
_PromptTargetTextMatch
(
update
,
match
)
for
match
in
re
.
finditer
(
re
.
escape
(
target
.
text
),
prompt
)
]
return
[
_PromptTargetTextMatch
(
update
,
match
)
for
update
in
prompt_updates
for
match
in
re
.
finditer
(
re
.
escape
(
update
.
target
.
text
),
prompt
)
match
for
update
in
prompt_updates
for
match
in
get_matches
(
update
)
]
...
...
@@ -547,45 +671,39 @@ def _apply_matches(
prev_end_idx
=
0
next_idx_by_modality
=
defaultdict
[
str
,
int
](
lambda
:
0
)
for
(
start_idx
,
end_idx
),
group
in
groupby
(
_resolve_matches
(
prompt
,
mm_matches
),
key
=
lambda
x
:
(
x
.
start_idx
,
x
.
end_idx
),
):
matches
=
tuple
(
group
)
assert
len
(
matches
)
==
1
for
match
in
matches
:
for
match
in
_resolve_matches
(
prompt
,
mm_matches
):
modality
=
match
.
modality
item_idx
=
next_idx_by_modality
[
modality
]
if
item_idx
>=
mm_item_counts
.
get
(
modality
,
0
):
item_start_idx
=
next_idx_by_modality
[
modality
]
max_item_count
=
mm_item_counts
.
get
(
modality
,
0
)
if
item_start_idx
>=
max_item_count
:
continue
start_idx
=
match
.
start_idx
end_idx
=
match
.
end_idx
origin
=
match
.
_origin
content
=
origin
.
get_content
(
item_idx
)
mode
=
origin
.
mode
if
mode
==
UpdateMode
.
INSERT
:
out_seqs
.
append
(
prompt
[
prev_end_idx
:
end_idx
])
num_inserts
=
m
m
_item_count
s
.
get
(
modality
,
0
)
num_inserts
=
m
ax
_item_count
elif
mode
==
UpdateMode
.
REPLACE
:
out_seqs
.
append
(
prompt
[
prev_end_idx
:
start_idx
])
num_inserts
=
1
num_inserts
=
max_item_count
if
start_idx
==
end_idx
else
1
else
:
assert_never
(
mode
)
for
_
in
range
(
num_inserts
):
if
item_idx
>=
mm_item_counts
.
get
(
modality
,
0
):
continue
item_end_idx
=
min
(
item_start_idx
+
num_inserts
,
max_item_count
)
if
isinst
an
c
e
(
prompt
,
str
):
out_seqs
.
append
(
content
.
full
.
text
)
else
:
out_seqs
.
append
(
content
.
full
.
token_ids
)
for
item_idx
in
r
an
g
e
(
item_start_idx
,
item_end_idx
):
content
=
origin
.
get_content
(
item_idx
)
insert_seq
=
(
content
.
full
.
text
if
isinstance
(
prompt
,
str
)
else
content
.
full
.
token_ids
)
next_idx_by_modality
[
modality
]
+=
1
out_seqs
.
append
(
insert_seq
)
prev_end_idx
=
end_idx
next_idx_by_modality
[
modality
]
+=
item_end_idx
-
item_start_idx
out_seqs
.
append
(
prompt
[
prev_end_idx
:])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment