Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
06ed2815
Unverified
Commit
06ed2815
authored
Sep 22, 2024
by
Cyrus Leung
Committed by
GitHub
Sep 22, 2024
Browse files
[Model] Refactor BLIP/BLIP-2 to support composite model loading (#8407)
parent
0e40ac9b
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
113 additions
and
114 deletions
+113
-114
vllm/model_executor/models/blip.py
vllm/model_executor/models/blip.py
+58
-3
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+47
-74
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+0
-3
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+4
-7
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+0
-3
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+0
-8
vllm/model_executor/models/llava_next_video.py
vllm/model_executor/models/llava_next_video.py
+0
-3
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+0
-3
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+4
-7
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+0
-3
No files found.
vllm/model_executor/models/blip.py
View file @
06ed2815
"""Minimal implementation of BlipVisionModel intended to be only used
within a vision language model."""
from
typing
import
Optional
,
Union
from
typing
import
Iterable
,
Optional
,
Tuple
,
Union
import
torch
import
torch.nn
as
nn
...
...
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.multimodal.utils
import
(
cached_get_tokenizer
,
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
SequenceData
...
...
@@ -342,6 +343,10 @@ class BlipVisionModel(nn.Module):
num_hidden_layers_override
:
Optional
[
int
]
=
None
):
super
().
__init__
()
tp_size
=
get_tensor_model_parallel_world_size
()
num_heads
=
config
.
num_attention_heads
self
.
shard_weight
=
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
self
.
config
=
config
self
.
embeddings
=
BlipVisionEmbeddings
(
config
)
...
...
@@ -350,11 +355,61 @@ class BlipVisionModel(nn.Module):
quant_config
=
quant_config
,
num_hidden_layers_override
=
num_hidden_layers_override
,
)
if
len
(
self
.
encoder
.
layers
)
>
config
.
num_hidden_layers
:
raise
ValueError
(
f
"The original encoder only has
{
config
.
num_hidden_layers
}
"
f
"layers, but you requested
{
len
(
self
.
encoder
.
layers
)
}
layers."
)
elif
len
(
self
.
encoder
.
layers
)
==
config
.
num_hidden_layers
:
self
.
post_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
else
:
# post_layernorm is unused when we extract intermediate features
# In this case, we can skip it to conserve memory
self
.
post_layernorm
=
None
def
forward
(
self
,
pixel_values
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embeddings
(
pixel_values
)
hidden_states
=
self
.
encoder
(
inputs_embeds
=
hidden_states
)
if
self
.
post_layernorm
is
None
:
return
hidden_states
return
self
.
post_layernorm
(
hidden_states
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
if
self
.
shard_weight
else
[]
params_dict
=
dict
(
self
.
named_parameters
())
layer_count
=
len
(
self
.
encoder
.
layers
)
for
name
,
loaded_weight
in
weights
:
# post_layernorm is not needed in BlipVisionModel
if
(
name
.
startswith
(
"post_layernorm"
)
and
self
.
post_layernorm
is
None
):
continue
# omit layers when num_hidden_layers_override is set
if
name
.
startswith
(
"encoder.layers"
):
layer_idx
=
int
(
name
.
split
(
"."
)[
2
])
if
layer_idx
>=
layer_count
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/model_executor/models/blip2.py
View file @
06ed2815
...
...
@@ -10,11 +10,9 @@ from vllm.attention import AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.opt
import
OPTModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
...
...
@@ -22,12 +20,8 @@ from vllm.sequence import IntermediateTensors, SequenceData
from
.blip
import
(
BlipVisionModel
,
dummy_image_for_blip
,
get_max_blip_image_tokens
)
from
.interfaces
import
SupportsMultiModal
from
.utils
import
merge_multimodal_embeddings
_KEYS_TO_MODIFY_MAPPING
=
{
"language_model.lm_head"
:
"lm_head"
,
"language_model.model"
:
"language_model"
,
}
from
.utils
import
(
group_weights_with_prefix
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
...
...
@@ -491,9 +485,6 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
super
().
__init__
()
# currently all existing BLIP-2 models have `tie_word_embeddings`
# enabled
assert
config
.
tie_word_embeddings
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
...
...
@@ -514,17 +505,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
bias
=
True
,
)
self
.
quant_config
=
quant_config
self
.
language_model
=
OPTModel
(
config
.
text_config
,
cache_config
,
quant_config
)
self
.
unpadded_vocab_size
=
config
.
text_config
.
vocab_size
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
)
self
.
sampler
=
Sampler
()
def
get_lm_head
(
self
):
return
self
.
language_model
.
decoder
.
embed_tokens
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
def
_validate_pixel_values
(
self
,
data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
h
=
w
=
self
.
config
.
vision_config
.
image_size
...
...
@@ -653,7 +635,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
if
image_input
is
not
None
:
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
...
...
@@ -663,7 +646,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
else
:
inputs_embeds
=
None
hidden_states
=
self
.
language_model
(
input_ids
,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
...
...
@@ -676,56 +659,46 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
get_lm_head
(),
hidden_states
,
return
self
.
language_model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
# only doing this for language model part for now.
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
if
"lm_head.weight"
in
name
:
continue
if
"rotary_emb.inv_freq"
in
name
:
continue
for
key_to_modify
,
new_key
in
_KEYS_TO_MODIFY_MAPPING
.
items
():
if
key_to_modify
in
name
:
name
=
name
.
replace
(
key_to_modify
,
new_key
)
use_default_weight_loading
=
False
if
"vision"
in
name
:
if
self
.
vision_model
is
not
None
:
# BlipVisionModel does not need sharding
use_default_weight_loading
=
True
else
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
use_default_weight_loading
=
True
if
use_default_weight_loading
:
param
=
params_dict
[
name
]
# prepare weight iterators for components
weights_group
=
group_weights_with_prefix
(
weights
)
# load vision encoder
self
.
vision_model
.
load_weights
(
weights_group
[
"vision_model"
])
# load query tokens
for
name
,
loaded_weight
in
weights_group
[
"query_tokens"
]:
assert
name
==
""
param
=
self
.
query_tokens
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load qformer
qformer_params_dict
=
dict
(
self
.
qformer
.
named_parameters
())
for
name
,
loaded_weight
in
weights_group
[
"qformer"
]:
param
=
qformer_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load mlp projector
mlp_params_dict
=
dict
(
self
.
language_projection
.
named_parameters
())
for
name
,
loaded_weight
in
weights_group
[
"language_projection"
]:
param
=
mlp_params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# load llm backbone
self
.
language_model
.
load_weights
(
weights_group
[
"language_model"
])
vllm/model_executor/models/chameleon.py
View file @
06ed2815
...
...
@@ -12,7 +12,6 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
...
@@ -36,8 +35,6 @@ from vllm.utils import print_warning_once
from
.interfaces
import
SupportsMultiModal
logger
=
init_logger
(
__name__
)
# These configs are not part of the model config but the preprocessor
# and processor files, so we hardcode them in the model file for now.
CHAMELEON_CROP_SIZE_HEIGHT
=
CHAMELEON_CROP_SIZE_WIDTH
=
512
...
...
vllm/model_executor/models/clip.py
View file @
06ed2815
...
...
@@ -391,6 +391,7 @@ class CLIPVisionModel(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
num_hidden_layers_override
:
Optional
[
int
]
=
None
):
super
().
__init__
()
tp_size
=
get_tensor_model_parallel_world_size
()
num_heads
=
config
.
num_attention_heads
self
.
shard_weight
=
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
...
...
@@ -400,10 +401,6 @@ class CLIPVisionModel(nn.Module):
quant_config
=
quant_config
,
num_hidden_layers_override
=
num_hidden_layers_override
)
@
property
def
_require_post_layernorm
(
self
)
->
bool
:
return
self
.
vision_model
.
post_layernorm
is
not
None
def
forward
(
self
,
pixel_values
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
vision_model
(
pixel_values
)
...
...
@@ -425,12 +422,12 @@ class CLIPVisionModel(nn.Module):
for
name
,
loaded_weight
in
weights
:
# post_layernorm is not needed in CLIPVisionModel
if
(
"vision_model.post_layernorm"
in
name
and
not
self
.
_require_
post_layernorm
):
if
(
name
.
startswith
(
"vision_model.post_layernorm"
)
and
self
.
vision_model
.
post_layernorm
is
None
):
continue
# omit layers when num_hidden_layers_override is set
if
"vision_model.encoder.layers
."
in
name
:
if
name
.
startswith
(
"vision_model.encoder.layers
"
)
:
layer_idx
=
int
(
name
.
split
(
"."
)[
3
])
if
layer_idx
>=
layer_count
:
continue
...
...
vllm/model_executor/models/fuyu.py
View file @
06ed2815
...
...
@@ -28,7 +28,6 @@ from transformers import FuyuConfig, FuyuImageProcessor
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
...
...
@@ -45,8 +44,6 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
from
.interfaces
import
SupportsMultiModal
from
.utils
import
merge_multimodal_embeddings
logger
=
init_logger
(
__name__
)
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID
=
71011
_NEWLINE_TOKEN_ID
=
71019
...
...
vllm/model_executor/models/llava_next.py
View file @
06ed2815
...
...
@@ -12,7 +12,6 @@ from typing_extensions import NotRequired
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
...
@@ -32,13 +31,6 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
from
.utils
import
(
flatten_bn
,
group_weights_with_prefix
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
logger
=
init_logger
(
__name__
)
_KEYS_TO_MODIFY_MAPPING
=
{
"language_model.lm_head"
:
"lm_head"
,
"language_model.model"
:
"language_model"
,
}
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
=
448
...
...
vllm/model_executor/models/llava_next_video.py
View file @
06ed2815
...
...
@@ -11,7 +11,6 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
...
...
@@ -32,8 +31,6 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
from
.utils
import
(
group_weights_with_prefix
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
logger
=
init_logger
(
__name__
)
# For profile run
_MAX_FRAMES_PER_VIDEO
=
32
_MAX_NUM_VIDEOS
=
1
...
...
vllm/model_executor/models/minicpmv.py
View file @
06ed2815
...
...
@@ -37,7 +37,6 @@ from transformers import PretrainedConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
...
...
@@ -59,8 +58,6 @@ from vllm.sequence import IntermediateTensors, SequenceData
from
.idefics2_vision_model
import
Idefics2VisionTransformer
logger
=
init_logger
(
__name__
)
_KEYS_TO_MODIFY_MAPPING
=
{
"llm.lm_head"
:
"lm_head"
,
"llm.model"
:
"llm"
,
...
...
vllm/model_executor/models/siglip.py
View file @
06ed2815
...
...
@@ -501,6 +501,7 @@ class SiglipVisionModel(nn.Module):
num_hidden_layers_override
:
Optional
[
int
]
=
None
,
):
super
().
__init__
()
num_heads
=
config
.
num_attention_heads
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
shard_weight
=
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
...
...
@@ -511,10 +512,6 @@ class SiglipVisionModel(nn.Module):
num_hidden_layers_override
=
num_hidden_layers_override
,
)
@
property
def
_require_post_layernorm
(
self
)
->
bool
:
return
self
.
vision_model
.
post_layernorm
is
not
None
def
get_input_embeddings
(
self
)
->
nn
.
Module
:
return
self
.
vision_model
.
embeddings
.
patch_embedding
...
...
@@ -540,12 +537,12 @@ class SiglipVisionModel(nn.Module):
for
name
,
loaded_weight
in
weights
:
# post_layernorm is optional in SiglipVisionModel
if
(
"vision_model.post_layernorm"
in
name
and
not
self
.
_require_
post_layernorm
):
if
(
name
.
startswith
(
"vision_model.post_layernorm"
)
and
self
.
vision_model
.
post_layernorm
is
None
):
continue
# omit layers when num_hidden_layers_override is set
if
"vision_model.encoder.layers
."
in
name
:
if
name
.
startswith
(
"vision_model.encoder.layers
"
)
:
layer_idx
=
int
(
name
.
split
(
"."
)[
3
])
if
layer_idx
>=
layer_count
:
continue
...
...
vllm/model_executor/models/ultravox.py
View file @
06ed2815
...
...
@@ -20,7 +20,6 @@ from vllm.config import CacheConfig, MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
from
vllm.inputs.data
import
LLMInputs
from
vllm.inputs.registry
import
InputContext
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
,
get_act_fn
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.base_config
import
(
...
...
@@ -43,8 +42,6 @@ from vllm.transformers_utils.configs.ultravox import UltravoxConfig
_AUDIO_PLACEHOLDER_TOKEN
=
128002
_AUDIO_TOKENS_PER_SECOND
=
6.25
logger
=
init_logger
(
__name__
)
class
UltravoxAudioFeatureInputs
(
TypedDict
):
type
:
Literal
[
"audio_features"
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment