Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5241aa14
Unverified
Commit
5241aa14
authored
Oct 21, 2024
by
Michael Goin
Committed by
GitHub
Oct 21, 2024
Browse files
[Model][Bugfix] Fix batching with multi-image in PixtralHF (#9518)
parent
ec6bd6c4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
54 additions
and
17 deletions
+54
-17
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+48
-12
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+6
-5
No files found.
vllm/model_executor/models/llava.py
View file @
5241aa14
...
@@ -287,6 +287,34 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -287,6 +287,34 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return
data
return
data
def
_validate_image_sizes
(
self
,
images
:
List
[
torch
.
Tensor
],
sizes
:
List
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
if
not
isinstance
(
sizes
,
list
):
sizes
=
[
sizes
]
total_images
=
sum
(
size
.
numel
()
//
2
for
size
in
sizes
)
if
total_images
!=
len
(
images
):
raise
ValueError
(
"Mismatch in number of images. "
f
"Expected
{
total_images
}
, got
{
len
(
images
)
}
"
)
img_idx
=
0
for
size
in
sizes
:
# Flatten the size tensor to a list of (height, width) pairs
size
=
size
.
view
(
-
1
,
2
).
tolist
()
for
expected_h
,
expected_w
in
size
:
if
img_idx
>=
len
(
images
):
raise
ValueError
(
"Ran out of images before sizes. "
f
"
{
img_idx
}
>=
{
len
(
images
)
}
"
)
img
=
images
[
img_idx
]
if
img
.
shape
[
-
2
:]
!=
(
expected_h
,
expected_w
):
raise
ValueError
(
"Image size mismatch. Expected "
f
"
{
(
expected_h
,
expected_w
)
}
, got
{
img
.
shape
[
-
2
:]
}
"
)
if
img
.
shape
[
-
3
]
!=
3
:
raise
ValueError
(
"Image channel mismatch. Expected 3, "
f
"got
{
img
.
shape
[
-
3
]
}
"
)
img_idx
+=
1
return
images
def
_parse_and_validate_image_input
(
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
LlavaImageInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
LlavaImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
...
@@ -305,20 +333,28 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -305,20 +333,28 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# so we need to produce a list of tensors
# so we need to produce a list of tensors
if
image_sizes
is
not
None
:
if
image_sizes
is
not
None
:
images
=
pixel_values
images
=
pixel_values
if
isinstance
(
images
,
torch
.
Tensor
):
# if passed as batch take all images
def
flatten_to_3d_tensors
(
item
):
NN
,
N
,
B
,
C
,
W
,
H
=
images
.
shape
if
isinstance
(
item
,
torch
.
Tensor
):
images
=
images
.
reshape
(
NN
*
N
*
B
,
C
,
W
,
H
)
if
item
.
dim
()
>=
3
:
images
=
[
images
[
i
]
for
i
in
range
(
images
.
size
(
0
))]
return
[
t
for
t
in
item
.
view
(
-
1
,
*
item
.
shape
[
-
3
:])]
elif
isinstance
(
images
,
list
):
else
:
# if passed as list flatten lists of tensors
raise
ValueError
(
while
isinstance
(
images
,
list
)
and
len
(
images
)
==
1
:
f
"Unexpected tensor dimension:
{
item
.
dim
()
}
"
)
images
=
images
[
0
]
elif
isinstance
(
item
,
list
):
return
[
# TODO: Add validation based on image_sizes
t
for
subitem
in
item
for
t
in
flatten_to_3d_tensors
(
subitem
)
]
else
:
raise
ValueError
(
f
"Unexpected type:
{
type
(
item
)
}
"
)
# Flatten the (possibly nested) batched images into one flat list of 3D image tensors
images
=
flatten_to_3d_tensors
(
pixel_values
)
return
LlavaImagePixelInputs
(
return
LlavaImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
data
=
images
,
data
=
self
.
_validate_image_sizes
(
images
,
image_sizes
)
,
)
)
return
LlavaImagePixelInputs
(
return
LlavaImagePixelInputs
(
...
...
vllm/model_executor/models/pixtral.py
View file @
5241aa14
...
@@ -907,17 +907,18 @@ class PixtralHFVisionModel(nn.Module):
...
@@ -907,17 +907,18 @@ class PixtralHFVisionModel(nn.Module):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
pixel_values: tensor of token features for
pixel_values: Each image to be processed will be a separate tensor
all tokens of all images of shape (N_toks, D)
in pixel_values. This means it will be a list of tensors
because multiple requests batched can have multiple images,
each with their own shape potentially
Returns:
Returns:
image_features: tensor of token features for
image_features: tensor of token features for
all tokens of all images of shape (N_toks, D)
all tokens of all images of shape (N_toks, D)
"""
"""
# pass images through initial convolution independently
# pass images through initial convolution independently
patch_embeds_list
=
[
patch_embeds_list
=
[
self
.
patch_conv
(
self
.
patch_conv
(
img
.
unsqueeze
(
0
).
to
(
self
.
dtype
))
img
.
reshape
(
-
1
,
img
.
shape
[
-
3
],
img
.
shape
[
-
2
],
img
.
shape
[
-
1
]).
to
(
self
.
dtype
))
for
img
in
pixel_values
for
img
in
pixel_values
]
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment