Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a4ce74c1
Unverified
Commit
a4ce74c1
authored
Feb 06, 2025
by
Cyrus Leung
Committed by
GitHub
Feb 05, 2025
Browse files
[VLM] Use shared field to pass token ids to model
parent
3b2005e1
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
235 additions
and
46 deletions
+235
-46
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+3
-3
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+232
-43
No files found.
vllm/model_executor/models/internvl.py
View file @
a4ce74c1
...
@@ -564,8 +564,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
...
@@ -564,8 +564,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
# Since there may be extra tokens in the feature placeholders,
# Since there may be extra tokens in the feature placeholders,
# we need to pass the image token ID to the model to select the
# we need to pass the image token ID to the model to select the
# tokens to merge from the vision encoder outputs
# tokens to merge from the vision encoder outputs
processed_outputs
[
"image_token_id"
]
=
[
image_token_id
processed_outputs
[
"image_token_id"
]
=
torch
.
tensor
(
image_token_id
)
]
*
len
(
image_data
)
return
processed_outputs
return
processed_outputs
...
@@ -575,13 +574,14 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
...
@@ -575,13 +574,14 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
image_num_patches
=
hf_inputs
.
get
(
"image_num_patches"
,
torch
.
empty
(
0
))
image_num_patches
=
hf_inputs
.
get
(
"image_num_patches"
,
torch
.
empty
(
0
))
num_images
=
len
(
image_num_patches
)
return
dict
(
return
dict
(
pixel_values_flat
=
MultiModalFieldConfig
.
flat_from_sizes
(
pixel_values_flat
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
image_num_patches
),
"image"
,
image_num_patches
),
image_num_patches
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_num_patches
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_token_id
=
MultiModalFieldConfig
.
batch
ed
(
"image"
),
image_token_id
=
MultiModalFieldConfig
.
shar
ed
(
"image"
,
num_images
),
)
)
def
_get_prompt_replacements
(
def
_get_prompt_replacements
(
...
...
vllm/multimodal/inputs.py
View file @
a4ce74c1
...
@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
...
@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
from
collections
import
UserDict
,
defaultdict
from
collections
import
UserDict
,
defaultdict
from
collections.abc
import
Mapping
,
Sequence
from
collections.abc
import
Mapping
,
Sequence
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
partial
from
itertools
import
accumulate
from
itertools
import
accumulate
from
typing
import
(
TYPE_CHECKING
,
Any
,
Literal
,
Optional
,
TypedDict
,
TypeVar
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Literal
,
Optional
,
TypedDict
,
TypeVar
,
Union
,
cast
,
final
)
Union
,
cast
,
final
)
...
@@ -164,51 +165,112 @@ A dictionary containing nested tensors which have been batched via
...
@@ -164,51 +165,112 @@ A dictionary containing nested tensors which have been batched via
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
MultiModalFieldElem
:
class
MultiModalFieldElem
:
"""Contains metadata and data of an item in :class:`MultiModalKwargs`."""
"""
field
:
"BaseMultiModalField"
Represents a keyword argument corresponding to a multi-modal item
in :class:`MultiModalKwargs`.
"""
modality
:
str
"""
The modality of the corresponding multi-modal item.
Each multi-modal item can consist of multiple keyword arguments.
"""
key
:
str
"""
The key of this field in :class:`MultiModalKwargs`,
i.e. the name of the keyword argument to be passed to the model.
"""
data
:
NestedTensors
data
:
NestedTensors
"""
The tensor data of this field in :class:`MultiModalKwargs`,
i.e. the value of the keyword argument to be passed to the model.
"""
field
:
"BaseMultiModalField"
"""
Defines how to combine the tensor data of this field with others
in order to batch multi-modal items together for model inference.
"""
def
__eq__
(
self
,
other
:
object
)
->
bool
:
def
__eq__
(
self
,
other
:
object
)
->
bool
:
if
not
isinstance
(
other
,
self
.
__class__
):
if
not
isinstance
(
other
,
self
.
__class__
):
return
False
return
False
return
(
self
.
field
==
other
.
field
return
((
self
.
modality
,
self
.
key
)
==
(
other
.
modality
,
other
.
key
)
and
nested_tensors_equal
(
self
.
data
,
other
.
data
))
and
nested_tensors_equal
(
self
.
data
,
other
.
data
)
and
type
(
self
.
field
)
==
type
(
other
.
field
))
# noqa: E721
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
BaseMultiModalField
(
ABC
):
class
BaseMultiModalField
(
ABC
):
"""Abstract base class for a field in :class:`MultiModalKwargs`."""
"""
key
:
str
Defines how to interpret tensor data belonging to a keyword argument in
modality
:
str
:class:`MultiModalKwargs` for multiple multi-modal items, and vice versa.
"""
def
_field_factory
(
self
,
*
,
modality
:
str
,
key
:
str
):
f
=
partial
(
MultiModalFieldElem
,
modality
=
modality
,
key
=
key
,
field
=
self
,
)
# Allow passing data as positional argument
def
factory
(
data
:
NestedTensors
)
->
MultiModalFieldElem
:
return
f
(
data
=
data
)
return
factory
@
abstractmethod
@
abstractmethod
def
_reduce_data
(
self
,
batch
:
list
[
NestedTensors
])
->
NestedTensors
:
def
build_elems
(
self
,
modality
:
str
,
key
:
str
,
data
:
NestedTensors
,
)
->
Sequence
[
MultiModalFieldElem
]:
"""
Construct :class:`MultiModalFieldElem` instances to represent
the provided data.
This is the inverse of :meth:`reduce_data`.
"""
raise
NotImplementedError
raise
NotImplementedError
def
_build_elem
(
self
,
data
:
NestedTensors
)
->
MultiModalFieldElem
:
@
abstractmethod
return
MultiModalFieldElem
(
self
,
data
)
def
_reduce_data
(
self
,
batch
:
list
[
NestedTensors
])
->
NestedTensors
:
raise
NotImplementedError
def
reduce
(
self
,
batch
:
list
[
MultiModalFieldElem
])
->
MultiModalFieldElem
:
def
reduce_data
(
self
,
elems
:
list
[
MultiModalFieldElem
])
->
NestedTensors
:
"""Merge multiple instances of :class:`MultiModalFieldElem` together."""
"""
fields
=
[
item
.
field
for
item
in
batch
]
Merge the data from multiple instances of :class:`MultiModalFieldElem`.
if
len
(
set
(
fields
))
>
1
:
raise
ValueError
(
f
"Cannot merge different
{
fields
=
}
"
)
data
=
self
.
_reduce_data
([
item
.
data
for
item
in
batch
])
This is the inverse of :meth:`build_elems`.
"""
field_types
=
[
type
(
item
.
field
)
for
item
in
elems
]
if
len
(
set
(
field_types
))
>
1
:
raise
ValueError
(
f
"Cannot merge different
{
field_types
=
}
"
)
return
self
.
_
build_el
em
(
data
)
return
self
.
_
reduce_data
([
it
em
.
data
for
item
in
elems
]
)
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
MultiModalBatchedField
(
BaseMultiModalField
):
class
MultiModalBatchedField
(
BaseMultiModalField
):
"""
"""
A :class:`BaseMultiModalField` implementation where an element in the batch
See also:
is obtained by indexing into the first dimension of the underlying data.
:func:`MultiModalFieldConfig.batched`
"""
"""
def
build_elems
(
self
,
batch
:
NestedTensors
)
->
list
[
MultiModalFieldElem
]:
def
build_elems
(
return
[
self
.
_build_elem
(
item
)
for
item
in
batch
]
self
,
modality
:
str
,
key
:
str
,
data
:
NestedTensors
,
)
->
Sequence
[
MultiModalFieldElem
]:
field_factory
=
self
.
_field_factory
(
modality
=
modality
,
key
=
key
)
return
[
field_factory
(
item
)
for
item
in
data
]
def
_reduce_data
(
self
,
batch
:
list
[
NestedTensors
])
->
NestedTensors
:
def
_reduce_data
(
self
,
batch
:
list
[
NestedTensors
])
->
NestedTensors
:
if
len
(
batch
)
>
0
and
is_list_of
(
batch
,
torch
.
Tensor
,
check
=
"all"
):
if
len
(
batch
)
>
0
and
is_list_of
(
batch
,
torch
.
Tensor
,
check
=
"all"
):
...
@@ -227,16 +289,20 @@ class MultiModalBatchedField(BaseMultiModalField):
...
@@ -227,16 +289,20 @@ class MultiModalBatchedField(BaseMultiModalField):
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
MultiModalFlatField
(
BaseMultiModalField
):
class
MultiModalFlatField
(
BaseMultiModalField
):
"""
"""
A :class:`BaseMultiModalField` implementation where an element in the batch
See also:
is obtained by slicing along the first dimension of the underlying data.
:func:`MultiModalFieldConfig.flat`
:func:`MultiModalFieldConfig.flat_from_sizes`
"""
"""
slices
:
Sequence
[
slice
]
def
build_elems
(
def
build_elems
(
self
,
self
,
batch
:
NestedTensors
,
modality
:
str
,
slices
:
Sequence
[
slice
],
key
:
str
,
)
->
list
[
MultiModalFieldElem
]:
data
:
NestedTensors
,
return
[
self
.
_build_elem
(
batch
[
slice_
])
for
slice_
in
slices
]
)
->
Sequence
[
MultiModalFieldElem
]:
field_factory
=
self
.
_field_factory
(
modality
=
modality
,
key
=
key
)
return
[
field_factory
(
data
[
s
])
for
s
in
self
.
slices
]
def
_reduce_data
(
self
,
batch
:
list
[
NestedTensors
])
->
NestedTensors
:
def
_reduce_data
(
self
,
batch
:
list
[
NestedTensors
])
->
NestedTensors
:
if
len
(
batch
)
>
0
and
is_list_of
(
batch
,
torch
.
Tensor
,
check
=
"all"
):
if
len
(
batch
)
>
0
and
is_list_of
(
batch
,
torch
.
Tensor
,
check
=
"all"
):
...
@@ -252,25 +318,121 @@ class MultiModalFlatField(BaseMultiModalField):
...
@@ -252,25 +318,121 @@ class MultiModalFlatField(BaseMultiModalField):
return
[
e
for
elem
in
batch
for
e
in
elem
]
return
[
e
for
elem
in
batch
for
e
in
elem
]
@
dataclass
(
frozen
=
True
)
class
MultiModalSharedField
(
BaseMultiModalField
):
"""
See also:
:func:`MultiModalFieldConfig.shared`
"""
batch_size
:
int
def
build_elems
(
self
,
modality
:
str
,
key
:
str
,
data
:
NestedTensors
,
)
->
Sequence
[
MultiModalFieldElem
]:
field_factory
=
self
.
_field_factory
(
modality
=
modality
,
key
=
key
)
return
[
field_factory
(
data
)]
*
self
.
batch_size
def
_reduce_data
(
self
,
batch
:
list
[
NestedTensors
])
->
NestedTensors
:
return
batch
[
0
]
class
MultiModalFieldConfig
:
class
MultiModalFieldConfig
:
@
staticmethod
@
staticmethod
def
batched
(
modality
:
str
):
def
batched
(
modality
:
str
):
"""
Defines a field where an element in the batch is obtained by
indexing into the first dimension of the underlying data.
Args:
modality: The modality of the multi-modal item that uses this
keyword argument.
Example:
.. code-block::
Input:
Data: [[AAAA]
[BBBB]
[CCCC]]
Output:
Element 1: [AAAA]
Element 2: [BBBB]
Element 3: [CCCC]
"""
return
MultiModalFieldConfig
(
return
MultiModalFieldConfig
(
field
_cls
=
MultiModalBatchedField
,
field
=
MultiModalBatchedField
()
,
modality
=
modality
,
modality
=
modality
,
)
)
@
staticmethod
@
staticmethod
def
flat
(
modality
:
str
,
slices
:
Sequence
[
slice
]):
def
flat
(
modality
:
str
,
slices
:
Sequence
[
slice
]):
"""
Defines a field where an element in the batch is obtained by
slicing along the first dimension of the underlying data.
Args:
modality: The modality of the multi-modal item that uses this
keyword argument.
slices: For each multi-modal item, a slice that is used to extract
the data corresponding to it.
Example:
.. code-block::
Given:
slices: [slice(0, 3), slice(3, 7), slice(7, 9)]
Input:
Data: [AAABBBBCC]
Output:
Element 1: [AAA]
Element 2: [BBBB]
Element 3: [CC]
"""
return
MultiModalFieldConfig
(
return
MultiModalFieldConfig
(
field
_cls
=
MultiModalFlatField
,
field
=
MultiModalFlatField
(
slices
=
slices
)
,
modality
=
modality
,
modality
=
modality
,
slices
=
slices
,
)
)
@
staticmethod
@
staticmethod
def
flat_from_sizes
(
modality
:
str
,
size_per_item
:
torch
.
Tensor
):
def
flat_from_sizes
(
modality
:
str
,
size_per_item
:
torch
.
Tensor
):
"""
Defines a field where an element in the batch is obtained by
slicing along the first dimension of the underlying data.
Args:
modality: The modality of the multi-modal item that uses this
keyword argument.
size_per_item: For each multi-modal item, the size of the slice that
is used to extract the data corresponding to it.
Example:
.. code-block::
Given:
size_per_item: [3, 4, 2]
Input:
Data: [AAABBBBCC]
Output:
Element 1: [AAA]
Element 2: [BBBB]
Element 3: [CC]
See also:
:func:`MultiModalFieldConfig.flat`
"""
slice_idxs
=
[
0
,
*
accumulate
(
size_per_item
)]
slice_idxs
=
[
0
,
*
accumulate
(
size_per_item
)]
slices
=
[
slices
=
[
slice
(
slice_idxs
[
i
],
slice_idxs
[
i
+
1
])
slice
(
slice_idxs
[
i
],
slice_idxs
[
i
+
1
])
...
@@ -279,25 +441,52 @@ class MultiModalFieldConfig:
...
@@ -279,25 +441,52 @@ class MultiModalFieldConfig:
return
MultiModalFieldConfig
.
flat
(
modality
,
slices
)
return
MultiModalFieldConfig
.
flat
(
modality
,
slices
)
def
__init__
(
@
staticmethod
self
,
def
shared
(
modality
:
str
,
batch_size
:
int
):
field_cls
:
type
[
BaseMultiModalField
],
"""
modality
:
str
,
Defines a field where an element in the batch is obtained by
**
field_config
:
Any
,
taking the entirety of the underlying data.
)
->
None
:
This means that the data is the same for each element in the batch.
Args:
modality: The modality of the multi-modal item that uses this
keyword argument.
batch_size: The number of multi-modal items which share this data.
Example:
.. code-block::
Given:
batch_size: 4
Input:
Data: [XYZ]
Output:
Element 1: [XYZ]
Element 2: [XYZ]
Element 3: [XYZ]
Element 4: [XYZ]
"""
return
MultiModalFieldConfig
(
field
=
MultiModalSharedField
(
batch_size
),
modality
=
modality
,
)
def
__init__
(
self
,
field
:
BaseMultiModalField
,
modality
:
str
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
field
_cls
=
field
_cls
self
.
field
=
field
self
.
modality
=
modality
self
.
modality
=
modality
self
.
field_config
=
field_config
def
build_elems
(
def
build_elems
(
self
,
self
,
key
:
str
,
key
:
str
,
batch
:
NestedTensors
,
batch
:
NestedTensors
,
)
->
Sequence
[
MultiModalFieldElem
]:
)
->
Sequence
[
MultiModalFieldElem
]:
field
=
self
.
field_cls
(
key
=
key
,
modality
=
self
.
modality
)
return
self
.
field
.
build_elems
(
self
.
modality
,
key
,
batch
)
return
field
.
build_elems
(
batch
,
**
self
.
field_config
)
# type: ignore
class
MultiModalKwargsItem
(
UserDict
[
str
,
MultiModalFieldElem
]):
class
MultiModalKwargsItem
(
UserDict
[
str
,
MultiModalFieldElem
]):
...
@@ -308,11 +497,11 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
...
@@ -308,11 +497,11 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
@
staticmethod
@
staticmethod
def
from_elems
(
elems
:
Sequence
[
MultiModalFieldElem
]):
def
from_elems
(
elems
:
Sequence
[
MultiModalFieldElem
]):
return
MultiModalKwargsItem
({
elem
.
field
.
key
:
elem
for
elem
in
elems
})
return
MultiModalKwargsItem
({
elem
.
key
:
elem
for
elem
in
elems
})
@
property
@
property
def
modality
(
self
)
->
str
:
def
modality
(
self
)
->
str
:
modalities
=
{
elem
.
field
.
modality
for
elem
in
self
.
data
.
values
()}
modalities
=
{
elem
.
modality
for
elem
in
self
.
data
.
values
()}
assert
len
(
modalities
)
==
1
,
f
"Found different modalities=
{
modalities
}
"
assert
len
(
modalities
)
==
1
,
f
"Found different modalities=
{
modalities
}
"
return
next
(
iter
(
modalities
))
return
next
(
iter
(
modalities
))
...
@@ -372,7 +561,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
...
@@ -372,7 +561,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
elems_by_key
[
key
].
append
(
elem
)
elems_by_key
[
key
].
append
(
elem
)
data
=
{
data
=
{
key
:
elems
[
0
].
field
.
reduce
(
elems
)
.
data
key
:
elems
[
0
].
field
.
reduce
_data
(
elems
)
for
key
,
elems
in
elems_by_key
.
items
()
if
len
(
elems
)
>
0
for
key
,
elems
in
elems_by_key
.
items
()
if
len
(
elems
)
>
0
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment