Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2a0596bc
Unverified
Commit
2a0596bc
authored
Jan 08, 2025
by
Cyrus Leung
Committed by
GitHub
Jan 08, 2025
Browse files
[VLM] Reorganize profiling/processing-related code (#11812)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
f1214117
Changes
23
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
263 additions
and
181 deletions
+263
-181
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+85
-134
vllm/multimodal/profiling.py
vllm/multimodal/profiling.py
+117
-35
vllm/multimodal/registry.py
vllm/multimodal/registry.py
+61
-12
No files found.
vllm/multimodal/processing.py
View file @
2a0596bc
...
@@ -4,12 +4,13 @@ from collections import defaultdict
...
@@ -4,12 +4,13 @@ from collections import defaultdict
from
collections.abc
import
Callable
,
ItemsView
,
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Callable
,
ItemsView
,
Iterable
,
Mapping
,
Sequence
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
Any
,
NamedTuple
,
Optional
,
Protocol
,
TypeVar
,
Union
from
typing
import
(
TYPE_CHECKING
,
Generic
,
NamedTuple
,
Optional
,
Protocol
,
TypeVar
,
Union
)
from
transformers
import
BatchFeature
,
PretrainedConfig
,
ProcessorMixin
from
transformers
import
BatchFeature
,
PretrainedConfig
,
ProcessorMixin
from
vllm
import
envs
import
vllm.envs
as
envs
from
vllm.inputs
import
DummyData
,
InputProcessingContext
from
vllm.inputs
import
InputProcessingContext
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
decode_tokens
,
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
decode_tokens
,
encode_tokens
)
encode_tokens
)
...
@@ -20,7 +21,9 @@ from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
...
@@ -20,7 +21,9 @@ from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2
,
MultiModalKwargs
,
MultiModalInputsV2
,
MultiModalKwargs
,
MultiModalKwargsItem
,
PlaceholderRange
)
MultiModalKwargsItem
,
PlaceholderRange
)
from
.parse
import
MultiModalDataItems
,
MultiModalDataParser
from
.parse
import
MultiModalDataItems
,
MultiModalDataParser
from
.profiling
import
BaseProfilingInfo
if
TYPE_CHECKING
:
from
.profiling
import
BaseDummyInputsBuilder
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -46,8 +49,8 @@ class PromptReplacement:
...
@@ -46,8 +49,8 @@ class PromptReplacement:
if it does not depend on the input.
if it does not depend on the input.
"""
"""
def
bind
(
self
,
tokenizer
:
AnyTokenizer
)
->
"
_
BoundPromptReplacement"
:
def
bind
(
self
,
tokenizer
:
AnyTokenizer
)
->
"BoundPromptReplacement"
:
return
_
BoundPromptReplacement
(
return
BoundPromptReplacement
(
tokenizer
=
tokenizer
,
tokenizer
=
tokenizer
,
modality
=
self
.
modality
,
modality
=
self
.
modality
,
_target
=
self
.
target
,
_target
=
self
.
target
,
...
@@ -128,7 +131,7 @@ class _BoundPromptSequence:
...
@@ -128,7 +131,7 @@ class _BoundPromptSequence:
@
dataclass
@
dataclass
class
_
BoundPromptReplacement
:
class
BoundPromptReplacement
:
tokenizer
:
AnyTokenizer
=
field
(
repr
=
False
)
tokenizer
:
AnyTokenizer
=
field
(
repr
=
False
)
modality
:
str
modality
:
str
...
@@ -207,7 +210,7 @@ def iter_token_matches(
...
@@ -207,7 +210,7 @@ def iter_token_matches(
@
dataclass
(
repr
=
False
)
@
dataclass
(
repr
=
False
)
class
_PromptReplacementMatch
(
ABC
):
class
_PromptReplacementMatch
(
ABC
):
prompt_repl
:
_
BoundPromptReplacement
prompt_repl
:
BoundPromptReplacement
@
property
@
property
def
modality
(
self
)
->
str
:
def
modality
(
self
)
->
str
:
...
@@ -255,7 +258,7 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch):
...
@@ -255,7 +258,7 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch):
@
dataclass
@
dataclass
class
_
PlaceholderInfo
:
class
PlaceholderInfo
:
modality
:
str
modality
:
str
item_idx
:
int
item_idx
:
int
start_idx
:
int
start_idx
:
int
...
@@ -274,7 +277,7 @@ class _PlaceholderInfo:
...
@@ -274,7 +277,7 @@ class _PlaceholderInfo:
def
find_token_matches
(
def
find_token_matches
(
prompt
:
list
[
int
],
prompt
:
list
[
int
],
prompt_repls
:
Sequence
[
_
BoundPromptReplacement
],
prompt_repls
:
Sequence
[
BoundPromptReplacement
],
)
->
list
[
_PromptReplacementTokenMatch
]:
)
->
list
[
_PromptReplacementTokenMatch
]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return
[
return
[
...
@@ -286,7 +289,7 @@ def find_token_matches(
...
@@ -286,7 +289,7 @@ def find_token_matches(
def
find_text_matches
(
def
find_text_matches
(
prompt
:
str
,
prompt
:
str
,
prompt_repls
:
Sequence
[
_
BoundPromptReplacement
],
prompt_repls
:
Sequence
[
BoundPromptReplacement
],
)
->
list
[
_PromptReplacementTextMatch
]:
)
->
list
[
_PromptReplacementTextMatch
]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return
[
return
[
...
@@ -390,9 +393,9 @@ def replace_text_matches(
...
@@ -390,9 +393,9 @@ def replace_text_matches(
def
_iter_modality_placeholders
(
def
_iter_modality_placeholders
(
prompt
:
list
[
int
],
prompt
:
list
[
int
],
modality
:
str
,
modality
:
str
,
modality_repls
:
Sequence
[
_
BoundPromptReplacement
],
modality_repls
:
Sequence
[
BoundPromptReplacement
],
modal_item_count
:
int
,
modal_item_count
:
int
,
)
->
Iterable
[
_
PlaceholderInfo
]:
)
->
Iterable
[
PlaceholderInfo
]:
if
modal_item_count
==
0
:
if
modal_item_count
==
0
:
return
return
...
@@ -413,7 +416,7 @@ def _iter_modality_placeholders(
...
@@ -413,7 +416,7 @@ def _iter_modality_placeholders(
continue
continue
if
prompt
[
start_idx
:
end_idx
]
==
repl_tokens
:
if
prompt
[
start_idx
:
end_idx
]
==
repl_tokens
:
yield
_
PlaceholderInfo
(
yield
PlaceholderInfo
(
modality
=
modality
,
modality
=
modality
,
item_idx
=
item_idx
,
item_idx
=
item_idx
,
start_idx
=
start_idx
,
start_idx
=
start_idx
,
...
@@ -434,10 +437,10 @@ def _iter_modality_placeholders(
...
@@ -434,10 +437,10 @@ def _iter_modality_placeholders(
def
_iter_placeholders
(
def
_iter_placeholders
(
mm_prompt_repls
:
Mapping
[
str
,
Sequence
[
_
BoundPromptReplacement
]],
mm_prompt_repls
:
Mapping
[
str
,
Sequence
[
BoundPromptReplacement
]],
prompt
:
list
[
int
],
prompt
:
list
[
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
Iterable
[
_
PlaceholderInfo
]:
)
->
Iterable
[
PlaceholderInfo
]:
"""
"""
For each modality, yield each set of placeholder tokens found in
For each modality, yield each set of placeholder tokens found in
:code:`prompt`.
:code:`prompt`.
...
@@ -455,10 +458,10 @@ def _iter_placeholders(
...
@@ -455,10 +458,10 @@ def _iter_placeholders(
def
find_mm_placeholders
(
def
find_mm_placeholders
(
mm_prompt_repls
:
Mapping
[
str
,
Sequence
[
_
BoundPromptReplacement
]],
mm_prompt_repls
:
Mapping
[
str
,
Sequence
[
BoundPromptReplacement
]],
prompt
:
list
[
int
],
prompt
:
list
[
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
list
[
_
PlaceholderInfo
]]:
)
->
Mapping
[
str
,
list
[
PlaceholderInfo
]]:
it
=
_iter_placeholders
(
mm_prompt_repls
,
prompt
,
mm_item_counts
)
it
=
_iter_placeholders
(
mm_prompt_repls
,
prompt
,
mm_item_counts
)
return
dict
(
full_groupby_modality
(
it
))
return
dict
(
full_groupby_modality
(
it
))
...
@@ -524,29 +527,59 @@ class ProcessingCache:
...
@@ -524,29 +527,59 @@ class ProcessingCache:
self
.
_cache
.
put
(
cache_key
,
output_kwargs
)
self
.
_cache
.
put
(
cache_key
,
output_kwargs
)
class
ProcessingMixin
:
class
BaseProcessingInfo
:
"""
"""Base class containing information to perform processing."""
Contains helper functions to perform processing.
Not to be confused with :class:`transformers.ProcessorMixin`.
def
__init__
(
self
,
ctx
:
InputProcessingContext
)
->
None
:
"""
super
().
__init__
()
ctx
:
InputProcessingContext
def
_get_tokenizer
(
self
)
->
AnyTokenizer
:
self
.
ctx
=
ctx
@
property
def
model_id
(
self
)
->
str
:
return
self
.
ctx
.
model_config
.
model
def
get_tokenizer
(
self
)
->
AnyTokenizer
:
return
self
.
ctx
.
tokenizer
return
self
.
ctx
.
tokenizer
def
_
get_hf_config
(
self
)
->
PretrainedConfig
:
def
get_hf_config
(
self
)
->
PretrainedConfig
:
return
self
.
ctx
.
get_hf_config
()
return
self
.
ctx
.
get_hf_config
()
def
_
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
ProcessorMixin
:
def
get_hf_processor
(
self
,
**
kwargs
:
object
)
->
ProcessorMixin
:
"""
"""
Subclasses can override this method to handle
Subclasses can override this method to handle
specific kwargs from model config or user inputs.
specific kwargs from model config or user inputs.
"""
"""
return
self
.
ctx
.
get_hf_processor
(
**
kwargs
)
return
self
.
ctx
.
get_hf_processor
(
**
kwargs
)
@
abstractmethod
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
"""
Return the maximum supported number of items for each modality.
A value of `None` means unlimited number of items.
Omitting a modality from the returned dictionary means that
it is not supported at all.
"""
raise
NotImplementedError
@
abstractmethod
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
)
->
Mapping
[
str
,
int
]:
"""
Get the maximum possible number of tokens per data item
for each modality.
The dictionary returned by this method should have the same
keys as that returned by :meth:`get_supported_mm_limits`.
"""
raise
NotImplementedError
_I
=
TypeVar
(
"_I"
,
bound
=
BaseProcessingInfo
)
class
BaseMultiModalProcessor
(
ProcessingMixin
,
ABC
):
class
BaseMultiModalProcessor
(
ABC
,
Generic
[
_I
]):
"""
"""
Abstract base class to process multi-modal inputs to be used in vLLM.
Abstract base class to process multi-modal inputs to be used in vLLM.
...
@@ -554,18 +587,19 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
...
@@ -554,18 +587,19 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
ctx
:
InputProcessingContext
,
info
:
_I
,
dummy_inputs
:
"BaseDummyInputsBuilder[_I]"
,
*
,
*
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
enable_sanity_checks
:
bool
=
True
)
->
None
:
enable_sanity_checks
:
bool
=
True
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
ctx
=
ctx
self
.
info
=
info
self
.
dummy_inputs
=
dummy_inputs
self
.
cache
=
cache
self
.
cache
=
cache
self
.
enable_sanity_checks
=
enable_sanity_checks
self
.
enable_sanity_checks
=
enable_sanity_checks
self
.
data_parser
=
self
.
_get_data_parser
()
self
.
data_parser
=
self
.
_get_data_parser
()
self
.
profiling_info
=
self
.
_get_profiling_info
()
def
__call__
(
def
__call__
(
self
,
self
,
...
@@ -585,13 +619,6 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
...
@@ -585,13 +619,6 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
"""
"""
return
MultiModalDataParser
()
return
MultiModalDataParser
()
def
_get_profiling_info
(
self
)
->
BaseProfilingInfo
:
"""
Get the profiling information to find the worst-case memory usage of
the model.
"""
raise
NotImplementedError
def
_to_mm_items
(
def
_to_mm_items
(
self
,
self
,
mm_data
:
MultiModalDataDict
,
mm_data
:
MultiModalDataDict
,
...
@@ -602,7 +629,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
...
@@ -602,7 +629,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
"""
"""
mm_items
=
self
.
data_parser
.
parse_mm_data
(
mm_data
)
mm_items
=
self
.
data_parser
.
parse_mm_data
(
mm_data
)
mm_limits
=
self
.
ctx
.
get_mm_config
().
limit_per_prompt
mm_limits
=
self
.
info
.
ctx
.
get_mm_config
().
limit_per_prompt
for
modality
,
items
in
mm_items
.
items
():
for
modality
,
items
in
mm_items
.
items
():
limit
=
mm_limits
.
get
(
modality
,
1
)
limit
=
mm_limits
.
get
(
modality
,
1
)
if
len
(
items
)
>
limit
:
if
len
(
items
)
>
limit
:
...
@@ -646,19 +673,19 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
...
@@ -646,19 +673,19 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def
_find_mm_placeholders
(
def
_find_mm_placeholders
(
self
,
self
,
mm_prompt_repls
:
Mapping
[
str
,
Sequence
[
_
BoundPromptReplacement
]],
mm_prompt_repls
:
Mapping
[
str
,
Sequence
[
BoundPromptReplacement
]],
new_token_ids
:
list
[
int
],
new_token_ids
:
list
[
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
list
[
_
PlaceholderInfo
]]:
)
->
Mapping
[
str
,
list
[
PlaceholderInfo
]]:
return
find_mm_placeholders
(
mm_prompt_repls
,
new_token_ids
,
return
find_mm_placeholders
(
mm_prompt_repls
,
new_token_ids
,
mm_item_counts
)
mm_item_counts
)
def
_get_hf_mm_data
(
def
_get_hf_mm_data
(
self
,
self
,
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
)
->
tuple
[
dict
[
str
,
Any
],
dict
[
str
,
Any
]]:
)
->
tuple
[
Mapping
[
str
,
object
],
Mapping
[
str
,
object
]]:
processor_data
=
dict
[
str
,
Any
]()
processor_data
=
dict
[
str
,
object
]()
passthrough_data
=
dict
[
str
,
Any
]()
passthrough_data
=
dict
[
str
,
object
]()
for
items
in
mm_items
.
values
():
for
items
in
mm_items
.
values
():
processor_data
.
update
(
items
.
get_processor_data
())
processor_data
.
update
(
items
.
get_processor_data
())
...
@@ -678,8 +705,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
...
@@ -678,8 +705,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
Call the HF processor on the prompt text and
Call the HF processor on the prompt text and
associated multi-modal data.
associated multi-modal data.
"""
"""
return
self
.
ctx
.
call_hf_processor
(
return
self
.
info
.
ctx
.
call_hf_processor
(
self
.
_
get_hf_processor
(
**
mm_kwargs
),
self
.
info
.
get_hf_processor
(
**
mm_kwargs
),
dict
(
text
=
prompt
,
**
mm_data
),
dict
(
text
=
prompt
,
**
mm_data
),
mm_kwargs
,
mm_kwargs
,
)
)
...
@@ -738,8 +765,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
...
@@ -738,8 +765,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
# Some HF processors (e.g. Qwen2-VL) expect corresponding
# Some HF processors (e.g. Qwen2-VL) expect corresponding
# multi-modal tokens to be in the prompt text
# multi-modal tokens to be in the prompt text
dummy_inputs
=
self
.
profiling_info
.
get_dummy_processor_inputs
(
dummy_inputs
=
self
.
dummy_inputs
.
get_dummy_processor_inputs
(
self
.
ctx
.
model_config
.
max_model_len
,
self
.
info
.
ctx
.
model_config
.
max_model_len
,
mm_missing_counts
,
mm_missing_counts
,
)
)
...
@@ -762,7 +789,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
...
@@ -762,7 +789,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
caching the results and reusing cached results.
caching the results and reusing cached results.
"""
"""
cache
=
self
.
cache
cache
=
self
.
cache
model_id
=
self
.
ctx
.
model_
config
.
model
model_id
=
self
.
info
.
model_
id
_
,
passthrough_data
=
self
.
_get_hf_mm_data
(
mm_data_items
)
_
,
passthrough_data
=
self
.
_get_hf_mm_data
(
mm_data_items
)
if
cache
is
None
or
passthrough_data
:
if
cache
is
None
or
passthrough_data
:
...
@@ -838,8 +865,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
...
@@ -838,8 +865,8 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def
_bind_and_group_repls
(
def
_bind_and_group_repls
(
self
,
self
,
prompt_repls
:
list
[
PromptReplacement
],
prompt_repls
:
list
[
PromptReplacement
],
)
->
dict
[
str
,
list
[
_
BoundPromptReplacement
]]:
)
->
dict
[
str
,
list
[
BoundPromptReplacement
]]:
tokenizer
=
self
.
_
get_tokenizer
()
tokenizer
=
self
.
info
.
get_tokenizer
()
it
=
(
prompt_repl
.
bind
(
tokenizer
)
for
prompt_repl
in
prompt_repls
)
it
=
(
prompt_repl
.
bind
(
tokenizer
)
for
prompt_repl
in
prompt_repls
)
return
dict
(
full_groupby_modality
(
it
))
return
dict
(
full_groupby_modality
(
it
))
...
@@ -859,10 +886,10 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
...
@@ -859,10 +886,10 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def
_apply_prompt_replacements
(
def
_apply_prompt_replacements
(
self
,
self
,
token_ids
:
list
[
int
],
token_ids
:
list
[
int
],
mm_prompt_repls
:
Mapping
[
str
,
Sequence
[
_
BoundPromptReplacement
]],
mm_prompt_repls
:
Mapping
[
str
,
Sequence
[
BoundPromptReplacement
]],
mm_item_counts
:
Mapping
[
str
,
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
tuple
[
list
[
int
],
str
,
Mapping
[
str
,
list
[
_
PlaceholderInfo
]]]:
)
->
tuple
[
list
[
int
],
str
,
Mapping
[
str
,
list
[
PlaceholderInfo
]]]:
tokenizer
=
self
.
_
get_tokenizer
()
tokenizer
=
self
.
info
.
get_tokenizer
()
mm_token_matches
=
{
mm_token_matches
=
{
modality
:
find_token_matches
(
token_ids
,
prompt_repls
)
modality
:
find_token_matches
(
token_ids
,
prompt_repls
)
...
@@ -950,7 +977,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
...
@@ -950,7 +977,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
def
_validate_mm_placeholders
(
def
_validate_mm_placeholders
(
self
,
self
,
mm_placeholders
:
Mapping
[
str
,
list
[
_
PlaceholderInfo
]],
mm_placeholders
:
Mapping
[
str
,
list
[
PlaceholderInfo
]],
mm_item_counts
:
Mapping
[
str
,
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
*
,
*
,
allow_missing
:
bool
=
False
,
allow_missing
:
bool
=
False
,
...
@@ -1001,7 +1028,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
...
@@ -1001,7 +1028,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
# instead of rehashing.
# instead of rehashing.
if
envs
.
VLLM_USE_V1
:
if
envs
.
VLLM_USE_V1
:
model_id
=
self
.
ctx
.
model_
config
.
model
model_id
=
self
.
info
.
model_
id
mm_hashes
=
{
mm_hashes
=
{
modality
:
[
modality
:
[
MultiModalHasher
.
hash_kwargs
(
model_id
=
model_id
,
MultiModalHasher
.
hash_kwargs
(
model_id
=
model_id
,
...
@@ -1046,7 +1073,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
...
@@ -1046,7 +1073,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
allow_missing
=
True
,
allow_missing
=
True
,
)
)
mm_missing_repls
=
dict
[
str
,
list
[
_
BoundPromptReplacement
]]()
mm_missing_repls
=
dict
[
str
,
list
[
BoundPromptReplacement
]]()
for
modality
,
missing_repl_count
in
mm_missing_repl_counts
.
items
():
for
modality
,
missing_repl_count
in
mm_missing_repl_counts
.
items
():
if
missing_repl_count
==
0
:
if
missing_repl_count
==
0
:
mm_missing_repls
[
modality
]
=
[]
mm_missing_repls
[
modality
]
=
[]
...
@@ -1059,7 +1086,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
...
@@ -1059,7 +1086,7 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
# If HF processor already inserts placeholder tokens,
# If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
# there is no need for us to insert them
if
all
(
len
(
repls
)
==
0
for
repls
in
mm_missing_repls
.
items
()):
if
all
(
len
(
repls
)
==
0
for
repls
in
mm_missing_repls
.
items
()):
tokenizer
=
self
.
_
get_tokenizer
()
tokenizer
=
self
.
info
.
get_tokenizer
()
prompt_text
=
decode_tokens
(
tokenizer
,
prompt_ids
)
prompt_text
=
decode_tokens
(
tokenizer
,
prompt_ids
)
mm_placeholders
=
hf_mm_placeholders
mm_placeholders
=
hf_mm_placeholders
else
:
else
:
...
@@ -1090,79 +1117,3 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
...
@@ -1090,79 +1117,3 @@ class BaseMultiModalProcessor(ProcessingMixin, ABC):
mm_hashes
=
mm_hashes
,
mm_hashes
=
mm_hashes
,
mm_placeholders
=
mm_placeholder_ranges
,
mm_placeholders
=
mm_placeholder_ranges
,
)
)
def
_get_dummy_mm_inputs
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
MultiModalInputsV2
:
profiling
=
self
.
profiling_info
processor_inputs
=
profiling
.
get_dummy_processor_inputs
(
seq_len
,
mm_counts
)
return
self
.
apply
(
prompt_text
=
processor_inputs
.
prompt_text
,
mm_data
=
processor_inputs
.
mm_data
,
hf_processor_mm_kwargs
=
processor_inputs
.
hf_processor_mm_kwargs
,
)
def
get_dummy_data
(
self
,
seq_len
:
int
)
->
DummyData
:
# Avoid circular import
from
vllm.sequence
import
SequenceData
profiling
=
self
.
profiling_info
mm_counts
=
profiling
.
get_mm_limits
()
mm_max_tokens_per_item
=
profiling
.
get_mm_max_tokens_per_item
(
seq_len
)
if
mm_counts
.
keys
()
!=
mm_max_tokens_per_item
.
keys
():
raise
AssertionError
(
"The keys returned by `get_supported_mm_limits`"
f
"(
{
set
(
mm_counts
.
keys
())
}
) should be the same as those "
"returned by `get_mm_max_tokens_per_item` "
f
"(
{
set
(
mm_max_tokens_per_item
.
keys
())
}
)"
)
mm_inputs
=
self
.
_get_dummy_mm_inputs
(
seq_len
,
mm_counts
)
prompt_token_ids
=
mm_inputs
[
"prompt_token_ids"
]
placeholders_by_modality
=
mm_inputs
[
"mm_placeholders"
]
total_placeholders_by_modality
=
{
modality
:
sum
(
item
[
"length"
]
for
item
in
placeholders
)
for
modality
,
placeholders
in
placeholders_by_modality
.
items
()
}
expected_placeholders_by_modality
=
{
modality
:
mm_max_tokens_per_item
[
modality
]
*
mm_counts
[
modality
]
for
modality
in
placeholders_by_modality
}
if
total_placeholders_by_modality
!=
expected_placeholders_by_modality
:
raise
AssertionError
(
f
"The processed dummy data has a total of "
f
"
{
total_placeholders_by_modality
}
placeholder tokens, which "
f
"is not the expected
{
expected_placeholders_by_modality
}
"
"tokens."
)
total_len
=
len
(
prompt_token_ids
)
# V0 does not support chunked prefill.
if
total_len
>
seq_len
and
not
envs
.
VLLM_USE_V1
:
logger
.
warning
(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
"(%d tokens in total, out of which %s are reserved for "
"multi-modal embeddings). This may cause certain multi-modal "
"inputs to fail during inference, even when the input text is "
"short. To avoid this, you should increase `max_model_len`, "
"reduce `max_num_seqs`, and/or reduce `mm_counts`."
,
seq_len
,
total_len
,
total_placeholders_by_modality
)
return
DummyData
(
seq_data
=
SequenceData
.
from_prompt_token_counts
((
0
,
seq_len
)),
multi_modal_data
=
None
,
multi_modal_placeholders
=
None
,
)
prompt_token_ids
.
extend
([
0
]
*
(
seq_len
-
len
(
prompt_token_ids
)))
return
DummyData
(
seq_data
=
SequenceData
.
from_seqs
(
prompt_token_ids
),
multi_modal_data
=
mm_inputs
[
"mm_kwargs"
],
multi_modal_placeholders
=
placeholders_by_modality
,
)
vllm/multimodal/profiling.py
View file @
2a0596bc
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Mapping
from
collections.abc
import
Mapping
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
Optional
from
typing
import
Generic
,
TypeVar
import
numpy
as
np
import
numpy
as
np
import
numpy.typing
as
npt
import
numpy.typing
as
npt
from
PIL
import
Image
from
PIL
import
Image
from
vllm.inputs
import
InputProcessingContext
import
vllm.envs
as
envs
from
vllm.inputs
import
DummyData
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
.inputs
import
MultiModalDataDict
from
.inputs
import
MultiModalDataDict
,
MultiModalInputsV2
from
.processing
import
BaseMultiModalProcessor
,
BaseProcessingInfo
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -23,39 +25,19 @@ class ProcessorInputs:
...
@@ -23,39 +25,19 @@ class ProcessorInputs:
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
]
=
field
(
default_factory
=
dict
)
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
]
=
field
(
default_factory
=
dict
)
class
BaseProfilingInfo
(
ABC
):
_I
=
TypeVar
(
"_I"
,
bound
=
BaseProcessingInfo
)
"""
Abstract base class that provides the information necessary to profile
multi-modal models.
"""
def
__init__
(
self
,
ctx
:
InputProcessingContext
)
->
None
:
super
().
__init__
()
self
.
ctx
=
ctx
@
abstractmethod
class
BaseDummyInputsBuilder
(
ABC
,
Generic
[
_I
]):
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
"""
"""
Return the maximum supported number of items for each modality.
Abstract base class that constructs the dummy data to profile
multi-modal models.
A value of `None` means unlimited number of items.
Omitting a modality from the returned dictionary means that
it is not supported at all.
"""
"""
raise
NotImplementedError
@
abstractmethod
def
__init__
(
self
,
info
:
_I
)
->
None
:
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
)
->
Mapping
[
str
,
int
]:
super
().
__init__
()
"""
Get the maximum possible number of tokens per data item
for each modality.
The dictionary returned by this method should have the same
self
.
info
=
info
keys as that returned by :meth:`get_supported_mm_limits`.
"""
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
get_dummy_processor_inputs
(
def
get_dummy_processor_inputs
(
...
@@ -64,8 +46,8 @@ class BaseProfilingInfo(ABC):
...
@@ -64,8 +46,8 @@ class BaseProfilingInfo(ABC):
mm_counts
:
Mapping
[
str
,
int
],
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
)
->
ProcessorInputs
:
"""
"""
Build the
multi-modal portion of the
input which, after processing,
Build the input which, after processing,
results in
results in `mm_max_tokens` in :meth:`
get_mm_max_tokens_per_item
`
.
`self.info.
get_mm_max_tokens_per_item
()` placeholder tokens
.
"""
"""
raise
NotImplementedError
raise
NotImplementedError
...
@@ -99,11 +81,33 @@ class BaseProfilingInfo(ABC):
...
@@ -99,11 +81,33 @@ class BaseProfilingInfo(ABC):
video
=
np
.
zeros
((
num_frames
,
width
,
height
,
3
))
video
=
np
.
zeros
((
num_frames
,
width
,
height
,
3
))
return
[
video
]
*
num_videos
return
[
video
]
*
num_videos
def
get_mm_limits
(
self
)
->
Mapping
[
str
,
int
]:
mm_config
=
self
.
ctx
.
get_mm_config
()
class
MultiModalProfiler
(
Generic
[
_I
]):
"""
Contains code for running memory profiling for multi-modal models.
"""
def
__init__
(
self
,
processor
:
BaseMultiModalProcessor
[
_I
],
)
->
None
:
super
().
__init__
()
self
.
processor
=
processor
@
property
def
processing_info
(
self
)
->
BaseProcessingInfo
:
return
self
.
processor
.
info
@
property
def
dummy_inputs
(
self
)
->
BaseDummyInputsBuilder
[
_I
]:
return
self
.
processor
.
dummy_inputs
def
_get_mm_limits
(
self
)
->
Mapping
[
str
,
int
]:
mm_config
=
self
.
processing_info
.
ctx
.
get_mm_config
()
mm_limit_per_prompt
=
mm_config
.
limit_per_prompt
mm_limit_per_prompt
=
mm_config
.
limit_per_prompt
supported_mm_limits
=
self
.
get_supported_mm_limits
()
supported_mm_limits
=
self
.
processing_info
.
get_supported_mm_limits
()
mm_limits
=
{
mm_limits
=
{
modality
:
mm_limit_per_prompt
.
get
(
modality
,
1
)
modality
:
mm_limit_per_prompt
.
get
(
modality
,
1
)
...
@@ -119,3 +123,81 @@ class BaseProfilingInfo(ABC):
...
@@ -119,3 +123,81 @@ class BaseProfilingInfo(ABC):
f
"at most
{
supported_limit
}
{
modality
}
items."
)
f
"at most
{
supported_limit
}
{
modality
}
items."
)
return
mm_limits
return
mm_limits
def
_get_dummy_mm_inputs
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
MultiModalInputsV2
:
factory
=
self
.
dummy_inputs
processor_inputs
=
factory
.
get_dummy_processor_inputs
(
seq_len
,
mm_counts
)
return
self
.
processor
.
apply
(
prompt_text
=
processor_inputs
.
prompt_text
,
mm_data
=
processor_inputs
.
mm_data
,
hf_processor_mm_kwargs
=
processor_inputs
.
hf_processor_mm_kwargs
,
)
def
get_dummy_data
(
self
,
seq_len
:
int
)
->
DummyData
:
# Avoid circular import
from
vllm.sequence
import
SequenceData
mm_counts
=
self
.
_get_mm_limits
()
info
=
self
.
processing_info
mm_max_tokens_per_item
=
info
.
get_mm_max_tokens_per_item
(
seq_len
)
if
mm_counts
.
keys
()
!=
mm_max_tokens_per_item
.
keys
():
raise
AssertionError
(
"The keys returned by `get_supported_mm_limits`"
f
"(
{
set
(
mm_counts
.
keys
())
}
) should be the same as those "
"returned by `get_mm_max_tokens_per_item` "
f
"(
{
set
(
mm_max_tokens_per_item
.
keys
())
}
)"
)
mm_inputs
=
self
.
_get_dummy_mm_inputs
(
seq_len
,
mm_counts
)
prompt_token_ids
=
mm_inputs
[
"prompt_token_ids"
]
placeholders_by_modality
=
mm_inputs
[
"mm_placeholders"
]
total_placeholders_by_modality
=
{
modality
:
sum
(
item
[
"length"
]
for
item
in
placeholders
)
for
modality
,
placeholders
in
placeholders_by_modality
.
items
()
}
expected_placeholders_by_modality
=
{
modality
:
mm_max_tokens_per_item
[
modality
]
*
mm_counts
[
modality
]
for
modality
in
placeholders_by_modality
}
if
total_placeholders_by_modality
!=
expected_placeholders_by_modality
:
raise
AssertionError
(
f
"The processed dummy data has a total of "
f
"
{
total_placeholders_by_modality
}
placeholder tokens, which "
f
"is not the expected
{
expected_placeholders_by_modality
}
"
"tokens."
)
total_len
=
len
(
prompt_token_ids
)
# V0 does not support chunked prefill.
if
total_len
>
seq_len
and
not
envs
.
VLLM_USE_V1
:
logger
.
warning
(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
"(%d tokens in total, out of which %s are reserved for "
"multi-modal embeddings). This may cause certain multi-modal "
"inputs to fail during inference, even when the input text is "
"short. To avoid this, you should increase `max_model_len`, "
"reduce `max_num_seqs`, and/or reduce `mm_counts`."
,
seq_len
,
total_len
,
total_placeholders_by_modality
)
return
DummyData
(
seq_data
=
SequenceData
.
from_prompt_token_counts
((
0
,
seq_len
)),
multi_modal_data
=
None
,
multi_modal_placeholders
=
None
,
)
prompt_token_ids
.
extend
([
0
]
*
(
seq_len
-
len
(
prompt_token_ids
)))
return
DummyData
(
seq_data
=
SequenceData
.
from_seqs
(
prompt_token_ids
),
multi_modal_data
=
mm_inputs
[
"mm_kwargs"
],
multi_modal_placeholders
=
placeholders_by_modality
,
)
vllm/multimodal/registry.py
View file @
2a0596bc
import
functools
import
functools
from
collections
import
UserDict
from
collections
import
UserDict
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Mapping
,
Optional
,
Protocol
,
from
dataclasses
import
dataclass
Sequence
,
Type
,
TypeVar
)
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generic
,
Mapping
,
Optional
,
Protocol
,
Sequence
,
Type
,
TypeVar
)
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -14,7 +15,9 @@ from .audio import AudioPlugin
...
@@ -14,7 +15,9 @@ from .audio import AudioPlugin
from
.base
import
MultiModalInputMapper
,
MultiModalPlugin
,
MultiModalTokensCalc
from
.base
import
MultiModalInputMapper
,
MultiModalPlugin
,
MultiModalTokensCalc
from
.image
import
ImagePlugin
from
.image
import
ImagePlugin
from
.inputs
import
MultiModalDataDict
,
MultiModalKwargs
,
NestedTensors
from
.inputs
import
MultiModalDataDict
,
MultiModalKwargs
,
NestedTensors
from
.processing
import
BaseMultiModalProcessor
,
ProcessingCache
from
.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
ProcessingCache
)
from
.profiling
import
BaseDummyInputsBuilder
from
.utils
import
cached_get_tokenizer
from
.utils
import
cached_get_tokenizer
from
.video
import
VideoPlugin
from
.video
import
VideoPlugin
...
@@ -27,20 +30,59 @@ logger = init_logger(__name__)
...
@@ -27,20 +30,59 @@ logger = init_logger(__name__)
MM_CACHE_SIZE
=
256
MM_CACHE_SIZE
=
256
N
=
TypeVar
(
"N"
,
bound
=
Type
[
nn
.
Module
])
N
=
TypeVar
(
"N"
,
bound
=
Type
[
nn
.
Module
])
_I
=
TypeVar
(
"_I"
,
bound
=
BaseProcessingInfo
)
_I_co
=
TypeVar
(
"_I_co"
,
bound
=
BaseProcessingInfo
,
covariant
=
True
)
class
MultiModalProcessor
Factory
(
Protocol
):
class
ProcessingInfo
Factory
(
Protocol
[
_I_co
]
):
"""Constructs a :class:`MultiModalProcessor` instance from the context."""
"""Constructs a :class:`MultiModalProcessor` instance from the context."""
def
__call__
(
def
__call__
(
self
,
self
,
ctx
:
InputProcessingContext
,
ctx
:
InputProcessingContext
,
)
->
_I_co
:
...
class
DummyInputsBuilderFactory
(
Protocol
[
_I
]):
"""
Constructs a :class:`BaseDummyInputsBuilder` instance from the context.
"""
def
__call__
(
self
,
info
:
_I
)
->
BaseDummyInputsBuilder
[
_I
]:
...
class
MultiModalProcessorFactory
(
Protocol
[
_I
]):
"""Constructs a :class:`MultiModalProcessor` instance from the context."""
def
__call__
(
self
,
info
:
_I
,
dummy_inputs
:
BaseDummyInputsBuilder
[
_I
],
*
,
*
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
)
->
BaseMultiModalProcessor
:
)
->
BaseMultiModalProcessor
[
_I
]
:
...
...
@
dataclass
(
frozen
=
True
)
class
_ProcessorFactories
(
Generic
[
_I
]):
info
:
ProcessingInfoFactory
[
_I
]
processor
:
MultiModalProcessorFactory
[
_I
]
dummy_inputs
:
DummyInputsBuilderFactory
[
_I
]
def
build_processor
(
self
,
ctx
:
InputProcessingContext
,
*
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
):
info
=
self
.
info
(
ctx
)
dummy_inputs_builder
=
self
.
dummy_inputs
(
info
)
return
self
.
processor
(
info
,
dummy_inputs_builder
,
cache
=
cache
)
class
_MultiModalLimits
(
UserDict
[
"ModelConfig"
,
Dict
[
str
,
int
]]):
class
_MultiModalLimits
(
UserDict
[
"ModelConfig"
,
Dict
[
str
,
int
]]):
"""
"""
Wraps `_limits_by_model` for a more informative error message
Wraps `_limits_by_model` for a more informative error message
...
@@ -71,7 +113,7 @@ class MultiModalRegistry:
...
@@ -71,7 +113,7 @@ class MultiModalRegistry:
self
.
_plugins
=
{
p
.
get_data_key
():
p
for
p
in
plugins
}
self
.
_plugins
=
{
p
.
get_data_key
():
p
for
p
in
plugins
}
self
.
_processor_factories
=
ClassRegistry
[
nn
.
Module
,
self
.
_processor_factories
=
ClassRegistry
[
nn
.
Module
,
MultiModal
ProcessorFactor
y
]()
_
ProcessorFactor
ies
]()
# This is used for non-multimodal models
# This is used for non-multimodal models
self
.
_disabled_limits_per_plugin
=
{
k
:
0
for
k
in
self
.
_plugins
}
self
.
_disabled_limits_per_plugin
=
{
k
:
0
for
k
in
self
.
_plugins
}
...
@@ -224,7 +266,7 @@ class MultiModalRegistry:
...
@@ -224,7 +266,7 @@ class MultiModalRegistry:
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
)
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
)
processor
=
self
.
create_processor
(
model_config
,
tokenizer
)
processor
=
self
.
create_processor
(
model_config
,
tokenizer
)
seq_len
=
model_config
.
max_model_len
seq_len
=
model_config
.
max_model_len
return
processor
.
profiling_
info
.
get_mm_max_tokens_per_item
(
seq_len
)
return
processor
.
info
.
get_mm_max_tokens_per_item
(
seq_len
)
return
{
return
{
key
:
plugin
.
get_max_multimodal_tokens
(
model_config
)
key
:
plugin
.
get_max_multimodal_tokens
(
model_config
)
...
@@ -315,7 +357,10 @@ class MultiModalRegistry:
...
@@ -315,7 +357,10 @@ class MultiModalRegistry:
def
register_processor
(
def
register_processor
(
self
,
self
,
factory
:
MultiModalProcessorFactory
,
processor
:
MultiModalProcessorFactory
[
_I
],
*
,
info
:
ProcessingInfoFactory
[
_I
],
dummy_inputs
:
DummyInputsBuilderFactory
[
_I
],
):
):
"""
"""
Register a multi-modal processor to a model class. The processor
Register a multi-modal processor to a model class. The processor
...
@@ -336,7 +381,11 @@ class MultiModalRegistry:
...
@@ -336,7 +381,11 @@ class MultiModalRegistry:
"registered to %s. It is overwritten by the new one."
,
"registered to %s. It is overwritten by the new one."
,
model_cls
,
self
)
model_cls
,
self
)
self
.
_processor_factories
[
model_cls
]
=
factory
self
.
_processor_factories
[
model_cls
]
=
_ProcessorFactories
(
info
=
info
,
dummy_inputs
=
dummy_inputs
,
processor
=
processor
,
)
return
model_cls
return
model_cls
...
@@ -359,15 +408,15 @@ class MultiModalRegistry:
...
@@ -359,15 +408,15 @@ class MultiModalRegistry:
self
,
self
,
model_config
:
"ModelConfig"
,
model_config
:
"ModelConfig"
,
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
)
->
BaseMultiModalProcessor
:
)
->
BaseMultiModalProcessor
[
BaseProcessingInfo
]
:
"""
"""
Create a multi-modal processor for a specific model and tokenizer.
Create a multi-modal processor for a specific model and tokenizer.
"""
"""
model_cls
=
self
.
_get_model_cls
(
model_config
)
model_cls
=
self
.
_get_model_cls
(
model_config
)
processor_
factor
y
=
self
.
_processor_factories
[
model_cls
]
factor
ies
=
self
.
_processor_factories
[
model_cls
]
ctx
=
InputProcessingContext
(
model_config
,
tokenizer
)
ctx
=
InputProcessingContext
(
model_config
,
tokenizer
)
cache
=
(
None
if
model_config
.
disable_mm_preprocessor_cache
else
cache
=
(
None
if
model_config
.
disable_mm_preprocessor_cache
else
self
.
_processing_cache
)
self
.
_processing_cache
)
return
processor
_factory
(
ctx
,
cache
=
cache
)
return
factories
.
build_
processor
(
ctx
,
cache
=
cache
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment