Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fab5f53e
Unverified
Commit
fab5f53e
authored
Aug 27, 2024
by
Peter Salas
Committed by
GitHub
Aug 28, 2024
Browse files
[Core][VLM] Stack multimodal tensors to represent multiple images within each prompt (#7902)
parent
9c71c97a
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
214 additions
and
60 deletions
+214
-60
docs/source/dev/multimodal/multimodal_index.rst
docs/source/dev/multimodal/multimodal_index.rst
+0
-2
tests/multimodal/test_base.py
tests/multimodal/test_base.py
+83
-0
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+7
-0
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+3
-0
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+3
-0
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+9
-0
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+8
-0
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+11
-0
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+8
-3
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+8
-0
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+8
-0
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+9
-0
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+37
-23
vllm/multimodal/__init__.py
vllm/multimodal/__init__.py
+1
-2
vllm/multimodal/base.py
vllm/multimodal/base.py
+19
-30
No files found.
docs/source/dev/multimodal/multimodal_index.rst
View file @
fab5f53e
...
...
@@ -45,8 +45,6 @@ Base Classes
.. autodata:: vllm.multimodal.NestedTensors
.. autodata:: vllm.multimodal.BatchedTensors
.. autodata:: vllm.multimodal.BatchedTensorInputs
.. autoclass:: vllm.multimodal.MultiModalDataBuiltins
...
...
tests/multimodal/test_base.py
0 → 100644
View file @
fab5f53e
import
torch
from
vllm.multimodal.base
import
MultiModalInputs
,
NestedTensors
def assert_nested_tensors_equal(expected: NestedTensors,
                                actual: NestedTensors):
    """Assert that two nested tensor structures are identical.

    Both arguments must be of exactly the same type at every level. Tensors
    are compared with ``torch.equal``; any other value is treated as a
    sequence whose elements are compared pairwise, recursively.

    Raises:
        AssertionError: if the types differ, if two tensors are not equal,
            or if two sequences have different lengths.
    """
    assert type(expected) is type(actual)

    if isinstance(expected, torch.Tensor):
        assert torch.equal(expected, actual)
        return

    # zip() silently stops at the shorter sequence, so without this check a
    # missing or extra trailing element would go unnoticed.
    assert len(expected) == len(actual)
    for expected_item, actual_item in zip(expected, actual):
        assert_nested_tensors_equal(expected_item, actual_item)
def assert_multimodal_inputs_equal(expected: MultiModalInputs,
                                   actual: MultiModalInputs):
    """Assert that two multimodal input mappings match.

    The two mappings must have exactly the same set of keys, and the value
    stored under each key must compare equal according to
    ``assert_nested_tensors_equal``.
    """
    expected_keys = set(expected.keys())
    actual_keys = set(actual.keys())
    assert expected_keys == actual_keys

    for name in expected:
        expected_value = expected[name]
        actual_value = actual[name]
        assert_nested_tensors_equal(expected_value, actual_value)
def test_multimodal_input_batch_single_tensor():
    """A batch of one tensor input is expected to gain a leading dimension."""
    image = torch.rand([1, 2])
    inputs = [{"image": image}]

    batched = MultiModalInputs.batch(inputs)

    expected = {"image": image.unsqueeze(0)}
    assert_multimodal_inputs_equal(batched, expected)
def test_multimodal_input_batch_multiple_tensors():
    """Same-shaped tensors are expected to be stacked into one tensor."""
    images = [torch.rand([1, 1, 2]) for _ in range(3)]
    inputs = [{"image": image} for image in images]

    batched = MultiModalInputs.batch(inputs)

    expected = {"image": torch.stack(images)}
    assert_multimodal_inputs_equal(batched, expected)
def test_multimodal_input_batch_multiple_heterogeneous_tensors():
    """Differently shaped tensors are expected to stay a list of tensors."""
    # The middle dimension differs (2, 3, 4), so the tensors cannot be stacked.
    images = [torch.rand([1, rows, 2]) for rows in (2, 3, 4)]
    inputs = [{"image": image} for image in images]

    batched = MultiModalInputs.batch(inputs)

    expected = {"image": list(images)}
    assert_multimodal_inputs_equal(batched, expected)
def test_multimodal_input_batch_nested_tensors():
    """Single-element lists of same-shaped tensors are expected to stack.

    Each input wraps its tensor in a one-element list, so the expected
    result stacks the tensors after giving each one a leading dimension.
    """
    images = [torch.rand([2, 3]) for _ in range(3)]
    inputs = [{"image": [image]} for image in images]

    batched = MultiModalInputs.batch(inputs)

    expected = {"image": torch.stack([image.unsqueeze(0) for image in images])}
    assert_multimodal_inputs_equal(batched, expected)
def test_multimodal_input_batch_heterogeneous_lists():
    """Lists of different lengths are expected to be stacked per item only.

    The first input holds two tensors and the second holds one, so the
    expected result is a list with one stacked tensor per input.
    """
    first, second, third = (torch.rand([1, 2, 3]) for _ in range(3))
    inputs = [{"image": [first, second]}, {"image": [third]}]

    batched = MultiModalInputs.batch(inputs)

    expected = {
        "image": [torch.stack([first, second]),
                  third.unsqueeze(0)],
    }
    assert_multimodal_inputs_equal(batched, expected)
def test_multimodal_input_batch_multiple_batchable_lists():
    """Equal-length lists of same-shaped tensors are expected to fully stack.

    Each input holds two tensors of the same shape, so the expected result
    is a single tensor built by stacking the per-input stacks.
    """
    tensors = [torch.rand([1, 2, 3]) for _ in range(4)]
    first_pair, second_pair = tensors[:2], tensors[2:]
    inputs = [{"image": first_pair}, {"image": second_pair}]

    batched = MultiModalInputs.batch(inputs)

    per_input = [torch.stack(first_pair), torch.stack(second_pair)]
    expected = {"image": torch.stack(per_input)}
    assert_multimodal_inputs_equal(batched, expected)
vllm/model_executor/models/blip2.py
View file @
fab5f53e
...
...
@@ -555,6 +555,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
# Remove the N dimension until multiple images are supported.
pixel_values
=
pixel_values
.
squeeze
(
1
)
return
Blip2ImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
...
...
@@ -564,6 +567,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
# Remove the N dimension until multiple images are supported.
image_embeds
=
image_embeds
.
squeeze
(
1
)
return
Blip2ImageEmbeddingInputs
(
type
=
"image_embeds"
,
data
=
image_embeds
,
...
...
vllm/model_executor/models/chameleon.py
View file @
fab5f53e
...
...
@@ -946,6 +946,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
# Remove the N dimension until multiple images are supported.
pixel_values
=
pixel_values
.
squeeze
(
1
)
return
ChameleonImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
...
...
vllm/model_executor/models/fuyu.py
View file @
fab5f53e
...
...
@@ -249,6 +249,9 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
image_patches
=
kwargs
.
pop
(
"image_patches"
,
None
)
if
isinstance
(
image_patches
,
torch
.
Tensor
):
# Remove the N dimension until multiple images are supported.
image_patches
=
image_patches
.
squeeze
(
1
)
expected_feature_size
=
self
.
image_feature_size
if
image_patches
.
size
(
-
1
)
!=
expected_feature_size
:
raise
ValueError
(
...
...
vllm/model_executor/models/internvl.py
View file @
fab5f53e
...
...
@@ -244,6 +244,8 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
min_num
,
max_num
,
use_thumbnail
=
use_thumbnail
)
# Add an N dimension for number of images per prompt (currently 1).
data
=
data
.
unsqueeze
(
0
)
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
...
...
@@ -410,6 +412,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
# Flatten the B and N dimensions
image_embeds
=
image_embeds
.
flatten
(
0
,
2
)
return
InternVLImageEmbeddingInputs
(
type
=
"image_embeds"
,
data
=
image_embeds
,
...
...
@@ -422,6 +428,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
# Flatten the B and N dimensions
pixel_values
=
pixel_values
.
flatten
(
0
,
2
)
return
InternVLImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
...
...
vllm/model_executor/models/llava.py
View file @
fab5f53e
...
...
@@ -232,6 +232,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
if
not
isinstance
(
pixel_values
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
# Remove the N dimension until multiple images are supported.
pixel_values
=
pixel_values
.
squeeze
(
1
)
return
LlavaImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
...
...
@@ -241,6 +245,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
# Remove the N dimension until multiple images are supported.
image_embeds
=
image_embeds
.
squeeze
(
1
)
return
LlavaImageEmbeddingInputs
(
type
=
"image_embeds"
,
data
=
image_embeds
,
...
...
vllm/model_executor/models/llava_next.py
View file @
fab5f53e
...
...
@@ -361,6 +361,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
raise
ValueError
(
"Incorrect type of image sizes. "
f
"Got type:
{
type
(
image_sizes
)
}
"
)
# Remove the N dimension until multiple images are supported.
if
isinstance
(
pixel_values
,
torch
.
Tensor
):
pixel_values
=
pixel_values
.
squeeze
(
1
)
else
:
pixel_values
=
[
t
.
squeeze
(
0
)
for
t
in
pixel_values
]
image_sizes
=
image_sizes
.
squeeze
(
1
)
return
LlavaNextImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
...
...
@@ -372,6 +380,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
raise
ValueError
(
"Incorrect type of image embeds. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
# Remove the N dimension until multiple images are supported.
image_embeds
=
image_embeds
.
squeeze
(
1
)
return
LlavaNextImageEmbeddingInputs
(
type
=
"image_embeds"
,
data
=
image_embeds
,
...
...
vllm/model_executor/models/minicpmv.py
View file @
fab5f53e
...
...
@@ -594,9 +594,14 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
pixel_values_flat
:
List
[
torch
.
Tensor
]
=
[]
tgt_sizes_flat
:
List
[
torch
.
Tensor
]
=
[]
for
b
in
range
(
len
(
pixel_values
)):
pixel_values_flat
+=
pixel_values
[
b
]
tgt_sizes_flat
+=
tgt_sizes
[
b
]
for
pixel_b
,
tgt_b
in
zip
(
pixel_values
,
tgt_sizes
):
if
len
(
pixel_b
)
!=
len
(
tgt_b
):
raise
ValueError
(
"Inconsistent N lengths, found: "
f
"
{
len
(
pixel_b
)
}
vs
{
len
(
tgt_b
)
}
"
)
for
pixel_n
,
tgt_n
in
zip
(
pixel_b
,
tgt_b
):
pixel_values_flat
+=
pixel_n
tgt_sizes_flat
+=
tgt_n
# NOTE: Input IDs do not contain image tokens during memory profiling,
# so we allow it to be empty
...
...
vllm/model_executor/models/paligemma.py
View file @
fab5f53e
...
...
@@ -185,6 +185,10 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
if
not
isinstance
(
pixel_values
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
# Remove the N dimension until multiple images are supported.
pixel_values
=
pixel_values
.
squeeze
(
1
)
return
PaliGemmaImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
...
...
@@ -194,6 +198,10 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
# Remove the N dimension until multiple images are supported.
image_embeds
=
image_embeds
.
squeeze
(
1
)
return
PaliGemmaImageEmbeddingInputs
(
type
=
"image_embeds"
,
data
=
image_embeds
,
...
...
vllm/model_executor/models/phi3v.py
View file @
fab5f53e
...
...
@@ -560,6 +560,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
raise
ValueError
(
"Incorrect type of image sizes. "
f
"Got type:
{
type
(
image_sizes
)
}
"
)
# Merge the B and N dimensions.
if
isinstance
(
pixel_values
,
torch
.
Tensor
):
pixel_values
=
pixel_values
.
flatten
(
0
,
1
)
else
:
pixel_values
=
torch
.
cat
(
pixel_values
)
image_sizes
=
image_sizes
.
flatten
(
0
,
1
)
return
Phi3VImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
...
...
vllm/model_executor/models/ultravox.py
View file @
fab5f53e
...
...
@@ -333,6 +333,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
raise
ValueError
(
"Incorrect type of audio features. "
f
"Got type:
{
type
(
audio_features
)
}
"
)
# Remove the N dimension until multiple audios are supported.
if
isinstance
(
audio_features
,
torch
.
Tensor
):
audio_features
=
audio_features
.
squeeze
(
1
)
else
:
audio_features
=
[
t
.
squeeze
(
0
)
for
t
in
audio_features
]
return
UltravoxAudioFeatureInputs
(
type
=
"audio_features"
,
data
=
audio_features
)
...
...
@@ -341,6 +347,9 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
raise
ValueError
(
"Incorrect type of audio embeds. "
f
"Got type:
{
type
(
audio_embeds
)
}
"
)
# Remove the N dimension until multiple audios are supported.
audio_embeds
=
audio_embeds
.
squeeze
(
1
)
return
UltravoxAudioEmbeddingInputs
(
type
=
"audio_embeds"
,
data
=
audio_embeds
)
...
...
vllm/model_executor/models/utils.py
View file @
fab5f53e
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Protocol
,
Tuple
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
torch.func
import
functional_call
...
...
@@ -10,7 +11,7 @@ from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.loader
import
build_model
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.multimodal
import
Batch
edTensors
from
vllm.multimodal
.base
import
Nest
edTensors
from
vllm.utils
import
is_pin_memory_available
...
...
@@ -54,9 +55,34 @@ def init_vllm_registered_model(
)
def
_flatten_embeddings
(
embeddings
:
NestedTensors
)
->
torch
.
Tensor
:
"""
Recursively concatenates NestedTensors along any heterogeneously sized
dimensions.
"""
if
isinstance
(
embeddings
,
torch
.
Tensor
):
return
embeddings
return
torch
.
cat
(
tuple
(
_flatten_embeddings
(
t
)
for
t
in
embeddings
))
def
_embedding_count_expression
(
embeddings
:
NestedTensors
)
->
str
:
"""
Constructs a debugging representation of the number of embeddings in the
NestedTensors.
"""
if
isinstance
(
embeddings
,
torch
.
Tensor
):
return
" x "
.
join
([
str
(
dim
)
for
dim
in
embeddings
.
shape
[:
-
1
]])
return
" + "
.
join
(
_embedding_count_expression
(
inner
)
for
inner
in
embeddings
)
def
merge_multimodal_embeddings
(
input_ids
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
,
multimodal_embeddings
:
Batch
edTensors
,
multimodal_embeddings
:
Nest
edTensors
,
placeholder_token_id
:
int
)
->
torch
.
Tensor
:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
...
...
@@ -69,28 +95,16 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
mask
=
(
input_ids
==
placeholder_token_id
)
num_expected_tokens
=
mask
.
sum
()
if
isinstance
(
multimodal_embeddings
,
torch
.
Tensor
):
batch_size
,
batch_tokens
,
*
_
,
embed_dim
=
multimodal_embeddings
.
shape
total_tokens
=
batch_size
*
batch_tokens
if
num_expected_tokens
!=
total_tokens
:
expr
=
f
"
{
batch_size
}
x
{
batch_tokens
}
"
raise
ValueError
(
f
"Attempted to assign
{
expr
}
=
{
total_tokens
}
"
f
"multimodal tokens to
{
num_expected_tokens
}
placeholders"
)
inputs_embeds
[
mask
]
=
multimodal_embeddings
.
view
(
total_tokens
,
embed_dim
)
else
:
size_per_batch
=
[
t
.
shape
[
0
]
for
t
in
multimodal_embeddings
]
total_tokens
=
sum
(
size_per_batch
)
if
num_expected_tokens
!=
total_tokens
:
expr
=
' + '
.
join
(
map
(
str
,
size_per_batch
))
raise
ValueError
(
f
"Attempted to assign
{
expr
}
=
{
total_tokens
}
"
f
"multimodal tokens to
{
num_expected_tokens
}
placeholders"
)
inputs_embeds
[
mask
]
=
torch
.
cat
(
multimodal_embeddings
)
flattened
=
_flatten_embeddings
(
multimodal_embeddings
)
*
dims
,
embed_dim
=
flattened
.
shape
num_multimodal_embeddings
=
np
.
prod
(
dims
)
if
num_multimodal_embeddings
!=
num_expected_tokens
:
expr
=
_embedding_count_expression
(
multimodal_embeddings
)
raise
ValueError
(
f
"Attempted to assign
{
expr
}
=
{
num_multimodal_embeddings
}
"
f
"multimodal tokens to
{
num_expected_tokens
}
placeholders"
)
inputs_embeds
[
mask
]
=
flattened
.
view
(
num_expected_tokens
,
embed_dim
)
return
inputs_embeds
...
...
vllm/multimodal/__init__.py
View file @
fab5f53e
from
.base
import
(
BatchedTensorInputs
,
BatchedTensors
,
MultiModalDataBuiltins
,
from
.base
import
(
BatchedTensorInputs
,
MultiModalDataBuiltins
,
MultiModalDataDict
,
MultiModalInputs
,
MultiModalPlugin
,
NestedTensors
)
from
.registry
import
MultiModalRegistry
...
...
@@ -14,7 +14,6 @@ See also:
__all__
=
[
"BatchedTensorInputs"
,
"BatchedTensors"
,
"MultiModalDataBuiltins"
,
"MultiModalDataDict"
,
"MultiModalInputs"
,
...
...
vllm/multimodal/base.py
View file @
fab5f53e
import
sys
from
abc
import
ABC
,
abstractmethod
from
collections
import
UserDict
,
defaultdict
from
typing
import
Callable
,
Dict
,
List
,
Mapping
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Tuple
,
Type
,
TypedDict
,
TypeVar
,
Union
,
cast
,
final
from
typing
import
(
Callable
,
Dict
,
List
,
Mapping
,
Optional
,
Tuple
,
Type
,
TypedDict
,
TypeVar
,
Union
,
cast
,
final
)
import
numpy
as
np
import
torch
...
...
@@ -15,23 +14,16 @@ from typing_extensions import TypeAlias
from
vllm.config
import
ModelConfig
from
vllm.inputs
import
InputContext
from
vllm.logger
import
init_logger
from
vllm.utils
import
JSONTree
,
json_map_leaves
from
vllm.utils
import
json_map_leaves
logger
=
init_logger
(
__name__
)
NestedTensors
=
Union
[
GenericSequence
[
torch
.
Tensor
],
torch
.
Tensor
]
NestedTensors
=
Union
[
List
[
"Nested
Tensor
s"
],
torch
.
Tensor
]
"""
Use a list instead of a tensor if the dimensions of each element do not match.
Currently only supports up to singly nested list of tensors.
Uses a list instead of a tensor if the dimensions of each element do not match.
"""
BatchedTensors
:
TypeAlias
=
JSONTree
[
torch
.
Tensor
]
"""
A nested JSON structure of tensors which have been batched via
:meth:`MultiModalInputs.batch`.
"""
BatchedTensorInputs
:
TypeAlias
=
Dict
[
str
,
JSONTree
[
torch
.
Tensor
]]
BatchedTensorInputs
:
TypeAlias
=
Dict
[
str
,
NestedTensors
]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalInputs.batch`.
...
...
@@ -54,26 +46,23 @@ class MultiModalInputs(_MultiModalInputsBase):
"""
@
staticmethod
def
_try_
concat
(
tensors
:
List
[
NestedTensors
]
)
->
Batch
edTensors
:
def
_try_
stack
(
nested_
tensors
:
NestedTensors
)
->
Nest
edTensors
:
"""
If each input tensor in the batch has the same shape, return a single
batched tensor; otherwise, return a list of :class:`NestedTensors` with
one element per item in the batch.
Recursively stacks lists of tensors when they all have the same shape.
"""
# may be list rather than tensors
if
isinstance
(
tensors
[
0
],
list
):
return
[[
t
for
t
in
tensor
[
0
]]
for
tensor
in
cast
(
List
[
List
[
torch
.
Tensor
]],
tensors
)]
tensors_
=
cast
(
List
[
torch
.
Tensor
],
tensors
)
if
isinstance
(
nested_tensors
,
torch
.
Tensor
):
return
nested_tensors
unbatched_shape
=
tensors_
[
0
].
shape
[
1
:]
stacked
=
[
MultiModalInputs
.
_try_stack
(
t
)
for
t
in
nested_tensors
]
if
any
(
isinstance
(
t
,
list
)
for
t
in
stacked
):
return
stacked
for
tensor
in
tensors_
:
if
tensor
.
shape
[
1
:]
!=
unbatched_shape
:
return
[
tensor
.
squeeze
(
0
)
for
tensor
in
tensors_
]
tensors_
=
cast
(
List
[
torch
.
Tensor
],
stacked
)
if
any
(
t
.
shape
!=
tensors_
[
0
].
shape
for
t
in
tensors_
):
# The tensors have incompatible shapes and can't be stacked.
return
tensors_
return
torch
.
cat
(
tensors_
,
dim
=
0
)
return
torch
.
stack
(
tensors_
)
@
staticmethod
def
batch
(
inputs_list
:
List
[
"MultiModalInputs"
])
->
BatchedTensorInputs
:
...
...
@@ -102,7 +91,7 @@ class MultiModalInputs(_MultiModalInputsBase):
item_lists
[
k
].
append
(
v
)
return
{
k
:
MultiModalInputs
.
_try_
concat
(
item_list
)
k
:
MultiModalInputs
.
_try_
stack
(
item_list
)
for
k
,
item_list
in
item_lists
.
items
()
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment