Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6287537a
Unverified
Commit
6287537a
authored
May 20, 2024
by
Cyrus Leung
Committed by
GitHub
May 20, 2024
Browse files
[Model] LLaVA model refactor (#4910)
parent
b57e6c59
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
107 additions
and
30 deletions
+107
-30
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+107
-30
No files found.
vllm/model_executor/models/llava.py
View file @
6287537a
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
,
Union
import
torch
from
torch
import
nn
...
...
@@ -67,6 +67,21 @@ def _merge_vision_embeddings(input_ids: torch.Tensor,
return
inputs_embeds
class LlavaImagePixelInputs(TypedDict):
    """Image input variant whose ``type`` tag is ``"pixel_values"``.

    ``type`` identifies the variant; ``data`` carries the tensor payload.
    """

    type: Literal["pixel_values"]

    data: torch.Tensor
    """Shape: (batch_size, num_channels, height, width)"""
class LlavaImageFeatureInputs(TypedDict):
    """Image input variant whose ``type`` tag is ``"image_features"``.

    ``type`` identifies the variant; ``data`` carries the tensor payload.
    """

    type: Literal["image_features"]

    data: torch.Tensor
    """Shape: (batch_size, image_feature_size, hidden_size)"""
# Either kind of image payload. Both members carry a literal "type" key
# ("pixel_values" or "image_features") that tells the two variants apart.
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
class
LlavaForConditionalGeneration
(
VisionLanguageModelBase
):
def
__init__
(
self
,
...
...
@@ -102,6 +117,90 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
def
_validate_image_data
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
list
(
data
.
shape
[
1
:])
!=
list
(
self
.
vision_language_config
.
image_input_shape
[
1
:]):
raise
ValueError
(
f
"The expected image tensor shape is batch dimension plus "
f
"
{
self
.
vision_language_config
.
image_input_shape
[
1
:]
}
. "
f
"You supplied
{
data
.
shape
}
. "
f
"If you are using vLLM's entrypoint, make sure your "
f
"supplied image input is consistent with "
f
"image_input_shape in engine args."
)
return
data
def _parse_and_validate_image_input(
        self, data: object) -> Optional[LlavaImageInputs]:
    """Wrap raw image data in the tagged dict matching the configured type.

    The variant is chosen from
    ``self.vision_language_config.image_input_type``; the tensor is passed
    through ``self._validate_image_data`` before being wrapped.

    Args:
        data: Raw image input, or ``None`` when there is no image.

    Returns:
        A ``LlavaImagePixelInputs`` or ``LlavaImageFeatureInputs`` dict, or
        ``None`` when ``data`` is ``None`` or the configured input type is
        neither ``PIXEL_VALUES`` nor ``IMAGE_FEATURES``.

    Raises:
        TypeError: If ``data`` is not a ``torch.Tensor`` for one of the two
            handled input types.
        ValueError: Propagated from ``_validate_image_data`` on a shape
            mismatch.
    """
    configured_type = self.vision_language_config.image_input_type
    input_types = VisionLanguageConfig.ImageInputType

    if data is None:
        return None

    if configured_type == input_types.PIXEL_VALUES:
        if isinstance(data, torch.Tensor):
            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_image_data(data),
            )
        raise TypeError("Image pixel vector should be a tensor, "
                        f"but received type: {type(data)}")

    if configured_type == input_types.IMAGE_FEATURES:
        if isinstance(data, torch.Tensor):
            return LlavaImageFeatureInputs(
                type="image_features",
                data=self._validate_image_data(data),
            )
        raise TypeError("Image feature vector should be a tensor, "
                        f"but received type: {type(data)}")

    # Any other configured input type yields no parsed image input.
    return None
def
_select_image_features
(
self
,
image_features
:
torch
.
Tensor
,
*
,
strategy
:
str
)
->
torch
.
Tensor
:
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
if
strategy
==
"default"
:
return
image_features
[:,
1
:]
elif
strategy
==
"full"
:
return
image_features
raise
ValueError
(
f
"Unexpected select feature strategy:
{
strategy
}
"
)
def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                              pixel_values: torch.Tensor) -> torch.Tensor:
    """Run pixel values through the vision tower and select features.

    Args:
        vision_tower: Model called on the pixel values; its ``device``
            attribute decides where the input is moved first.
        pixel_values: Image tensor to encode.

    Returns:
        The hidden state at index ``self.config.vision_feature_layer``,
        filtered by ``self._select_image_features`` using
        ``self.config.vision_feature_select_strategy``.
    """
    # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
    pixels_on_device = pixel_values.to(vision_tower.device)
    tower_outputs = vision_tower(pixels_on_device, output_hidden_states=True)

    layer_index = self.config.vision_feature_layer
    layer_features = tower_outputs.hidden_states[layer_index]

    return self._select_image_features(
        layer_features,
        strategy=self.config.vision_feature_select_strategy,
    )
def _process_image_pixels(self,
                          inputs: LlavaImagePixelInputs) -> torch.Tensor:
    """Turn a pixel-values input dict into image features.

    Args:
        inputs: Tagged dict whose ``"data"`` entry holds the pixel tensor.

    Returns:
        The result of ``self._image_pixels_to_features`` applied to
        ``self.vision_tower`` and the pixel tensor.
    """
    tower = self.vision_tower
    # Pixel inputs can only be encoded when a vision tower is present.
    assert tower is not None

    return self._image_pixels_to_features(tower, inputs["data"])
def
_process_image_input
(
self
,
image_input
:
LlavaImageInputs
)
->
torch
.
Tensor
:
if
image_input
[
"type"
]
==
"pixel_values"
:
assert
self
.
vision_tower
is
not
None
image_features
=
self
.
_process_image_pixels
(
image_input
)
else
:
image_features
=
image_input
[
"data"
]
return
self
.
multi_modal_projector
(
image_features
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
...
@@ -144,42 +243,20 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
For PIXEL_VALUES, expecting [1, 3, 336, 336].
For IMAGE_FEATURES, expecting [1, 576, 1024].
"""
if
image_input
is
not
None
:
if
list
(
image_input
.
shape
[
1
:])
!=
list
(
self
.
vision_language_config
.
image_input_shape
[
1
:]):
raise
ValueError
(
f
"The expected image tensor shape is batch dimension "
f
"plus "
f
"
{
self
.
vision_language_config
.
image_input_shape
[
1
:]
}
."
f
" You supplied
{
image_input
.
shape
}
. "
f
"If you are using vLLM's entrypoint, make sure your "
f
"supplied image input is consistent with "
f
"image_input_shape in engine args."
)
if
self
.
vision_tower
is
not
None
:
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
image_outputs
=
self
.
vision_tower
(
image_input
,
output_hidden_states
=
True
)
image_features
=
image_outputs
.
hidden_states
[
self
.
config
.
vision_feature_layer
]
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
if
self
.
config
.
vision_feature_select_strategy
==
"default"
:
image_features
=
image_features
[:,
1
:]
elif
self
.
config
.
vision_feature_select_strategy
==
"full"
:
image_features
=
image_features
else
:
raise
ValueError
(
f
"Unexpected select feature strategy: "
f
"
{
self
.
config
.
vision_feature_select_strategy
}
"
)
else
:
image_features
=
image_input
vision_embeddings
=
self
.
multi_modal_projector
(
image_features
)
parsed_image_input
=
self
.
_parse_and_validate_image_input
(
image_input
)
if
parsed_image_input
is
not
None
:
vision_embeddings
=
self
.
_process_image_input
(
parsed_image_input
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
_merge_vision_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
vision_language_config
.
image_token_id
)
input_ids
=
None
else
:
inputs_embeds
=
None
hidden_states
=
self
.
language_model
(
input_ids
,
positions
,
kv_caches
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment