Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2a0596bc
Unverified
Commit
2a0596bc
authored
Jan 08, 2025
by
Cyrus Leung
Committed by
GitHub
Jan 08, 2025
Browse files
[VLM] Reorganize profiling/processing-related code (#11812)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
f1214117
Changes
23
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
263 additions
and
181 deletions
+263
-181
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+85
-134
vllm/multimodal/profiling.py
vllm/multimodal/profiling.py
+117
-35
vllm/multimodal/registry.py
vllm/multimodal/registry.py
+61
-12
No files found.
vllm/multimodal/processing.py
View file @
2a0596bc
...
...
@@ -4,12 +4,13 @@ from collections import defaultdict
from
collections.abc
import
Callable
,
ItemsView
,
Iterable
,
Mapping
,
Sequence
from
dataclasses
import
dataclass
,
field
from
functools
import
lru_cache
from
typing
import
Any
,
NamedTuple
,
Optional
,
Protocol
,
TypeVar
,
Union
from
typing
import
(
TYPE_CHECKING
,
Generic
,
NamedTuple
,
Optional
,
Protocol
,
TypeVar
,
Union
)
from
transformers
import
BatchFeature
,
PretrainedConfig
,
ProcessorMixin
from
vllm
import
envs
from
vllm.inputs
import
DummyData
,
InputProcessingContext
import
vllm.envs
as
envs
from
vllm.inputs
import
InputProcessingContext
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
decode_tokens
,
encode_tokens
)
...
...
@@ -20,7 +21,9 @@ from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2
,
MultiModalKwargs
,
MultiModalKwargsItem
,
PlaceholderRange
)
from
.parse
import
MultiModalDataItems
,
MultiModalDataParser
from
.profiling
import
BaseProfilingInfo
if
TYPE_CHECKING
:
from
.profiling
import
BaseDummyInputsBuilder
logger
=
init_logger
(
__name__
)
...
...
@@ -46,8 +49,8 @@ class PromptReplacement:
if it does not depend on the input.
"""
def
bind
(
self
,
tokenizer
:
AnyTokenizer
)
->
"
_
BoundPromptReplacement"
:
return
_
BoundPromptReplacement
(
def
bind
(
self
,
tokenizer
:
AnyTokenizer
)
->
"BoundPromptReplacement"
:
return
BoundPromptReplacement
(
tokenizer
=
tokenizer
,
modality
=
self
.
modality
,
_target
=
self
.
target
,
...
...
@@ -128,7 +131,7 @@ class _BoundPromptSequence:
@
dataclass
class
_
BoundPromptReplacement
:
class
BoundPromptReplacement
:
tokenizer
:
AnyTokenizer
=
field
(
repr
=
False
)
modality
:
str
...
...
@@ -207,7 +210,7 @@ def iter_token_matches(
@
dataclass
(
repr
=
False
)
class
_PromptReplacementMatch
(
ABC
):
prompt_repl
:
_
BoundPromptReplacement
prompt_repl
:
BoundPromptReplacement
@
property
def
modality
(
self
)
->
str
:
...
...
@@ -255,7 +258,7 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch):
@
dataclass
class
_
PlaceholderInfo
:
class
PlaceholderInfo
:
modality
:
str
item_idx
:
int
start_idx
:
int
...
...
@@ -274,7 +277,7 @@ class _PlaceholderInfo:
def
find_token_matches
(
prompt
:
list
[
int
],
prompt_repls
:
Sequence
[
_
BoundPromptReplacement
],
prompt_repls
:
Sequence
[
BoundPromptReplacement
],
)
->
list
[
_PromptReplacementTokenMatch
]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return
[
...
...
@@ -286,7 +289,7 @@ def find_token_matches(
def
find_text_matches
(
prompt
:
str
,
prompt_repls
:
Sequence
[
_
BoundPromptReplacement
],
prompt_repls
:
Sequence
[
BoundPromptReplacement
],
)
->
list
[
_PromptReplacementTextMatch
]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return
[
...
...
@@ -390,9 +393,9 @@ def replace_text_matches(
def
_iter_modality_placeholders
(
prompt
:
list
[
int
],
modality
:
str
,
modality_repls
:
Sequence
[
_
BoundPromptReplacement
],
modality_repls
:
Sequence
[
BoundPromptReplacement
],
modal_item_count
:
int
,
)
->
Iterable
[
_
PlaceholderInfo
]:
)
->
Iterable
[
PlaceholderInfo
]:
if
modal_item_count
==
0
:
return
...
...
@@ -413,7 +416,7 @@ def _iter_modality_placeholders(
continue
if
prompt
[
start_idx
:
end_idx
]
==
repl_tokens
:
yield
_
PlaceholderInfo
(
yield
PlaceholderInfo
(
modality
=
modality
,
item_idx
=
item_idx
,
start_idx
=
start_idx
,
...
...
@@ -434,10 +437,10 @@ def _iter_modality_placeholders(
def
_iter_placeholders
(
mm_prompt_repls
:
Mapping
[
str
,
Sequence
[
_
BoundPromptReplacement
]],
mm_prompt_repls
:
Mapping
[
str
,
Sequence
[
BoundPromptReplacement
]],
prompt
:
list
[
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
Iterable
[
_
PlaceholderInfo
]:
)
->
Iterable
[
PlaceholderInfo
]:
"""
For each modality, yield each set of placeholder tokens found in
:code:`prompt`.
...
...
@@ -455,10 +458,10 @@ def _iter_placeholders(
def
find_mm_placeholders
(
mm_prompt_repls
:
Mapping
[
str
,
Sequence
[
_
BoundPromptReplacement
]],
mm_prompt_repls
:
Mapping
[
str
,
Sequence
[
BoundPromptReplacement
]],
prompt
:
list
[
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
list
[
_
PlaceholderInfo
]]:
)
->
Mapping
[
str
,
list
[
PlaceholderInfo
]]:
it
=
_iter_placeholders
(
mm_prompt_repls
,
prompt
,
mm_item_counts
)
return
dict
(
full_groupby_modality
(
it
))
...
...
@@ -524,29 +527,59 @@ class ProcessingCache:
self
.
_cache
.
put
(
cache_key
,
output_kwargs
)
class
ProcessingMixin
:
"""
Contains helper functions to perform processing.
class
BaseProcessingInfo
:
"""Base class containing information to perform processing."""
Not to be confused with :class:`transformers.ProcessorMixin`.
"""
ctx
:
InputProcessingContext
def
__init__
(
self
,
ctx
:
InputProcessingContext
)
->
None
:
super
().
__init__
()
def
_get_tokenizer
(
self
)
->
AnyTokenizer
:
self
.
ctx
=
ctx
@
property
def
model_id
(
self
)
->
str
:
return
self
.
ctx
.
model_config
.
model
def
get_tokenizer
(
self
)
->
AnyTokenizer
:
return
self
.
ctx
.
tokenizer
def
_
get_hf_config
(
self
)
->
PretrainedConfig
:
def
get_hf_config
(
self
)
->
PretrainedConfig
:
return
self
.
ctx
.
get_hf_config
()
def
_
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
ProcessorMixin
:
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
ProcessorMixin
:
"""
Subclasses can override this method to handle
specific kwargs from model config or user inputs.
"""
return
self
.
ctx
.
get_hf_processor
(
**
kwargs
)
@
abstractmethod
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
"""
Return the maximum supported number of items for each modality.
A value of `None` means unlimited number of items.
Omitting a modality from the returned dictionary means that
it is not supported at all.
"""
raise
NotImplementedError
@
abstractmethod
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
)
->
Mapping
[
str
,
int
]:
"""
Get the maximum possible number of tokens per data item
for each modality.
The dictionary returned by this method should have the same
keys as that returned by :meth:`get_supported_mm_limits`.
"""
raise
NotImplementedError
_I
=
TypeVar
(
"_I"
,
bound
=
BaseProcessingInfo
)
class
BaseMultiModalProcessor
(
ProcessingMixin
,
ABC
):
class
BaseMultiModalProcessor
(
ABC
,
Generic
[
_I
]):
"""
Abstract base class to process multi-modal inputs to be used in vLLM.
...
...
@@ -554,18 +587,19 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
"""
def
__init__
(
self
,
ctx
:
InputProcessingContext
,
info
:
_I
,
dummy_inputs
:
"BaseDummyInputsBuilder[_I]"
,
*
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
enable_sanity_checks
:
bool
=
True
)
->
None
:
super
().
__init__
()
self
.
ctx
=
ctx
self
.
info
=
info
self
.
dummy_inputs
=
dummy_inputs
self
.
cache
=
cache
self
.
enable_sanity_checks
=
enable_sanity_checks
self
.
data_parser
=
self
.
_get_data_parser
()
self
.
profiling_info
=
self
.
_get_profiling_info
()
def
__call__
(
self
,
...
...
@@ -585,13 +619,6 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
"""
return
MultiModalDataParser
()
def
_get_profiling_info
(
self
)
->
BaseProfilingInfo
:
"""
Get the profiling information to find the worst-case memory usage of
the model.
"""
raise
NotImplementedError
def
_to_mm_items
(
self
,
mm_data
:
MultiModalDataDict
,
...
...
@@ -602,7 +629,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
"""
mm_items
=
self
.
data_parser
.
parse_mm_data
(
mm_data
)
mm_limits
=
self
.
ctx
.
get_mm_config
().
limit_per_prompt
mm_limits
=
self
.
info
.
ctx
.
get_mm_config
().
limit_per_prompt
for
modality
,
items
in
mm_items
.
items
():
limit
=
mm_limits
.
get
(
modality
,
1
)
if
len
(
items
)
>
limit
:
...
...
@@ -646,19 +673,19 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def
_find_mm_placeholders
(
self
,
mm_prompt_repls
:
Mapping
[
str
,
Sequence
[
_
BoundPromptReplacement
]],
mm_prompt_repls
:
Mapping
[
str
,
Sequence
[
BoundPromptReplacement
]],
new_token_ids
:
list
[
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
list
[
_
PlaceholderInfo
]]:
)
->
Mapping
[
str
,
list
[
PlaceholderInfo
]]:
return
find_mm_placeholders
(
mm_prompt_repls
,
new_token_ids
,
mm_item_counts
)
def
_get_hf_mm_data
(
self
,
mm_items
:
MultiModalDataItems
,
)
->
tuple
[
dict
[
str
,
Any
],
dict
[
str
,
Any
]]:
processor_data
=
dict
[
str
,
Any
]()
passthrough_data
=
dict
[
str
,
Any
]()
)
->
tuple
[
Mapping
[
str
,
object
],
Mapping
[
str
,
object
]]:
processor_data
=
dict
[
str
,
object
]()
passthrough_data
=
dict
[
str
,
object
]()
for
items
in
mm_items
.
values
():
processor_data
.
update
(
items
.
get_processor_data
())
...
...
@@ -678,8 +705,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
Call the HF processor on the prompt text and
associated multi-modal data.
"""
return
self
.
ctx
.
call_hf_processor
(
self
.
_
get_hf_processor
(
**
mm_kwargs
),
return
self
.
info
.
ctx
.
call_hf_processor
(
self
.
info
.
get_hf_processor
(
**
mm_kwargs
),
dict
(
text
=
prompt
,
**
mm_data
),
mm_kwargs
,
)
...
...
@@ -738,8 +765,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
# Some HF processors (e.g. Qwen2-VL) expect corresponding
# multi-modal tokens to be in the prompt text
dummy_inputs
=
self
.
profiling_info
.
get_dummy_processor_inputs
(
self
.
ctx
.
model_config
.
max_model_len
,
dummy_inputs
=
self
.
dummy_inputs
.
get_dummy_processor_inputs
(
self
.
info
.
ctx
.
model_config
.
max_model_len
,
mm_missing_counts
,
)
...
...
@@ -762,7 +789,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
caching the results and reusing cached results.
"""
cache
=
self
.
cache
model_id
=
self
.
ctx
.
model_
config
.
model
model_id
=
self
.
info
.
model_
id
_
,
passthrough_data
=
self
.
_get_hf_mm_data
(
mm_data_items
)
if
cache
is
None
or
passthrough_data
:
...
...
@@ -838,8 +865,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def
_bind_and_group_repls
(
self
,
prompt_repls
:
list
[
PromptReplacement
],
)
->
dict
[
str
,
list
[
_
BoundPromptReplacement
]]:
tokenizer
=
self
.
_
get_tokenizer
()
)
->
dict
[
str
,
list
[
BoundPromptReplacement
]]:
tokenizer
=
self
.
info
.
get_tokenizer
()
it
=
(
prompt_repl
.
bind
(
tokenizer
)
for
prompt_repl
in
prompt_repls
)
return
dict
(
full_groupby_modality
(
it
))
...
...
@@ -859,10 +886,10 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def
_apply_prompt_replacements
(
self
,
token_ids
:
list
[
int
],
mm_prompt_repls
:
Mapping
[
str
,
Sequence
[
_
BoundPromptReplacement
]],
mm_prompt_repls
:
Mapping
[
str
,
Sequence
[
BoundPromptReplacement
]],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
tuple
[
list
[
int
],
str
,
Mapping
[
str
,
list
[
_
PlaceholderInfo
]]]:
tokenizer
=
self
.
_
get_tokenizer
()
)
->
tuple
[
list
[
int
],
str
,
Mapping
[
str
,
list
[
PlaceholderInfo
]]]:
tokenizer
=
self
.
info
.
get_tokenizer
()
mm_token_matches
=
{
modality
:
find_token_matches
(
token_ids
,
prompt_repls
)
...
...
@@ -950,7 +977,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def
_validate_mm_placeholders
(
self
,
mm_placeholders
:
Mapping
[
str
,
list
[
_
PlaceholderInfo
]],
mm_placeholders
:
Mapping
[
str
,
list
[
PlaceholderInfo
]],
mm_item_counts
:
Mapping
[
str
,
int
],
*
,
allow_missing
:
bool
=
False
,
...
...
@@ -1001,7 +1028,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
# instead of rehashing.
if
envs
.
VLLM_USE_V1
:
model_id
=
self
.
ctx
.
model_
config
.
model
model_id
=
self
.
info
.
model_
id
mm_hashes
=
{
modality
:
[
MultiModalHasher
.
hash_kwargs
(
model_id
=
model_id
,
...
...
@@ -1046,7 +1073,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
allow_missing
=
True
,
)
mm_missing_repls
=
dict
[
str
,
list
[
_
BoundPromptReplacement
]]()
mm_missing_repls
=
dict
[
str
,
list
[
BoundPromptReplacement
]]()
for
modality
,
missing_repl_count
in
mm_missing_repl_counts
.
items
():
if
missing_repl_count
==
0
:
mm_missing_repls
[
modality
]
=
[]
...
...
@@ -1059,7 +1086,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
# If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
if
all
(
len
(
repls
)
==
0
for
repls
in
mm_missing_repls
.
items
()):
tokenizer
=
self
.
_
get_tokenizer
()
tokenizer
=
self
.
info
.
get_tokenizer
()
prompt_text
=
decode_tokens
(
tokenizer
,
prompt_ids
)
mm_placeholders
=
hf_mm_placeholders
else
:
...
...
@@ -1090,79 +1117,3 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
mm_hashes
=
mm_hashes
,
mm_placeholders
=
mm_placeholder_ranges
,
)
def
_get_dummy_mm_inputs
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
MultiModalInputsV2
:
profiling
=
self
.
profiling_info
processor_inputs
=
profiling
.
get_dummy_processor_inputs
(
seq_len
,
mm_counts
)
return
self
.
apply
(
prompt_text
=
processor_inputs
.
prompt_text
,
mm_data
=
processor_inputs
.
mm_data
,
hf_processor_mm_kwargs
=
processor_inputs
.
hf_processor_mm_kwargs
,
)
def
get_dummy_data
(
self
,
seq_len
:
int
)
->
DummyData
:
# Avoid circular import
from
vllm.sequence
import
SequenceData
profiling
=
self
.
profiling_info
mm_counts
=
profiling
.
get_mm_limits
()
mm_max_tokens_per_item
=
profiling
.
get_mm_max_tokens_per_item
(
seq_len
)
if
mm_counts
.
keys
()
!=
mm_max_tokens_per_item
.
keys
():
raise
AssertionError
(
"The keys returned by `get_supported_mm_limits`"
f
"(
{
set
(
mm_counts
.
keys
())
}
) should be the same as those "
"returned by `get_mm_max_tokens_per_item` "
f
"(
{
set
(
mm_max_tokens_per_item
.
keys
())
}
)"
)
mm_inputs
=
self
.
_get_dummy_mm_inputs
(
seq_len
,
mm_counts
)
prompt_token_ids
=
mm_inputs
[
"prompt_token_ids"
]
placeholders_by_modality
=
mm_inputs
[
"mm_placeholders"
]
total_placeholders_by_modality
=
{
modality
:
sum
(
item
[
"length"
]
for
item
in
placeholders
)
for
modality
,
placeholders
in
placeholders_by_modality
.
items
()
}
expected_placeholders_by_modality
=
{
modality
:
mm_max_tokens_per_item
[
modality
]
*
mm_counts
[
modality
]
for
modality
in
placeholders_by_modality
}
if
total_placeholders_by_modality
!=
expected_placeholders_by_modality
:
raise
AssertionError
(
f
"The processed dummy data has a total of "
f
"
{
total_placeholders_by_modality
}
placeholder tokens, which "
f
"is not the expected
{
expected_placeholders_by_modality
}
"
"tokens."
)
total_len
=
len
(
prompt_token_ids
)
# V0 does not support chunked prefill.
if
total_len
>
seq_len
and
not
envs
.
VLLM_USE_V1
:
logger
.
warning
(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
"(%d tokens in total, out of which %s are reserved for "
"multi-modal embeddings). This may cause certain multi-modal "
"inputs to fail during inference, even when the input text is "
"short. To avoid this, you should increase `max_model_len`, "
"reduce `max_num_seqs`, and/or reduce `mm_counts`."
,
seq_len
,
total_len
,
total_placeholders_by_modality
)
return
DummyData
(
seq_data
=
SequenceData
.
from_prompt_token_counts
((
0
,
seq_len
)),
multi_modal_data
=
None
,
multi_modal_placeholders
=
None
,
)
prompt_token_ids
.
extend
([
0
]
*
(
seq_len
-
len
(
prompt_token_ids
)))
return
DummyData
(
seq_data
=
SequenceData
.
from_seqs
(
prompt_token_ids
),
multi_modal_data
=
mm_inputs
[
"mm_kwargs"
],
multi_modal_placeholders
=
placeholders_by_modality
,
)
vllm/multimodal/profiling.py
View file @
2a0596bc
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Mapping
from
dataclasses
import
dataclass
,
field
from
typing
import
Optional
from
typing
import
Generic
,
TypeVar
import
numpy
as
np
import
numpy.typing
as
npt
from
PIL
import
Image
from
vllm.inputs
import
InputProcessingContext
import
vllm.envs
as
envs
from
vllm.inputs
import
DummyData
from
vllm.logger
import
init_logger
from
.inputs
import
MultiModalDataDict
from
.inputs
import
MultiModalDataDict
,
MultiModalInputsV2
from
.processing
import
BaseMultiModalProcessor
,
BaseProcessingInfo
logger
=
init_logger
(
__name__
)
...
...
@@ -23,39 +25,19 @@ class ProcessorInputs:
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
]
=
field
(
default_factory
=
dict
)
class
BaseProfilingInfo
(
ABC
):
"""
Abstract base class that provides the information necessary to profile
multi-modal models.
"""
def
__init__
(
self
,
ctx
:
InputProcessingContext
)
->
None
:
super
().
__init__
()
_I
=
TypeVar
(
"_I"
,
bound
=
BaseProcessingInfo
)
self
.
ctx
=
ctx
@
abstractmethod
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
class
BaseDummyInputsBuilder
(
ABC
,
Generic
[
_I
]):
"""
Return the maximum supported number of items for each modality.
A value of `None` means unlimited number of items.
Omitting a modality from the returned dictionary means that
it is not supported at all.
Abstract base class that constructs the dummy data to profile
multi-modal models.
"""
raise
NotImplementedError
@
abstractmethod
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
)
->
Mapping
[
str
,
int
]:
"""
Get the maximum possible number of tokens per data item
for each modality.
def
__init__
(
self
,
info
:
_I
)
->
None
:
super
().
__init__
()
The dictionary returned by this method should have the same
keys as that returned by :meth:`get_supported_mm_limits`.
"""
raise
NotImplementedError
self
.
info
=
info
@
abstractmethod
def
get_dummy_processor_inputs
(
...
...
@@ -64,8 +46,8 @@ class BaseProfilingInfo(ABC):
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
"""
Build the
multi-modal portion of the
input which, after processing,
results in `mm_max_tokens` in :meth:`
get_mm_max_tokens_per_item
`
.
Build the input which, after processing,
results in
`self.info.
get_mm_max_tokens_per_item
()` placeholder tokens
.
"""
raise
NotImplementedError
...
...
@@ -99,11 +81,33 @@ class BaseProfilingInfo(ABC):
video
=
np
.
zeros
((
num_frames
,
width
,
height
,
3
))
return
[
video
]
*
num_videos
def
get_mm_limits
(
self
)
->
Mapping
[
str
,
int
]:
mm_config
=
self
.
ctx
.
get_mm_config
()
class
MultiModalProfiler
(
Generic
[
_I
]):
"""
Contains code for running memory profiling for multi-modal models.
"""
def
__init__
(
self
,
processor
:
BaseMultiModalProcessor
[
_I
],
)
->
None
:
super
().
__init__
()
self
.
processor
=
processor
@
property
def
processing_info
(
self
)
->
BaseProcessingInfo
:
return
self
.
processor
.
info
@
property
def
dummy_inputs
(
self
)
->
BaseDummyInputsBuilder
[
_I
]:
return
self
.
processor
.
dummy_inputs
def
_get_mm_limits
(
self
)
->
Mapping
[
str
,
int
]:
mm_config
=
self
.
processing_info
.
ctx
.
get_mm_config
()
mm_limit_per_prompt
=
mm_config
.
limit_per_prompt
supported_mm_limits
=
self
.
get_supported_mm_limits
()
supported_mm_limits
=
self
.
processing_info
.
get_supported_mm_limits
()
mm_limits
=
{
modality
:
mm_limit_per_prompt
.
get
(
modality
,
1
)
...
...
@@ -119,3 +123,81 @@ class BaseProfilingInfo(ABC):
f
"at most
{
supported_limit
}
{
modality
}
items."
)
return
mm_limits
def
_get_dummy_mm_inputs
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
MultiModalInputsV2
:
factory
=
self
.
dummy_inputs
processor_inputs
=
factory
.
get_dummy_processor_inputs
(
seq_len
,
mm_counts
)
return
self
.
processor
.
apply
(
prompt_text
=
processor_inputs
.
prompt_text
,
mm_data
=
processor_inputs
.
mm_data
,
hf_processor_mm_kwargs
=
processor_inputs
.
hf_processor_mm_kwargs
,
)
def
get_dummy_data
(
self
,
seq_len
:
int
)
->
DummyData
:
# Avoid circular import
from
vllm.sequence
import
SequenceData
mm_counts
=
self
.
_get_mm_limits
()
info
=
self
.
processing_info
mm_max_tokens_per_item
=
info
.
get_mm_max_tokens_per_item
(
seq_len
)
if
mm_counts
.
keys
()
!=
mm_max_tokens_per_item
.
keys
():
raise
AssertionError
(
"The keys returned by `get_supported_mm_limits`"
f
"(
{
set
(
mm_counts
.
keys
())
}
) should be the same as those "
"returned by `get_mm_max_tokens_per_item` "
f
"(
{
set
(
mm_max_tokens_per_item
.
keys
())
}
)"
)
mm_inputs
=
self
.
_get_dummy_mm_inputs
(
seq_len
,
mm_counts
)
prompt_token_ids
=
mm_inputs
[
"prompt_token_ids"
]
placeholders_by_modality
=
mm_inputs
[
"mm_placeholders"
]
total_placeholders_by_modality
=
{
modality
:
sum
(
item
[
"length"
]
for
item
in
placeholders
)
for
modality
,
placeholders
in
placeholders_by_modality
.
items
()
}
expected_placeholders_by_modality
=
{
modality
:
mm_max_tokens_per_item
[
modality
]
*
mm_counts
[
modality
]
for
modality
in
placeholders_by_modality
}
if
total_placeholders_by_modality
!=
expected_placeholders_by_modality
:
raise
AssertionError
(
f
"The processed dummy data has a total of "
f
"
{
total_placeholders_by_modality
}
placeholder tokens, which "
f
"is not the expected
{
expected_placeholders_by_modality
}
"
"tokens."
)
total_len
=
len
(
prompt_token_ids
)
# V0 does not support chunked prefill.
if
total_len
>
seq_len
and
not
envs
.
VLLM_USE_V1
:
logger
.
warning
(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
"(%d tokens in total, out of which %s are reserved for "
"multi-modal embeddings). This may cause certain multi-modal "
"inputs to fail during inference, even when the input text is "
"short. To avoid this, you should increase `max_model_len`, "
"reduce `max_num_seqs`, and/or reduce `mm_counts`."
,
seq_len
,
total_len
,
total_placeholders_by_modality
)
return
DummyData
(
seq_data
=
SequenceData
.
from_prompt_token_counts
((
0
,
seq_len
)),
multi_modal_data
=
None
,
multi_modal_placeholders
=
None
,
)
prompt_token_ids
.
extend
([
0
]
*
(
seq_len
-
len
(
prompt_token_ids
)))
return
DummyData
(
seq_data
=
SequenceData
.
from_seqs
(
prompt_token_ids
),
multi_modal_data
=
mm_inputs
[
"mm_kwargs"
],
multi_modal_placeholders
=
placeholders_by_modality
,
)
vllm/multimodal/registry.py
View file @
2a0596bc
import
functools
from
collections
import
UserDict
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Mapping
,
Optional
,
Protocol
,
Sequence
,
Type
,
TypeVar
)
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generic
,
Mapping
,
Optional
,
Protocol
,
Sequence
,
Type
,
TypeVar
)
import
torch.nn
as
nn
...
...
@@ -14,7 +15,9 @@ from .audio import AudioPlugin
from
.base
import
MultiModalInputMapper
,
MultiModalPlugin
,
MultiModalTokensCalc
from
.image
import
ImagePlugin
from
.inputs
import
MultiModalDataDict
,
MultiModalKwargs
,
NestedTensors
from
.processing
import
BaseMultiModalProcessor
,
ProcessingCache
from
.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
ProcessingCache
)
from
.profiling
import
BaseDummyInputsBuilder
from
.utils
import
cached_get_tokenizer
from
.video
import
VideoPlugin
...
...
@@ -27,20 +30,59 @@ logger = init_logger(__name__)
MM_CACHE_SIZE
=
256
N
=
TypeVar
(
"N"
,
bound
=
Type
[
nn
.
Module
])
_I
=
TypeVar
(
"_I"
,
bound
=
BaseProcessingInfo
)
_I_co
=
TypeVar
(
"_I_co"
,
bound
=
BaseProcessingInfo
,
covariant
=
True
)
class
MultiModalProcessor
Factory
(
Protocol
):
class
ProcessingInfo
Factory
(
Protocol
[
_I_co
]
):
"""Constructs a :class:`MultiModalProcessor` instance from the context."""
def
__call__
(
self
,
ctx
:
InputProcessingContext
,
)
->
_I_co
:
...
class
DummyInputsBuilderFactory
(
Protocol
[
_I
]):
"""
Constructs a :class:`BaseDummyInputsBuilder` instance from the context.
"""
def
__call__
(
self
,
info
:
_I
)
->
BaseDummyInputsBuilder
[
_I
]:
...
class
MultiModalProcessorFactory
(
Protocol
[
_I
]):
"""Constructs a :class:`MultiModalProcessor` instance from the context."""
def
__call__
(
self
,
info
:
_I
,
dummy_inputs
:
BaseDummyInputsBuilder
[
_I
],
*
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
)
->
BaseMultiModalProcessor
:
)
->
BaseMultiModalProcessor
[
_I
]
:
...
@
dataclass
(
frozen
=
True
)
class
_ProcessorFactories
(
Generic
[
_I
]):
info
:
ProcessingInfoFactory
[
_I
]
processor
:
MultiModalProcessorFactory
[
_I
]
dummy_inputs
:
DummyInputsBuilderFactory
[
_I
]
def
build_processor
(
self
,
ctx
:
InputProcessingContext
,
*
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
):
info
=
self
.
info
(
ctx
)
dummy_inputs_builder
=
self
.
dummy_inputs
(
info
)
return
self
.
processor
(
info
,
dummy_inputs_builder
,
cache
=
cache
)
class
_MultiModalLimits
(
UserDict
[
"ModelConfig"
,
Dict
[
str
,
int
]]):
"""
Wraps `_limits_by_model` for a more informative error message
...
...
@@ -71,7 +113,7 @@ class MultiModalRegistry:
self
.
_plugins
=
{
p
.
get_data_key
():
p
for
p
in
plugins
}
self
.
_processor_factories
=
ClassRegistry
[
nn
.
Module
,
MultiModal
ProcessorFactor
y
]()
_
ProcessorFactor
ies
]()
# This is used for non-multimodal models
self
.
_disabled_limits_per_plugin
=
{
k
:
0
for
k
in
self
.
_plugins
}
...
...
@@ -224,7 +266,7 @@ class MultiModalRegistry:
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
)
processor
=
self
.
create_processor
(
model_config
,
tokenizer
)
seq_len
=
model_config
.
max_model_len
return
processor
.
profiling_
info
.
get_mm_max_tokens_per_item
(
seq_len
)
return
processor
.
info
.
get_mm_max_tokens_per_item
(
seq_len
)
return
{
key
:
plugin
.
get_max_multimodal_tokens
(
model_config
)
...
...
@@ -315,7 +357,10 @@ class MultiModalRegistry:
def
register_processor
(
self
,
factory
:
MultiModalProcessorFactory
,
processor
:
MultiModalProcessorFactory
[
_I
],
*
,
info
:
ProcessingInfoFactory
[
_I
],
dummy_inputs
:
DummyInputsBuilderFactory
[
_I
],
):
"""
Register a multi-modal processor to a model class. The processor
...
...
@@ -336,7 +381,11 @@ class MultiModalRegistry:
"registered to %s. It is overwritten by the new one."
,
model_cls
,
self
)
self
.
_processor_factories
[
model_cls
]
=
factory
self
.
_processor_factories
[
model_cls
]
=
_ProcessorFactories
(
info
=
info
,
dummy_inputs
=
dummy_inputs
,
processor
=
processor
,
)
return
model_cls
...
...
@@ -359,15 +408,15 @@ class MultiModalRegistry:
self
,
model_config
:
"ModelConfig"
,
tokenizer
:
AnyTokenizer
,
)
->
BaseMultiModalProcessor
:
)
->
BaseMultiModalProcessor
[
BaseProcessingInfo
]
:
"""
Create a multi-modal processor for a specific model and tokenizer.
"""
model_cls
=
self
.
_get_model_cls
(
model_config
)
processor_
factor
y
=
self
.
_processor_factories
[
model_cls
]
factor
ies
=
self
.
_processor_factories
[
model_cls
]
ctx
=
InputProcessingContext
(
model_config
,
tokenizer
)
cache
=
(
None
if
model_config
.
disable_mm_preprocessor_cache
else
self
.
_processing_cache
)
return
processor
_factory
(
ctx
,
cache
=
cache
)
return
factories
.
build_
processor
(
ctx
,
cache
=
cache
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment