Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ef9baee3
Unverified
Commit
ef9baee3
authored
Aug 28, 2024
by
Cyrus Leung
Committed by
GitHub
Aug 28, 2024
Browse files
[Bugfix][VLM] Fix incompatibility between #7902 and #7230 (#7948)
parent
98c12cff
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
120 additions
and
92 deletions
+120
-92
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+2
-2
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+1
-1
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+15
-31
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+2
-2
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+26
-26
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+2
-2
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+27
-23
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+1
-1
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+42
-2
vllm/multimodal/base.py
vllm/multimodal/base.py
+2
-2
No files found.
vllm/model_executor/models/blip2.py
View file @
ef9baee3
...
@@ -40,13 +40,13 @@ BLIP2_IMAGE_TOKEN_ID = 50265
...
@@ -40,13 +40,13 @@ BLIP2_IMAGE_TOKEN_ID = 50265
class
Blip2ImagePixelInputs
(
TypedDict
):
class
Blip2ImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
type
:
Literal
[
"pixel_values"
]
data
:
torch
.
Tensor
data
:
torch
.
Tensor
"""Shape: (batch_size, num_channels, height, width)"""
"""Shape:
`
(batch_size
* num_images
, num_channels, height, width)
`
"""
class
Blip2ImageEmbeddingInputs
(
TypedDict
):
class
Blip2ImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
type
:
Literal
[
"image_embeds"
]
data
:
torch
.
Tensor
data
:
torch
.
Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
"""Shape: `(batch_size
* num_images
, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
`hidden_size` must match the hidden size of language model backbone.
"""
"""
...
...
vllm/model_executor/models/chameleon.py
View file @
ef9baee3
...
@@ -53,7 +53,7 @@ CHAMELEON_SEP_TOKEN_ID = 8710
...
@@ -53,7 +53,7 @@ CHAMELEON_SEP_TOKEN_ID = 8710
class
ChameleonImagePixelInputs
(
TypedDict
):
class
ChameleonImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
type
:
Literal
[
"pixel_values"
]
data
:
torch
.
Tensor
data
:
torch
.
Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""
"""Shape: `(batch_size
* num_images
, num_channels, height, width)`"""
def
get_max_chameleon_image_tokens
(
ctx
:
InputContext
):
def
get_max_chameleon_image_tokens
(
ctx
:
InputContext
):
...
...
vllm/model_executor/models/internvl.py
View file @
ef9baee3
...
@@ -29,7 +29,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
...
@@ -29,7 +29,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from
.clip
import
(
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
from
.clip
import
(
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
get_clip_num_patches
)
get_clip_num_patches
)
from
.interfaces
import
SupportsMultiModal
from
.interfaces
import
SupportsMultiModal
from
.utils
import
(
filter_weights
,
init_vllm_registered_model
,
from
.utils
import
(
filter_weights
,
flatten_bn
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
IMG_START
=
'<img>'
IMG_START
=
'<img>'
...
@@ -42,19 +42,17 @@ IMAGENET_STD = (0.229, 0.224, 0.225)
...
@@ -42,19 +42,17 @@ IMAGENET_STD = (0.229, 0.224, 0.225)
class
InternVLImagePixelInputs
(
TypedDict
):
class
InternVLImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
type
:
Literal
[
"pixel_values"
]
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
data
:
torch
.
Tensor
"""
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Shape:
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
Note that `num_patches` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
"""
"""
class
InternVLImageEmbeddingInputs
(
TypedDict
):
class
InternVLImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
type
:
Literal
[
"image_embeds"
]
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
data
:
torch
.
Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
"""Shape: `(batch_size
* num_images
, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
`hidden_size` must match the hidden size of language model backbone.
"""
"""
...
@@ -357,7 +355,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
...
@@ -357,7 +355,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
x
=
x
.
permute
(
0
,
2
,
1
,
3
).
contiguous
()
x
=
x
.
permute
(
0
,
2
,
1
,
3
).
contiguous
()
return
x
return
x
def
extract_feature
(
self
,
pixel_values
)
:
def
extract_feature
(
self
,
pixel_values
:
torch
.
Tensor
)
->
torch
.
Tensor
:
vit_embeds
=
self
.
vision_model
(
pixel_values
=
pixel_values
)
vit_embeds
=
self
.
vision_model
(
pixel_values
=
pixel_values
)
vit_embeds
=
vit_embeds
[:,
1
:,
:]
vit_embeds
=
vit_embeds
[:,
1
:,
:]
...
@@ -370,17 +368,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
...
@@ -370,17 +368,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
vit_embeds
=
self
.
mlp1
(
vit_embeds
)
vit_embeds
=
self
.
mlp1
(
vit_embeds
)
return
vit_embeds
return
vit_embeds
def
_validate_image_sizes
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
list
(
data
.
shape
[
1
:])
!=
[
2
]:
raise
ValueError
(
f
"The expected image sizes shape is batch dimension plus "
f
"
{
[
2
]
}
. You supplied
{
data
.
shape
}
."
)
return
data
def
_validate_pixel_values
(
self
,
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
)
->
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]:
h
=
w
=
self
.
config
.
vision_config
.
image_size
h
=
w
=
self
.
config
.
vision_config
.
image_size
expected_dims
=
(
3
,
h
,
w
)
expected_dims
=
(
3
,
h
,
w
)
...
@@ -389,10 +377,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
...
@@ -389,10 +377,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
actual_dims
=
tuple
(
d
.
shape
)
actual_dims
=
tuple
(
d
.
shape
)
if
actual_dims
!=
expected_dims
:
if
actual_dims
!=
expected_dims
:
expected_expr
=
(
"num_patches"
,
*
map
(
str
,
expected_dims
)
)
expected_expr
=
str
(
expected_dims
)
raise
ValueError
(
raise
ValueError
(
"The expected shape of pixel values in each batch element "
"The expected shape of pixel values per image per batch "
f
"is
{
expected_expr
}
. You supplied
{
tuple
(
d
.
shape
)
}
."
)
f
" per patch is
{
expected_expr
}
. "
f
"You supplied
{
tuple
(
d
.
shape
)
}
."
)
for
d
in
data
:
for
d
in
data
:
_validate_shape
(
d
)
_validate_shape
(
d
)
...
@@ -413,12 +402,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
...
@@ -413,12 +402,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
raise
ValueError
(
"Incorrect type of image embeddings. "
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
f
"Got type:
{
type
(
image_embeds
)
}
"
)
# Flatten the B and N dimensions
image_embeds
=
image_embeds
.
flatten
(
0
,
2
)
return
InternVLImageEmbeddingInputs
(
return
InternVLImageEmbeddingInputs
(
type
=
"image_embeds"
,
type
=
"image_embeds"
,
data
=
image_embeds
,
data
=
flatten_bn
(
image_embeds
)
,
)
)
self
.
img_context_token_id
=
image_token_id
[
0
]
self
.
img_context_token_id
=
image_token_id
[
0
]
...
@@ -428,12 +414,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
...
@@ -428,12 +414,10 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
raise
ValueError
(
"Incorrect type of pixel values. "
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
f
"Got type:
{
type
(
pixel_values
)
}
"
)
# Flatten the B and N dimensions
pixel_values
=
pixel_values
.
flatten
(
0
,
2
)
return
InternVLImagePixelInputs
(
return
InternVLImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
data
=
self
.
_validate_pixel_values
(
flatten_bn
(
pixel_values
,
concat
=
True
).
flatten
(
0
,
1
)),
)
)
raise
AssertionError
(
"This line should be unreachable."
)
raise
AssertionError
(
"This line should be unreachable."
)
...
...
vllm/model_executor/models/llava.py
View file @
ef9baee3
...
@@ -30,13 +30,13 @@ from .utils import (filter_weights, init_vllm_registered_model,
...
@@ -30,13 +30,13 @@ from .utils import (filter_weights, init_vllm_registered_model,
class
LlavaImagePixelInputs
(
TypedDict
):
class
LlavaImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
type
:
Literal
[
"pixel_values"
]
data
:
torch
.
Tensor
data
:
torch
.
Tensor
"""Shape: `(batch_size, num_channels, height, width)`"""
"""Shape: `(batch_size
* num_images
, num_channels, height, width)`"""
class
LlavaImageEmbeddingInputs
(
TypedDict
):
class
LlavaImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
type
:
Literal
[
"image_embeds"
]
data
:
torch
.
Tensor
data
:
torch
.
Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
"""Shape: `(batch_size
* num_images
, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
`hidden_size` must match the hidden size of language model backbone.
"""
"""
...
...
vllm/model_executor/models/llava_next.py
View file @
ef9baee3
...
@@ -29,7 +29,7 @@ from .llava import LlavaMultiModalProjector
...
@@ -29,7 +29,7 @@ from .llava import LlavaMultiModalProjector
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_siglip_image_feature_size
,
dummy_seq_data_for_siglip
,
get_siglip_image_feature_size
,
get_siglip_patch_grid_length
,
input_processor_for_siglip
)
get_siglip_patch_grid_length
,
input_processor_for_siglip
)
from
.utils
import
(
filter_weights
,
init_vllm_registered_model
,
from
.utils
import
(
filter_weights
,
flatten_bn
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -47,15 +47,16 @@ class LlavaNextImagePixelInputs(TypedDict):
...
@@ -47,15 +47,16 @@ class LlavaNextImagePixelInputs(TypedDict):
type
:
Literal
[
"pixel_values"
]
type
:
Literal
[
"pixel_values"
]
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
"""
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Shape:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different
for each batch, in which case
Note that `num_patches` may be different
per batch and image,
the data is passed as a list instead of a batched tensor.
in which case
the data is passed as a list instead of a batched tensor.
"""
"""
image_sizes
:
NotRequired
[
torch
.
Tensor
]
image_sizes
:
NotRequired
[
torch
.
Tensor
]
"""
"""
Shape: `(batch_size, 2)`
Shape: `(batch_size
* num_images
, 2)`
This should be in `(height, width)` format.
This should be in `(height, width)` format.
"""
"""
...
@@ -64,7 +65,7 @@ class LlavaNextImagePixelInputs(TypedDict):
...
@@ -64,7 +65,7 @@ class LlavaNextImagePixelInputs(TypedDict):
class
LlavaNextImageEmbeddingInputs
(
TypedDict
):
class
LlavaNextImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
type
:
Literal
[
"image_embeds"
]
data
:
torch
.
Tensor
data
:
torch
.
Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
"""Shape: `(batch_size
* num_images
, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
`hidden_size` must match the hidden size of language model backbone.
"""
"""
...
@@ -315,10 +316,19 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -315,10 +316,19 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
torch
.
empty
(
config
.
text_config
.
hidden_size
))
torch
.
empty
(
config
.
text_config
.
hidden_size
))
def
_validate_image_sizes
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_validate_image_sizes
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
list
(
data
.
shape
[
1
:])
!=
[
2
]:
expected_dims
=
(
2
,
)
raise
ValueError
(
f
"The expected image sizes shape is batch dimension plus "
def
_validate_shape
(
d
:
torch
.
Tensor
):
f
"
{
[
2
]
}
. You supplied
{
data
.
shape
}
."
)
actual_dims
=
tuple
(
d
.
shape
)
if
actual_dims
!=
expected_dims
:
expected_expr
=
str
(
expected_dims
)
raise
ValueError
(
f
"The expected shape of image sizes per image per batch "
f
"is
{
expected_expr
}
. You supplied
{
tuple
(
d
.
shape
)
}
."
)
for
d
in
data
:
_validate_shape
(
d
)
return
data
return
data
...
@@ -335,7 +345,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -335,7 +345,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
if
actual_dims
!=
expected_dims
:
if
actual_dims
!=
expected_dims
:
expected_expr
=
(
"num_patches"
,
*
map
(
str
,
expected_dims
))
expected_expr
=
(
"num_patches"
,
*
map
(
str
,
expected_dims
))
raise
ValueError
(
raise
ValueError
(
"The expected shape of pixel values
in each batch element
"
"The expected shape of pixel values
per image per batch
"
f
"is
{
expected_expr
}
. You supplied
{
tuple
(
d
.
shape
)
}
."
)
f
"is
{
expected_expr
}
. You supplied
{
tuple
(
d
.
shape
)
}
."
)
for
d
in
data
:
for
d
in
data
:
...
@@ -357,22 +367,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -357,22 +367,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
raise
ValueError
(
"Incorrect type of pixel values. "
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
f
"Got type:
{
type
(
pixel_values
)
}
"
)
if
not
isinstance
(
image_sizes
,
torch
.
Tensor
):
if
not
isinstance
(
image_sizes
,
(
torch
.
Tensor
,
list
)
):
raise
ValueError
(
"Incorrect type of image sizes. "
raise
ValueError
(
"Incorrect type of image sizes. "
f
"Got type:
{
type
(
image_sizes
)
}
"
)
f
"Got type:
{
type
(
image_sizes
)
}
"
)
# Remove the N dimension until multiple images are supported.
if
isinstance
(
pixel_values
,
torch
.
Tensor
):
pixel_values
=
pixel_values
.
squeeze
(
1
)
else
:
pixel_values
=
[
t
.
squeeze
(
0
)
for
t
in
pixel_values
]
image_sizes
=
image_sizes
.
squeeze
(
1
)
return
LlavaNextImagePixelInputs
(
return
LlavaNextImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
data
=
self
.
_validate_pixel_values
(
flatten_bn
(
pixel_values
)),
image_sizes
=
self
.
_validate_image_sizes
(
image_sizes
),
image_sizes
=
self
.
_validate_image_sizes
(
flatten_bn
(
image_sizes
,
concat
=
True
)),
)
)
if
image_embeds
is
not
None
:
if
image_embeds
is
not
None
:
...
@@ -380,12 +383,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -380,12 +383,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
raise
ValueError
(
"Incorrect type of image embeds. "
raise
ValueError
(
"Incorrect type of image embeds. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
f
"Got type:
{
type
(
image_embeds
)
}
"
)
# Remove the N dimension until multiple images are supported.
image_embeds
=
image_embeds
.
squeeze
(
1
)
return
LlavaNextImageEmbeddingInputs
(
return
LlavaNextImageEmbeddingInputs
(
type
=
"image_embeds"
,
type
=
"image_embeds"
,
data
=
image_embeds
,
data
=
flatten_bn
(
image_embeds
)
,
)
)
raise
AssertionError
(
"This line should be unreachable."
)
raise
AssertionError
(
"This line should be unreachable."
)
...
...
vllm/model_executor/models/paligemma.py
View file @
ef9baee3
...
@@ -34,13 +34,13 @@ _KEYS_TO_MODIFY_MAPPING = {
...
@@ -34,13 +34,13 @@ _KEYS_TO_MODIFY_MAPPING = {
class
PaliGemmaImagePixelInputs
(
TypedDict
):
class
PaliGemmaImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
type
:
Literal
[
"pixel_values"
]
data
:
torch
.
Tensor
data
:
torch
.
Tensor
"""Shape: (batch_size, num_channels, height, width)"""
"""Shape:
`
(batch_size
* num_images
, num_channels, height, width)
`
"""
class
PaliGemmaImageEmbeddingInputs
(
TypedDict
):
class
PaliGemmaImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
type
:
Literal
[
"image_embeds"
]
data
:
torch
.
Tensor
data
:
torch
.
Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
"""Shape: `(batch_size
* num_images
, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
`hidden_size` must match the hidden size of language model backbone.
"""
"""
...
...
vllm/model_executor/models/phi3v.py
View file @
ef9baee3
...
@@ -44,7 +44,7 @@ from vllm.utils import is_list_of
...
@@ -44,7 +44,7 @@ from vllm.utils import is_list_of
from
.clip
import
dummy_image_for_clip
,
dummy_seq_data_for_clip
from
.clip
import
dummy_image_for_clip
,
dummy_seq_data_for_clip
from
.interfaces
import
SupportsMultiModal
from
.interfaces
import
SupportsMultiModal
from
.utils
import
merge_multimodal_embeddings
from
.utils
import
flatten_bn
,
merge_multimodal_embeddings
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -75,15 +75,16 @@ class Phi3VImagePixelInputs(TypedDict):
...
@@ -75,15 +75,16 @@ class Phi3VImagePixelInputs(TypedDict):
type
:
Literal
[
"pixel_values"
]
type
:
Literal
[
"pixel_values"
]
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
"""
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Shape:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different
for each batch, in which case
Note that `num_patches` may be different
per batch and image,
the data is passed as a list instead of a batched tensor.
in which case
the data is passed as a list instead of a batched tensor.
"""
"""
image_sizes
:
torch
.
Tensor
image_sizes
:
torch
.
Tensor
"""
"""
Shape: `(batch_size, 2)`
Shape: `(batch_size
* num_images
, 2)`
This should be in `(height, width)` format.
This should be in `(height, width)` format.
"""
"""
...
@@ -92,7 +93,7 @@ class Phi3VImagePixelInputs(TypedDict):
...
@@ -92,7 +93,7 @@ class Phi3VImagePixelInputs(TypedDict):
class
Phi3VImageEmbeddingInputs
(
TypedDict
):
class
Phi3VImageEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"image_embeds"
]
type
:
Literal
[
"image_embeds"
]
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
"""Shape: `(batch_size, image_feature_size, hidden_size)`
"""Shape: `(batch_size
* num_images
, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
`hidden_size` must match the hidden size of language model backbone.
"""
"""
...
@@ -511,10 +512,19 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
...
@@ -511,10 +512,19 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
def
_validate_image_sizes
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_validate_image_sizes
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
list
(
data
.
shape
[
1
:])
!=
[
2
]:
expected_dims
=
(
2
,
)
raise
ValueError
(
f
"The expected shape of image sizes is batch dimension plus "
def
_validate_shape
(
d
:
torch
.
Tensor
):
f
"
{
[
2
]
}
. You supplied
{
tuple
(
data
.
shape
)
}
."
)
actual_dims
=
tuple
(
d
.
shape
)
if
actual_dims
!=
expected_dims
:
expected_expr
=
str
(
expected_dims
)
raise
ValueError
(
f
"The expected shape of image sizes per image per batch "
f
"is
{
expected_expr
}
. You supplied
{
tuple
(
d
.
shape
)
}
."
)
for
d
in
data
:
_validate_shape
(
d
)
return
data
return
data
...
@@ -531,7 +541,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
...
@@ -531,7 +541,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
if
actual_dims
!=
expected_dims
:
if
actual_dims
!=
expected_dims
:
expected_expr
=
(
"num_patches"
,
*
map
(
str
,
expected_dims
))
expected_expr
=
(
"num_patches"
,
*
map
(
str
,
expected_dims
))
raise
ValueError
(
raise
ValueError
(
"The expected shape of pixel values
in each batch element
"
"The expected shape of pixel values
per image per batch
"
f
"is
{
expected_expr
}
. You supplied
{
tuple
(
d
.
shape
)
}
."
)
f
"is
{
expected_expr
}
. You supplied
{
tuple
(
d
.
shape
)
}
."
)
for
d
in
data
:
for
d
in
data
:
...
@@ -556,30 +566,24 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
...
@@ -556,30 +566,24 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
raise
ValueError
(
"Incorrect type of pixel values. "
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
f
"Got type:
{
type
(
pixel_values
)
}
"
)
if
not
isinstance
(
image_sizes
,
torch
.
Tensor
):
if
not
isinstance
(
image_sizes
,
(
torch
.
Tensor
,
list
)
):
raise
ValueError
(
"Incorrect type of image sizes. "
raise
ValueError
(
"Incorrect type of image sizes. "
f
"Got type:
{
type
(
image_sizes
)
}
"
)
f
"Got type:
{
type
(
image_sizes
)
}
"
)
# Merge the B and N dimensions.
if
isinstance
(
pixel_values
,
torch
.
Tensor
):
pixel_values
=
pixel_values
.
flatten
(
0
,
1
)
else
:
pixel_values
=
torch
.
cat
(
pixel_values
)
image_sizes
=
image_sizes
.
flatten
(
0
,
1
)
return
Phi3VImagePixelInputs
(
return
Phi3VImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
data
=
self
.
_validate_pixel_values
(
flatten_bn
(
pixel_values
)),
image_sizes
=
self
.
_validate_image_sizes
(
image_sizes
))
image_sizes
=
self
.
_validate_image_sizes
(
flatten_bn
(
image_sizes
,
concat
=
True
)))
if
image_embeds
is
not
None
:
if
image_embeds
is
not
None
:
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
raise
ValueError
(
"Incorrect type of image embeddings. "
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
f
"Got type:
{
type
(
image_embeds
)
}
"
)
return
Phi3VImageEmbeddingInputs
(
return
Phi3VImageEmbeddingInputs
(
type
=
"image_embeds"
,
type
=
"image_embeds"
,
data
=
image_embeds
,
data
=
flatten_bn
(
image_embeds
)
,
)
)
raise
AssertionError
(
"This line should be unreachable."
)
raise
AssertionError
(
"This line should be unreachable."
)
...
...
vllm/model_executor/models/ultravox.py
View file @
ef9baee3
...
@@ -49,7 +49,7 @@ logger = init_logger(__name__)
...
@@ -49,7 +49,7 @@ logger = init_logger(__name__)
class
UltravoxAudioFeatureInputs
(
TypedDict
):
class
UltravoxAudioFeatureInputs
(
TypedDict
):
type
:
Literal
[
"audio_features"
]
type
:
Literal
[
"audio_features"
]
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
"""Shape: `(batch_size, 80, M)"""
"""Shape: `(batch_size
* num_audios
, 80, M)"""
class
UltravoxAudioEmbeddingInputs
(
TypedDict
):
class
UltravoxAudioEmbeddingInputs
(
TypedDict
):
...
...
vllm/model_executor/models/utils.py
View file @
ef9baee3
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Protocol
,
Tuple
from
typing
import
(
Dict
,
Iterable
,
List
,
Literal
,
Optional
,
Protocol
,
Tuple
,
Union
,
overload
)
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -55,6 +56,44 @@ def init_vllm_registered_model(
...
@@ -55,6 +56,44 @@ def init_vllm_registered_model(
)
)
@overload
def flatten_bn(x: torch.Tensor) -> torch.Tensor:
    ...


@overload
def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]:
    ...


@overload
def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: Literal[True],
) -> torch.Tensor:
    ...


def flatten_bn(
    x: Union[List[torch.Tensor], torch.Tensor],
    *,
    concat: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
    """
    Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.

    The input tensor should have shape ``(B, N, ...)``.

    Args:
        x: Either a single tensor, or a list of tensors (one entry per
            ``B`` element).
        concat: Only consulted when ``x`` is a list. If true, the list
            entries are joined with ``torch.cat`` into one tensor; otherwise
            a flat list is returned.

    Returns:
        - If ``x`` is a tensor: ``x`` with its first two dimensions merged.
        - If ``x`` is a list and ``concat`` is true: a single tensor holding
          the entries of ``x`` concatenated along their first dimension.
        - Otherwise: a flat list containing the items obtained by iterating
          over each entry of ``x`` in order.
    """
    if isinstance(x, torch.Tensor):
        # A batched tensor can be collapsed directly; ``concat`` is
        # irrelevant because the result is already a single tensor.
        return x.flatten(0, 1)

    if concat:
        # Each list entry already carries its own leading dimension, so
        # concatenating along dim 0 merges the two levels.
        return torch.cat(x)

    # Keep the items separate: iterating an entry yields its slices along
    # the first dimension.
    return [x_n for x_b in x for x_n in x_b]
def
_flatten_embeddings
(
embeddings
:
NestedTensors
)
->
torch
.
Tensor
:
def
_flatten_embeddings
(
embeddings
:
NestedTensors
)
->
torch
.
Tensor
:
"""
"""
Recursively concatenates NestedTensors along any heterogeneously sized
Recursively concatenates NestedTensors along any heterogeneously sized
...
@@ -93,7 +132,8 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
...
@@ -93,7 +132,8 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
This updates ``inputs_embeds`` in place.
This updates ``inputs_embeds`` in place.
"""
"""
mask
=
(
input_ids
==
placeholder_token_id
)
mask
=
(
input_ids
==
placeholder_token_id
)
num_expected_tokens
=
mask
.
sum
()
num_expected_tokens
=
mask
.
sum
().
item
()
assert
isinstance
(
num_expected_tokens
,
int
)
flattened
=
_flatten_embeddings
(
multimodal_embeddings
)
flattened
=
_flatten_embeddings
(
multimodal_embeddings
)
*
dims
,
embed_dim
=
flattened
.
shape
*
dims
,
embed_dim
=
flattened
.
shape
...
...
vllm/multimodal/base.py
View file @
ef9baee3
...
@@ -18,7 +18,7 @@ from vllm.utils import JSONTree, is_list_of, json_map_leaves
...
@@ -18,7 +18,7 @@ from vllm.utils import JSONTree, is_list_of, json_map_leaves
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
NestedTensors
=
Union
[
List
[
"NestedTensors"
],
torch
.
Tensor
]
NestedTensors
=
Union
[
List
[
"NestedTensors"
],
List
[
torch
.
Tensor
],
torch
.
Tensor
]
"""
"""
Uses a list instead of a tensor if the dimensions of each element do not match.
Uses a list instead of a tensor if the dimensions of each element do not match.
"""
"""
...
@@ -61,7 +61,7 @@ class MultiModalInputs(_MultiModalInputsBase):
...
@@ -61,7 +61,7 @@ class MultiModalInputs(_MultiModalInputsBase):
tensors_
=
cast
(
List
[
torch
.
Tensor
],
stacked
)
tensors_
=
cast
(
List
[
torch
.
Tensor
],
stacked
)
if
any
(
t
.
shape
!=
tensors_
[
0
].
shape
for
t
in
tensors_
):
if
any
(
t
.
shape
!=
tensors_
[
0
].
shape
for
t
in
tensors_
):
# The tensors have incompatible shapes and can't be stacked.
# The tensors have incompatible shapes and can't be stacked.
return
stacked
return
tensors_
return
torch
.
stack
(
tensors_
)
return
torch
.
stack
(
tensors_
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment