Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f68470e8
Unverified
Commit
f68470e8
authored
May 19, 2024
by
Cyrus Leung
Committed by
GitHub
May 19, 2024
Browse files
[Bugfix][Model] Add base class for vision-language models (#4809)
parent
2e9a2227
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
53 additions
and
29 deletions
+53
-29
tests/models/test_registry.py
tests/models/test_registry.py
+9
-0
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+7
-6
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+25
-23
vllm/model_executor/models/vlm_base.py
vllm/model_executor/models/vlm_base.py
+12
-0
No files found.
tests/models/test_registry.py
0 → 100644
View file @
f68470e8
import
pytest
from
vllm.model_executor.models
import
_MODELS
,
ModelRegistry
@
pytest
.
mark
.
parametrize
(
"model_cls"
,
_MODELS
)
def
test_registry_imports
(
model_cls
):
# Ensure all model classes can be imported successfully
ModelRegistry
.
load_model_cls
(
model_cls
)
vllm/model_executor/model_loader/loader.py
View file @
f68470e8
...
@@ -26,11 +26,7 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -26,11 +26,7 @@ from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf
,
filter_files_not_needed_for_inference
,
download_weights_from_hf
,
filter_files_not_needed_for_inference
,
get_quant_config
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
get_quant_config
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
pt_weights_iterator
,
safetensors_weights_iterator
)
pt_weights_iterator
,
safetensors_weights_iterator
)
from
vllm.model_executor.models.llava
import
LlavaForConditionalGeneration
from
vllm.model_executor.models.vlm_base
import
VisionLanguageModelBase
_VISION_MODEL_CLASSES
=
[
LlavaForConditionalGeneration
,
]
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -73,7 +69,12 @@ def _get_model_initialization_kwargs(
...
@@ -73,7 +69,12 @@ def _get_model_initialization_kwargs(
"but LoRA is enabled. Support for this model may "
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"be added in the future. If this is important to you, "
"please open an issue on github."
)
"please open an issue on github."
)
elif
model_class
in
_VISION_MODEL_CLASSES
:
elif
issubclass
(
model_class
,
VisionLanguageModelBase
):
if
vision_language_config
is
None
:
raise
ValueError
(
"Provide `image_input_type` and other vision "
"related configurations through LLM entrypoint "
"or engine arguments."
)
extra_kwargs
[
"vision_language_config"
]
=
vision_language_config
extra_kwargs
[
"vision_language_config"
]
=
vision_language_config
return
extra_kwargs
return
extra_kwargs
...
...
vllm/model_executor/models/llava.py
View file @
f68470e8
...
@@ -19,6 +19,8 @@ from vllm.model_executor.models.llama import LlamaModel
...
@@ -19,6 +19,8 @@ from vllm.model_executor.models.llama import LlamaModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
from
.vlm_base
import
VisionLanguageModelBase
_KEYS_TO_MODIFY_MAPPING
=
{
_KEYS_TO_MODIFY_MAPPING
=
{
"language_model.lm_head"
:
"lm_head"
,
"language_model.lm_head"
:
"lm_head"
,
"language_model.model"
:
"language_model"
,
"language_model.model"
:
"language_model"
,
...
@@ -40,7 +42,7 @@ class LlavaMultiModalProjector(nn.Module):
...
@@ -40,7 +42,7 @@ class LlavaMultiModalProjector(nn.Module):
text_hidden_size
,
text_hidden_size
,
bias
=
True
)
bias
=
True
)
def
forward
(
self
,
image_features
)
:
def
forward
(
self
,
image_features
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
=
self
.
linear_1
(
image_features
)
hidden_states
=
self
.
linear_1
(
image_features
)
hidden_states
=
self
.
act
(
hidden_states
)
hidden_states
=
self
.
act
(
hidden_states
)
hidden_states
=
self
.
linear_2
(
hidden_states
)
hidden_states
=
self
.
linear_2
(
hidden_states
)
...
@@ -50,29 +52,31 @@ class LlavaMultiModalProjector(nn.Module):
...
@@ -50,29 +52,31 @@ class LlavaMultiModalProjector(nn.Module):
def
_merge_vision_embeddings
(
input_ids
:
torch
.
Tensor
,
def
_merge_vision_embeddings
(
input_ids
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
,
vision_embeddings
:
torch
.
Tensor
,
vision_embeddings
:
torch
.
Tensor
,
image_token_id
:
int
):
image_token_id
:
int
)
->
torch
.
Tensor
:
"""In place merges in vision_embeddings with inputs_embeds."""
"""In place merges in vision_embeddings with inputs_embeds."""
mask
=
(
input_ids
==
image_token_id
)
mask
=
(
input_ids
==
image_token_id
)
inputs_embeds
[
mask
]
=
vision_embeddings
.
view
(
-
1
,
image_feature_size
=
vision_embeddings
.
shape
[
0
]
*
vision_embeddings
.
shape
[
1
]
if
mask
.
sum
()
!=
image_feature_size
:
raise
ValueError
(
f
"image_feature_size should be
{
image_feature_size
}
, "
f
"but found:
{
mask
.
sum
()
}
"
)
inputs_embeds
[
mask
]
=
vision_embeddings
.
view
(
image_feature_size
,
vision_embeddings
.
shape
[
-
1
])
vision_embeddings
.
shape
[
-
1
])
return
inputs_embeds
class
LlavaForConditionalGeneration
(
nn
.
Module
):
class
LlavaForConditionalGeneration
(
VisionLanguageModelBase
):
def
__init__
(
self
,
def
__init__
(
self
,
config
:
"
LlavaConfig
"
,
config
:
LlavaConfig
,
vision_language_config
:
VisionLanguageConfig
,
vision_language_config
:
VisionLanguageConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
"QuantizationConfig"
]
=
None
)
->
None
:
quant_config
:
Optional
[
QuantizationConfig
]
=
None
)
->
None
:
super
().
__init__
()
super
().
__init__
(
vision_language_config
)
self
.
config
=
config
self
.
vision_language_config
=
vision_language_config
assert
self
.
vision_language_config
,
(
self
.
config
=
config
"Provide `image_input_type` and other vision "
"related configurations through LLM entrypoint "
"or engine arguments."
)
if
self
.
vision_language_config
.
image_input_type
==
(
if
self
.
vision_language_config
.
image_input_type
==
(
VisionLanguageConfig
.
ImageInputType
.
PIXEL_VALUES
):
VisionLanguageConfig
.
ImageInputType
.
PIXEL_VALUES
):
...
@@ -98,14 +102,12 @@ class LlavaForConditionalGeneration(nn.Module):
...
@@ -98,14 +102,12 @@ class LlavaForConditionalGeneration(nn.Module):
config
.
vocab_size
,
logit_scale
)
config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
image_input
:
Optional
[
torch
.
Tensor
]
=
None
)
->
SamplerOutput
:
image_input
:
Optional
[
torch
.
Tensor
]
=
None
)
->
SamplerOutput
:
# noqa: E501
"""Run forward pass for Llava 1.5.
"""Run forward pass for Llava 1.5.
One key thing to understand is the `input_ids` already accounts for the
One key thing to understand is the `input_ids` already accounts for the
...
@@ -172,7 +174,7 @@ class LlavaForConditionalGeneration(nn.Module):
...
@@ -172,7 +174,7 @@ class LlavaForConditionalGeneration(nn.Module):
image_features
=
image_input
image_features
=
image_input
vision_embeddings
=
self
.
multi_modal_projector
(
image_features
)
vision_embeddings
=
self
.
multi_modal_projector
(
image_features
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
_merge_vision_embeddings
(
inputs_embeds
=
_merge_vision_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
vision_language_config
.
image_token_id
)
self
.
vision_language_config
.
image_token_id
)
input_ids
=
None
input_ids
=
None
...
...
vllm/model_executor/models/vlm_base.py
0 → 100644
View file @
f68470e8
from
torch
import
nn
from
vllm.config
import
VisionLanguageConfig
class
VisionLanguageModelBase
(
nn
.
Module
):
"""Base class for all vision language models (VLMs)."""
def
__init__
(
self
,
vision_language_config
:
VisionLanguageConfig
)
->
None
:
super
().
__init__
()
self
.
vision_language_config
=
vision_language_config
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment