Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a4ce74c1
Unverified
Commit
a4ce74c1
authored
Feb 06, 2025
by
Cyrus Leung
Committed by
GitHub
Feb 05, 2025
Browse files
[VLM] Use shared field to pass token ids to model
parent
3b2005e1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
235 additions
and
46 deletions
+235
-46
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+3
-3
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+232
-43
No files found.
vllm/model_executor/models/internvl.py
View file @
a4ce74c1
...
...
@@ -564,8 +564,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
# Since there may be extra tokens in the feature placeholders,
# we need to pass the image token ID to the model to select the
# tokens to merge from the vision encoder outputs
processed_outputs
[
"image_token_id"
]
=
[
image_token_id
]
*
len
(
image_data
)
processed_outputs
[
"image_token_id"
]
=
torch
.
tensor
(
image_token_id
)
return
processed_outputs
...
...
@@ -575,13 +574,14 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
image_num_patches
=
hf_inputs
.
get
(
"image_num_patches"
,
torch
.
empty
(
0
))
num_images
=
len
(
image_num_patches
)
return
dict
(
pixel_values_flat
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
image_num_patches
),
image_num_patches
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_token_id
=
MultiModalFieldConfig
.
batch
ed
(
"image"
),
image_token_id
=
MultiModalFieldConfig
.
shar
ed
(
"image"
,
num_images
),
)
def
_get_prompt_replacements
(
...
...
vllm/multimodal/inputs.py
View file @
a4ce74c1
...
...
@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
from
collections
import
UserDict
,
defaultdict
from
collections.abc
import
Mapping
,
Sequence
from
dataclasses
import
dataclass
from
functools
import
partial
from
itertools
import
accumulate
from
typing
import
(
TYPE_CHECKING
,
Any
,
Literal
,
Optional
,
TypedDict
,
TypeVar
,
Union
,
cast
,
final
)
...
...
@@ -164,51 +165,112 @@ A dictionary containing nested tensors which have been batched via
@
dataclass
(
frozen
=
True
)
class
MultiModalFieldElem
:
"""Contains metadata and data of an item in :class:`MultiModalKwargs`."""
field
:
"BaseMultiModalField"
"""
Represents a keyword argument corresponding to a multi-modal item
in :class:`MultiModalKwargs`.
"""
modality
:
str
"""
The modality of the corresponding multi-modal item.
Each multi-modal item can consist of multiple keyword arguments.
"""
key
:
str
"""
The key of this field in :class:`MultiModalKwargs`,
i.e. the name of the keyword argument to be passed to the model.
"""
data
:
NestedTensors
"""
The tensor data of this field in :class:`MultiModalKwargs`,
i.e. the value of the keyword argument to be passed to the model.
"""
field
:
"BaseMultiModalField"
"""
Defines how to combine the tensor data of this field with others
in order to batch multi-modal items together for model inference.
"""
def
__eq__
(
self
,
other
:
object
)
->
bool
:
if
not
isinstance
(
other
,
self
.
__class__
):
return
False
return
(
self
.
field
==
other
.
field
and
nested_tensors_equal
(
self
.
data
,
other
.
data
))
return
((
self
.
modality
,
self
.
key
)
==
(
other
.
modality
,
other
.
key
)
and
nested_tensors_equal
(
self
.
data
,
other
.
data
)
and
type
(
self
.
field
)
==
type
(
other
.
field
))
# noqa: E721
@
dataclass
(
frozen
=
True
)
class
BaseMultiModalField
(
ABC
):
"""Abstract base class for a field in :class:`MultiModalKwargs`."""
key
:
str
modality
:
str
"""
Defines how to interpret tensor data belonging to a keyword argument in
:class:`MultiModalKwargs` for multiple multi-modal items, and vice versa.
"""
def
_field_factory
(
self
,
*
,
modality
:
str
,
key
:
str
):
f
=
partial
(
MultiModalFieldElem
,
modality
=
modality
,
key
=
key
,
field
=
self
,
)
# Allow passing data as positional argument
def
factory
(
data
:
NestedTensors
)
->
MultiModalFieldElem
:
return
f
(
data
=
data
)
return
factory
@
abstractmethod
def
_reduce_data
(
self
,
batch
:
list
[
NestedTensors
])
->
NestedTensors
:
def
build_elems
(
self
,
modality
:
str
,
key
:
str
,
data
:
NestedTensors
,
)
->
Sequence
[
MultiModalFieldElem
]:
"""
Construct :class:`MultiModalFieldElem` instances to represent
the provided data.
This is the inverse of :meth:`reduce_data`.
"""
raise
NotImplementedError
def
_build_elem
(
self
,
data
:
NestedTensors
)
->
MultiModalFieldElem
:
return
MultiModalFieldElem
(
self
,
data
)
@
abstractmethod
def
_reduce_data
(
self
,
batch
:
list
[
NestedTensors
])
->
NestedTensors
:
raise
NotImplementedError
def
reduce
(
self
,
batch
:
list
[
MultiModalFieldElem
])
->
MultiModalFieldElem
:
"""Merge multiple instances of :class:`MultiModalFieldElem` together."""
fields
=
[
item
.
field
for
item
in
batch
]
if
len
(
set
(
fields
))
>
1
:
raise
ValueError
(
f
"Cannot merge different
{
fields
=
}
"
)
def
reduce_data
(
self
,
elems
:
list
[
MultiModalFieldElem
])
->
NestedTensors
:
"""
Merge the data from multiple instances of :class:`MultiModalFieldElem`.
data
=
self
.
_reduce_data
([
item
.
data
for
item
in
batch
])
This is the inverse of :meth:`build_elems`.
"""
field_types
=
[
type
(
item
.
field
)
for
item
in
elems
]
if
len
(
set
(
field_types
))
>
1
:
raise
ValueError
(
f
"Cannot merge different
{
field_types
=
}
"
)
return
self
.
_
build_el
em
(
data
)
return
self
.
_
reduce_data
([
it
em
.
data
for
item
in
elems
]
)
@
dataclass
(
frozen
=
True
)
class
MultiModalBatchedField
(
BaseMultiModalField
):
"""
A :class:`BaseMultiModalField` implementation where an element in the batch
is obtained by indexing into the first dimension of the underlying data.
See also:
:func:`MultiModalFieldConfig.batched`
"""
def
build_elems
(
self
,
batch
:
NestedTensors
)
->
list
[
MultiModalFieldElem
]:
return
[
self
.
_build_elem
(
item
)
for
item
in
batch
]
def
build_elems
(
self
,
modality
:
str
,
key
:
str
,
data
:
NestedTensors
,
)
->
Sequence
[
MultiModalFieldElem
]:
field_factory
=
self
.
_field_factory
(
modality
=
modality
,
key
=
key
)
return
[
field_factory
(
item
)
for
item
in
data
]
def
_reduce_data
(
self
,
batch
:
list
[
NestedTensors
])
->
NestedTensors
:
if
len
(
batch
)
>
0
and
is_list_of
(
batch
,
torch
.
Tensor
,
check
=
"all"
):
...
...
@@ -227,16 +289,20 @@ class MultiModalBatchedField(BaseMultiModalField):
@
dataclass
(
frozen
=
True
)
class
MultiModalFlatField
(
BaseMultiModalField
):
"""
A :class:`BaseMultiModalField` implementation where an element in the batch
is obtained by slicing along the first dimension of the underlying data.
See also:
:func:`MultiModalFieldConfig.flat`
:func:`MultiModalFieldConfig.flat_from_sizes`
"""
slices
:
Sequence
[
slice
]
def
build_elems
(
self
,
batch
:
NestedTensors
,
slices
:
Sequence
[
slice
],
)
->
list
[
MultiModalFieldElem
]:
return
[
self
.
_build_elem
(
batch
[
slice_
])
for
slice_
in
slices
]
modality
:
str
,
key
:
str
,
data
:
NestedTensors
,
)
->
Sequence
[
MultiModalFieldElem
]:
field_factory
=
self
.
_field_factory
(
modality
=
modality
,
key
=
key
)
return
[
field_factory
(
data
[
s
])
for
s
in
self
.
slices
]
def
_reduce_data
(
self
,
batch
:
list
[
NestedTensors
])
->
NestedTensors
:
if
len
(
batch
)
>
0
and
is_list_of
(
batch
,
torch
.
Tensor
,
check
=
"all"
):
...
...
@@ -252,25 +318,121 @@ class MultiModalFlatField(BaseMultiModalField):
return
[
e
for
elem
in
batch
for
e
in
elem
]
@
dataclass
(
frozen
=
True
)
class
MultiModalSharedField
(
BaseMultiModalField
):
"""
See also:
:func:`MultiModalFieldConfig.shared`
"""
batch_size
:
int
def
build_elems
(
self
,
modality
:
str
,
key
:
str
,
data
:
NestedTensors
,
)
->
Sequence
[
MultiModalFieldElem
]:
field_factory
=
self
.
_field_factory
(
modality
=
modality
,
key
=
key
)
return
[
field_factory
(
data
)]
*
self
.
batch_size
def
_reduce_data
(
self
,
batch
:
list
[
NestedTensors
])
->
NestedTensors
:
return
batch
[
0
]
class
MultiModalFieldConfig
:
@
staticmethod
def
batched
(
modality
:
str
):
"""
Defines a field where an element in the batch is obtained by
indexing into the first dimension of the underlying data.
Args:
modality: The modality of the multi-modal item that uses this
keyword argument.
Example:
.. code-block::
Input:
Data: [[AAAA]
[BBBB]
[CCCC]]
Output:
Element 1: [AAAA]
Element 2: [BBBB]
Element 3: [CCCC]
"""
return
MultiModalFieldConfig
(
field
_cls
=
MultiModalBatchedField
,
field
=
MultiModalBatchedField
()
,
modality
=
modality
,
)
@
staticmethod
def
flat
(
modality
:
str
,
slices
:
Sequence
[
slice
]):
"""
Defines a field where an element in the batch is obtained by
slicing along the first dimension of the underlying data.
Args:
modality: The modality of the multi-modal item that uses this
keyword argument.
slices: For each multi-modal item, a slice that is used to extract
the data corresponding to it.
Example:
.. code-block::
Given:
slices: [slice(0, 3), slice(3, 7), slice(7, 9)]
Input:
Data: [AAABBBBCC]
Output:
Element 1: [AAA]
Element 2: [BBBB]
Element 3: [CC]
"""
return
MultiModalFieldConfig
(
field
_cls
=
MultiModalFlatField
,
field
=
MultiModalFlatField
(
slices
=
slices
)
,
modality
=
modality
,
slices
=
slices
,
)
@
staticmethod
def
flat_from_sizes
(
modality
:
str
,
size_per_item
:
torch
.
Tensor
):
"""
Defines a field where an element in the batch is obtained by
slicing along the first dimension of the underlying data.
Args:
modality: The modality of the multi-modal item that uses this
keyword argument.
slices: For each multi-modal item, the size of the slice that
is used to extract the data corresponding to it.
Example:
.. code-block::
Given:
size_per_item: [3, 4, 2]
Input:
Data: [AAABBBBCC]
Output:
Element 1: [AAA]
Element 2: [BBBB]
Element 3: [CC]
See also:
:func:`MultiModalFieldConfig.flat`
"""
slice_idxs
=
[
0
,
*
accumulate
(
size_per_item
)]
slices
=
[
slice
(
slice_idxs
[
i
],
slice_idxs
[
i
+
1
])
...
...
@@ -279,25 +441,52 @@ class MultiModalFieldConfig:
return
MultiModalFieldConfig
.
flat
(
modality
,
slices
)
def
__init__
(
self
,
field_cls
:
type
[
BaseMultiModalField
],
modality
:
str
,
**
field_config
:
Any
,
)
->
None
:
@
staticmethod
def
shared
(
modality
:
str
,
batch_size
:
int
):
"""
Defines a field where an element in the batch is obtained by
taking the entirety of the underlying data.
This means that the data is the same for each element in the batch.
Args:
modality: The modality of the multi-modal item that uses this
keyword argument.
batch_size: The number of multi-modal items which share this data.
Example:
.. code-block::
Given:
batch_size: 4
Input:
Data: [XYZ]
Output:
Element 1: [XYZ]
Element 2: [XYZ]
Element 3: [XYZ]
Element 4: [XYZ]
"""
return
MultiModalFieldConfig
(
field
=
MultiModalSharedField
(
batch_size
),
modality
=
modality
,
)
def
__init__
(
self
,
field
:
BaseMultiModalField
,
modality
:
str
)
->
None
:
super
().
__init__
()
self
.
field
_cls
=
field
_cls
self
.
field
=
field
self
.
modality
=
modality
self
.
field_config
=
field_config
def
build_elems
(
self
,
key
:
str
,
batch
:
NestedTensors
,
)
->
Sequence
[
MultiModalFieldElem
]:
field
=
self
.
field_cls
(
key
=
key
,
modality
=
self
.
modality
)
return
field
.
build_elems
(
batch
,
**
self
.
field_config
)
# type: ignore
return
self
.
field
.
build_elems
(
self
.
modality
,
key
,
batch
)
class
MultiModalKwargsItem
(
UserDict
[
str
,
MultiModalFieldElem
]):
...
...
@@ -308,11 +497,11 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
@
staticmethod
def
from_elems
(
elems
:
Sequence
[
MultiModalFieldElem
]):
return
MultiModalKwargsItem
({
elem
.
field
.
key
:
elem
for
elem
in
elems
})
return
MultiModalKwargsItem
({
elem
.
key
:
elem
for
elem
in
elems
})
@
property
def
modality
(
self
)
->
str
:
modalities
=
{
elem
.
field
.
modality
for
elem
in
self
.
data
.
values
()}
modalities
=
{
elem
.
modality
for
elem
in
self
.
data
.
values
()}
assert
len
(
modalities
)
==
1
,
f
"Found different modalities=
{
modalities
}
"
return
next
(
iter
(
modalities
))
...
...
@@ -372,7 +561,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
elems_by_key
[
key
].
append
(
elem
)
data
=
{
key
:
elems
[
0
].
field
.
reduce
(
elems
)
.
data
key
:
elems
[
0
].
field
.
reduce
_data
(
elems
)
for
key
,
elems
in
elems_by_key
.
items
()
if
len
(
elems
)
>
0
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment