Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c8acd805
Unverified
Commit
c8acd805
authored
Nov 23, 2024
by
Cyrus Leung
Committed by
GitHub
Nov 22, 2024
Browse files
[2/N] handling placeholders in merged multi-modal processor (#10485)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
4634a89d
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
975 additions
and
147 deletions
+975
-147
tests/multimodal/test_processing.py
tests/multimodal/test_processing.py
+370
-0
tests/multimodal/test_utils.py
tests/multimodal/test_utils.py
+2
-1
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+1
-8
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+583
-137
vllm/utils.py
vllm/utils.py
+19
-1
No files found.
tests/multimodal/test_processing.py
0 → 100644
View file @
c8acd805
from
typing
import
cast
import
pytest
from
transformers
import
BatchFeature
from
vllm.multimodal.processing
import
(
PromptReplacement
,
find_text_matches
,
find_token_matches
,
iter_token_matches
,
iter_token_runs
,
replace_text_matches
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
full_groupby
# yapf: disable
@
pytest
.
mark
.
parametrize
(
(
"token_ids"
,
"expected"
),
[
([],
[]),
(
[
32000
,
32000
,
32000
],
[{
"token_id"
:
32000
,
"start_idx"
:
0
,
"length"
:
3
}],
),
(
[
9833
,
28747
,
32000
,
32000
,
32000
,
9833
,
28747
,
32000
,
32000
,
918
],
[
{
"token_id"
:
9833
,
"start_idx"
:
0
,
"length"
:
1
},
{
"token_id"
:
28747
,
"start_idx"
:
1
,
"length"
:
1
},
{
"token_id"
:
32000
,
"start_idx"
:
2
,
"length"
:
3
},
{
"token_id"
:
9833
,
"start_idx"
:
5
,
"length"
:
1
},
{
"token_id"
:
28747
,
"start_idx"
:
6
,
"length"
:
1
},
{
"token_id"
:
32000
,
"start_idx"
:
7
,
"length"
:
2
},
{
"token_id"
:
918
,
"start_idx"
:
9
,
"length"
:
1
},
],
),
],
)
# yapf: enable
def
test_iter_token_runs
(
token_ids
,
expected
):
result
=
list
(
iter_token_runs
(
token_ids
))
# Only displayed on error
print
(
"result:"
,
result
)
# Manually constructed results
assert
[
item
.
_asdict
()
for
item
in
result
]
==
expected
# Invariants
assert
sum
(
run_info
.
length
for
run_info
in
result
)
==
len
(
token_ids
)
# yapf: disable
@
pytest
.
mark
.
parametrize
(
(
"token_ids"
,
"match_ids"
,
"expected"
),
[
([],
[],
[{
"start_idx"
:
0
,
"end_idx"
:
0
}]),
([],
[
32000
],
[]),
(
[
32000
,
32000
,
32000
],
[
32000
],
[
{
"start_idx"
:
0
,
"end_idx"
:
1
},
{
"start_idx"
:
1
,
"end_idx"
:
2
},
{
"start_idx"
:
2
,
"end_idx"
:
3
},
],
),
(
[
32000
,
32000
,
32000
],
[
32000
,
32000
],
[{
"start_idx"
:
0
,
"end_idx"
:
2
}],
),
(
[
32000
,
32000
,
32000
],
[
32000
,
32000
,
32000
],
[{
"start_idx"
:
0
,
"end_idx"
:
3
}],
),
(
[
9833
,
28747
,
32000
,
32000
,
32000
,
9833
,
28747
,
32000
,
32000
,
918
],
[
28747
,
32000
],
[
{
"start_idx"
:
1
,
"end_idx"
:
3
},
{
"start_idx"
:
6
,
"end_idx"
:
8
},
],
),
(
[
9833
,
28747
,
32000
,
32000
,
32000
,
9833
,
28747
,
32000
,
32000
,
918
],
[
28747
,
32000
,
32000
,
32000
],
[
{
"start_idx"
:
1
,
"end_idx"
:
5
},
],
),
(
[
9833
,
28747
,
32000
,
32000
,
32000
,
9833
,
28747
,
32000
,
32000
,
918
],
[
28747
,
0
,
32000
],
[],
),
],
)
# yapf: enable
def
test_iter_token_matches
(
token_ids
,
match_ids
,
expected
):
result
=
list
(
iter_token_matches
(
token_ids
,
match_ids
))
# Manually constructed results
assert
[
item
.
_asdict
()
for
item
in
result
]
==
expected
# Invariants
match_lens
=
[
end
-
start
for
start
,
end
in
result
]
print
(
"match_lens:"
,
match_lens
)
# Only displayed on error
assert
all
(
match_len
==
len
(
match_ids
)
for
match_len
in
match_lens
)
# yapf: disable
@
pytest
.
mark
.
parametrize
(
(
"prompt"
,
"target_by_key"
,
"expected_by_key"
),
[
(
[],
{
"pattern_1"
:
[],
"pattern_2"
:
[
32000
],
},
{
"pattern_1"
:
[{
"start_idx"
:
0
,
"end_idx"
:
0
}],
"pattern_2"
:
[],
}
),
(
[
32000
,
32000
,
32000
,
32000
],
{
"pattern_1"
:
[
32000
],
"pattern_2"
:
[
32000
,
32000
],
"pattern_3"
:
[
32000
,
32000
,
32000
],
},
{
"pattern_1"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
1
},
{
"start_idx"
:
1
,
"end_idx"
:
2
},
{
"start_idx"
:
2
,
"end_idx"
:
3
},
{
"start_idx"
:
3
,
"end_idx"
:
4
},
],
"pattern_2"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
2
},
{
"start_idx"
:
2
,
"end_idx"
:
4
},
],
"pattern_3"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
3
},
],
},
),
(
[
9833
,
28747
,
32000
,
32000
,
32000
,
9833
,
28747
,
32000
,
32000
,
918
],
{
"pattern_1"
:
[
28747
,
32000
],
"pattern_2"
:
[
28747
,
32000
,
32000
,
32000
],
"pattern_3"
:
[
28747
,
0
,
32000
],
},
{
"pattern_1"
:
[
{
"start_idx"
:
1
,
"end_idx"
:
3
},
{
"start_idx"
:
6
,
"end_idx"
:
8
},
],
"pattern_2"
:
[
{
"start_idx"
:
1
,
"end_idx"
:
5
},
],
"pattern_3"
:
[],
},
),
],
)
# yapf: enable
def
test_find_token_matches
(
prompt
,
target_by_key
,
expected_by_key
):
# Should not be used since there is nothing to convert to token IDs
mock_tokenizer
=
cast
(
AnyTokenizer
,
object
())
result
=
find_token_matches
(
prompt
,
[
PromptReplacement
(
target
,
[],
0
).
bind
(
key
,
mock_tokenizer
)
for
key
,
target
in
target_by_key
.
items
()
],
)
# Only displayed on error
print
(
"result:"
,
result
)
# Manually constructed results
result_groups
=
dict
(
full_groupby
(
result
,
key
=
lambda
x
:
x
.
modality
))
assert
{
key
:
[
dict
(
start_idx
=
item
.
start_idx
,
end_idx
=
item
.
end_idx
)
for
item
in
result_groups
.
get
(
key
,
[])
]
for
key
in
expected_by_key
}
==
expected_by_key
# yapf: disable
@
pytest
.
mark
.
parametrize
(
(
"prompt"
,
"target_by_key"
,
"expected_by_key"
),
[
# Detokenized test cases of `test_find_token_matches`
# using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
(
""
,
{
"pattern_1"
:
""
,
"pattern_2"
:
"<image>"
,
},
{
"pattern_1"
:
[{
"start_idx"
:
0
,
"end_idx"
:
0
}],
"pattern_2"
:
[],
}
),
(
"<image><image><image><image>"
,
{
"pattern_1"
:
"<image>"
,
"pattern_2"
:
"<image><image>"
,
"pattern_3"
:
"<image><image><image>"
,
},
{
"pattern_1"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
7
},
{
"start_idx"
:
7
,
"end_idx"
:
14
},
{
"start_idx"
:
14
,
"end_idx"
:
21
},
{
"start_idx"
:
21
,
"end_idx"
:
28
},
],
"pattern_2"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
14
},
{
"start_idx"
:
14
,
"end_idx"
:
28
},
],
"pattern_3"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
21
},
],
},
),
(
"Image:<image><image><image>Image:<image><image>!"
,
{
"pattern_1"
:
"Image:<image>"
,
"pattern_2"
:
"Image:<image><image><image>"
,
"pattern_3"
:
"Image:<unk><image>"
,
},
{
"pattern_1"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
13
},
{
"start_idx"
:
27
,
"end_idx"
:
40
},
],
"pattern_2"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
27
},
],
"pattern_3"
:
[],
},
),
# Test regex escape
(
"<|image|><image><|image|><image>"
,
{
"pattern_1"
:
"<|image|>"
,
"pattern_2"
:
"<|image|><image>"
,
"pattern_3"
:
"<|image|><image><|image|>"
,
},
{
"pattern_1"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
9
},
{
"start_idx"
:
16
,
"end_idx"
:
25
},
],
"pattern_2"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
16
},
{
"start_idx"
:
16
,
"end_idx"
:
32
},
],
"pattern_3"
:
[
{
"start_idx"
:
0
,
"end_idx"
:
25
},
],
},
),
],
)
# yapf: enable
def
test_find_text_matches
(
prompt
,
target_by_key
,
expected_by_key
):
# Should not be used since there is nothing to convert to text
mock_tokenizer
=
cast
(
AnyTokenizer
,
object
())
result
=
find_text_matches
(
prompt
,
[
PromptReplacement
(
target
,
[],
0
).
bind
(
key
,
mock_tokenizer
)
for
key
,
target
in
target_by_key
.
items
()
],
)
# Only displayed on error
print
(
"result:"
,
result
)
# Manually constructed results
result_groups
=
dict
(
full_groupby
(
result
,
key
=
lambda
x
:
x
.
modality
))
assert
{
key
:
[
dict
(
start_idx
=
item
.
start_idx
,
end_idx
=
item
.
end_idx
)
for
item
in
result_groups
.
get
(
key
,
[])
]
for
key
in
expected_by_key
}
==
expected_by_key
# yapf: disable
@
pytest
.
mark
.
parametrize
(
(
"prompt"
,
"target_by_key"
,
"repl_by_key"
,
"expected_by_mm_count"
),
[
(
"Image:<image>Image:<image><image>!"
,
{
# We use `<image>` before `Image:` to test matches that
# occur out of order
"pattern_1"
:
"<image>"
,
"pattern_2"
:
"Image:"
,
"pattern_3"
:
"!"
,
},
{
# Test whether target is confused with repl_unit
"pattern_1"
:
(
"<image><image>"
,
1
),
# Test empty repl_unit
"pattern_2"
:
(
""
,
1
),
# Test multiple repl_count
"pattern_3"
:
(
"?"
,
2
),
},
{
# Test no replacement
0
:
"Image:<image>Image:<image><image>!"
,
# Test single replacement
1
:
"<image><image>Image:<image><image>??"
,
# Test repeated replacement
2
:
"<image><image><image><image><image>??"
,
},
),
]
)
# yapf: enable
def
test_find_replace_text
(
prompt
,
target_by_key
,
repl_by_key
,
expected_by_mm_count
,
):
# Should not be used since there is nothing to convert to text
mock_tokenizer
=
cast
(
AnyTokenizer
,
object
())
matches
=
find_text_matches
(
prompt
,
[
PromptReplacement
(
target
,
*
repl_by_key
[
key
])
\
.
bind
(
key
,
mock_tokenizer
)
for
key
,
target
in
target_by_key
.
items
()
],
)
result_by_mm_count
=
{
mm_count
:
replace_text_matches
(
prompt
,
matches
,
{
key
:
list
(
range
(
mm_count
))
for
key
in
repl_by_key
},
BatchFeature
(),
)
for
mm_count
in
expected_by_mm_count
}
# Only displayed on error
print
(
"matches:"
,
matches
)
print
(
"result_by_mm_count:"
,
result_by_mm_count
)
# Manually constructed results
assert
result_by_mm_count
==
expected_by_mm_count
tests/multimodal/test_utils.py
View file @
c8acd805
...
@@ -139,7 +139,8 @@ def test_repeat_and_pad_placeholder_tokens(model):
...
@@ -139,7 +139,8 @@ def test_repeat_and_pad_placeholder_tokens(model):
2
,
2
,
"<image><image><image>"
,
"<image><image><image>"
,
[
32000
,
32000
,
32000
],
[
32000
,
32000
,
32000
],
[{
"offset"
:
0
,
"length"
:
2
}]),
[{
"offset"
:
0
,
"length"
:
2
}],
),
(
(
"<image><image>"
,
"<image><image>"
,
[
3
,
2
],
[
3
,
2
],
...
...
vllm/multimodal/inputs.py
View file @
c8acd805
...
@@ -203,14 +203,7 @@ class MultiModalInputsV2(TypedDict):
...
@@ -203,14 +203,7 @@ class MultiModalInputsV2(TypedDict):
"""The type of inputs."""
"""The type of inputs."""
prompt
:
str
prompt
:
str
"""
"""The processed prompt text."""
The original, unprocessed prompt text.
Note:
Since prompt text is not required by vLLM internals, we leave this
unprocessed to save CPU computation. You can still call
:code:`tokenizer.decode(prompt_token_ids)` to get the processed text.
"""
prompt_token_ids
:
List
[
int
]
prompt_token_ids
:
List
[
int
]
"""The processed token IDs which includes placeholder tokens."""
"""The processed token IDs which includes placeholder tokens."""
...
...
vllm/multimodal/processing.py
View file @
c8acd805
import
re
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Callable
,
ItemsView
,
Iterable
,
Mapping
,
Sequence
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
lru_cache
,
partial
from
functools
import
lru_cache
from
typing
import
(
Any
,
Callable
,
Collection
,
Generic
,
List
,
Mapping
,
from
itertools
import
groupby
Optional
,
TypedDict
,
TypeVar
,
final
)
from
typing
import
Any
,
Generic
,
NamedTuple
,
Optional
,
Protocol
,
TypeVar
,
Union
import
numpy
as
np
from
transformers
import
BatchFeature
from
transformers
import
BatchFeature
from
typing_extensions
import
TypeAlias
from
typing_extensions
import
TypeAlias
,
TypedDict
from
vllm.inputs
import
InputProcessingContext
from
vllm.inputs
import
InputProcessingContext
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.utils
import
is_list_of
from
vllm.utils
import
flatten_2d_lists
,
full_groupby
,
is_list_of
from
.inputs
import
(
AudioItem
,
ImageItem
,
MultiModalDataDict
,
from
.inputs
import
(
AudioItem
,
ImageItem
,
MultiModalDataDict
,
MultiModalInputsV2
,
MultiModalKwargs
,
PlaceholderRange
,
MultiModalInputsV2
,
MultiModalKwargs
,
PlaceholderRange
,
VideoItem
)
VideoItem
)
def
bind_prompt_sequence
(
seq
:
Union
[
str
,
list
[
int
]],
tokenizer
:
AnyTokenizer
,
)
->
"_BoundPromptSequence"
:
"""
Bind a text or token sequence to a tokenizer so that it can be
lazily converted into the other format on demand.
"""
return
_BoundPromptSequence
(
tokenizer
=
tokenizer
,
_text
=
seq
if
isinstance
(
seq
,
str
)
else
None
,
_token_ids
=
seq
if
isinstance
(
seq
,
list
)
else
None
,
)
_T
=
TypeVar
(
"_T"
)
_T
=
TypeVar
(
"_T"
)
_S
=
TypeVar
(
"_S"
,
str
,
list
[
int
])
ReplacementFunc
:
TypeAlias
=
Callable
[[
_T
,
BatchFeature
,
int
],
List
[
int
]]
"""
@
dataclass
Given the original data item, HF-processed data, and index of the processed
class
PromptReplacement
(
Generic
[
_S
,
_T
]):
item, output the replacement token IDs to be allocated in vLLM.
target
:
_S
"""
"""The text or token sequence to find and replace."""
repl_unit
:
_S
"""
The unit making up the replacement text or token sequence.
See :code:`repl_count` for more details.
"""
repl_count
:
Union
[
Callable
[[
list
[
_T
],
BatchFeature
,
int
],
int
],
int
]
"""
Given the original multi-modal items for this modality, HF-processed data,
and index of the processed item, output the number of repetitions of
:code:`repl_unit` to build up the replacement text or token sequence.
For convenience, you can pass in an integer if the number of repetitions is
a constant.
"""
def
__repr__
(
self
)
->
str
:
return
(
f
"
{
type
(
self
).
__name__
}
(target=
{
self
.
target
!
r
}
, "
f
"repl_unit=
{
self
.
repl_unit
!
r
}
)"
)
def
bind
(
self
,
modality
:
str
,
tokenizer
:
AnyTokenizer
,
)
->
"_BoundPromptReplacement[_T]"
:
return
_BoundPromptReplacement
(
modality
=
modality
,
target
=
bind_prompt_sequence
(
self
.
target
,
tokenizer
),
repl_unit
=
bind_prompt_sequence
(
self
.
repl_unit
,
tokenizer
),
repl_count
=
self
.
repl_count
,
)
@
dataclass
@
dataclass
class
ModalityProcessingMetadata
(
Generic
[
_T
]):
class
ModalityProcessingMetadata
(
Generic
[
_T
]):
placeholder_replacements
:
Mapping
[
str
,
ReplacementFunc
]
prompt_repls
:
Sequence
[
Union
[
PromptReplacement
[
str
,
_T
],
PromptReplacement
[
list
[
int
],
_T
]]]
"""
"""
A dictionary where each item represents the original placeholder in the
Defines each text or token sequence to replace in the HF-processed prompt.
prompt text and the corresponding replacement.
This is skipped if the HF-processed prompt is found to already contain
the replacement prompts.
"""
"""
...
@@ -52,46 +109,138 @@ Note:
...
@@ -52,46 +109,138 @@ Note:
Read more on that :ref:`here <adding_multimodal_plugin>`.
Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
"""
MultiModalMultiData
:
TypeAlias
=
List
[
_T
]
"""
A list of data items, where the number of data items allowed
per modality is restricted by :code:`--limit-mm-per-prompt`.
"""
def
_encode
(
tokenizer
:
AnyTokenizer
,
text
:
str
,
*
,
add_special_tokens
:
bool
=
False
,
)
->
list
[
int
]:
"""
Backend-agnostic equivalent of HF's
:code:`tokenizer.encode(text, add_special_tokens=...)`.
"""
if
isinstance
(
tokenizer
,
MistralTokenizer
):
return
tokenizer
.
tokenizer
.
encode
(
text
,
bos
=
add_special_tokens
,
eos
=
add_special_tokens
)
@
final
return
tokenizer
.
encode
(
text
,
add_special_tokens
=
add_special_tokens
)
class
MultiModalMultiDataBuiltins
(
TypedDict
,
total
=
False
):
"""Type annotations for modality types predefined by vLLM."""
image
:
MultiModalMultiData
[
ImageItem
]
"""The input images."""
video
:
MultiModalMultiData
[
VideoItem
]
@
lru_cache
(
maxsize
=
2048
)
"""The input videos."""
def
_cached_encode
(
tokenizer
:
AnyTokenizer
,
text
:
str
,
*
,
add_special_tokens
:
bool
=
False
,
)
->
list
[
int
]:
return
_encode
(
tokenizer
,
text
,
add_special_tokens
=
add_special_tokens
)
audio
:
MultiModalMultiData
[
AudioItem
]
"""The input audios."""
def
_decode
(
tokenizer
:
AnyTokenizer
,
token_ids
:
list
[
int
],
*
,
skip_special_tokens
:
bool
=
False
,
)
->
str
:
"""
Backend-agnostic equivalent of HF's
:code:`tokenizer.decode(token_ids, skip_special_tokens=...)`.
"""
return
tokenizer
.
decode
(
token_ids
,
skip_special_tokens
=
skip_special_tokens
)
MultiModalMultiDataDict
:
TypeAlias
=
Mapping
[
str
,
MultiModalMultiData
[
Any
]]
"""
A dictionary containing an entry for each modality type to input.
Note:
@
lru_cache
(
maxsize
=
2048
)
This dictionary also accepts modality keys defined outside
def
_cached_decode
(
:class:`MultiModalMultiDataBuiltins` as long as a customized plugin
tokenizer
:
AnyTokenizer
,
is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
token_ids
:
tuple
[
int
,
...],
Read more on that :ref:`here <adding_multimodal_plugin>`.
*
,
"""
skip_special_tokens
:
bool
=
False
,
)
->
str
:
return
_decode
(
tokenizer
,
list
(
token_ids
),
skip_special_tokens
=
skip_special_tokens
)
class
_HasModalityAttr
(
Protocol
):
modality
:
str
class
_HasModalityProp
(
Protocol
):
def
to_multi_format
(
data
:
MultiModalDataDict
)
->
MultiModalMultiDataDict
:
@
property
def
modality
(
self
)
->
str
:
...
_M
=
TypeVar
(
"_M"
,
bound
=
Union
[
_HasModalityAttr
,
_HasModalityProp
])
def
full_groupby_modality
(
values
:
Iterable
[
_M
])
->
ItemsView
[
str
,
list
[
_M
]]:
"""Convenience function to apply :func:`full_groupby` based on modality."""
return
full_groupby
(
values
,
key
=
lambda
x
:
x
.
modality
)
@
dataclass
class
_BoundPromptSequence
:
tokenizer
:
AnyTokenizer
_text
:
Optional
[
str
]
_token_ids
:
Optional
[
list
[
int
]]
def
__post_init__
(
self
)
->
None
:
if
self
.
_text
is
None
and
self
.
_token_ids
is
None
:
raise
ValueError
(
"At least one of 'text' and 'token_ids' must be "
"specified"
)
@
property
def
text
(
self
)
->
str
:
if
self
.
_text
is
None
:
assert
self
.
_token_ids
is
not
None
self
.
_text
=
_cached_decode
(
self
.
tokenizer
,
tuple
(
self
.
_token_ids
))
return
self
.
_text
@
property
def
token_ids
(
self
)
->
list
[
int
]:
if
self
.
_token_ids
is
None
:
assert
self
.
_text
is
not
None
self
.
_token_ids
=
_cached_encode
(
self
.
tokenizer
,
self
.
_text
)
return
self
.
_token_ids
def
__repr__
(
self
)
->
str
:
return
(
f
"
{
type
(
self
).
__name__
}
(_text=
{
self
.
_text
!
r
}
, "
f
"_token_ids=
{
self
.
_token_ids
!
r
}
)"
)
@
dataclass
class
_BoundPromptReplacement
(
Generic
[
_T
]):
modality
:
str
target
:
_BoundPromptSequence
repl_unit
:
_BoundPromptSequence
repl_count
:
Union
[
Callable
[[
list
[
_T
],
BatchFeature
,
int
],
int
],
int
]
def
get_count
(
self
,
mm_items
:
list
[
_T
],
hf_inputs
:
BatchFeature
,
item_idx
:
int
,
)
->
int
:
repl_count
=
self
.
repl_count
if
isinstance
(
repl_count
,
int
):
return
repl_count
return
repl_count
(
mm_items
,
hf_inputs
,
item_idx
)
def
to_multi_format
(
data
:
MultiModalDataDict
)
->
dict
[
str
,
list
[
Any
]]:
"""
"""
Convert a :class:`MultiModalDataDict` containing single data items
Convert a :class:`MultiModalDataDict` containing single data items
to a :class:`MultiModalMultiDataDict` containing multiple data items
to a :class:`MultiModalMultiDataDict` containing multiple data items
per entry.
per entry.
"""
"""
multi_data
:
Mapping
[
str
,
MultiModalMultiData
[
Any
]]
=
{}
multi_data
=
dict
[
str
,
list
[
Any
]]
()
for
k
,
v
in
data
.
items
():
for
k
,
v
in
data
.
items
():
# yapf: disable
# yapf: disable
...
@@ -107,86 +256,279 @@ def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict:
...
@@ -107,86 +256,279 @@ def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict:
return
multi_data
return
multi_data
def
encode_no_special_tokens
(
class
_TokenRun
(
NamedTuple
):
tokenizer
:
AnyTokenizer
,
token_id
:
int
text
:
str
,
)
->
List
[
int
]:
start_idx
:
int
length
:
int
def
iter_token_runs
(
token_ids
:
list
[
int
])
->
Iterable
[
_TokenRun
]:
"""
"""
Backend-agnostic equivalent of HF's
Yield the starting index and length of each run of tokens that are the same.
:code:`tokenizer.encode(text, add_special_tokens=False)`.
"""
"""
if
isinstance
(
tokenizer
,
MistralTokenizer
):
start_idx
=
0
return
tokenizer
.
tokenizer
.
encode
(
text
,
bos
=
False
,
eos
=
False
)
for
token_id
,
it
in
groupby
(
token_ids
):
length
=
sum
(
1
for
_
in
it
)
yield
_TokenRun
(
token_id
=
token_id
,
start_idx
=
start_idx
,
length
=
length
)
start_idx
+=
length
class
_PlaceholderInfo
(
NamedTuple
):
modality
:
str
offset
:
int
length
:
int
def
to_range
(
self
)
->
PlaceholderRange
:
return
PlaceholderRange
(
offset
=
self
.
offset
,
length
=
self
.
length
)
def
iter_placeholders
(
prompt_repls
:
Sequence
[
_BoundPromptReplacement
[
Any
]],
token_ids
:
list
[
int
],
*
,
min_placeholder_count
:
int
,
)
->
Iterable
[
_PlaceholderInfo
]:
"""Yield each set of placeholder tokens found in :code:`token_ids`."""
placeholder_ids_by_modality
=
{
modality
:
{
token_id
for
prompt_repl
in
repls
for
token_id
in
prompt_repl
.
repl_unit
.
token_ids
}
for
modality
,
repls
in
full_groupby_modality
(
prompt_repls
)
}
return
tokenizer
.
encode
(
text
,
add_special_tokens
=
False
)
for
run_info
in
iter_token_runs
(
token_ids
):
if
run_info
.
length
>
min_placeholder_count
:
for
(
modality
,
placeholder_ids
)
in
placeholder_ids_by_modality
.
items
():
if
run_info
.
token_id
in
placeholder_ids
:
yield
_PlaceholderInfo
(
modality
=
modality
,
offset
=
run_info
.
start_idx
,
length
=
run_info
.
length
,
)
@
lru_cache
class
_TokenMatch
(
NamedTuple
):
def
candidate_placeholders
(
start_idx
:
int
tokenizer
:
AnyTokenizer
,
end_idx
:
int
placeholder_text
:
str
,
)
->
Collection
[
List
[
int
]]:
"""Generate token ID sequences that may represent a placeholder text."""
# When the placeholder text is not mapped to a special token ID,
# it may be tokenized differently based on whether it is at the start/end
# of the string. So, we go through each combination of whether the text
# is at the start and end boundaries of the string
# Matches the placeholder when it is in the middle of the string
start_id
,
=
encode_no_special_tokens
(
tokenizer
,
"a"
)
end_id
,
=
encode_no_special_tokens
(
tokenizer
,
"b"
)
candidate_basic
=
encode_no_special_tokens
(
tokenizer
,
placeholder_text
)
start_id_
,
*
candidate_a
=
encode_no_special_tokens
(
tokenizer
,
f
"a
{
placeholder_text
}
"
,
)
assert
start_id
==
start_id_
start_id_
,
*
candidate_ab
,
end_id_
=
encode_no_special_tokens
(
tokenizer
,
f
"a
{
placeholder_text
}
b"
,
)
assert
start_id
==
start_id_
and
end_id
==
end_id_
*
candidate_b
,
end_id_
=
encode_no_special_tokens
(
def
iter_token_matches
(
tokenizer
,
token_ids
:
list
[
int
],
f
"
{
placeholder_text
}
b"
,
match_ids
:
list
[
int
],
)
)
->
Iterable
[
_TokenMatch
]:
assert
end_id
==
end_id_
"""Yield each occurrence of :code:`match_ids` in :code:`token_ids`."""
match_len
=
len
(
match_ids
)
# Remove duplicates (need to convert to tuple to be hashable)
last_end_idx
=
0
unique_candidates
=
{
for
start_idx
in
range
(
len
(
token_ids
)
-
match_len
+
1
):
tuple
(
c
)
if
start_idx
<
last_end_idx
:
for
c
in
[
candidate_basic
,
candidate_a
,
candidate_ab
,
candidate_b
]
continue
# Exclude overlapping matches
}
# Convert back to list
end_idx
=
start_idx
+
match_len
return
[
list
(
c
)
for
c
in
unique_candidates
]
if
token_ids
[
start_idx
:
end_idx
]
==
match_ids
:
yield
_TokenMatch
(
start_idx
=
start_idx
,
end_idx
=
end_idx
)
last_end_idx
=
end_idx
def
apply_placeholders
(
class
_PromptReplacementMatch
(
ABC
,
Generic
[
_T
,
_S
]):
token_ids
:
List
[
int
],
prompt_repl
:
_BoundPromptReplacement
[
_T
]
placeholder_ids
:
List
[
int
],
get_replacement_ids
:
Callable
[[],
List
[
int
]],
@
property
)
->
Optional
[
PlaceholderRange
]:
def
modality
(
self
)
->
str
:
"""
return
self
.
prompt_repl
.
modality
Find the first occurrence of :code:`placeholder_ids`,
and replace it with the output of :code:`get_replacement_ids`.
@
property
@
abstractmethod
def
start_idx
(
self
)
->
int
:
raise
NotImplementedError
@
property
@
abstractmethod
def
end_idx
(
self
)
->
int
:
raise
NotImplementedError
@
abstractmethod
def
get_repl
(
self
,
mm_items
:
list
[
_T
],
hf_inputs
:
BatchFeature
,
item_idx
:
int
,
)
->
_S
:
raise
NotImplementedError
def
__repr__
(
self
)
->
str
:
return
(
f
"
{
type
(
self
).
__name__
}
(modality=
{
self
.
modality
!
r
}
, "
f
"start_idx=
{
self
.
start_idx
!
r
}
, end_idx=
{
self
.
end_idx
!
r
}
)"
)
@
dataclass
(
repr
=
False
)
class
_PromptReplacementTokenMatch
(
_PromptReplacementMatch
[
_T
,
list
[
int
]]):
prompt_repl
:
_BoundPromptReplacement
[
_T
]
match
:
_TokenMatch
@
property
def
start_idx
(
self
)
->
int
:
return
self
.
match
.
start_idx
@
property
def
end_idx
(
self
)
->
int
:
return
self
.
match
.
end_idx
def
get_repl
(
self
,
mm_items
:
list
[
_T
],
hf_inputs
:
BatchFeature
,
item_idx
:
int
,
)
->
list
[
int
]:
prompt_repl
=
self
.
prompt_repl
count
=
prompt_repl
.
get_count
(
mm_items
,
hf_inputs
,
item_idx
)
return
prompt_repl
.
repl_unit
.
token_ids
*
count
This function updates :code:`token_ids` in place.
@
dataclass
(
repr
=
False
)
class
_PromptReplacementTextMatch
(
_PromptReplacementMatch
[
_T
,
str
]):
prompt_repl
:
_BoundPromptReplacement
[
_T
]
match
:
re
.
Match
[
str
]
@
property
def
start_idx
(
self
)
->
int
:
return
self
.
match
.
start
()
@
property
def
end_idx
(
self
)
->
int
:
return
self
.
match
.
end
()
def
get_repl
(
self
,
mm_items
:
list
[
_T
],
hf_inputs
:
BatchFeature
,
item_idx
:
int
,
)
->
str
:
prompt_repl
=
self
.
prompt_repl
count
=
prompt_repl
.
get_count
(
mm_items
,
hf_inputs
,
item_idx
)
return
prompt_repl
.
repl_unit
.
text
*
count
def
find_token_matches
(
prompt
:
list
[
int
],
prompt_repls
:
Sequence
[
_BoundPromptReplacement
[
_T
]],
)
->
list
[
_PromptReplacementTokenMatch
[
_T
]]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return
[
_PromptReplacementTokenMatch
(
prompt_repl
,
match
)
for
prompt_repl
in
prompt_repls
for
match
in
iter_token_matches
(
prompt
,
prompt_repl
.
target
.
token_ids
)
]
def
find_text_matches
(
prompt
:
str
,
prompt_repls
:
Sequence
[
_BoundPromptReplacement
[
_T
]],
)
->
list
[
_PromptReplacementTextMatch
[
_T
]]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return
[
_PromptReplacementTextMatch
(
prompt_repl
,
match
)
for
prompt_repl
in
prompt_repls
for
match
in
re
.
finditer
(
re
.
escape
(
prompt_repl
.
target
.
text
),
prompt
)
]
def
_resolve_matches
(
prompt
:
_S
,
matches
:
Sequence
[
_PromptReplacementMatch
[
_T
,
_S
]],
)
->
list
[
_PromptReplacementMatch
[
_T
,
_S
]]:
"""
Resolve :code:`matches` to ensure that there are no overlapping matches,
and sort them such that earlier matches take priority over later ones.
"""
"""
placeholder_length
=
len
(
placeholder_ids
)
num_matches_by_idx
=
np
.
zeros
(
len
(
prompt
),
dtype
=
int
)
for
match
in
matches
:
num_matches_by_idx
[
match
.
start_idx
:
match
.
end_idx
]
+=
1
duplicate_matches_idxs
,
=
np
.
nonzero
(
num_matches_by_idx
>
1
)
if
len
(
duplicate_matches_idxs
)
>
0
:
raise
ValueError
(
"Unable to find a unique replacement "
f
"at indices=
{
duplicate_matches_idxs
}
"
f
"of prompt=
{
prompt
}
"
)
return
sorted
(
matches
,
key
=
lambda
x
:
x
.
start_idx
)
def
_replace_matches
(
prompt
:
_S
,
matches
:
Sequence
[
_PromptReplacementMatch
[
_T
,
_S
]],
mm_items_by_modality
:
Mapping
[
str
,
list
[
_T
]],
hf_inputs
:
BatchFeature
,
)
->
list
[
_S
]:
out_seqs
=
list
[
_S
]()
prev_end_idx
=
0
next_idx_by_modality
=
{
modality
:
0
for
modality
in
mm_items_by_modality
}
for
match
in
_resolve_matches
(
prompt
,
matches
):
modality
=
match
.
modality
mm_items
=
mm_items_by_modality
[
modality
]
item_idx
=
next_idx_by_modality
[
modality
]
if
item_idx
>=
len
(
mm_items
):
continue
start_idx
=
match
.
start_idx
end_idx
=
match
.
end_idx
repl_ids
=
match
.
get_repl
(
mm_items
,
hf_inputs
,
item_idx
)
out_seqs
.
append
(
prompt
[
prev_end_idx
:
start_idx
]
+
repl_ids
)
prev_end_idx
=
end_idx
next_idx_by_modality
[
modality
]
+=
1
out_seqs
.
append
(
prompt
[
prev_end_idx
:])
return
out_seqs
def
replace_token_matches
(
prompt
:
list
[
int
],
matches
:
Sequence
[
_PromptReplacementMatch
[
_T
,
list
[
int
]]],
mm_items_by_modality
:
Mapping
[
str
,
list
[
_T
]],
hf_inputs
:
BatchFeature
,
)
->
list
[
int
]:
"""Apply :code:`prompt_repls` to :code:`prompt`."""
if
not
matches
:
return
prompt
token_id_seqs
=
_replace_matches
(
prompt
,
matches
,
mm_items_by_modality
,
hf_inputs
,
)
return
flatten_2d_lists
(
token_id_seqs
)
for
start_idx
in
range
(
len
(
token_ids
)
-
placeholder_length
+
1
):
if
token_ids
[
start_idx
:
placeholder_length
]
==
placeholder_ids
:
token_ids
[
start_idx
:
placeholder_length
]
=
get_replacement_ids
()
return
PlaceholderRange
(
offset
=
start_idx
,
def
replace_text_matches
(
length
=
placeholder_length
)
prompt
:
str
,
matches
:
Sequence
[
_PromptReplacementMatch
[
_T
,
str
]],
mm_items_by_modality
:
Mapping
[
str
,
list
[
_T
]],
hf_inputs
:
BatchFeature
,
)
->
str
:
"""Apply :code:`prompt_repls` to :code:`prompt`."""
if
not
matches
:
return
prompt
return
None
texts
=
_replace_matches
(
prompt
,
matches
,
mm_items_by_modality
,
hf_inputs
,
)
return
""
.
join
(
texts
)
class
MultiModalProcessor
:
class
MultiModalProcessor
:
...
@@ -212,62 +554,166 @@ class MultiModalProcessor:
...
@@ -212,62 +554,166 @@ class MultiModalProcessor:
)
->
MultiModalInputsV2
:
)
->
MultiModalInputsV2
:
return
self
.
apply
(
prompt
,
mm_data
,
mm_processor_kwargs
)
return
self
.
apply
(
prompt
,
mm_data
,
mm_processor_kwargs
)
def
apply
(
def
_find_placeholders
(
self
,
all_prompt_repls
:
Sequence
[
_BoundPromptReplacement
[
Any
]],
new_token_ids
:
list
[
int
],
*
,
# To avoid false positives from multi-input when detecting
# whether placeholder tokens have been inserted, in case
# the target sequence is a subset of the replacement tokens
min_placeholder_count
:
int
=
16
,
)
->
list
[
_PlaceholderInfo
]:
return
list
(
iter_placeholders
(
all_prompt_repls
,
new_token_ids
,
min_placeholder_count
=
min_placeholder_count
,
))
def
_apply_hf_processor
(
self
,
self
,
prompt
:
str
,
prompt
:
str
,
mm_data
:
MultiModalDataDict
,
mm_data
:
MultiModalDataDict
,
mm_processor_kwargs
:
Mapping
[
str
,
object
],
mm_processor_kwargs
:
Mapping
[
str
,
object
],
)
->
MultiModalInputsV2
:
)
->
BatchFeature
:
tokenizer
=
self
.
ctx
.
tokenizer
hf_processor
=
self
.
ctx
.
get_hf_processor
()
hf_processor
=
self
.
ctx
.
get_hf_processor
()
processed_inputs
=
hf_processor
(
return
hf_processor
(
text
=
prompt
,
# type: ignore
text
=
prompt
,
# type: ignore
**
mm_data
,
**
mm_data
,
**
mm_processor_kwargs
,
**
mm_processor_kwargs
,
)
)
new_token_ids
,
=
processed_inputs
.
pop
(
"input_ids"
).
tolist
()
mm_kwargs
=
MultiModalKwargs
(
processed_inputs
)
mm_placeholders
:
Mapping
[
str
,
List
[
PlaceholderRange
]]
=
{}
def
_bind_prompt_replacements
(
self
,
mm_data
:
MultiModalDataDict
,
)
->
list
[
_BoundPromptReplacement
[
Any
]]:
tokenizer
=
self
.
ctx
.
tokenizer
for
modality
,
orig_inputs
in
to_multi_format
(
mm_data
).
items
():
return
[
assert
isinstance
(
orig_inputs
,
list
)
prompt_repl
.
bind
(
modality
,
tokenizer
)
for
modality
,
metadata
in
self
.
metadata
.
items
()
if
modality
in
mm_data
for
prompt_repl
in
metadata
.
prompt_repls
]
metadata
=
self
.
metadata
[
modality
]
def
_apply_prompt_replacements
(
placeholder_replacements
=
metadata
.
placeholder_replacements
self
,
mm_data
:
MultiModalDataDict
,
hf_inputs
:
BatchFeature
,
token_ids
:
list
[
int
],
prompt_repls
:
Sequence
[
_BoundPromptReplacement
[
Any
]],
)
->
tuple
[
list
[
int
],
str
,
list
[
_PlaceholderInfo
]]:
tokenizer
=
self
.
ctx
.
tokenizer
modality_placeholders
:
List
[
PlaceholderRange
]
=
[]
mm_items
=
to_multi_format
(
mm_data
)
token_matches
=
find_token_matches
(
token_ids
,
prompt_repls
)
# If the search text does not represent a special token,
# it may have different token IDs in the prompt, because
# the tokens may go across the boundaries of the search text.
# ----
# e.g. when searching for "foo" in "food", if "food" itself makes
# up a token, then the token ID of "foo" will not appear at all
# ----
# Since it is inefficient to search for all possible tokenizations
# of the search text in the prompt, we instead perform string
# replacement on the decoded token IDs, then encode them back.
if
all
(
len
(
matches
)
>=
len
(
mm_data
[
modality
])
for
modality
,
matches
in
full_groupby_modality
(
token_matches
)
):
# yapf: disable
token_ids
=
replace_token_matches
(
token_ids
,
token_matches
,
mm_items
,
hf_inputs
,
)
text
=
_decode
(
tokenizer
,
token_ids
)
matched_repls
=
[
match
.
prompt_repl
for
match
in
token_matches
]
else
:
text
=
_decode
(
tokenizer
,
token_ids
)
text_matches
=
find_text_matches
(
text
,
prompt_repls
)
text
=
replace_text_matches
(
text
,
text_matches
,
mm_items
,
hf_inputs
,
)
token_ids
=
_encode
(
tokenizer
,
text
)
matched_repls
=
[
match
.
prompt_repl
for
match
in
text_matches
]
placeholders
=
self
.
_find_placeholders
(
matched_repls
,
token_ids
)
# Sanity check
assert
len
(
placeholders
)
==
len
(
matched_repls
),
dict
(
# Log this information for easier debugging
text
=
text
,
token_ids
=
token_ids
,
placeholders
=
placeholders
,
matched_repls
=
matched_repls
,
)
for
item_idx
,
orig_item
in
enumerate
(
orig_inputs
):
return
token_ids
,
text
,
placeholders
for
match_text
,
replace_fn
in
placeholder_replacements
.
items
():
candidates
=
candidate_placeholders
(
tokenizer
,
match_text
)
get_replacement_ids
=
partial
(
replace_fn
,
orig_item
,
processed_inputs
,
item_idx
,
)
for
match_ids
in
candidates
:
def
apply
(
# TODO(youkaichao): Don't update new_token_ids
self
,
placeholders
=
apply_placeholders
(
prompt_text
:
str
,
new_token_ids
,
mm_data
:
MultiModalDataDict
,
match_ids
,
mm_processor_kwargs
:
Mapping
[
str
,
object
],
get_replacement_ids
,
)
->
MultiModalInputsV2
:
)
"""
Process multi-modal inputs to be used in vLLM.
The main steps are:
1. Apply HF Processor on prompt text and multi-modal data together,
outputting token IDs and processed tensors.
2. Find and replace sequences in the token IDs with placeholder tokens.
The number of placeholder tokens equals the feature size of the
multi-modal data outputted by the multi-modal encoder.
3. Extract information about the placeholder tokens from the
processed token IDs.
"""
tokenizer
=
self
.
ctx
.
tokenizer
hf_inputs
=
self
.
_apply_hf_processor
(
prompt_text
,
mm_data
,
mm_processor_kwargs
)
prompt_ids
,
=
hf_inputs
.
pop
(
"input_ids"
).
tolist
()
mm_kwargs
=
MultiModalKwargs
(
hf_inputs
)
if
placeholders
is
not
None
:
all_prompt_repls
=
self
.
_bind_prompt_replacements
(
mm_data
)
modality_placeholders
.
append
(
placeholders
)
# yapf: disable
# If HF processor already inserts placeholder tokens,
mm_placeholders
[
modality
]
=
modality_placeholders
# type: ignore[index]
# there is no need for us to insert them
# yapf: enable
all_placeholders
=
self
.
_find_placeholders
(
all_prompt_repls
,
prompt_ids
)
if
all_placeholders
:
prompt_text
=
_decode
(
tokenizer
,
prompt_ids
)
else
:
(
prompt_ids
,
prompt_text
,
all_placeholders
,
)
=
self
.
_apply_prompt_replacements
(
mm_data
,
hf_inputs
,
prompt_ids
,
all_prompt_repls
,
)
mm_placeholders
=
{
modality
:
[
item
.
to_range
()
for
item
in
items
]
for
modality
,
items
in
full_groupby_modality
(
all_placeholders
)
}
return
MultiModalInputsV2
(
return
MultiModalInputsV2
(
type
=
"multimodal"
,
type
=
"multimodal"
,
prompt
=
prompt
,
prompt
=
prompt
_text
,
prompt_token_ids
=
new_token
_ids
,
prompt_token_ids
=
prompt
_ids
,
mm_kwargs
=
mm_kwargs
,
mm_kwargs
=
mm_kwargs
,
mm_placeholders
=
mm_placeholders
,
mm_placeholders
=
mm_placeholders
,
)
)
vllm/utils.py
View file @
c8acd805
...
@@ -19,7 +19,8 @@ import uuid
...
@@ -19,7 +19,8 @@ import uuid
import
warnings
import
warnings
import
weakref
import
weakref
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Future
,
Task
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Future
,
Task
from
collections.abc
import
Mapping
from
collections
import
defaultdict
from
collections.abc
import
Iterable
,
Mapping
from
functools
import
lru_cache
,
partial
,
wraps
from
functools
import
lru_cache
,
partial
,
wraps
from
platform
import
uname
from
platform
import
uname
from
typing
import
(
Any
,
AsyncGenerator
,
Awaitable
,
Callable
,
Dict
,
Generic
,
from
typing
import
(
Any
,
AsyncGenerator
,
Awaitable
,
Callable
,
Dict
,
Generic
,
...
@@ -905,6 +906,23 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
...
@@ -905,6 +906,23 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
return
[
item
for
sublist
in
lists
for
item
in
sublist
]
return
[
item
for
sublist
in
lists
for
item
in
sublist
]
_K
=
TypeVar
(
"_K"
,
bound
=
Hashable
)
_V
=
TypeVar
(
"_V"
)
def
full_groupby
(
values
:
Iterable
[
_V
],
*
,
key
:
Callable
[[
_V
],
_K
]):
"""
Unlike :class:`itertools.groupby`, groups are not broken by
non-contiguous data.
"""
groups
=
defaultdict
[
_K
,
list
[
_V
]](
list
)
for
value
in
values
:
groups
[
key
(
value
)].
append
(
value
)
return
groups
.
items
()
# TODO: This function can be removed if transformer_modules classes are
# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
# serialized by value when communicating between processes
def
init_cached_hf_modules
()
->
None
:
def
init_cached_hf_modules
()
->
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment