Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
74d5543e
Unverified
Commit
74d5543e
authored
Aug 28, 2024
by
Peter Salas
Committed by
GitHub
Aug 29, 2024
Browse files
[VLM][Core] Fix exceptions on ragged NestedTensors (#7974)
parent
a7f65c2b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
11 deletions
+21
-11
tests/multimodal/test_base.py
tests/multimodal/test_base.py
+12
-0
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+7
-9
vllm/multimodal/base.py
vllm/multimodal/base.py
+2
-2
No files found.
tests/multimodal/test_base.py
View file @
74d5543e
...
@@ -81,3 +81,15 @@ def test_multimodal_input_batch_multiple_batchable_lists():
...
@@ -81,3 +81,15 @@ def test_multimodal_input_batch_multiple_batchable_lists():
result
,
result
,
{
"image"
:
torch
.
stack
([
torch
.
stack
([
a
,
b
]),
{
"image"
:
torch
.
stack
([
torch
.
stack
([
a
,
b
]),
torch
.
stack
([
c
,
d
])])})
torch
.
stack
([
c
,
d
])])})
def test_multimodal_input_batch_mixed_stacking_depths():
    """Check batching when requests carry different numbers of images.

    The three tensors differ in their second dimension (2, 3 and 4), so a
    request holding two of them cannot be stacked into a single tensor. The
    assertions expect such a request to stay a plain list in the batched
    result, while a single-tensor request appears with an extra leading
    dimension.
    """
    short = torch.rand([1, 2, 3])
    middle = torch.rand([1, 3, 3])
    tall = torch.rand([1, 4, 3])

    # Ragged request first, single-image request second.
    ragged_then_single = [
        {"image": [short, middle]},
        {"image": [tall]},
    ]
    result = MultiModalInputs.batch(ragged_then_single)
    expected = {"image": [[short, middle], tall.unsqueeze(0)]}
    assert_multimodal_inputs_equal(result, expected)

    # Same idea with the order flipped: single-image request first.
    single_then_ragged = [
        {"image": [short]},
        {"image": [middle, tall]},
    ]
    result = MultiModalInputs.batch(single_then_ragged)
    expected = {"image": [short.unsqueeze(0), [middle, tall]]}
    assert_multimodal_inputs_equal(result, expected)
vllm/model_executor/models/utils.py
View file @
74d5543e
from
typing
import
(
Dict
,
Iterable
,
List
,
Literal
,
Optional
,
Protocol
,
Tuple
,
from
typing
import
(
Dict
,
Iterable
,
List
,
Literal
,
Optional
,
Protocol
,
Tuple
,
Union
,
overload
)
Union
,
overload
)
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
torch.func
import
functional_call
from
torch.func
import
functional_call
...
@@ -96,12 +95,13 @@ def flatten_bn(
...
@@ -96,12 +95,13 @@ def flatten_bn(
def
_flatten_embeddings
(
embeddings
:
NestedTensors
)
->
torch
.
Tensor
:
def
_flatten_embeddings
(
embeddings
:
NestedTensors
)
->
torch
.
Tensor
:
"""
"""
Recursively concatenates NestedTensors along any heterogeneously sized
Recursively flattens and concatenates NestedTensors on all but the last
dimensions.
dimension.
"""
"""
if
isinstance
(
embeddings
,
torch
.
Tensor
):
if
isinstance
(
embeddings
,
torch
.
Tensor
):
return
embeddings
# Flatten all but the last dimension.
return
embeddings
.
flatten
(
0
,
-
2
)
return
torch
.
cat
(
tuple
(
_flatten_embeddings
(
t
)
for
t
in
embeddings
))
return
torch
.
cat
(
tuple
(
_flatten_embeddings
(
t
)
for
t
in
embeddings
))
...
@@ -136,15 +136,13 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
...
@@ -136,15 +136,13 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
assert
isinstance
(
num_expected_tokens
,
int
)
assert
isinstance
(
num_expected_tokens
,
int
)
flattened
=
_flatten_embeddings
(
multimodal_embeddings
)
flattened
=
_flatten_embeddings
(
multimodal_embeddings
)
*
dims
,
embed_dim
=
flattened
.
shape
if
flattened
.
shape
[
0
]
!=
num_expected_tokens
:
num_multimodal_embeddings
=
np
.
prod
(
dims
)
if
num_multimodal_embeddings
!=
num_expected_tokens
:
expr
=
_embedding_count_expression
(
multimodal_embeddings
)
expr
=
_embedding_count_expression
(
multimodal_embeddings
)
raise
ValueError
(
raise
ValueError
(
f
"Attempted to assign
{
expr
}
=
{
num_multimodal_embeddings
}
"
f
"Attempted to assign
{
expr
}
=
{
flattened
.
shape
[
0
]
}
"
f
"multimodal tokens to
{
num_expected_tokens
}
placeholders"
)
f
"multimodal tokens to
{
num_expected_tokens
}
placeholders"
)
inputs_embeds
[
mask
]
=
flattened
.
view
(
num_expected_tokens
,
embed_dim
)
inputs_embeds
[
mask
]
=
flattened
return
inputs_embeds
return
inputs_embeds
...
...
vllm/multimodal/base.py
View file @
74d5543e
...
@@ -54,8 +54,8 @@ class MultiModalInputs(_MultiModalInputsBase):
...
@@ -54,8 +54,8 @@ class MultiModalInputs(_MultiModalInputsBase):
return
nested_tensors
return
nested_tensors
stacked
=
[
MultiModalInputs
.
_try_stack
(
t
)
for
t
in
nested_tensors
]
stacked
=
[
MultiModalInputs
.
_try_stack
(
t
)
for
t
in
nested_tensors
]
if
is_list_of
(
stacked
,
list
):
if
not
is_list_of
(
stacked
,
torch
.
Tensor
,
check
=
"all"
):
# Do not stack nested lists
# Only tensors (not lists) can be stacked.
return
stacked
return
stacked
tensors_
=
cast
(
List
[
torch
.
Tensor
],
stacked
)
tensors_
=
cast
(
List
[
torch
.
Tensor
],
stacked
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment