Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4851c202
Commit
4851c202
authored
Sep 13, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.1' into v0.6.1-dev
parents
9b902f9e
3fd2b0d2
Changes
203
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3088 additions
and
341 deletions
+3088
-341
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+26
-5
vllm/model_executor/models/commandr.py
vllm/model_executor/models/commandr.py
+3
-2
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+38
-14
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+52
-17
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+4
-2
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+51
-0
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+20
-12
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+3
-1
vllm/model_executor/models/llava_next_video.py
vllm/model_executor/models/llava_next_video.py
+471
-0
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+5
-155
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+7
-2
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+35
-77
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+3
-1
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+551
-0
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+667
-26
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+8
-2
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+1088
-0
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+37
-20
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+16
-0
vllm/multimodal/base.py
vllm/multimodal/base.py
+3
-5
No files found.
vllm/model_executor/models/clip.py
View file @
4851c202
...
@@ -105,7 +105,7 @@ def input_processor_for_clip(
...
@@ -105,7 +105,7 @@ def input_processor_for_clip(
if
isinstance
(
image_data
,
Image
.
Image
):
if
isinstance
(
image_data
,
Image
.
Image
):
image_feature_size
=
get_clip_image_feature_size
(
hf_config
)
image_feature_size
=
get_clip_image_feature_size
(
hf_config
)
elif
isinstance
(
image_data
,
torch
.
Tensor
):
elif
isinstance
(
image_data
,
torch
.
Tensor
):
image_feature_size
=
image_data
.
shape
[
0
]
num_images
,
image_feature_size
,
hidden_size
=
image_data
.
shape
else
:
else
:
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
else
:
else
:
...
@@ -355,6 +355,19 @@ class CLIPVisionTransformer(nn.Module):
...
@@ -355,6 +355,19 @@ class CLIPVisionTransformer(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
num_hidden_layers_override
=
num_hidden_layers_override
)
num_hidden_layers_override
=
num_hidden_layers_override
)
if
len
(
self
.
encoder
.
layers
)
>
config
.
num_hidden_layers
:
raise
ValueError
(
f
"The original encoder only has
{
config
.
num_hidden_layers
}
"
f
"layers, but you requested
{
len
(
self
.
encoder
.
layers
)
}
layers."
)
elif
len
(
self
.
encoder
.
layers
)
==
config
.
num_hidden_layers
:
self
.
post_layernorm
=
nn
.
LayerNorm
(
embed_dim
,
eps
=
config
.
layer_norm_eps
)
else
:
# post_layernorm is unused when we extract intermediate features
# In this case, we can skip it to conserve memory
self
.
post_layernorm
=
None
def
forward
(
def
forward
(
self
,
self
,
pixel_values
:
torch
.
Tensor
,
pixel_values
:
torch
.
Tensor
,
...
@@ -364,7 +377,10 @@ class CLIPVisionTransformer(nn.Module):
...
@@ -364,7 +377,10 @@ class CLIPVisionTransformer(nn.Module):
hidden_states
=
self
.
pre_layrnorm
(
hidden_states
)
hidden_states
=
self
.
pre_layrnorm
(
hidden_states
)
hidden_states
=
self
.
encoder
(
inputs_embeds
=
hidden_states
)
hidden_states
=
self
.
encoder
(
inputs_embeds
=
hidden_states
)
return
hidden_states
if
self
.
post_layernorm
is
None
:
return
hidden_states
return
self
.
post_layernorm
(
hidden_states
)
class
CLIPVisionModel
(
nn
.
Module
):
class
CLIPVisionModel
(
nn
.
Module
):
...
@@ -386,9 +402,12 @@ class CLIPVisionModel(nn.Module):
...
@@ -386,9 +402,12 @@ class CLIPVisionModel(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
num_hidden_layers_override
=
num_hidden_layers_override
)
num_hidden_layers_override
=
num_hidden_layers_override
)
def
forward
(
self
,
pixel_values
:
Optional
[
torch
.
Tensor
]
=
None
):
@
property
def
_require_post_layernorm
(
self
)
->
bool
:
return
self
.
vision_model
.
post_layernorm
is
not
None
return
self
.
vision_model
(
pixel_values
=
pixel_values
)
def
forward
(
self
,
pixel_values
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
vision_model
(
pixel_values
)
@
property
@
property
def
device
(
self
):
def
device
(
self
):
...
@@ -408,8 +427,10 @@ class CLIPVisionModel(nn.Module):
...
@@ -408,8 +427,10 @@ class CLIPVisionModel(nn.Module):
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
# post_layernorm is not needed in CLIPVisionModel
# post_layernorm is not needed in CLIPVisionModel
if
"vision_model.post_layernorm"
in
name
:
if
(
"vision_model.post_layernorm"
in
name
and
not
self
.
_require_post_layernorm
):
continue
continue
# omit layers when num_hidden_layers_override is set
# omit layers when num_hidden_layers_override is set
if
"vision_model.encoder.layers."
in
name
:
if
"vision_model.encoder.layers."
in
name
:
layer_idx
=
int
(
name
.
split
(
"."
)[
3
])
layer_idx
=
int
(
name
.
split
(
"."
)[
3
])
...
...
vllm/model_executor/models/commandr.py
View file @
4851c202
...
@@ -47,6 +47,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -47,6 +47,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
@
torch
.
compile
@
torch
.
compile
def
layer_norm_func
(
hidden_states
,
weight
,
variance_epsilon
):
def
layer_norm_func
(
hidden_states
,
weight
,
variance_epsilon
):
...
@@ -292,8 +294,7 @@ class CohereModel(nn.Module):
...
@@ -292,8 +294,7 @@ class CohereModel(nn.Module):
return
hidden_states
return
hidden_states
class
CohereForCausalLM
(
nn
.
Module
):
class
CohereForCausalLM
(
nn
.
Module
,
SupportsLoRA
):
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"qkv_proj"
:
[
"q_proj"
,
"q_proj"
,
...
...
vllm/model_executor/models/internlm2.py
View file @
4851c202
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
from
functools
import
partial
from
functools
import
partial
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -8,7 +8,7 @@ from transformers import PretrainedConfig
...
@@ -8,7 +8,7 @@ from transformers import PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
split_tensor_along_last_dim
,
split_tensor_along_last_dim
,
tensor_model_parallel_all_gather
)
tensor_model_parallel_all_gather
)
...
@@ -28,6 +28,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -28,6 +28,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
class
InternLM2MLP
(
nn
.
Module
):
class
InternLM2MLP
(
nn
.
Module
):
...
@@ -234,6 +237,7 @@ class InternLM2Model(nn.Module):
...
@@ -234,6 +237,7 @@ class InternLM2Model(nn.Module):
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -243,11 +247,15 @@ class InternLM2Model(nn.Module):
...
@@ -243,11 +247,15 @@ class InternLM2Model(nn.Module):
config
.
vocab_size
,
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
)
)
self
.
layers
=
nn
.
ModuleList
([
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
InternLMDecoderLayer
(
config
,
cache_config
,
quant_config
)
config
.
num_hidden_layers
,
for
_
in
range
(
config
.
num_hidden_layers
)
lambda
prefix
:
InternLMDecoderLayer
(
config
,
cache_config
,
])
quant_config
),
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
tok_embeddings
(
input_ids
)
return
self
.
tok_embeddings
(
input_ids
)
...
@@ -260,21 +268,31 @@ class InternLM2Model(nn.Module):
...
@@ -260,21 +268,31 @@ class InternLM2Model(nn.Module):
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
IntermediateTensors
=
None
,
intermediate_tensors
:
IntermediateTensors
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
inputs_embeds
is
not
None
:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
inputs_embeds
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
tok_embeddings
(
input_ids
)
residual
=
None
else
:
else
:
hidden_states
=
self
.
tok_embeddings
(
input_ids
)
assert
intermediate_tensors
is
not
None
residual
=
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
len
(
self
.
layers
)):
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
,
positions
,
hidden_states
,
hidden_states
,
kv_caches
[
i
],
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
attn_metadata
,
residual
,
residual
,
)
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
...
@@ -298,6 +316,8 @@ class InternLM2ForCausalLM(nn.Module):
...
@@ -298,6 +316,8 @@ class InternLM2ForCausalLM(nn.Module):
self
.
output
.
weight
=
self
.
model
.
tok_embeddings
.
weight
self
.
output
.
weight
=
self
.
model
.
tok_embeddings
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -308,7 +328,7 @@ class InternLM2ForCausalLM(nn.Module):
...
@@ -308,7 +328,7 @@ class InternLM2ForCausalLM(nn.Module):
intermediate_tensors
:
IntermediateTensors
,
intermediate_tensors
:
IntermediateTensors
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
return
hidden_states
def
compute_logits
(
def
compute_logits
(
...
@@ -345,6 +365,8 @@ class InternLM2ForCausalLM(nn.Module):
...
@@ -345,6 +365,8 @@ class InternLM2ForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
...
@@ -353,6 +375,8 @@ class InternLM2ForCausalLM(nn.Module):
...
@@ -353,6 +375,8 @@ class InternLM2ForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
vllm/model_executor/models/internvl.py
View file @
4851c202
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
# Licensed under The MIT License [see LICENSE for details]
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
# --------------------------------------------------------
import
itertools
import
itertools
import
re
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
TypedDict
,
Union
)
...
@@ -16,6 +17,7 @@ from transformers import PretrainedConfig
...
@@ -16,6 +17,7 @@ from transformers import PretrainedConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.distributed
import
get_pp_group
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
...
@@ -26,6 +28,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
...
@@ -26,6 +28,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_list_of
from
.clip
import
(
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
from
.clip
import
(
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
get_clip_num_patches
)
get_clip_num_patches
)
...
@@ -95,8 +98,8 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
...
@@ -95,8 +98,8 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
def
calculate_num_blocks
(
orig_width
:
int
,
orig_height
:
int
,
min_num
:
int
,
def
calculate_num_blocks
(
orig_width
:
int
,
orig_height
:
int
,
min_num
:
int
,
max_num
:
int
,
max_num
:
int
,
image_size
:
int
,
image_size
:
int
)
->
Tuple
[
int
,
int
,
int
]:
use_thumbnail
:
bool
)
->
Tuple
[
int
,
int
,
int
]:
aspect_ratio
=
orig_width
/
orig_height
aspect_ratio
=
orig_width
/
orig_height
# calculate the existing image aspect ratio
# calculate the existing image aspect ratio
...
@@ -114,17 +117,26 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
...
@@ -114,17 +117,26 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
target_width
=
image_size
*
target_aspect_ratio
[
0
]
target_width
=
image_size
*
target_aspect_ratio
[
0
]
target_height
=
image_size
*
target_aspect_ratio
[
1
]
target_height
=
image_size
*
target_aspect_ratio
[
1
]
blocks
=
target_aspect_ratio
[
0
]
*
target_aspect_ratio
[
1
]
blocks
=
target_aspect_ratio
[
0
]
*
target_aspect_ratio
[
1
]
# add thumbnail image if num_blocks > 1
if
use_thumbnail
and
blocks
>
1
:
blocks
+=
1
return
blocks
,
target_width
,
target_height
return
blocks
,
target_width
,
target_height
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def
dynamic_preprocess
(
image
:
Image
.
Image
,
min_num
:
int
,
max_num
:
int
,
def
dynamic_preprocess
(
image
:
Image
.
Image
,
min_num
:
int
,
max_num
:
int
,
image_size
:
int
,
image_size
:
int
,
use_thumbnail
:
int
)
->
List
[
Image
.
Image
]:
use_thumbnail
:
bool
)
->
List
[
Image
.
Image
]:
orig_width
,
orig_height
=
image
.
size
orig_width
,
orig_height
=
image
.
size
# calculate the number of blocks without thumbnail
blocks
,
target_width
,
target_height
=
calculate_num_blocks
(
blocks
,
target_width
,
target_height
=
calculate_num_blocks
(
orig_width
,
orig_height
,
min_num
,
max_num
,
image_size
)
orig_width
,
orig_height
,
min_num
,
max_num
,
image_size
,
use_thumbnail
=
False
)
# resize the image
# resize the image
resized_img
=
image
.
resize
((
target_width
,
target_height
))
resized_img
=
image
.
resize
((
target_width
,
target_height
))
processed_images
=
[]
processed_images
=
[]
...
@@ -197,19 +209,25 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -197,19 +209,25 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
downsample_ratio
)
downsample_ratio
)
image_data
=
multi_modal_data
[
"image"
]
image_data
=
multi_modal_data
[
"image"
]
min_num
=
hf_config
.
min_dynamic_patch
max_num
=
hf_config
.
max_dynamic_patch
use_thumbnail
=
hf_config
.
use_thumbnail
if
isinstance
(
image_data
,
Image
.
Image
):
if
isinstance
(
image_data
,
Image
.
Image
):
width
,
height
=
image_data
.
size
width
,
height
=
image_data
.
size
min_num
=
hf_config
.
min_dynamic_patch
max_num
=
hf_config
.
max_dynamic_patch
num_blocks
,
_
,
_
=
calculate_num_blocks
(
width
,
height
,
min_num
,
num_blocks
,
_
,
_
=
calculate_num_blocks
(
width
,
height
,
min_num
,
max_num
,
image_size
)
max_num
,
image_size
,
# add thumbnail image if num_blocks > 1
use_thumbnail
)
if
hf_config
.
use_thumbnail
and
num_blocks
>
1
:
image_feature_size
=
[
num_blocks
*
num_patches
]
num_blocks
+=
1
elif
is_list_of
(
image_data
,
Image
.
Image
):
image_feature_size
=
num_blocks
*
num_patches
image_feature_size
=
[]
for
image
in
image_data
:
width
,
height
=
image
.
size
num_blocks
,
_
,
_
=
calculate_num_blocks
(
width
,
height
,
min_num
,
max_num
,
image_size
,
use_thumbnail
)
image_feature_size
.
append
(
num_blocks
*
num_patches
)
elif
isinstance
(
image_data
,
torch
.
Tensor
):
elif
isinstance
(
image_data
,
torch
.
Tensor
):
image_feature_size
=
image_data
.
shape
[
0
]
num_images
,
image_feature_size
,
hidden_size
=
image_data
.
shape
else
:
else
:
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
...
@@ -220,8 +238,14 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -220,8 +238,14 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
prompt_token_ids
=
llm_inputs
[
"prompt_token_ids"
]
prompt_token_ids
=
llm_inputs
[
"prompt_token_ids"
]
if
prompt
is
None
:
if
prompt
is
None
:
prompt
=
tokenizer
.
decode
(
prompt_token_ids
)
prompt
=
tokenizer
.
decode
(
prompt_token_ids
)
image_prompt
=
IMG_START
+
IMG_CONTEXT
*
image_feature_size
+
IMG_END
new_prompt
=
prompt
.
replace
(
'<image>'
,
image_prompt
,
1
)
new_prompt
=
prompt
image_idx
=
sorted
(
map
(
int
,
re
.
findall
(
r
"Image-(\d+): <image>\n"
,
prompt
)))
for
idx
,
feature_size
in
enumerate
(
image_feature_size
,
start
=
1
):
image_prompt
=
IMG_START
+
IMG_CONTEXT
*
feature_size
+
IMG_END
if
not
image_idx
:
image_prompt
=
f
"Image-
{
idx
}
:
{
image_prompt
}
"
new_prompt
=
new_prompt
.
replace
(
'<image>'
,
image_prompt
,
1
)
new_prompt_token_ids
=
tokenizer
.
encode
(
new_prompt
)
new_prompt_token_ids
=
tokenizer
.
encode
(
new_prompt
)
return
LLMInputs
(
prompt
=
prompt
,
return
LLMInputs
(
prompt
=
prompt
,
...
@@ -245,6 +269,15 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
...
@@ -245,6 +269,15 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
use_thumbnail
=
use_thumbnail
)
use_thumbnail
=
use_thumbnail
)
# Add an N dimension for number of images per prompt (currently 1).
# Add an N dimension for number of images per prompt (currently 1).
data
=
data
.
unsqueeze
(
0
)
data
=
data
.
unsqueeze
(
0
)
elif
is_list_of
(
data
,
Image
.
Image
):
data
=
[
image_to_pixel_values
(
img
,
image_size
,
min_num
,
max_num
,
use_thumbnail
=
use_thumbnail
)
for
img
in
data
]
data
=
torch
.
stack
(
data
)
model_config
=
ctx
.
model_config
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
trust_remote_code
=
True
)
...
@@ -341,6 +374,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
...
@@ -341,6 +374,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
nn
.
Linear
(
llm_hidden_size
,
llm_hidden_size
))
nn
.
Linear
(
llm_hidden_size
,
llm_hidden_size
))
self
.
img_context_token_id
=
None
self
.
img_context_token_id
=
None
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
def
pixel_shuffle
(
self
,
x
,
scale_factor
=
0.5
):
def
pixel_shuffle
(
self
,
x
,
scale_factor
=
0.5
):
n
,
w
,
h
,
c
=
x
.
size
()
n
,
w
,
h
,
c
=
x
.
size
()
...
@@ -446,7 +481,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
...
@@ -446,7 +481,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
**
kwargs
:
object
,
**
kwargs
:
object
,
)
->
SamplerOutput
:
)
->
SamplerOutput
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
not
None
:
if
image_input
is
not
None
and
get_pp_group
().
is_first_rank
:
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
input_ids
)
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
...
@@ -461,7 +496,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
...
@@ -461,7 +496,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
positions
,
positions
,
kv_caches
,
kv_caches
,
attn_metadata
,
attn_metadata
,
None
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/jamba.py
View file @
4851c202
...
@@ -38,6 +38,8 @@ from vllm.sequence import IntermediateTensors
...
@@ -38,6 +38,8 @@ from vllm.sequence import IntermediateTensors
from
vllm.worker.model_runner
import
(
_BATCH_SIZES_TO_CAPTURE
,
from
vllm.worker.model_runner
import
(
_BATCH_SIZES_TO_CAPTURE
,
_get_graph_batch_size
)
_get_graph_batch_size
)
from
.interfaces
import
SupportsLoRA
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
...
@@ -539,7 +541,7 @@ class JambaModel(nn.Module):
...
@@ -539,7 +541,7 @@ class JambaModel(nn.Module):
return
hidden_states
return
hidden_states
class
JambaForCausalLM
(
nn
.
Module
,
HasInnerState
):
class
JambaForCausalLM
(
nn
.
Module
,
HasInnerState
,
SupportsLoRA
):
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"qkv_proj"
:
[
"q_proj"
,
"q_proj"
,
...
@@ -731,7 +733,7 @@ class JambaForCausalLM(nn.Module, HasInnerState):
...
@@ -731,7 +733,7 @@ class JambaForCausalLM(nn.Module, HasInnerState):
indices_for_current_run
:
List
[
int
]):
indices_for_current_run
:
List
[
int
]):
# move out all of the occupied but currently not running blocks
# move out all of the occupied but currently not running blocks
# outside of the first n blocks
# outside of the first n blocks
destination_indices
=
set
([
range
(
batch_size
)
])
destination_indices
=
range
(
batch_size
)
max_possible_batch_size
=
self
.
mamba_cache
[
0
].
shape
[
1
]
max_possible_batch_size
=
self
.
mamba_cache
[
0
].
shape
[
1
]
for
destination_index
in
destination_indices
:
for
destination_index
in
destination_indices
:
if
destination_index
in
self
.
_get_all_occupied_indices
()
and
\
if
destination_index
in
self
.
_get_all_occupied_indices
()
and
\
...
...
vllm/model_executor/models/llama.py
View file @
4851c202
...
@@ -387,6 +387,25 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -387,6 +387,25 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
"gate_proj"
:
(
"gate_up_proj"
,
0
),
"gate_proj"
:
(
"gate_up_proj"
,
0
),
"up_proj"
:
(
"gate_up_proj"
,
1
),
"up_proj"
:
(
"gate_up_proj"
,
1
),
}
}
# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints
mistral_mapping
=
{
"layers"
:
"model.layers"
,
"attention"
:
"self_attn"
,
"wq"
:
"q_proj"
,
"wk"
:
"k_proj"
,
"wv"
:
"v_proj"
,
"wo"
:
"o_proj"
,
"attention_norm"
:
"input_layernorm"
,
"feed_forward"
:
"mlp"
,
"w1"
:
"gate_proj"
,
"w2"
:
"down_proj"
,
"w3"
:
"up_proj"
,
"ffn_norm"
:
"post_attention_layernorm"
,
"tok_embeddings"
:
"model.embed_tokens"
,
"output"
:
"lm_head"
,
"norm"
:
"model.norm"
}
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -493,6 +512,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -493,6 +512,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
name
,
loaded_weight
=
self
.
maybe_remap_mistral
(
name
,
loaded_weight
)
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
if
(
"rotary_emb.cos_cached"
in
name
if
(
"rotary_emb.cos_cached"
in
name
...
@@ -642,3 +663,33 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -642,3 +663,33 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
else
:
else
:
raise
RuntimeError
(
"Self attention has no KV cache scaling "
raise
RuntimeError
(
"Self attention has no KV cache scaling "
"factor attribute!"
)
"factor attribute!"
)
# This function is used to remap the mistral format as
# used by Mistral and Llama <=2
def
maybe_remap_mistral
(
self
,
name
:
str
,
loaded_weight
:
torch
.
Tensor
)
->
Tuple
[
str
,
torch
.
Tensor
]:
def
permute
(
w
,
n_heads
):
attn_in
=
self
.
config
.
head_dim
*
n_heads
attn_out
=
self
.
config
.
hidden_size
return
w
.
view
(
n_heads
,
attn_in
//
n_heads
//
2
,
2
,
attn_out
).
transpose
(
1
,
2
).
reshape
(
attn_in
,
attn_out
)
mapping
=
self
.
mistral_mapping
modules
=
name
.
split
(
"."
)
# rotary embeds should be sliced
if
"wk"
in
modules
:
loaded_weight
=
permute
(
loaded_weight
,
self
.
config
.
num_key_value_heads
)
elif
"wq"
in
modules
:
loaded_weight
=
permute
(
loaded_weight
,
self
.
config
.
num_attention_heads
)
for
item
in
modules
:
if
item
in
mapping
and
mapping
[
item
]
not
in
name
:
name
=
name
.
replace
(
item
,
mapping
[
item
])
return
name
,
loaded_weight
vllm/model_executor/models/llava.py
View file @
4851c202
...
@@ -4,6 +4,7 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
...
@@ -4,6 +4,7 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
PIL
import
Image
from
transformers
import
CLIPVisionConfig
,
LlavaConfig
,
SiglipVisionConfig
from
transformers
import
CLIPVisionConfig
,
LlavaConfig
,
SiglipVisionConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
...
@@ -16,6 +17,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -16,6 +17,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_list_of
from
.clip
import
(
CLIPVisionModel
,
dummy_image_for_clip
,
from
.clip
import
(
CLIPVisionModel
,
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
get_max_clip_image_tokens
,
dummy_seq_data_for_clip
,
get_max_clip_image_tokens
,
...
@@ -24,7 +26,7 @@ from .interfaces import SupportsMultiModal
...
@@ -24,7 +26,7 @@ from .interfaces import SupportsMultiModal
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_max_siglip_image_tokens
,
dummy_seq_data_for_siglip
,
get_max_siglip_image_tokens
,
input_processor_for_siglip
)
input_processor_for_siglip
)
from
.utils
import
(
filter_weights
,
init_vllm_registered_model
,
from
.utils
import
(
filter_weights
,
flatten_bn
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
...
@@ -133,7 +135,18 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -133,7 +135,18 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
hf_config
=
ctx
.
get_hf_config
(
LlavaConfig
)
hf_config
=
ctx
.
get_hf_config
(
LlavaConfig
)
vision_config
=
hf_config
.
vision_config
vision_config
=
hf_config
.
vision_config
image_feature_size
=
get_max_llava_image_tokens
(
ctx
)
image_data
=
multi_modal_data
[
"image"
]
if
isinstance
(
image_data
,
Image
.
Image
):
image_feature_size
=
get_max_llava_image_tokens
(
ctx
)
elif
is_list_of
(
image_data
,
Image
.
Image
):
image_feature_size
=
[
get_max_llava_image_tokens
(
ctx
)
]
*
len
(
image_data
)
elif
isinstance
(
image_data
,
torch
.
Tensor
):
num_images
,
image_feature_size
,
hidden_size
=
image_data
.
shape
elif
is_list_of
(
image_data
,
torch
.
Tensor
):
image_feature_size
=
[
item
.
shape
[
1
]
for
item
in
image_data
]
else
:
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
if
isinstance
(
vision_config
,
CLIPVisionConfig
):
return
input_processor_for_clip
(
return
input_processor_for_clip
(
...
@@ -230,29 +243,24 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -230,29 +243,24 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
return
None
return
None
if
pixel_values
is
not
None
:
if
pixel_values
is
not
None
:
if
not
isinstance
(
pixel_values
,
torch
.
Tensor
):
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)
):
raise
ValueError
(
"Incorrect type of pixel values. "
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
f
"Got type:
{
type
(
pixel_values
)
}
"
)
# Remove the N dimension until multiple images are supported.
pixel_values
=
pixel_values
.
squeeze
(
1
)
return
LlavaImagePixelInputs
(
return
LlavaImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
pixel_values
),
data
=
self
.
_validate_pixel_values
(
flatten_bn
(
pixel_values
,
concat
=
True
)),
)
)
if
image_embeds
is
not
None
:
if
image_embeds
is
not
None
:
if
not
isinstance
(
image_embeds
,
torch
.
Tensor
):
if
not
isinstance
(
image_embeds
,
(
torch
.
Tensor
,
list
)
):
raise
ValueError
(
"Incorrect type of image embeddings. "
raise
ValueError
(
"Incorrect type of image embeddings. "
f
"Got type:
{
type
(
image_embeds
)
}
"
)
f
"Got type:
{
type
(
image_embeds
)
}
"
)
# Remove the N dimension until multiple images are supported.
image_embeds
=
image_embeds
.
squeeze
(
1
)
return
LlavaImageEmbeddingInputs
(
return
LlavaImageEmbeddingInputs
(
type
=
"image_embeds"
,
type
=
"image_embeds"
,
data
=
image_embeds
,
data
=
flatten_bn
(
image_embeds
,
concat
=
True
),
)
)
raise
AssertionError
(
"This line should be unreachable."
)
raise
AssertionError
(
"This line should be unreachable."
)
...
...
vllm/model_executor/models/llava_next.py
View file @
4851c202
...
@@ -234,7 +234,9 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -234,7 +234,9 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
for
img
in
image_data
for
img
in
image_data
]
]
elif
isinstance
(
image_data
,
torch
.
Tensor
):
elif
isinstance
(
image_data
,
torch
.
Tensor
):
image_feature_size
=
image_data
.
shape
[
0
]
num_images
,
image_feature_size
,
hidden_size
=
image_data
.
shape
elif
is_list_of
(
image_data
,
torch
.
Tensor
):
image_feature_size
=
[
item
.
shape
[
1
]
for
item
in
image_data
]
else
:
else
:
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
...
...
vllm/model_executor/models/llava_next_video.py
0 → 100644
View file @
4851c202
import
itertools
import
math
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
transformers
import
(
CLIPVisionConfig
,
LlavaNextVideoConfig
,
SiglipVisionConfig
)
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.clip
import
CLIPVisionModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.utils
import
(
cached_get_tokenizer
,
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_list_of
from
.clip
import
dummy_image_for_clip
,
dummy_seq_data_for_clip
from
.interfaces
import
SupportsMultiModal
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
)
from
.utils
import
(
filter_weights
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
# Module-level logger, named after this module.
logger = init_logger(__name__)

# For profile run
# Number of frames assumed per video when sizing dummy data and the
# maximum multimodal token count.
_MAX_FRAMES_PER_VIDEO = 32
# Only this many videos per prompt are accepted by the dummy-data builder.
_MAX_NUM_VIDEOS = 1
class LlavaNextVideoPixelInputs(TypedDict):
    """Raw video pixels handed to the LLaVA-NeXT-Video model."""

    # Discriminator tag; always the literal "pixel_values_videos".
    type: Literal["pixel_values_videos"]

    # Video pixels, either as one batched tensor or as a list of tensors.
    #
    # Batched tensor: `_process_video_pixels` unpacks the shape as
    # `(batch_size, num_videos, num_frames, num_channels, height, width)`
    # and requires `num_videos == 1`, i.e. one video per prompt.
    #
    # List of tensors: used when `num_frames` differs between items, so
    # the data cannot be stacked. The tensors are concatenated along dim 0
    # and dim 0 of each element is taken as that video's frame count.
    # NOTE(review): this implies `(num_frames, num_channels, height, width)`
    # per element - confirm against the input mapper.
    data: Union[torch.Tensor, List[torch.Tensor]]
def get_llava_next_video_frame_feature_size(
        hf_config: LlavaNextVideoConfig) -> int:
    """Return the number of embedding tokens produced for one video frame.

    The count is computed with integer (floor) division at every stage so
    that it matches what a pooling layer with ``kernel_size == stride``
    (floor mode) emits. Plain float division over-counts whenever the
    sizes do not divide evenly, e.g. ``384 / 14 / 2`` squared truncates to
    188 while a 27x27 patch grid pooled with stride 2 yields 13x13 = 169.

    Args:
        hf_config: Model config providing ``vision_config.image_size``,
            ``vision_config.patch_size`` and ``spatial_pool_stride``.

    Returns:
        Tokens per frame after spatial pooling.
    """
    # Support both CLIPVisionConfig and SiglipVisionConfig
    vision_config = hf_config.vision_config

    # NOTE(review): assumes the vision tower yields
    # (image_size // patch_size) patches per side - confirm for new towers.
    patches_per_side = vision_config.image_size // vision_config.patch_size
    pooled_per_side = patches_per_side // hf_config.spatial_pool_stride

    return int(pooled_per_side**2)
def _get_max_llm_tokens(ctx: InputContext) -> int:
    """
    Upper bound on tokens available to video frames, derived from the
    language model's context length and scaled by the RoPE scaling factor
    when RoPE scaling is configured.
    """
    model_config = ctx.model_config
    text_config = model_config.hf_text_config
    token_budget = model_config.max_model_len

    # The factor is read from the HF text config, but only when the model
    # config reports that RoPE scaling is enabled.
    scale = (text_config.rope_scaling["factor"]
             if model_config.rope_scaling else 1)

    return token_budget * scale
def get_max_llava_next_video_tokens(ctx: InputContext) -> int:
    """Return the maximum number of placeholder tokens one video can use.

    The frame count is fixed at ``_MAX_FRAMES_PER_VIDEO``.
    """
    # TODO: max_tokens = _get_max_llm_tokens(ctx)
    per_frame = get_llava_next_video_frame_feature_size(
        ctx.get_hf_config(LlavaNextVideoConfig))
    return _MAX_FRAMES_PER_VIDEO * per_frame
def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
                                    mm_counts: Mapping[str, int]):
    """Build dummy sequence data and a dummy video for the profile run.

    Args:
        ctx: Input context used to fetch the ``LlavaNextVideoConfig``.
        seq_len: Sequence length forwarded to the dummy sequence builder.
        mm_counts: Per-modality item counts; ``mm_counts["video"]`` must
            equal ``_MAX_NUM_VIDEOS``.

    Returns:
        ``(seq_data, mm_data)`` where ``mm_data["video"]`` is one dummy
        frame repeated ``_MAX_FRAMES_PER_VIDEO`` times along axis 0.

    Raises:
        NotImplementedError: If the video count is unsupported or the
            vision config is neither CLIP nor SigLIP.
    """
    hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
    vision_config = hf_config.vision_config

    # TODO: support multiple videos
    num_videos = mm_counts["video"]
    if num_videos != _MAX_NUM_VIDEOS:
        raise NotImplementedError(
            f"Only {_MAX_NUM_VIDEOS} videos are supported")

    # TODO: support configuring the number of frames
    frames_per_video = _MAX_FRAMES_PER_VIDEO
    # Fill the sequence with as long a video as possible.
    video_feature_size = (
        frames_per_video * get_llava_next_video_frame_feature_size(hf_config))

    # Both vision backends expose the same pair of dummy builders, so pick
    # the pair once and share the rest of the logic.
    if isinstance(vision_config, CLIPVisionConfig):
        build_seq_data = dummy_seq_data_for_clip
        build_image = dummy_image_for_clip
    elif isinstance(vision_config, SiglipVisionConfig):
        build_seq_data = dummy_seq_data_for_siglip
        build_image = dummy_image_for_siglip
    else:
        raise NotImplementedError(
            f"Unsupported vision config: {type(vision_config)}")

    seq_data = build_seq_data(
        vision_config,
        seq_len,
        num_videos,
        image_token_id=hf_config.video_token_index,
        image_feature_size_override=video_feature_size,
    )

    # One dummy frame, repeated to form the whole video.
    frame = np.array(build_image(vision_config, num_images=1)["image"])
    video = np.repeat([frame], frames_per_video, axis=0)
    return seq_data, {"video": video}
def input_processor_for_llava_next_video(ctx: InputContext,
                                         llm_inputs: LLMInputs):
    """Expand the video placeholder token to one token per video feature.

    Args:
        ctx: Input context providing the model config and HF config.
        llm_inputs: Prompt inputs; returned unchanged when they carry no
            ``"video"`` entry under ``"multi_modal_data"``.

    Returns:
        ``llm_inputs`` itself when there is no video, otherwise new
        ``LLMInputs`` whose prompt and token ids have the video placeholder
        repeated ``num_frames * tokens_per_frame`` times.

    Raises:
        NotImplementedError: If several videos are passed (a list of
            arrays) or the video data has an unsupported type.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "video" not in multi_modal_data:
        return llm_inputs
    video_data = multi_modal_data["video"]

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaNextVideoConfig)

    if isinstance(video_data, np.ndarray):
        # Supports both CLIP and Siglip
        # The frame count is read from axis 0 of the array.
        num_frames = video_data.shape[0]
        frame_feature_size = \
            get_llava_next_video_frame_feature_size(hf_config)
        video_feature_size = num_frames * frame_feature_size

        tokenizer = cached_get_tokenizer(model_config.tokenizer)

        new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
            tokenizer,
            llm_inputs.get("prompt"),
            llm_inputs["prompt_token_ids"],
            placeholder_token_id=hf_config.video_token_index,
            repeat_count=video_feature_size,
        )

        return LLMInputs(prompt_token_ids=new_token_ids,
                         prompt=new_prompt,
                         multi_modal_data=multi_modal_data)

    elif is_list_of(video_data, np.ndarray):
        raise NotImplementedError(
            "Processing multiple videos is not supported")

    # Reaching this point means the *video data* has an unsupported type;
    # the vision config plays no part in the dispatch above.
    msg = f"Unsupported video type: {type(video_data)}"
    raise NotImplementedError(msg)
def _init_vision_tower(hf_config: LlavaNextVideoConfig):
    """Create the vision encoder, truncated at the configured feature layer.

    Args:
        hf_config: Model config providing ``vision_config`` and
            ``vision_feature_layer``.

    Returns:
        A ``CLIPVisionModel`` or ``SiglipVisionModel`` built with
        ``num_hidden_layers_override`` set to the required depth.

    Raises:
        NotImplementedError: If the vision config is neither CLIP nor
            SigLIP.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer.
    # A negative index counts back from the last hidden layer.
    feature_layer = hf_config.vision_feature_layer
    depth = feature_layer + 1
    if feature_layer < 0:
        depth += hf_config.vision_config.num_hidden_layers

    supported = (
        (CLIPVisionConfig, CLIPVisionModel),
        (SiglipVisionConfig, SiglipVisionModel),
    )
    for config_cls, model_cls in supported:
        if isinstance(vision_config, config_cls):
            return model_cls(
                vision_config,
                num_hidden_layers_override=depth,
            )

    raise NotImplementedError(
        f"Unsupported vision config: {type(vision_config)}")
# Adapted from transformers modeling_llava_next_video.py
class LlavaNextVideoPooler(nn.Module):
    """Spatially pool per-frame patch features with a strided 2-D pool."""

    def __init__(self, config):
        """Build the pooling layer described by ``config``.

        Args:
            config: Object providing ``spatial_pool_mode`` ("average" or
                "max"), ``spatial_pool_stride`` and
                ``vision_config.image_size`` / ``vision_config.patch_size``.

        Raises:
            ValueError: If the pooling mode is not "average" or "max".
        """
        super().__init__()

        mode = config.spatial_pool_mode
        stride = config.spatial_pool_stride
        image_size = config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        # Kept with its original value for attribute compatibility; note it
        # evaluates as image_size // (patch_size**2) and is no longer used
        # by forward().
        self.image_size = image_size // patch_size**2

        if mode == "average":
            self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride)
        elif mode == "max":
            self.pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
        else:
            # TODO: Support Conv2d pooling layer, need to load weights
            raise ValueError(
                f"Unknown pooling mode: {mode}. Expected [`average`, `max`]")

    def forward(self, image_features):
        """Pool ``(batch, tokens, dim)`` features over a square token grid.

        The token axis is reshaped to ``(side, side)`` with
        ``side = int(sqrt(tokens))``, pooled, then flattened back to
        ``(batch, pooled_tokens, dim)``.
        """
        batch_size, num_tokens, dim = image_features.shape

        # The previous `n * s // s` form (s = self.image_size) equals `n`
        # for every non-zero s but raised ZeroDivisionError when s was 0
        # (e.g. image_size=224, patch_size=16); use the token count directly.
        ori_width = int(math.sqrt(num_tokens))
        ori_height = ori_width

        image_features_spatial = image_features \
            .view(batch_size, ori_height, ori_width, dim) \
            .permute(0, 3, 1, 2)
        image_features_spatial = self.pool(image_features_spatial)

        return image_features_spatial.flatten(2).transpose(1, 2).contiguous()
class LlavaNextMultiModalProjector(nn.Module):
    """Two-layer MLP that maps vision features into the text hidden size."""

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        """Create ``linear_1 -> act -> linear_2``.

        Args:
            vision_hidden_size: Input feature width.
            text_hidden_size: Width of both the hidden and output layers.
            projector_hidden_act: Activation name passed to ``get_act_fn``.
        """
        super().__init__()

        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Apply the projection MLP to ``image_features``."""
        projected = self.act(self.linear_1(image_features))
        return self.linear_2(projected)
@MULTIMODAL_REGISTRY.register_input_mapper("video")
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "video", get_max_llava_next_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video)
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
    """LLaVA-NeXT-Video: vision tower + pooler + projector + language model."""

    def __init__(self,
                 config: LlavaNextVideoConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        """Build all sub-modules from ``config``.

        Args:
            config: HF config for the whole model.
            multimodal_config: Stored on the instance as-is.
            cache_config: Forwarded to the language model constructor.
            quant_config: Forwarded to the language model constructor.
        """
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # Initialize the vision tower only up to the required feature layer
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaNextMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)
        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)
        self.vision_resampler = LlavaNextVideoPooler(config)

    def _validate_video_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check that dims 2.. of every item equal ``(3, H, W)``.

        ``H`` and ``W`` are both ``vision_config.image_size``.

        Returns:
            ``data`` unchanged.

        Raises:
            ValueError: On the first item whose trailing dims differ.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[2:])

            if actual_dims != expected_dims:
                expected_expr = ("num_frames", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values in each video frame "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]:
        """
        Extract ``pixel_values_videos`` from ``kwargs`` and type-check it.

        A legal video input should have the following dimensions:
        {
            "pixel_values_videos" :
                List[b, Tensor(nb_frames, nb_channels, height, width)]
        }

        Returns:
            ``None`` when the key is absent or ``None``, otherwise the
            tagged pixel inputs.

        Raises:
            ValueError: If the value is neither a tensor nor a list of
                tensors.
        """
        pixel_values = kwargs.pop("pixel_values_videos", None)

        if pixel_values is None:
            return None

        if not (is_list_of(pixel_values,
                           (torch.Tensor))  # different shape videos
                or isinstance(pixel_values,
                              torch.Tensor)):  # same shape videos
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return LlavaNextVideoPixelInputs(
            type="pixel_values_videos",
            data=pixel_values,
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Select tokens per ``strategy``.

        ``"default"`` drops index 0 along dim 1; ``"full"`` keeps all.

        Raises:
            ValueError: For any other strategy.
        """
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _video_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode frames, select tokens, pool spatially, then project."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)
        image_features = self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )
        image_features = self.vision_resampler(image_features)
        image_features = self.multi_modal_projector(image_features)
        return image_features

    def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs):
        """Turn video pixels into per-frame embeddings.

        A batched tensor is unpacked as ``(b, num_videos, num_frames, c, h,
        w)`` with ``num_videos == 1`` and the result is reshaped to
        ``(b, num_frames, ...)``. A list of tensors is concatenated along
        dim 0 and the result is split back by each element's dim-0 size.

        Raises:
            ValueError: If the data is neither form.
        """
        assert self.vision_tower is not None

        video_pixels = inputs["data"]

        if isinstance(video_pixels, torch.Tensor):
            # TODO: support multiple videos per input
            b, num_videos, num_frames, c, h, w = video_pixels.shape
            assert (num_videos == 1)
            # Fold batch/video/frame axes so the tower sees a batch of frames.
            stacked_pixels = video_pixels.view(b * num_videos * num_frames, c,
                                               h, w)
            stacked_embeddings = self._video_pixels_to_features(
                self.vision_tower, stacked_pixels)
            return stacked_embeddings.view(b, num_frames,
                                           *stacked_embeddings.shape[1:])

        elif is_list_of(video_pixels, torch.Tensor):
            frames_per_videos = [v.shape[0] for v in video_pixels]
            stacked_pixels = torch.cat(video_pixels, dim=0)
            stacked_embeddings = self._video_pixels_to_features(
                self.vision_tower, stacked_pixels)
            return torch.split(stacked_embeddings, frames_per_videos, dim=0)

        else:
            raise ValueError(
                f"Unsupported type of video input {type(video_pixels)}")

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run forward pass for LlaVA-NeXT-Video.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Forwarded to the language model.
            kv_caches: Forwarded to the language model.
            attn_metadata: Forwarded to the language model.
            intermediate_tensors: Accepted but not forwarded; ``None`` is
                passed to the language model in its place.
            pixel_values_videos: Pixels of each frame of each input video
                (read from ``kwargs``).

        Returns:
            The hidden states returned by ``self.language_model.model``
            (sampling happens separately in ``sample``).
        """
        video_input = self._parse_and_validate_video_input(**kwargs)

        # merge video embeddings into input embeddings
        if video_input is not None:
            video_embeddings = self._process_video_pixels(video_input)
            inputs_embeds = self.language_model \
                .model.get_input_embeddings(input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, video_embeddings,
                self.config.video_token_index)

            # The language model is driven by embeddings from here on.
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load vision tower, projector and language model weights.

        ``weights`` is teed once per component; each copy is filtered by
        that component's name prefix.
        """
        # prepare weight iterators (one per component that is loaded; the
        # former fourth "newline" iterator was never consumed)
        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)

        # load vision encoder
        vit_weights = filter_weights(vit_weights, "vision_tower")
        self.vision_tower.load_weights(vit_weights)

        # load mlp projector
        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)
vllm/model_executor/models/minicpmv.py
View file @
4851c202
...
@@ -26,11 +26,9 @@ import re
...
@@ -26,11 +26,9 @@ import re
from
array
import
array
from
array
import
array
from
functools
import
partial
from
functools
import
partial
from
typing
import
(
Any
,
Callable
,
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
from
typing
import
(
Any
,
Callable
,
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
TypedDict
)
import
numpy
as
np
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.types
import
torch.types
from
PIL
import
Image
from
PIL
import
Image
from
torch
import
nn
from
torch
import
nn
...
@@ -44,6 +42,8 @@ from vllm.logger import init_logger
...
@@ -44,6 +42,8 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.resampler
import
(
Resampler2
,
get_2d_sincos_pos_embed
)
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
...
@@ -98,101 +98,6 @@ MiniCPMVImageInputs = MiniCPMVImagePixelInputs
...
@@ -98,101 +98,6 @@ MiniCPMVImageInputs = MiniCPMVImagePixelInputs
DEFAULT_LN
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
DEFAULT_LN
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
)
def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor):
    """Resize a square grid of position embeddings to ``tgt_size``.

    Args:
        abs_pos: ``(L, C)`` embeddings laid out on a
            ``sqrt(L) x sqrt(L)`` grid.
        tgt_size: ``(H, W)`` target grid size.

    Returns:
        ``(H * W, C)`` embeddings, bicubically interpolated and cast back
        to the dtype of ``abs_pos``.
    """
    side = int(math.sqrt(abs_pos.size(0)))
    original_dtype = abs_pos.dtype

    # (L, C) -> (1, C, side, side); interpolation is done in float32.
    grid = abs_pos.float().reshape(1, side, side, -1).permute(0, 3, 1, 2)
    resized = F.interpolate(
        grid,
        size=(tgt_size[0], tgt_size[1]),
        mode="bicubic",
        align_corners=False,
    )
    # (1, C, H, W) -> (H * W, C)
    flat = resized.permute(0, 2, 3, 1).flatten(0, 2)
    return flat.to(dtype=original_dtype)
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(
    embed_dim: int,
    grid_size: Union[int, Tuple[int, int]],
    cls_token: bool = False,
    version: Tuple[int, int] = (2, 0),
):
    """
    Build 2-D sine/cosine position embeddings for a grid.

    grid_size: int of the grid height and width, or an (H, W) pair
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or
                [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)

    The zero row for ``cls_token`` is only prepended when ``version`` is
    ``(2, 0)``; other versions return the grid-shaped embedding as-is.
    """
    if isinstance(grid_size, int):
        rows = cols = grid_size
    else:
        rows, cols = grid_size[0], grid_size[1]

    row_coords = np.arange(rows, dtype=np.float32)
    col_coords = np.arange(cols, dtype=np.float32)
    # here w goes first
    grid = np.stack(np.meshgrid(col_coords, row_coords), axis=0)

    if version != (2, 0):
        return get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)

    grid = grid.reshape([2, 1, rows, cols])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
                                   axis=0)
    return pos_embed
def get_2d_sincos_pos_embed_from_grid(
        embed_dim: int, grid: np.ndarray, version: Tuple[int, int] = (2, 0)):
    """Encode a 2-plane coordinate grid with sine/cosine embeddings.

    Half of ``embed_dim`` encodes ``grid[0]`` and the other half encodes
    ``grid[1]``; the halves are joined on axis 1 for version ``(2, 0)``
    (result ``(H*W, D)``) and on the last axis otherwise
    (result ``(H, W, D)``).
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(half_dim, grid[0], version)
    emb_w = get_1d_sincos_pos_embed_from_grid(half_dim, grid[1], version)

    join_axis = 1 if version == (2, 0) else -1
    return np.concatenate([emb_h, emb_w], axis=join_axis)
def get_1d_sincos_pos_embed_from_grid(
        embed_dim: int, pos: np.ndarray, version: Tuple[int, int] = (2, 0)):
    """
    Sine/cosine embedding of scalar positions.

    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,) / (H, W)
    out: (M, D) / (H, W, D)

    For version ``(2, 0)`` ``pos`` is flattened first; any other version
    expects a 2-D ``pos`` and keeps its layout.
    """
    assert embed_dim % 2 == 0

    # Frequencies 1 / 10000^(2i / D), shape (D/2,)
    freqs = np.arange(embed_dim // 2, dtype=np.float32)
    freqs /= embed_dim / 2.0
    freqs = 1.0 / 10000**freqs

    if version == (2, 0):
        # (M, D/2), outer product
        angles = np.einsum("m,d->md", pos.reshape(-1), freqs)
        join_axis = 1
    else:
        # (H, W, D/2), outer product
        angles = np.einsum("hw,d->hwd", pos, freqs)
        join_axis = -1

    # Sine half first, cosine half second.
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=join_axis)
class
BaseResampler
(
nn
.
Module
):
class
BaseResampler
(
nn
.
Module
):
"""
"""
A 2D perceiver-resampler network with one cross attention layers by
A 2D perceiver-resampler network with one cross attention layers by
...
@@ -245,62 +150,6 @@ class BaseResampler(nn.Module):
...
@@ -245,62 +150,6 @@ class BaseResampler(nn.Module):
return
query
.
unsqueeze
(
1
).
repeat
(
1
,
N
,
1
)
return
query
.
unsqueeze
(
1
).
repeat
(
1
,
N
,
1
)
class Resampler2(BaseResampler):
    """Resampler with a fixed 2-D sin/cos position embedding on the queries."""

    def __init__(
        self,
        grid_size: int,
        embed_dim: int,
        num_heads: int,
        kv_dim: Optional[int] = None,
        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
        adaptive: bool = False,
    ) -> None:
        """Set up the base resampler with ``grid_size**2`` queries.

        Args:
            grid_size: Side length of the query grid.
            embed_dim: Embedding width.
            num_heads: Number of attention heads.
            kv_dim: Forwarded to the base class.
            norm_layer: Forwarded to the base class.
            adaptive: When true, key position embeddings are recomputed
                from ``tgt_sizes`` on every forward call instead of being
                interpolated from the stored table.
        """
        super().__init__(grid_size**2, embed_dim, num_heads, kv_dim,
                         norm_layer)

        self.adaptive = adaptive

        # Frozen position table for the query grid.
        table = get_2d_sincos_pos_embed(embed_dim, grid_size, version=(2, 0))
        self.pos_embed = nn.Parameter(
            torch.from_numpy(table).float()).requires_grad_(False)

        self.apply(self._init_weights)

    def forward(
        self,
        x: torch.Tensor,
        tgt_sizes: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        """Cross-attend the learned queries over ``x``.

        Args:
            x: Key/value features, passed through ``kv_proj`` first.
            tgt_sizes: Grid size used to build the key position embedding.
            attn_mask: Forwarded to the attention layer.

        Returns:
            The attended queries after ``ln_post`` and the ``proj`` matmul.
        """
        if self.adaptive:
            table = get_2d_sincos_pos_embed(self.embed_dim,
                                            tgt_sizes,
                                            version=(2, 0))
            key_pos = torch.from_numpy(table).to(device=x.device,
                                                 dtype=x.dtype)
        else:
            key_pos = get_abs_pos(self.pos_embed, tgt_sizes)

        x, _ = self.kv_proj(x)
        x = self.ln_kv(x).permute(1, 0, 2)

        N = x.shape[1]
        queries = self._repeat(self.ln_q(self.query), N)
        attended = self.attn(
            queries + self.pos_embed.unsqueeze(1),
            x + key_pos.unsqueeze(1),
            x,
            attn_mask=attn_mask,
        )[0]

        x = self.ln_post(attended.permute(1, 0, 2))
        return x @ self.proj
class
Resampler2_5
(
BaseResampler
):
class
Resampler2_5
(
BaseResampler
):
def
__init__
(
def
__init__
(
...
@@ -782,7 +631,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
...
@@ -782,7 +631,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
num_heads
=
embed_dim
//
128
,
num_heads
=
embed_dim
//
128
,
grid_size
=
int
(
math
.
sqrt
(
self
.
config
.
query_num
)),
grid_size
=
int
(
math
.
sqrt
(
self
.
config
.
query_num
)),
kv_dim
=
vision_dim
,
kv_dim
=
vision_dim
,
adaptive
=
True
,
adaptive
=
False
,
do_post_projection
=
True
,
)
)
return
resampler
return
resampler
...
...
vllm/model_executor/models/mixtral.py
View file @
4851c202
...
@@ -435,7 +435,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
...
@@ -435,7 +435,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
((
name
.
endswith
(
".bias"
)
or
name
.
endswith
(
"_bias"
))
and
name
not
in
params_dict
):
continue
continue
# Skip layers on other devices.
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
...
@@ -454,6 +455,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
...
@@ -454,6 +455,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
# Skip layers on other devices.
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
continue
continue
if
((
name
.
endswith
(
".bias"
)
or
name
.
endswith
(
"_bias"
))
and
name
not
in
params_dict
):
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
weight_loader
(
param
,
...
@@ -464,7 +468,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
...
@@ -464,7 +468,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
break
break
else
:
else
:
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
((
name
.
endswith
(
".bias"
)
or
name
.
endswith
(
"_bias"
))
and
name
not
in
params_dict
):
continue
continue
# Skip layers on other devices.
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
...
...
vllm/model_executor/models/paligemma.py
View file @
4851c202
import
itertools
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
TypedDict
,
Union
)
...
@@ -13,7 +14,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
...
@@ -13,7 +14,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.gemma
import
Gemma
Model
from
vllm.model_executor.models.gemma
import
Gemma
ForCausalLM
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.multimodal.utils
import
cached_get_tokenizer
...
@@ -22,14 +23,10 @@ from vllm.sequence import IntermediateTensors
...
@@ -22,14 +23,10 @@ from vllm.sequence import IntermediateTensors
from
.interfaces
import
SupportsMultiModal
from
.interfaces
import
SupportsMultiModal
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_max_siglip_image_tokens
)
dummy_seq_data_for_siglip
,
get_max_siglip_image_tokens
)
from
.utils
import
merge_multimodal_embeddings
from
.utils
import
filter_weights
,
merge_multimodal_embeddings
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_KEYS_TO_MODIFY_MAPPING
=
{
"language_model.model"
:
"language_model"
,
}
class
PaliGemmaImagePixelInputs
(
TypedDict
):
class
PaliGemmaImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
type
:
Literal
[
"pixel_values"
]
...
@@ -151,8 +148,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -151,8 +148,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
projection_dim
=
config
.
vision_config
.
projection_dim
)
projection_dim
=
config
.
vision_config
.
projection_dim
)
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
language_model
=
Gemma
Model
(
config
.
text_config
,
cache_config
,
self
.
language_model
=
Gemma
ForCausalLM
(
config
.
text_config
,
quant_config
)
cache_config
,
quant_config
)
self
.
unpadded_vocab_size
=
config
.
text_config
.
vocab_size
self
.
unpadded_vocab_size
=
config
.
text_config
.
vocab_size
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
...
@@ -252,7 +249,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -252,7 +249,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
vision_embeddings
=
vision_embeddings
*
(
self
.
config
.
hidden_size
**
vision_embeddings
=
vision_embeddings
*
(
self
.
config
.
hidden_size
**
-
0.5
)
-
0.5
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
input_ids
,
inputs_embeds
,
vision_embeddings
,
...
@@ -262,87 +260,47 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -262,87 +260,47 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
else
:
else
:
inputs_embeds
=
None
inputs_embeds
=
None
hidden_states
=
self
.
language_model
(
input_ids
,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
positions
,
kv_caches
,
kv_caches
,
attn_metadata
,
attn_metadata
,
None
,
None
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
return
hidden_states
return
hidden_states
# Copied from vllm/model_executor/models/gemma.py
def
compute_logits
(
def
compute_logits
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
torch
.
Tensor
]:
)
->
Optional
[
torch
.
Tensor
]:
logits
=
self
.
logits_processor
(
self
.
language_model
.
embed_tokens
,
return
self
.
language_model
.
compute_logits
(
hidden_states
,
hidden_states
,
sampling_metadata
)
sampling_metadata
)
return
logits
# Copied from vllm/model_executor/models/gemma.py
def
sample
(
def
sample
(
self
,
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
return
next_tokens
# Adapted from vllm/model_executor/models/gemma.py
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# prepare weight iterators for components
# (param_name, shard_name, shard_id)
vit_weights
,
mlp_weights
,
llm_weights
=
itertools
.
tee
(
weights
,
3
)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
# load vision tower
(
"qkv_proj"
,
"v_proj"
,
"v"
),
vit_weights
=
filter_weights
(
vit_weights
,
"vision_tower"
)
(
"gate_up_proj"
,
"gate_proj"
,
0
),
self
.
vision_tower
.
load_weights
(
vit_weights
)
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
# load mlp projector
params_dict
=
dict
(
self
.
named_parameters
())
mlp_weights
=
filter_weights
(
mlp_weights
,
"multi_modal_projector"
)
loaded_params
=
set
()
mlp_params_dict
=
dict
(
self
.
multi_modal_projector
.
named_parameters
())
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
mlp_weights
:
for
key_to_modify
,
new_key
in
_KEYS_TO_MODIFY_MAPPING
.
items
():
param
=
mlp_params_dict
[
name
]
if
key_to_modify
in
name
:
weight_loader
=
getattr
(
param
,
"weight_loader"
,
name
=
name
.
replace
(
key_to_modify
,
new_key
)
default_weight_loader
)
use_default_weight_loading
=
False
weight_loader
(
param
,
loaded_weight
)
if
"vision"
not
in
name
or
self
.
vision_tower
.
shard_weight
:
for
(
param_name
,
shard_name
,
# load llm backbone
shard_id
)
in
stacked_params_mapping
:
llm_weights
=
filter_weights
(
llm_weights
,
"language_model"
)
if
shard_name
not
in
name
:
self
.
language_model
.
load_weights
(
llm_weights
)
continue
name
=
name
.
replace
(
shard_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# lm_head is not used in vllm as it is tied with
# embed_token. To prevent errors, skip loading
# lm_head.weight.
if
"lm_head.weight"
in
name
:
continue
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
use_default_weight_loading
=
True
else
:
use_default_weight_loading
=
True
if
use_default_weight_loading
:
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
unloaded_params
=
params_dict
.
keys
()
-
loaded_params
if
unloaded_params
:
logger
.
warning
(
"Some weights are not initialized from checkpoints: %s"
,
unloaded_params
)
vllm/model_executor/models/phi3v.py
View file @
4851c202
...
@@ -424,7 +424,9 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -424,7 +424,9 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
input_width
=
w
,
input_width
=
w
,
input_height
=
h
))
input_height
=
h
))
elif
isinstance
(
image_data
,
torch
.
Tensor
):
elif
isinstance
(
image_data
,
torch
.
Tensor
):
image_feature_size
=
image_data
.
shape
[
0
]
num_images
,
image_feature_size
,
hidden_size
=
image_data
.
shape
elif
is_list_of
(
image_data
,
torch
.
Tensor
):
image_feature_size
=
[
item
.
shape
[
1
]
for
item
in
image_data
]
else
:
else
:
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
...
...
vllm/model_executor/models/pixtral.py
0 → 100644
View file @
4851c202
import
math
from
array
import
array
from
dataclasses
import
dataclass
,
fields
from
itertools
import
tee
from
typing
import
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
Union
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
mistral_common.protocol.instruct.messages
import
ImageChunk
from
PIL
import
Image
from
transformers
import
PretrainedConfig
from
xformers.ops.fmha
import
memory_efficient_attention
from
xformers.ops.fmha.attn_bias
import
BlockDiagonalMask
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SequenceData
)
from
.interfaces
import
SupportsMultiModal
from
.utils
import
init_vllm_registered_model
def get_max_pixtral_image_tokens(ctx: InputContext):
    """Return the maximum number of patch tokens a single image can occupy.

    The value is the square of ``max_image_size // image_patch_size``, both
    read from the multimodal encoder config attached to the model tokenizer.
    """
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        tokenizer_mode=model_config.tokenizer_mode)
    encoder_config = tokenizer.instruct.mm_encoder.mm_config

    # Number of patches along one side of the largest supported image.
    patches_per_side = (encoder_config.max_image_size //
                        encoder_config.image_patch_size)
    return patches_per_side**2
def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
                           mm_counts: Mapping[str, int]):
    """Build dummy sequence and image data for Pixtral.

    A black square RGB image whose side is roughly
    ``sqrt(seq_len) * image_patch_size`` pixels is encoded once; its tokens
    and the image itself are then repeated once per allowed image.

    NOTE(review): ``mm_counts`` is accepted but not read here; the repeat
    count comes from ``limit_per_prompt["image"]`` (default 1) instead.

    Returns:
        A ``(SequenceData, {"image": [PIL image, ...]})`` tuple.
    """
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        tokenizer_mode=model_config.tokenizer_mode)
    mm_encoder = tokenizer.instruct.mm_encoder

    num_images = model_config.multimodal_config.limit_per_prompt.get(
        "image", 1)

    # approximate image size
    side = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size)
    dummy_image = Image.new("RGB", (side, side), color=0)

    image_tokens = mm_encoder(ImageChunk(image=dummy_image)).tokens
    repeated_ids = num_images * array(VLLM_TOKEN_ID_ARRAY_TYPE, image_tokens)

    return SequenceData(repeated_ids), {"image": num_images * [dummy_image]}
def input_mapper_for_pixtral(ctx: InputContext,
                             data: object) -> MultiModalInputs:
    """Map raw image input to the ``images`` entry of MultiModalInputs.

    Each item is wrapped in an ``ImageChunk`` and run through the tokenizer's
    multimodal encoder; the encoded image array is converted to a float16
    tensor on the ``"cuda"`` device.

    Args:
        ctx: Context of the loaded model.
        data: A single image or a list of images. A non-list value is
            treated as a one-element list.

    Returns:
        MultiModalInputs holding ``{"images": [tensor, ...]}`` with one
        tensor per input image (the tensors are not stacked).
    """
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode)

    data_list = data if isinstance(data, list) else [data]

    images = []
    for image_data in data_list:
        image = ImageChunk(image=image_data)
        encoding = tokenizer.instruct.mm_encoder(image)
        # NOTE(review): device and dtype are hard-coded here rather than
        # taken from the model; confirm this is intended.
        image = torch.from_numpy(encoding.image).to(device="cuda",
                                                    dtype=torch.float16)
        images.append(image)

    return MultiModalInputs({"images": images})
def merge_multimodal_embeddings(input_ids: torch.Tensor,
                                inputs_embeds: torch.Tensor,
                                image_features: torch.Tensor,
                                image_id: int) -> torch.Tensor:
    """Overwrite image-placeholder rows of ``inputs_embeds`` in place.

    Args:
        input_ids: 1-D token ids; positions equal to ``image_id`` are image
            placeholders.
        inputs_embeds: ``(seq_len, D)`` text embeddings, modified in place.
        image_features: ``(N_img, D)`` image token features, one row per
            placeholder position.
        image_id: Token id that marks an image placeholder.

    Returns:
        The same ``inputs_embeds`` tensor with placeholder rows replaced.

    Raises:
        AssertionError: If the feature dims differ or the number of image
            feature rows does not match the number of placeholders.
    """
    text_locations = input_ids != image_id
    image_locations = input_ids == image_id

    seq_len = input_ids.shape[0]

    N_txt = text_locations.sum().item()
    _, D_txt = inputs_embeds.shape
    N_img, D_img = image_features.shape

    # Both message fragments must be f-strings so D_img is interpolated.
    assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal "
                              f"to image features dim {D_img}")
    assert (seq_len == N_txt +
            N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img "
                     f"{(N_txt, N_img, image_locations.sum().item())}")

    inputs_embeds[image_locations, :] = image_features
    return inputs_embeds
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
    """Pixtral model: vision encoder + adapter in front of a language model.

    Image features produced by ``vision_encoder`` are projected by
    ``vision_language_adapter`` and written over the image-placeholder
    positions of the language model's input embeddings.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.config = config
        self.multimodal_config = multimodal_config

        # Keep only the vision_config entries that VisionEncoderArgs declares.
        dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
        vision_args = {
            key: value
            for key, value in self.config.vision_config.to_dict().items()
            if key in dataclass_fields
        }

        self.vision_args = VisionEncoderArgs(**vision_args)

        # init MistralForCausalLM
        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)

        self.vision_encoder = VisionTransformer(self.vision_args)
        self.vision_language_adapter = VisionLanguageAdapter(
            self.vision_args, dim=config.text_config.hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run the language model, merging image embeddings when present.

        When ``kwargs`` carries images, their features replace the
        embeddings at ``image_token_id`` positions and the language model
        is driven by ``inputs_embeds`` (``input_ids`` is set to None).
        ``intermediate_tensors`` is accepted but not forwarded.

        Returns:
            The hidden states returned by ``language_model.model``.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.vision_args.image_token_id)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def _parse_and_validate_image_input(
        self,
        images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor],
                               torch.Tensor]] = None
    ) -> Optional[List[torch.Tensor]]:
        """Return the images of the last batch entry as a flat list.

        Returns None when no images were passed. Inputs that are neither a
        tensor nor a list are returned unchanged.
        """
        if images is None:
            return None

        if isinstance(images, torch.Tensor):
            # always take last images
            images = [images[-1][i] for i in range(images.size(1))]
        elif isinstance(images, list):
            # always take last images; size the loop from that same entry so
            # a differing image count in the first entry cannot cause an
            # IndexError or drop images.
            last_images = images[-1]
            images = [last_images[i] for i in range(len(last_images))]

        return images

    def _process_image_input(self,
                             image_input: List[torch.Tensor]) -> torch.Tensor:
        """Encode images and project them to the language model width."""
        return self.vision_language_adapter(self.vision_encoder(image_input))

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the wrapped language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the wrapped language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Route checkpoint weights to the three sub-modules by name prefix.

        Names starting with ``vision_encoder`` or ``vision_language_adapter``
        are loaded into those modules with the prefix stripped; everything
        else is handed to the language model's own ``load_weights``.
        """

        def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_encoder")

        def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
            return weight[0].startswith("vision_language_adapter")

        def is_vision_weights(weight: Tuple[str, torch.Tensor]):
            return is_vision_encoder_weights(
                weight) or is_vision_lang_adapter_weights(weight)

        # One independent iterator per consumer over the same weight stream.
        llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee(
            weights, 3)

        # llm
        llm_weights = filter(lambda x: not is_vision_weights(x), llm_weights)
        self.language_model.load_weights(llm_weights)

        # vision encoder
        vision_encoder_weights = filter(is_vision_encoder_weights,
                                        vision_encoder_weights)
        vision_encoder_dict = dict(self.vision_encoder.named_parameters())
        for name, loaded_weight in vision_encoder_weights:
            # cut 'vision_encoder.'
            name = '.'.join(name.split(".")[1:])
            param = vision_encoder_dict[name]

            default_weight_loader(param, loaded_weight)

        # adapter
        vision_lang_adapter_weights = filter(is_vision_lang_adapter_weights,
                                             vision_lang_adapter_weights)
        vision_lang_adapter_dict = dict(
            self.vision_language_adapter.named_parameters())
        for name, loaded_weight in vision_lang_adapter_weights:
            # cut 'vision_language_adapter.'
            name = '.'.join(name.split(".")[1:])
            param = vision_lang_adapter_dict[name]
            default_weight_loader(param, loaded_weight)
# Vision encoder
@dataclass
class VisionEncoderArgs:
    """Hyper-parameters consumed by the Pixtral vision modules."""

    # Width of the vision transformer.
    hidden_size: int
    # Input channels of the patch convolution.
    num_channels: int
    # Maximum image side, in pixels.
    image_size: int
    # Patch side, in pixels (kernel size and stride of the patch conv).
    patch_size: int
    # Width of the feed-forward inner projection.
    intermediate_size: int
    # Number of transformer blocks.
    num_hidden_layers: int
    # Number of attention heads.
    num_attention_heads: int
    rope_theta: float  # for rope-2D
    # Token id marking image placeholder positions in the prompt.
    image_token_id: int
def
_reshape_for_broadcast
(
freqs_cis
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
freqs_cis: complex - (seq_len, head_dim / 2)
x: complex - (bsz, seq_len, head_dim / 2)
"""
ndim
=
x
.
ndim
assert
ndim
>
1
assert
freqs_cis
.
shape
==
(
x
.
shape
[
1
],
x
.
shape
[
-
1
]),
(
freqs_cis
.
shape
,
(
x
.
shape
[
1
],
x
.
shape
[
-
1
]),
)
shape
=
[
d
if
i
==
1
or
i
==
ndim
-
1
else
1
for
i
,
d
in
enumerate
(
x
.
shape
)
]
return
freqs_cis
.
view
(
*
shape
)
def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """Build the 2-D rotary table as unit-magnitude complex numbers.

    Returns a complex tensor of shape ``(height, width, dim // 2)`` meant to
    be indexed by ``(row, column)`` patch positions. Even-indexed frequency
    bases are driven by the row index, odd-indexed ones by the column index.
    """
    # (dim / 2) frequency bases
    bases = 1.0 / (theta**(torch.arange(0, dim, 2).float() / dim))

    rows = torch.arange(height, device=bases.device)
    cols = torch.arange(width, device=bases.device)

    row_angles = torch.outer(rows, bases[::2]).float()
    col_angles = torch.outer(cols, bases[1::2]).float()

    # Tile each half over the other spatial axis, then join on the last dim.
    row_part = row_angles[:, None, :].repeat(1, width, 1)
    col_part = col_angles[None, :, :].repeat(height, 1, 1)
    angles = torch.cat([row_part, col_part], dim=-1)

    return torch.polar(torch.ones_like(angles), angles)
def apply_rotary_emb_vit(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors by the complex table ``freqs_cis``.

    The last dimension of each input is paired up into complex numbers,
    multiplied by the broadcast ``freqs_cis`` and converted back to real
    values. Outputs are cast back to the dtype of the matching input.
    """

    def as_complex(t: torch.Tensor) -> torch.Tensor:
        # Pair consecutive values of the last dim as (real, imag).
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    query_c = as_complex(xq)
    key_c = as_complex(xk)
    assert freqs_cis.dtype == torch.complex64

    freqs_cis = _reshape_for_broadcast(freqs_cis, query_c)
    rotated_q = torch.view_as_real(query_c * freqs_cis).flatten(3)
    rotated_k = torch.view_as_real(key_c * freqs_cis).flatten(3)
    return rotated_q.type_as(xq), rotated_k.type_as(xk)
class FeedForward(nn.Module):
    """Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))``, no biases."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        assert args.intermediate_size is not None
        hidden, inner = args.hidden_size, args.intermediate_size
        self.w1 = nn.Linear(hidden, inner, bias=False)
        self.w2 = nn.Linear(inner, hidden, bias=False)
        self.w3 = nn.Linear(hidden, inner, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated projection and map back to the hidden width."""
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
class Attention(nn.Module):
    """Multi-head self-attention with rotary embeddings on queries and keys."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        assert not args.hidden_size % args.num_attention_heads

        self.n_heads = args.num_attention_heads
        self.head_dim = args.hidden_size // args.num_attention_heads

        width = args.hidden_size
        self.wq = nn.Linear(width, width, bias=False)
        self.wk = nn.Linear(width, width, bias=False)
        self.wv = nn.Linear(width, width, bias=False)
        self.wo = nn.Linear(width, width, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        mask: BlockDiagonalMask,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """Attend over ``x`` of shape ``(batch, patches, hidden)``.

        ``mask`` is passed to ``memory_efficient_attention`` as the
        attention bias; ``freqs_cis`` is the rotary table for the patches.
        """
        batch, patches, _ = x.shape
        per_head_shape = (batch, patches, self.n_heads, self.head_dim)

        query = self.wq(x).reshape(*per_head_shape)
        key = self.wk(x).reshape(*per_head_shape)
        value = self.wv(x).reshape(*per_head_shape)

        query, key = apply_rotary_emb_vit(query, key, freqs_cis=freqs_cis)
        attended = memory_efficient_attention(query,
                                              key,
                                              value,
                                              attn_bias=mask)
        # Fold the heads back into a single feature dimension.
        attended = attended.reshape(batch, patches,
                                    self.n_heads * self.head_dim)
        return self.wo(attended)
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention then feed-forward, each residual."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5)
        self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5)

    def forward(
        self,
        x: torch.Tensor,
        mask: BlockDiagonalMask,
        freqs_cis: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``x`` after the attention and feed-forward residual steps."""
        attn_out = self.attention.forward(self.attention_norm(x),
                                          mask=mask,
                                          freqs_cis=freqs_cis)
        after_attn = x + attn_out
        ffn_out = self.feed_forward.forward(self.ffn_norm(after_attn))
        return after_attn + ffn_out
class Transformer(nn.Module):
    """Stack of ``num_hidden_layers`` TransformerBlock modules."""

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.layers = torch.nn.ModuleList()
        for _ in range(args.num_hidden_layers):
            block = TransformerBlock(args)
            self.layers.append(block)

    def forward(
        self,
        x: torch.Tensor,
        mask: BlockDiagonalMask,
        freqs_cis: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Run ``x`` through every block in order with a shared mask/table."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask=mask, freqs_cis=freqs_cis)
        return hidden
def position_meshgrid(patch_embeds_list: list[torch.Tensor], ) -> torch.Tensor:
    """Return the (row, column) index of every patch of every image.

    For each tensor the last two dimensions are taken as the patch grid.
    The per-image index pairs, in row-major order, are concatenated into a
    single ``(total_patches, 2)`` tensor.
    """
    per_image = []
    for patches in patch_embeds_list:
        rows = torch.arange(patches.shape[-2])
        cols = torch.arange(patches.shape[-1])
        grid = torch.meshgrid(rows, cols, indexing="ij")
        # Stack to (H, W, 2) and flatten the grid to a list of pairs.
        per_image.append(torch.stack(grid, dim=-1).reshape(-1, 2))
    positions = torch.cat(per_image)
    return positions
class VisionTransformer(nn.Module):
    """Pixtral vision encoder: patch convolution, RMSNorm and a transformer.

    Images of different sizes are encoded together as one flattened patch
    sequence; a block-diagonal mask keeps attention within each image.
    """

    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        # Non-overlapping patches: kernel size and stride are both patch_size.
        self.patch_conv = nn.Conv2d(
            in_channels=args.num_channels,
            out_channels=args.hidden_size,
            kernel_size=args.patch_size,
            stride=args.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
        self.transformer = Transformer(args)

        head_dim = self.args.hidden_size // self.args.num_attention_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        # Rotary table, built lazily on first access to ``freqs_cis``.
        self._freqs_cis: Optional[torch.Tensor] = None

    @property
    def max_patches_per_side(self) -> int:
        """Number of patches along one side of the largest image."""
        return self.args.image_size // self.args.patch_size

    @property
    def device(self) -> torch.device:
        """Device of the first parameter of this module."""
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first parameter of this module."""
        return next(self.parameters()).dtype

    @property
    def freqs_cis(self) -> torch.Tensor:
        """Cached 2-D rotary table, moved to this module's device if needed."""
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.args.rope_theta,
            )

        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)

        return self._freqs_cis

    def forward(
        self,
        images: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            images: list of N_img images of variable sizes,
                each of shape (C, H, W)
        Returns:
            image_features: tensor of token features for
                all tokens of all images of shape (N_toks, D)
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
        ]

        # flatten to a single sequence
        patch_embeds = torch.cat(
            [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        positions = position_meshgrid(patch_embeds_list).to(self.device)
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

        # pass through Transformer with a block diagonal mask delimiting images
        mask = BlockDiagonalMask.from_seqlens(
            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

        # remove batch dimension of the single sequence
        return out.squeeze(0)
class VisionLanguageAdapter(nn.Module):
    """Two-layer GELU MLP projecting vision features to width ``dim``."""

    def __init__(self, args: VisionEncoderArgs, dim: int):
        super().__init__()
        assert isinstance(args, VisionEncoderArgs)
        self.w_in = nn.Linear(
            args.hidden_size,
            dim,
            bias=True,
        )
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``w_in``, GELU and ``w_out`` in sequence."""
        projected = self.w_in(x)
        activated = self.gelu(projected)
        return self.w_out(activated)
vllm/model_executor/models/qwen.py
View file @
4851c202
...
@@ -4,41 +4,408 @@
...
@@ -4,41 +4,408 @@
# Copyright (c) Alibaba Cloud.
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
"""Inference-only QWen model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
import
math
import
re
from
array
import
array
from
functools
import
partial
from
typing
import
(
Any
,
Callable
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
import
numpy
as
np
import
torch
import
torch
from
PIL
import
Image
from
torch
import
nn
from
torch
import
nn
from
torchvision
import
transforms
from
torchvision.transforms
import
InterpolationMode
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
import
os
import
os
import
re
import
re
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
,
get_act_fn
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.resampler
import
Resampler2
,
get_abs_pos
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.utils
import
print_warning_once
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SequenceData
)
from
.utils
import
flatten_bn
,
is_pp_missing_parameter
,
make_layers
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
.utils
import
is_pp_missing_parameter
,
make_layers
from
.utils
import
is_pp_missing_parameter
,
make_layers
logger
=
init_logger
(
__name__
)
# NOTE: Qwen models have a few other special tags, e.g., ref, bbox, quad;
# for the time being, these tags are not considered as special at encoding
# time. This may change as VLLMs multimodal API changes in the future.
IMG_START
=
"<img>"
IMG_END
=
"</img>"
IMG_PAD
=
"<imgpad>"
# Image context is fixed at 256 for all images
MAX_QWEN_IMG_TOKENS
=
256
# Image normalization params
CLIP_MEAN
=
(
0.48145466
,
0.4578275
,
0.40821073
)
CLIP_STD
=
(
0.26862954
,
0.26130258
,
0.27577711
)
class QwenImagePixelInputs(TypedDict):
    # Image input given as pixel values; `type` is the discriminator tag.
    type: Literal["pixel_values"]
    data: torch.Tensor
    """
    Shape: `(batch_size * num_images, 3, image_size, image_size)`

    Note that image_size is the value in the vision config to which we resize
    the image to in the normalization transform. Currently multi-image support
    can only be leveraged by passing image embeddings directly.
    """
class QwenImageEmbeddingInputs(TypedDict):
    # Image input given as precomputed embeddings; `type` is the tag.
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, 256, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone
    and is stored in the visual config of the model if we have one.
    """
QwenImageInputs
=
Union
[
QwenImagePixelInputs
,
QwenImageEmbeddingInputs
]
class VisualAttention(nn.Module):
    """Multi-head self-attention over sequence-first input.

    Input and output both have shape ``[seq, batch, hidden]``. Only
    self-attention is supported: ``kdim`` and ``vdim`` must equal
    ``embed_dim`` (or be left as None).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        kdim: Optional[int] = None,
        vdim: Optional[int] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = embed_dim if kdim is None else kdim
        self.vdim = embed_dim if vdim is None else vdim
        self._qkv_same_embed_dim = (self.kdim == embed_dim
                                    and self.vdim == embed_dim)

        self.num_heads = num_heads

        # Per attention head and per partition values.
        assert embed_dim % num_heads == 0
        self.hidden_size_per_attention_head = embed_dim // num_heads
        self.num_attention_heads_per_partition = num_heads
        self.hidden_size_per_partition = embed_dim

        # Strided linear layer.
        assert self._qkv_same_embed_dim, \
                'Visual Attention implementation only supports self-attention'
        # NOTE(review): `bias` is not forwarded; both projections use the
        # nn.Linear default.
        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over ``x``; ``attn_mask`` (if given) is added to the scores."""
        seq_len, batch, _ = x.size()
        heads = self.num_attention_heads_per_partition
        head_dim = self.hidden_size_per_attention_head

        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
        fused = self.in_proj(x)
        fused = fused.view(*(fused.size()[:-1] + (heads, 3 * head_dim)))

        def to_batched_heads(t: torch.Tensor) -> torch.Tensor:
            # [sq, b, np, hn] --> [b * np, sq, hn]
            return t.view(seq_len, batch * heads, head_dim).transpose(0, 1)

        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        query, key, value = fused.split(head_dim, dim=-1)
        query = to_batched_heads(query)
        key = to_batched_heads(key)

        scaled_query = query / self.norm_factor
        key_t = key.transpose(-2, -1)
        if attn_mask is None:
            scores = torch.bmm(scaled_query, key_t)
        else:
            scores = torch.baddbmm(attn_mask, scaled_query, key_t)
        probs = scores.softmax(dim=-1)

        value = to_batched_heads(value)

        # matmul: [b * np, sq, hn]
        context = torch.bmm(probs, value)

        # [b * np, sq, hn] --> [b, np, sq, hn] --> [sq, b, np, hn]
        context = context.view(batch, heads, seq_len, head_dim)
        context = context.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        merged_shape = context.size()[:-2] + (self.hidden_size_per_partition, )
        context = context.view(*merged_shape)

        return self.out_proj(context)
class QwenVMLP(nn.Module):
    """MLP for the visual component of the Qwen model."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # Up-projection, GELU activation, then down-projection.
        self.c_fc = ColumnParallelLinear(hidden_size,
                                         intermediate_size,
                                         bias=True,
                                         quant_config=quant_config)
        self.act_fn = get_act_fn("gelu", quant_config, intermediate_size)
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
        )

    def forward(self, x):
        """Return ``c_proj(act_fn(c_fc(x)))``.

        Both linear layers return a pair; only the first element is used.
        """
        expanded, _ = self.c_fc(x)
        activated = self.act_fn(expanded)
        projected, _ = self.c_proj(activated)
        return projected
class VisualAttentionBlock(nn.Module):
    """Pre-norm residual block: VisualAttention followed by QwenVMLP."""

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        norm_layer: Callable = nn.LayerNorm,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.ln_2 = norm_layer(d_model)
        self.attn = VisualAttention(d_model, n_head)
        self.mlp = QwenVMLP(
            hidden_size=d_model,
            intermediate_size=int(d_model * mlp_ratio),
            quant_config=quant_config,
        )

    def attention(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run self-attention, casting the mask (if any) to ``x``'s dtype."""
        if attn_mask is not None:
            attn_mask = attn_mask.to(x.dtype)
        return self.attn(x, attn_mask=attn_mask)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers, each with a residual add."""
        attended = self.attention(self.ln_1(x), attn_mask=attn_mask)
        x = x + attended
        return x + self.mlp(self.ln_2(x))
class TransformerBlock(nn.Module):
    """Sequence of ``layers`` VisualAttentionBlock modules of equal width."""

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        norm_layer: Callable = nn.LayerNorm,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.width = width
        self.layers = layers

        blocks = []
        for _ in range(layers):
            blocks.append(
                VisualAttentionBlock(width,
                                     heads,
                                     mlp_ratio,
                                     norm_layer=norm_layer,
                                     quant_config=quant_config))
        self.resblocks = nn.ModuleList(blocks)

    def get_cast_dtype(self) -> torch.dtype:
        """Dtype of the first block's ``mlp.c_fc`` weight."""
        first_block = self.resblocks[0]
        return first_block.mlp.c_fc.weight.dtype

    def get_cast_device(self) -> torch.device:
        """Device of the first block's ``mlp.c_fc`` weight."""
        first_block = self.resblocks[0]
        return first_block.mlp.c_fc.weight.device

    def forward(self,
                x: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Feed ``x`` through every block in order with the same mask."""
        hidden = x
        for block in self.resblocks:
            hidden = block(hidden, attn_mask=attn_mask)
        return hidden
class VisionTransformer(nn.Module):
    """Visual encoder: patch conv -> transformer -> resampler -> projection.

    Args:
        image_size: Side length of the (square) input image.
        patch_size: Side length of each (square) patch; used as both the
            kernel size and the stride of the patch convolution.
        width: Channel width of the patch embedding / transformer.
        layers: Number of transformer residual blocks.
        heads: Number of attention heads in the transformer.
        mlp_ratio: MLP ratio passed to the transformer blocks.
        n_queries: Number of resampler queries; its integer square root is
            used as the resampler grid size.
        output_dim: Output embedding dimension.
        image_start_id: Token ID marking the start of an image; the end ID
            is taken to be ``image_start_id + 1``.
        quant_config: Optional quantization config for the transformer.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 image_size: int,
                 patch_size: int,
                 width: int,
                 layers: int,
                 heads: int,
                 mlp_ratio: float,
                 n_queries: int = 256,
                 output_dim: int = 512,
                 image_start_id: int = 151857,
                 quant_config: Optional[QuantizationConfig] = None,
                 **kwargs):
        super().__init__()
        self.image_size = (image_size, image_size)
        self.patch_size = (patch_size, patch_size)
        patches_per_side = image_size // patch_size
        self.grid_size = (patches_per_side, patches_per_side)
        self.output_dim = output_dim

        # Non-overlapping patch embedding.
        self.conv1 = nn.Conv2d(in_channels=3,
                               out_channels=width,
                               kernel_size=patch_size,
                               stride=patch_size,
                               bias=False)

        # class embeddings and positional embeddings
        pos_scale = width**-0.5
        self.positional_embedding = nn.Parameter(
            pos_scale * torch.randn(256, width))

        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.ln_pre = norm_layer(width)
        self.transformer = TransformerBlock(width,
                                            layers,
                                            heads,
                                            mlp_ratio,
                                            norm_layer=norm_layer,
                                            quant_config=quant_config)

        resampler = Resampler2(
            grid_size=int(math.sqrt(n_queries)),
            embed_dim=output_dim,
            num_heads=output_dim // 128,
            kv_dim=width,
            norm_layer=norm_layer,
            adaptive=False,
            do_post_projection=False,
        )
        # Keep the resampler on the same device/dtype as the positional
        # embedding parameter.
        self.attn_pool = resampler.to(
            device=self.positional_embedding.device,
            dtype=self.positional_embedding.dtype,
        )

        self.ln_post = norm_layer(output_dim)
        proj_scale = output_dim**-0.5
        self.proj = nn.Parameter(
            proj_scale * torch.randn(output_dim, output_dim))

        self.image_start_id = image_start_id
        self.image_end_id = image_start_id + 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of images into visual embeddings.

        Args:
            x: Image tensor; it is moved to the transformer's dtype and
                device before the patch convolution is applied.

        Returns:
            The resampled, normalized features multiplied by ``self.proj``.
        """
        x = x.to(
            dtype=self.transformer.get_cast_dtype(),
            device=self.transformer.get_cast_device(),
        )

        # to patches
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.flatten(2)  # shape = [*, width, grid ** 2]
        x = x.transpose(1, 2)  # shape = [*, grid ** 2, width]

        side = int(math.sqrt(x.size(1)))
        x = x + get_abs_pos(self.positional_embedding, side)
        x = self.ln_pre(x)

        x = x.transpose(0, 1)  # NLD -> LND
        x = self.transformer(x)
        x = x.transpose(0, 1)  # LND -> NLD

        x = self.attn_pool(x)
        x = self.ln_post(x)
        return x @ self.proj

    def get_image_positions(self,
                            input_ids: torch.Tensor) -> Optional[torch.Tensor]:
        """Locate image start/end token positions in ``input_ids``.

        Args:
            input_ids: Token IDs to search for ``self.image_start_id`` and
                ``self.image_end_id``.

        Returns:
            ``None`` when no image start token is present; otherwise a
            tensor whose rows pair the first-dimension index of each start
            token with that of the end token at the same rank.
        """
        start_mask = input_ids == self.image_start_id
        if not torch.any(start_mask):
            return None

        start_idx = torch.where(start_mask)[0]
        end_idx = torch.where(input_ids == self.image_end_id)[0]
        return torch.stack((start_idx, end_idx), dim=1)
class
QWenMLP
(
nn
.
Module
):
class
QWenMLP
(
nn
.
Module
):
"""MLP for the language component of the Qwen model, which contains a
MergedColumnParallelLinear merging 2 outputs via silu activation."""
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -61,7 +428,7 @@ class QWenMLP(nn.Module):
...
@@ -61,7 +428,7 @@ class QWenMLP(nn.Module):
"Only silu is supported for now."
)
"Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
)
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
c_proj
(
x
)
x
,
_
=
self
.
c_proj
(
x
)
...
@@ -215,6 +582,9 @@ class QWenModel(nn.Module):
...
@@ -215,6 +582,9 @@ class QWenModel(nn.Module):
lambda
prefix
:
QWenBlock
(
config
,
cache_config
,
quant_config
),
lambda
prefix
:
QWenBlock
(
config
,
cache_config
,
quant_config
),
prefix
=
f
"
{
prefix
}
.h"
)
prefix
=
f
"
{
prefix
}
.h"
)
self
.
ln_f
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_f
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
visual
=
VisionTransformer
(
**
config
.
visual
,
quant_config
=
quant_config
)
if
hasattr
(
config
,
"visual"
)
else
None
def
forward
(
def
forward
(
self
,
self
,
...
@@ -223,9 +593,33 @@ class QWenModel(nn.Module):
...
@@ -223,9 +593,33 @@ class QWenModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
pixel_values
:
Optional
[
QwenImageInputs
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
img_pos
=
None
# If pixel / visual embeddings are provided, this is a visual model
if
pixel_values
is
not
None
and
self
.
visual
is
not
None
:
if
pixel_values
[
"type"
]
!=
"image_embeds"
:
image_embeds
=
self
.
visual
(
pixel_values
[
"data"
])
else
:
image_embeds
=
pixel_values
[
"data"
]
# features should be of shape (# images, 256, hidden_dim)
img_pos
=
self
.
visual
.
get_image_positions
(
input_ids
)
if
isinstance
(
img_pos
,
np
.
ndarray
)
and
img_pos
.
shape
[
0
]
!=
image_embeds
.
shape
[
0
]:
raise
ValueError
(
f
"Number of placeholders:
{
img_pos
.
shape
[
0
]
}
"
f
"does not match number of images
{
image_embeds
.
shape
[
0
]
}
."
)
if
get_pp_group
().
is_first_rank
:
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
wte
(
input_ids
)
hidden_states
=
self
.
wte
(
input_ids
)
# Merge the image embeddings into the hidden states if actually have
# visual features and the corresponding image tokens
if
img_pos
is
not
None
:
for
idx
,
(
img_bos
,
img_eos
)
in
enumerate
(
img_pos
):
hidden_states
[
img_bos
+
1
:
img_eos
]
=
image_embeds
[
idx
]
residual
=
None
residual
=
None
else
:
else
:
assert
intermediate_tensors
is
not
None
assert
intermediate_tensors
is
not
None
...
@@ -249,16 +643,241 @@ class QWenModel(nn.Module):
...
@@ -249,16 +643,241 @@ class QWenModel(nn.Module):
return
hidden_states
return
hidden_states
class
QWenLMHeadModel
(
nn
.
Module
):
def get_image_text(image_num: int, padding: bool) -> str:
    """Build the placeholder text for one image in a prompt.

    Args:
        image_num: The number of the image that we want a text prompt for.
            Images should be indexed starting at 1.
        padding: When True, ``MAX_QWEN_IMG_TOKENS`` copies of ``IMG_PAD`` are
            inserted between the image start and end tags; when False the
            tags are emitted back to back.

    Returns:
        Text placeholder prompt for the image being considered.
    """
    pads = MAX_QWEN_IMG_TOKENS * IMG_PAD if padding else ""
    return f"Picture {image_num}: {IMG_START}{pads}{IMG_END}\n"
def input_processor_for_qwen(ctx: InputContext,
                             llm_inputs: LLMInputs) -> LLMInputs:
    """Processes the inputs, which may or may not be multimodal.

    Multimodal inputs will only be processed if the model has a "visual"
    component in its model config, otherwise they'll be ignored.

    Args:
        ctx: Context of the loaded model.
        llm_inputs: LLM inputs which may have a multi_modal_data attribute.

    Returns:
        If the model is language only or no multimodal inputs were provided,
        returns llm_inputs unmodified. Otherwise returns new LLMInputs whose
        prompt has everything between the image tags dropped and whose token
        IDs are the re-encoded prompt.

    Raises:
        ValueError: If image embeddings are given as a tensor that has
            neither 2 nor 3 dimensions.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")

    # Only process images if we have multimodal data and a visual config
    hf_config = ctx.get_hf_config()
    if (multi_modal_data is None or "image" not in multi_modal_data
            or not hasattr(hf_config, "visual")):
        return llm_inputs

    prompt = llm_inputs.get("prompt")
    prompt_token_ids = llm_inputs["prompt_token_ids"]
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)

    image_data = multi_modal_data["image"]
    if isinstance(image_data, torch.Tensor):
        num_dims = len(image_data.shape)
        if num_dims not in (2, 3):
            # The message now states the accepted ranks (2 or 3), matching
            # the check; it previously claimed only 3 was valid.
            raise ValueError(
                f"Expected img embeds to have 2 or 3 dimensions, "
                f"got {num_dims}")
        # A 2D tensor is treated as the embedding of a single image.
        num_images = 1 if num_dims == 2 else image_data.shape[0]
    else:
        # TODO - handle multiple image inputs once the API is solidified
        num_images = 1

    if prompt is None:
        prompt = tokenizer.decode(prompt_token_ids)

    # Drops anything between <img>/</img> tags; encoding with the tokenizer
    # will automatically add the image pads for the context.
    new_prompt, num_matched_images = re.subn(
        r"(Picture \d*: <img>).*?(<\/img>\n)",
        r"\1\2",
        prompt,
    )

    if num_matched_images != num_images:
        logger.warning(
            "Number of matched image placeholders %s doesn't match the number "
            "of expected images %s; check your placeholder formatting.",
            num_matched_images, num_images)

    new_prompt_token_ids = tokenizer.encode(new_prompt)

    return LLMInputs(prompt=new_prompt,
                     prompt_token_ids=new_prompt_token_ids,
                     multi_modal_data=multi_modal_data)
def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
    """Maps the input data to its MultiModalInputs (if any).

    Args:
        ctx: Context of the loaded model.
        data: data potentially containing image/image embeddings to be mapped
            to pixel_values in .forward() for a visual QWenLMHeadModel model.

    Returns:
        MultiModalInputs containing the stacked normalized images tensor or
        image embeddings. Empty MultiModalInputs if the model config has no
        "visual" attribute.

    Raises:
        ValueError: If the tokenizer does not encode the image start/end
            pair as expected, or if provided image embeddings do not have
            shape [# images, MAX_QWEN_IMG_TOKENS, output_dim].
    """
    # Early exit if we have provided an image to a language only Qwen model
    hf_config = ctx.get_hf_config()
    if not hasattr(hf_config, "visual"):
        logger.warning(
            "Images were provided but this model has no visual config; "
            "multimodal inputs will not be forwarded to the model.")
        return MultiModalInputs()

    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                     trust_remote_code=True)

    # Sanity-check the tokenizer: encoding the bare start/end pair should
    # yield the start ID, the fixed number of pads, then the end ID.
    image_pair_tok = tokenizer.encode(IMG_START + IMG_END,
                                      add_special_tokens=False,
                                      return_tensors="pt").squeeze()
    image_start_id = image_pair_tok[0]
    image_end_id = image_pair_tok[-1]
    if (image_start_id + 1) != image_end_id:
        raise ValueError(
            f"Found image end ID {image_end_id}, but expected "
            f"{image_start_id + 1} ({IMG_START} + 1)")
    if len(image_pair_tok) != (MAX_QWEN_IMG_TOKENS + 2):
        # Report the number of tokens between the tags; the previous message
        # formatted the token tensor itself minus 2.
        raise ValueError(
            f"Expected image context length of {MAX_QWEN_IMG_TOKENS}, "
            f"but got {len(image_pair_tok) - 2}")

    image_size = hf_config.visual["image_size"]
    img_emb_size = hf_config.visual["output_dim"]

    if isinstance(data, torch.Tensor):
        # It's expected that our values have already been processed
        # by the visual transformer; shape is expected to be:
        # (# images, 256, hidden_size)
        if len(data.shape) == 2:
            # Assume only one image embed was provided; unsqueeze the extra dim
            data = data.unsqueeze(0)
        if (len(data.shape) != 3 or data.shape[1] != MAX_QWEN_IMG_TOKENS
                or data.shape[2] != img_emb_size):
            raise ValueError(
                "Expected image embeds to be a tensor of shape "
                f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but "
                f"received shape [{data.shape}]")
        pixel_values = data
    else:
        transform = build_normalization_transform(image_size)
        # TODO - handle multiple image inputs once the API is solidified
        transformed_images = [transform(data)]
        pixel_values = torch.stack(transformed_images, dim=0)

    return MultiModalInputs({"pixel_values": pixel_values})
def build_normalization_transform(image_size: int) -> transforms.Compose:
    """Create the resize / to-tensor / normalize pipeline for input images.

    Args:
        image_size: Target side length; images are resized to
            ``(image_size, image_size)`` with bicubic interpolation.

    Returns:
        Callable transform for normalizing and resizing one RGB image.
    """
    resize = transforms.Resize((image_size, image_size),
                               interpolation=InterpolationMode.BICUBIC)
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD)
    return transforms.Compose([resize, to_tensor, normalize])
def dummy_data_for_qwen(
    ctx: InputContext,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> Tuple[SequenceData, Optional[Dict]]:
    """Create warm-up data for Qwen models.

    Without a visual config the result is text only (``seq_len`` zero
    tokens and no multimodal data). With a visual config, image placeholder
    prompts and blank images are produced.

    Args:
        ctx: Context of the loaded model.
        seq_len: Number of tokens in the text sequence.
        mm_counts: multimodal data counts; ``mm_counts["image"]`` is read
            when the model has a visual config.

    Returns:
        Tuple containing sequential and multimodal data.

    Raises:
        ValueError: If the tokenized placeholders do not contain exactly
            ``MAX_QWEN_IMG_TOKENS`` pad tokens per image.
    """
    hf_config = ctx.get_hf_config()

    # The presence of a visual config indicates this is a multimodal model.
    # If we don't have it, the model is considered an LLM for warmup purposes.
    if not hasattr(hf_config, "visual"):
        text_only = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                       [0] * seq_len))
        return text_only, None

    # We have a visual component - use images to warm up
    num_images = mm_counts["image"]
    tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer,
                                     trust_remote_code=True)

    # Build the image prompts with no imgpads; the tokenizer will add img pads
    image_prompt = ''.join(
        get_image_text(image_idx, False)
        for image_idx in range(1, num_images + 1))
    toks = tokenizer.encode(image_prompt, add_special_tokens=False)

    # Make sure we actually get the fixed context size per tok padding
    pad_token_id = tokenizer.encode(IMG_PAD)[0]
    num_pads = toks.count(pad_token_id)
    expected_pads = num_images * MAX_QWEN_IMG_TOKENS
    if num_pads != expected_pads:
        raise ValueError(
            f"Tokenized dummy data should encode {MAX_QWEN_IMG_TOKENS} pads"
            f" per image, but got {num_pads} pads for {num_images} image(s)"
            " in total. Are you using a qwen tokenizer?")

    # Ensure the number of tokens is at minimum the sequence length provided
    missing = seq_len - len(toks)
    if missing > 0:
        toks.extend([0] * missing)

    # Build the input images; width/height doesn't actually matter here since
    # the data will get resized and the # of tokens per image is constant
    image = Image.new("RGB", (224, 224), color=0)
    if num_images == 1:
        mm_data = {"image": image}
    else:
        mm_data = {"image": [image] * num_images}

    return SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, toks)), mm_data
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
(
input_mapper_for_qwen
)
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
MAX_QWEN_IMG_TOKENS
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_qwen
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_qwen
)
class
QWenLMHeadModel
(
nn
.
Module
,
SupportsMultiModal
):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
multimodal_config
:
MultiModalConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
transformer
=
QWenModel
(
config
,
cache_config
,
quant_config
)
self
.
transformer
=
QWenModel
(
config
,
cache_config
,
quant_config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
...
@@ -278,16 +897,47 @@ class QWenLMHeadModel(nn.Module):
...
@@ -278,16 +897,47 @@ class QWenLMHeadModel(nn.Module):
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
def
forward
(
def
_get_image_input_type
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
pixel_values
:
Optional
[
torch
.
Tensor
])
->
Optional
[
QwenImageInputs
]:
positions
:
torch
.
Tensor
,
"""Determines if the provided pixel_values are normalized pixel values
kv_caches
:
List
[
torch
.
Tensor
],
or image embeddings.
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
Args:
)
->
torch
.
Tensor
:
pixel_values: Optional data to processed into visual embeddings.
Returns:
None of the QwenImageInputs type used to determine whether or not
the visual transformer needs to process the pixel_values.
"""
if
pixel_values
is
not
None
and
self
.
transformer
.
visual
is
not
None
:
pixel_values
=
flatten_bn
(
pixel_values
)
if
len
(
pixel_values
.
shape
)
==
3
and
pixel_values
.
shape
[
1
]
==
MAX_QWEN_IMG_TOKENS
and
pixel_values
.
shape
[
2
]
==
self
.
config
.
visual
[
"output_dim"
]:
return
QwenImageEmbeddingInputs
(
type
=
"image_embeds"
,
data
=
pixel_values
,
)
else
:
# If we have the wrong shape, assume we still need to process
return
QwenImagePixelInputs
(
type
=
"pixel_values"
,
data
=
pixel_values
,
)
return
None
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
pixel_values
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
pixel_values
=
self
.
_get_image_input_type
(
pixel_values
)
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
)
attn_metadata
,
intermediate_tensors
,
pixel_values
)
return
hidden_states
return
hidden_states
def
make_empty_intermediate_tensors
(
def
make_empty_intermediate_tensors
(
...
@@ -349,15 +999,6 @@ class QWenLMHeadModel(nn.Module):
...
@@ -349,15 +999,6 @@ class QWenLMHeadModel(nn.Module):
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
# Skip loading visual weights to support Qwen-VL models
# in cases with text-only inputs
# TODO: add support for Qwen-VL
if
(
name
not
in
params_dict
and
name
.
startswith
(
"transformer.visual."
)):
print_warning_once
(
"Only text inputs are allowed. Images won't be handled "
"until Qwen-VL models are fully supported."
)
continue
# Skip layers on other devices.
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
continue
continue
...
...
vllm/model_executor/models/qwen2_moe.py
View file @
4851c202
...
@@ -469,7 +469,8 @@ class Qwen2MoeForCausalLM(nn.Module):
...
@@ -469,7 +469,8 @@ class Qwen2MoeForCausalLM(nn.Module):
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
((
name
.
endswith
(
".bias"
)
or
name
.
endswith
(
"_bias"
))
and
name
not
in
params_dict
):
continue
continue
# Skip layers on other devices.
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
...
@@ -490,6 +491,10 @@ class Qwen2MoeForCausalLM(nn.Module):
...
@@ -490,6 +491,10 @@ class Qwen2MoeForCausalLM(nn.Module):
# Skip layers on other devices.
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
continue
continue
# Skip loading extra bias for GPTQ models.
if
((
name
.
endswith
(
".bias"
)
or
name
.
endswith
(
"_bias"
))
and
name
not
in
params_dict
):
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
weight_loader
(
param
,
...
@@ -500,7 +505,8 @@ class Qwen2MoeForCausalLM(nn.Module):
...
@@ -500,7 +505,8 @@ class Qwen2MoeForCausalLM(nn.Module):
break
break
else
:
else
:
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
((
name
.
endswith
(
".bias"
)
or
name
.
endswith
(
"_bias"
))
and
name
not
in
params_dict
):
continue
continue
# Skip layers on other devices.
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
...
...
vllm/model_executor/models/qwen2_vl.py
0 → 100644
View file @
4851c202
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from
array
import
array
from
functools
import
lru_cache
,
partial
from
typing
import
(
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
Type
,
TypedDict
,
Union
)
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
,
repeat
from
PIL
import
Image
from
transformers
import
Qwen2VLConfig
from
transformers.image_utils
import
(
get_image_size
,
infer_channel_dimension_format
,
to_numpy_array
)
from
transformers.models.qwen2_vl.configuration_qwen2_vl
import
(
Qwen2VLVisionConfig
)
from
transformers.models.qwen2_vl.image_processing_qwen2_vl
import
(
make_batched_images
,
make_batched_videos
,
smart_resize
)
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
from
vllm.attention.selector
import
(
_Backend
,
backend_name_to_enum
,
get_global_forced_attn_backend
)
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.distributed
import
parallel_state
from
vllm.distributed
import
utils
as
dist_utils
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.activation
import
QuickGELU
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.model_executor.models.qwen2
import
Qwen2Model
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalDataDict
,
MultiModalInputs
)
from
vllm.multimodal.base
import
MultiModalData
from
vllm.multimodal.image
import
cached_get_image_processor
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SequenceData
)
from
vllm.transformers_utils.processor
import
get_processor
logger
=
init_logger
(
__name__
)
# === Vision Inputs === #
class Qwen2VLImageInputs(TypedDict):
    """Typed dictionary holding the image inputs of Qwen2-VL."""

    # Shape: `(num_patches, num_channels * patch_size * patch_size)`
    pixel_values: torch.Tensor

    # Shape: `(num_images, 3)`
    # This should be in `(grid_t, grid_h, grid_w)` format.
    image_grid_thw: torch.Tensor
class Qwen2VLVideoInputs(TypedDict):
    """Typed dictionary holding the video inputs of Qwen2-VL."""

    # Shape: `(num_patches,
    #          num_channels * temporal_patch_size * patch_size * patch_size)`
    pixel_values_videos: torch.Tensor

    # Shape: `(num_videos, 3)`
    # This should be in `(grid_t, grid_h, grid_w)` format.
    video_grid_thw: torch.Tensor
# === Vision Encoder === #
class Qwen2VisionMLP(nn.Module):
    """Two-layer MLP (column-parallel -> activation -> row-parallel).

    Args:
        in_features: Input size of ``fc1`` and output size of ``fc2``.
        hidden_features: Output size of ``fc1`` and input size of ``fc2``.
        act_layer: Module class instantiated (with no arguments) as the
            activation between the two linear layers.
        quant_config: Optional quantization config for both linear layers.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int = None,
        act_layer: Type[nn.Module] = QuickGELU,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.fc1 = ColumnParallelLinear(in_features,
                                        hidden_features,
                                        quant_config=quant_config)
        self.act = act_layer()
        self.fc2 = RowParallelLinear(hidden_features,
                                     in_features,
                                     quant_config=quant_config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``fc2(act(fc1(x)))``.

        Each linear layer returns a pair; only its first element is used.
        """
        hidden, _ = self.fc1(x)
        out, _ = self.fc2(self.act(hidden))
        return out
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate pairs of features in the last dimension of ``x``.

    Args:
        x: Input tensor.
        interleaved: If False, the last dimension is split into two halves
            ``(x1, x2)`` and ``(-x2, x1)`` is returned. If True, even and odd
            positions are treated as ``x1`` and ``x2`` and the result keeps
            the interleaved layout ``(-x2[0], x1[0], -x2[1], x1[1], ...)``.

    Returns:
        The rotated tensor, same shape as ``x``.
    """
    if interleaved:
        evens, odds = x[..., ::2], x[..., 1::2]
        # Stack as (..., d, 2) and merge the last two dims back together.
        return torch.stack((-odds, evens), dim=-1).flatten(start_dim=-2)

    first_half, second_half = x.chunk(2, dim=-1)
    return torch.cat((-second_half, first_half), dim=-1)
def apply_rotary_emb_torch(x: torch.Tensor,
                           cos: torch.Tensor,
                           sin: torch.Tensor,
                           interleaved: bool = False) -> torch.Tensor:
    """Apply rotary embeddings to the leading features of ``x``.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Only the first ``2 * cos.shape[-1]`` features of the last dimension are
    rotated; any remaining features are passed through unchanged.
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]

    # Add a singleton head axis and duplicate each value to cover both
    # members of a rotated pair, in the layout matching ``interleaved``.
    if interleaved:
        pattern = "... d -> ... 1 (d 2)"
    else:
        pattern = "... d -> ... 1 (2 d)"
    cos = repeat(cos, pattern)
    sin = repeat(sin, pattern)

    rotary_part = x[..., :ro_dim]
    passthrough = x[..., ro_dim:]
    rotated = rotary_part * cos + rotate_half(rotary_part, interleaved) * sin
    return torch.cat([rotated, passthrough], dim=-1)
def apply_rotary_pos_emb_vision(t: torch.Tensor,
                                freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embeddings given raw frequencies.

    Args:
        t: Tensor to rotate; it is promoted to float for the computation.
        freqs: Frequencies whose cosine and sine are used for the rotation.

    Returns:
        The rotated tensor, cast back to the dtype of ``t``.
    """
    rotated = apply_rotary_emb_torch(t.float(), freqs.cos(), freqs.sin())
    return rotated.type_as(t)
class Qwen2VisionAttention(nn.Module):
    """Attention for the Qwen2-VL vision tower.

    Uses a fused, column-parallel QKV projection and a row-parallel output
    projection. The kernel (flash-attn or xformers) is selected once at
    construction time.

    Args:
        embed_dim: Input size of the QKV projection and output size of the
            output projection.
        num_heads: Total number of attention heads (divided across tensor
            parallel ranks).
        projection_size: Combined size of all heads for each of Q, K and V.
        quant_config: Optional quantization config for both projections.
    """

    def __init__(
        self,
        embed_dim: Optional[int] = None,
        num_heads: Optional[int] = None,
        projection_size: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        tp_world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, tp_world_size)

        self.qkv = ColumnParallelLinear(input_size=embed_dim,
                                        output_size=3 * projection_size,
                                        quant_config=quant_config)
        self.proj = RowParallelLinear(input_size=projection_size,
                                      output_size=embed_dim,
                                      quant_config=quant_config)

        # Detect attention implementation: a globally forced backend wins,
        # then the VLLM_ATTENTION_BACKEND setting, then auto-detection.
        backend: Optional[_Backend] = get_global_forced_attn_backend()
        if backend is None:
            env_backend: Optional[str] = envs.VLLM_ATTENTION_BACKEND
            if env_backend is not None:
                backend = backend_name_to_enum(env_backend)

        if backend is not None:
            if backend == _Backend.FLASH_ATTN:
                self._use_flash_attn = True
            elif backend == _Backend.XFORMERS:
                self._use_flash_attn = False
            else:
                raise RuntimeError(
                    f"Qwen2-VL does not support {backend} backend now.")
        elif current_platform.get_device_capability()[0] < 8:
            # For Volta and Turing GPUs, use xformers instead.
            self._use_flash_attn = False
        else:
            from transformers.utils import is_flash_attn_2_available

            if is_flash_attn_2_available():
                self._use_flash_attn = True
            else:
                logger.warning(
                    "Current Qwen2-VL implementation has a bug with "
                    "`vllm-flash-attn` inside vision module, so we use "
                    "xformers backend instead. You can run `pip install "
                    "flash-attn to use flash-attention backend.")
                self._use_flash_attn = False

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        """Compute attention over ``x``.

        Args:
            x: Input laid out as [s, b, c] (per the shape comments below).
            cu_seqlens: Cumulative sequence lengths; consecutive differences
                give the length of each independently-attended segment.
            rotary_pos_emb: Optional rotary frequencies applied to Q and K.

        Returns:
            The first element of the output projection's result.
        """
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        qkv, _ = self.qkv(x)

        # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
        qkv = qkv.view(
            *qkv.size()[:-1],
            self.num_attention_heads_per_partition,
            3 * self.hidden_size_per_attention_head,
        )

        # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
        q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
        batch_size = q.shape[1]

        q, k, v = [
            rearrange(part, "s b ... -> b s ...").contiguous()
            for part in (q, k, v)
        ]
        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)

        seg_lens = cu_seqlens[1:] - cu_seqlens[:-1]
        if self._use_flash_attn:
            # The upstream ``flash_attn`` package is used here rather than
            # ``vllm_flash_attn``.
            from flash_attn import flash_attn_varlen_func

            q, k, v = [
                rearrange(part, "b s ... -> (b s) ...") for part in [q, k, v]
            ]

            max_seqlen = seg_lens.max().item()
            output = flash_attn_varlen_func(q,
                                            k,
                                            v,
                                            cu_seqlens_q=cu_seqlens,
                                            cu_seqlens_k=cu_seqlens,
                                            max_seqlen_q=max_seqlen,
                                            max_seqlen_k=max_seqlen,
                                            dropout_p=0,
                                            causal=False)

            context_layer = rearrange(output,
                                      "(b s) ... -> b s ...",
                                      b=batch_size)
        else:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            # Block-diagonal bias built from the per-segment lengths.
            attn_bias = BlockDiagonalMask.from_seqlens(
                q_seqlen=seg_lens.tolist(), kv_seqlen=None)

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None)

        context_layer = rearrange(context_layer,
                                  "b s h d -> s b (h d)").contiguous()

        output, _ = self.proj(context_layer)
        return output
class Qwen2VisionBlock(nn.Module):
    """Pre-norm vision transformer block: attention then MLP, both residual.

    Args:
        dim: Feature dimension of the block.
        num_heads: Number of attention heads.
        mlp_ratio: The MLP hidden size is ``int(dim * mlp_ratio)``.
        act_layer: Activation module class used inside the MLP.
        norm_layer: Factory for the two norm layers; defaults to
            ``nn.LayerNorm`` with ``eps=1e-6`` when None.
        quant_config: Optional quantization config for attention and MLP.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float,
        act_layer: Type[nn.Module] = QuickGELU,
        norm_layer: Type[nn.Module] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        make_norm = (partial(nn.LayerNorm, eps=1e-6)
                     if norm_layer is None else norm_layer)
        self.norm1 = make_norm(dim)
        self.norm2 = make_norm(dim)

        self.attn = Qwen2VisionAttention(embed_dim=dim,
                                         num_heads=num_heads,
                                         projection_size=dim,
                                         quant_config=quant_config)
        self.mlp = Qwen2VisionMLP(dim,
                                  int(dim * mlp_ratio),
                                  act_layer=act_layer,
                                  quant_config=quant_config)

    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
                rotary_pos_emb: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after the attention and MLP residual updates."""
        attn_out = self.attn(self.norm1(x),
                             cu_seqlens=cu_seqlens,
                             rotary_pos_emb=rotary_pos_emb)
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        x = x + mlp_out
        return x
class Qwen2VisionPatchEmbed(nn.Module):
    """Embed flattened patches with a bias-free 3D convolution.

    The convolution's kernel and stride are both
    ``(temporal_patch_size, patch_size, patch_size)``.

    Args:
        patch_size: Spatial side length of a patch.
        temporal_patch_size: Temporal extent of a patch.
        in_chans: Number of input channels of the convolution.
        embed_dim: Number of output channels (embedding size).
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_chans: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        patch_shape = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(in_chans,
                              embed_dim,
                              kernel_size=patch_shape,
                              stride=patch_shape,
                              bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a 2D tensor of flattened patches.

        Args:
            x: 2D tensor with one flattened patch per row.

        Returns:
            Tensor of shape ``(x.shape[0], embed_dim)``.
        """
        # Unpacking enforces that the input is exactly two-dimensional.
        num_patches, _ = x.shape
        patches = x.view(num_patches, -1, self.temporal_patch_size,
                         self.patch_size, self.patch_size)
        embedded = self.proj(patches)
        return embedded.view(num_patches, self.embed_dim)
class Qwen2VisionPatchMerger(nn.Module):
    """Normalize patch features, merge neighbours, and project to ``d_model``.

    Args:
        d_model: Output size of the final linear layer.
        context_dim: Feature size of each incoming patch.
        norm_layer: Factory for the input norm; defaults to ``nn.LayerNorm``
            with ``eps=1e-6`` when None.
        spatial_merge_size: The merged feature size is
            ``context_dim * spatial_merge_size ** 2``.
        quant_config: Optional quantization config for the linear layers.
    """

    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Type[nn.Module] = None,
        spatial_merge_size: int = 2,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)

        make_norm = (partial(nn.LayerNorm, eps=1e-6)
                     if norm_layer is None else norm_layer)
        self.ln_q = make_norm(context_dim)

        fc_in = ColumnParallelLinear(self.hidden_size,
                                     self.hidden_size,
                                     bias=True,
                                     quant_config=quant_config)
        fc_out = RowParallelLinear(self.hidden_size,
                                   d_model,
                                   bias=True,
                                   quant_config=quant_config)
        self.mlp = nn.ModuleList([fc_in, nn.GELU(), fc_out])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the merged and projected features for ``x``."""
        # Normalize per patch, then regroup rows to the merged feature size.
        merged = self.ln_q(x).view(-1, self.hidden_size)

        fc_in, activation, fc_out = self.mlp
        hidden, _ = fc_in(merged)
        out, _ = fc_out(activation(hidden))
        return out
class Qwen2VisionRotaryEmbedding(nn.Module):
    """Cache of rotary frequencies ``outer(position, inv_freq)``.

    Args:
        dim: Rotary dimension; ``inv_freq`` has ``ceil(dim / 2)`` entries.
        theta: Base of the geometric frequency progression.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        inv_freq = 1.0 / (theta
                          **(torch.arange(0, dim, 2, dtype=torch.float) / dim))
        # Non-persistent: the buffer is not saved in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._freqs_cached = None

    def update_freqs_cache(self, seqlen: int) -> None:
        """Grow the cached frequency table to cover ``seqlen`` positions.

        The table is rebuilt with twice the requested length whenever the
        request exceeds the cached length (or nothing is cached yet).

        Args:
            seqlen: Required number of positions. May be a Python int or a
                0-d integer tensor; it is converted to a plain int so the
                caller's tensor is never modified in place.
        """
        seqlen = int(seqlen)
        if seqlen <= self._seq_len_cached and self._freqs_cached is not None:
            return

        # Over-allocate so that slightly larger requests reuse the cache.
        cache_len = seqlen * 2
        self._seq_len_cached = cache_len
        # Recompute in float32 on the buffer's current device.
        # NOTE(review): presumably guards against the buffer having been
        # cast to a lower precision with the rest of the model - confirm.
        self.inv_freq = 1.0 / (self.theta**(torch.arange(
            0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) /
                                            self.dim))
        positions = torch.arange(cache_len,
                                 device=self.inv_freq.device,
                                 dtype=self.inv_freq.dtype)
        self._freqs_cached = torch.outer(positions, self.inv_freq)

    def forward(self, seqlen: int) -> torch.Tensor:
        """Return the frequency table for the first ``seqlen`` positions.

        Args:
            seqlen: Number of positions (Python int or 0-d integer tensor).

        Returns:
            Tensor of shape ``(seqlen, len(inv_freq))``.
        """
        seqlen = int(seqlen)
        self.update_freqs_cache(seqlen)
        return self._freqs_cached[:seqlen]
class Qwen2VisionTransformer(nn.Module):
    """Vision tower: patch embedding -> transformer blocks -> patch merger.

    Rotary position angles are derived from each item's ``(t, h, w)`` grid and
    handed to every block together with cumulative per-frame sequence
    lengths.
    """

    def __init__(
        self,
        vision_config: Qwen2VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the sub-modules described by ``vision_config``.

        Args:
            vision_config: Supplies patch sizes, channel count, widths, depth,
                head count and MLP ratio.
            norm_eps: Epsilon used by every ``nn.LayerNorm`` created here.
            quant_config: Forwarded unchanged to the blocks and the merger.
        """
        super().__init__()

        # Read every hyper-parameter up front.
        cfg = vision_config
        patch_size: int = cfg.patch_size
        temporal_patch_size: int = cfg.temporal_patch_size
        spatial_merge_size: int = cfg.spatial_merge_size
        in_chans: int = cfg.in_chans
        hidden_size: int = cfg.hidden_size
        embed_dim: int = cfg.embed_dim
        depth: int = cfg.depth
        num_heads: int = cfg.num_heads
        mlp_ratio: float = cfg.mlp_ratio

        self.spatial_merge_size = spatial_merge_size

        self.patch_embed = Qwen2VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        # The rotary table is built with half of the per-head width.
        self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(
            (embed_dim // num_heads) // 2)

        layers = []
        for _ in range(depth):
            layers.append(
                Qwen2VisionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                ))
        self.blocks = nn.ModuleList(layers)

        self.merger = Qwen2VisionPatchMerger(
            d_model=hidden_size,
            context_dim=embed_dim,
            norm_layer=norm_layer,
            quant_config=quant_config,
        )

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first block's ``mlp.fc2`` weight."""
        reference_weight = self.blocks[0].mlp.fc2.weight
        return reference_weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the first block's ``mlp.fc2`` weight."""
        reference_weight = self.blocks[0].mlp.fc2.weight
        return reference_weight.device

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Look up rotary angles for every patch described by ``grid_thw``.

        Args:
            grid_thw: One ``(t, h, w)`` row per item.

        Returns:
            For each patch, the angle rows of its height index and its width
            index, flattened into a single row.
        """
        merge = self.spatial_merge_size

        def in_merge_order(ids: torch.Tensor, rows, cols) -> torch.Tensor:
            # Split both axes into (outer, merge) pairs and bring the two
            # `merge` axes next to each other, so each merge x merge window
            # is contiguous after flattening.
            windows = ids.reshape(rows // merge, merge, cols // merge, merge)
            return windows.permute(0, 2, 1, 3).flatten()

        per_item_ids = []
        for t, h, w in grid_thw:
            row_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            col_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            row_ids = in_merge_order(row_ids, h, w)
            col_ids = in_merge_order(col_ids, h, w)
            hw_ids = torch.stack([row_ids, col_ids], dim=-1)
            # The same spatial layout is repeated for each of the t steps.
            per_item_ids.append(hw_ids.repeat(t, 1))
        pos_ids = torch.cat(per_item_ids, dim=0)

        # One table large enough for the biggest height / width in the batch.
        largest_side = grid_thw[:, 1:].max()
        angle_table = self.rotary_pos_emb(largest_side)
        return angle_table[pos_ids].flatten(1)

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """Encode patches ``x`` laid out according to ``grid_thw``."""
        # patchify
        x = self.patch_embed(x.to(device=self.device, dtype=self.dtype))

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # compute cu_seqlens: every one of an item's t steps contributes
        # h * w positions; prepend 0 so entries are segment boundaries.
        patches_per_step = grid_thw[:, 1] * grid_thw[:, 2]
        step_lengths = torch.repeat_interleave(patches_per_step,
                                               grid_thw[:, 0])
        cu_seqlens = step_lengths.cumsum(dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # transformers
        x = x.unsqueeze(1)
        for blk in self.blocks:
            x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)

        # adapter
        return self.merger(x)
# === Vision input helpers === #

# `get_processor` wrapped in `lru_cache`, so repeated calls with the same
# arguments reuse the previously returned processor object.
cached_get_processor = lru_cache(get_processor)
def mm_input_mapper_for_qwen2_vl(
    ctx: InputContext,
    data: MultiModalData[object],
    data_type_key: str,
) -> MultiModalInputs:
    """Run the HuggingFace image processor on image or video data.

    Args:
        ctx: Supplies the model name and the ``trust_remote_code`` flag used
            to look up the image processor.
        data: The raw multimodal object(s) to preprocess.
        data_type_key: ``"image"`` or ``"video"``; selects which keyword
            argument of ``preprocess`` receives ``data``.

    Returns:
        ``MultiModalInputs`` wrapping the ``.data`` of the processor output.

    Raises:
        RuntimeError: If no image processor could be obtained.
    """
    cfg = ctx.model_config
    image_processor = cached_get_image_processor(
        cfg.model, trust_remote_code=cfg.trust_remote_code)
    if image_processor is None:
        raise RuntimeError("No HuggingFace processor is available "
                           "to process the image object")

    # Exactly one of the two modalities is populated; the other stays None.
    if data_type_key == "image":
        images, videos = data, None
    else:
        assert data_type_key == "video"
        images, videos = None, data

    try:
        processed = image_processor.preprocess(images=images,
                                               videos=videos,
                                               return_tensors="pt")
        batch_data = processed.data
    except Exception:
        # Log which input failed, then let the original error propagate.
        logger.error("Failed to process image (%s)", data)
        raise

    return MultiModalInputs(batch_data)
# Modality-specific input mappers: `mm_input_mapper_for_qwen2_vl` with
# `data_type_key` pre-bound to "image" / "video" respectively.
image_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl,
                                          data_type_key="image")
video_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl,
                                          data_type_key="video")
def
_get_vision_info
(
image_processor
,
height
:
int
,
width
:
int
,
min_pixels
:
int
,
max_pixels
:
int
,
do_resize
:
bool
=
True
,
data_type_key
:
str
=
"image"
,
mm_count
:
int
=
1
,
):
"""Get information (resized height / width and number of vision tokens)
of input image / video frame."""
if
do_resize
:
resized_height
,
resized_width
=
smart_resize
(
height
=
height
,
width
=
width
,
factor
=
image_processor
.
patch_size
*
image_processor
.
merge_size
,
min_pixels
=
min_pixels
,
max_pixels
=
max_pixels
,
)
else
:
resized_height
,
resized_width
=
height
,
width
if
data_type_key
==
"image"
:
grid_t
=
mm_count
else
:
assert
data_type_key
==
"video"
grid_t
=
max
(
mm_count
//
image_processor
.
temporal_patch_size
,
1
)
grid_h
=
resized_height
//
image_processor
.
patch_size
grid_w
=
resized_width
//
image_processor
.
patch_size
vision_tokens
=
grid_t
*
grid_h
*
grid_w
llm_num_vision_tokens
=
(
vision_tokens
//
image_processor
.
merge_size
//
image_processor
.
merge_size
)
return
resized_height
,
resized_width
,
llm_num_vision_tokens
def _get_max_image_info(
    image_processor,
    data_type_key: str = "image",
    mm_count: int = 1,
):
    """Return ``_get_vision_info`` for a deliberately oversized input.

    A 9999999 x 9999999 input is used, with pixel bounds taken from the
    processor but limited to at least ``28 * 28`` and at most
    ``1280 * 28 * 28``.
    """
    # Limit min / max pixels.
    pixel_floor = max(image_processor.min_pixels, 28 * 28)
    pixel_ceiling = min(image_processor.max_pixels, 1280 * 28 * 28)

    oversized_side = 9999999
    return _get_vision_info(
        image_processor,
        height=oversized_side,
        width=oversized_side,
        min_pixels=pixel_floor,
        max_pixels=pixel_ceiling,
        data_type_key=data_type_key,
        mm_count=mm_count,
    )
def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int:
    """Return the token count reported by ``_get_max_image_info`` for one item.

    Args:
        ctx: Supplies the model name used to look up the image processor.
        data_type_key: Modality forwarded to ``_get_max_image_info``.
    """
    processor = cached_get_image_processor(ctx.model_config.model)
    # Only the token count is needed; the resized height / width are dropped.
    _, _, max_llm_tokens = _get_max_image_info(processor,
                                               data_type_key=data_type_key,
                                               mm_count=1)
    return max_llm_tokens
# Modality-specific token-limit helpers: `get_max_qwen2_vl_mm_tokens` with
# `data_type_key` pre-bound to "image" / "video" respectively.
get_max_qwen2_vl_image_tokens = partial(get_max_qwen2_vl_mm_tokens,
                                        data_type_key="image")
get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens,
                                        data_type_key="video")
def dummy_data_for_qwen2_vl(
    ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
    """Build a dummy token sequence and dummy images of maximal size.

    Args:
        ctx: Supplies the model name and the HuggingFace config.
        seq_len: Total length of the dummy token sequence.
        mm_counts: Number of items per modality; the ``"image"`` and
            ``"video"`` entries are read.

    Returns:
        A ``SequenceData`` of ``seq_len`` tokens (vision start, image
        placeholders, vision end, zero padding) and a dict whose ``"image"``
        entry is one black image, or a list of ``num_images`` of them.

    Raises:
        RuntimeError: If the image or the video placeholders, plus the two
            vision start / end tokens, do not fit into ``seq_len``.
    """
    image_processor = cached_get_image_processor(ctx.model_config.model)

    num_images = mm_counts["image"]
    max_resized_height, max_resized_width, max_llm_image_tokens = \
        _get_max_image_info(image_processor, data_type_key="image",
                            mm_count=num_images)
    if seq_len - max_llm_image_tokens - 2 < 0:
        raise RuntimeError(
            f"Qwen2-VL cannot process {num_images} images in a prompt, "
            "please increase max_model_len or reduce image limit by "
            "--limit-mm-per-prompt.")

    # Check video counts. Only the token count is needed here; the resized
    # size is not rebound, so the dummy image below uses the image query.
    num_videos = mm_counts["video"]
    _, _, max_llm_video_tokens = \
        _get_max_image_info(image_processor, data_type_key="video",
                            mm_count=num_videos)
    if seq_len - max_llm_video_tokens - 2 < 0:
        # Report the video count (the message previously showed num_images).
        raise RuntimeError(
            f"Qwen2-VL cannot process {num_videos} videos in a prompt, "
            "please increase max_model_len or reduce video limit by "
            "--limit-mm-per-prompt.")

    hf_config = ctx.get_hf_config(Qwen2VLConfig)

    # <vision_start> <image_pad> * N <vision_end> 0 0 0 ...
    num_padding = seq_len - max_llm_image_tokens - 2
    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                      [hf_config.vision_start_token_id])
    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
                       [hf_config.image_token_id]) * max_llm_image_tokens
    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
                       [hf_config.vision_end_token_id])
    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * num_padding
    dummy_seqdata = SequenceData(token_ids)

    dummy_image = Image.new("RGB", (max_resized_width, max_resized_height),
                            color=0)
    if num_images == 1:
        dummy_mm_data = {"image": dummy_image}
    else:
        dummy_mm_data = {"image": [dummy_image] * num_images}

    return dummy_seqdata, dummy_mm_data
def _get_llm_num_vision_tokens(
    mm_inputs: list,
    data_type_key: str,
    image_processor,
):
    """Count the vision tokens a list of images / frames expands to.

    The size is taken from the first element of ``mm_inputs`` only, and
    ``len(mm_inputs)`` is passed on as the item / frame count.

    This method is derived from `transformers.models.qwen2_vl.
    image_processing_qwen2_vl.Qwen2VLImageProcessor._preprocess`.
    """
    first_item = to_numpy_array(mm_inputs[0])
    channel_format = infer_channel_dimension_format(first_item)
    height, width = get_image_size(first_item, channel_dim=channel_format)

    vision_info = _get_vision_info(
        image_processor,
        height=height,
        width=width,
        min_pixels=image_processor.min_pixels,
        max_pixels=image_processor.max_pixels,
        do_resize=image_processor.do_resize,
        data_type_key=data_type_key,
        mm_count=len(mm_inputs),
    )
    # (resized_height, resized_width, llm_num_vision_tokens)
    return vision_info[2]
def _expand_placeholder_tokens(
    prompt_token_ids: List[int],
    placeholder_token_id: int,
    tokens_per_item: List[int],
) -> List[int]:
    """Replace the i-th placeholder token by ``tokens_per_item[i]`` copies.

    All other tokens are kept in their original order. The caller must ensure
    that the prompt contains exactly ``len(tokens_per_item)`` placeholders.
    """
    expanded: List[int] = []
    remaining_counts = iter(tokens_per_item)
    for token in prompt_token_ids:
        if token == placeholder_token_id:
            expanded.extend([placeholder_token_id] * next(remaining_counts))
        else:
            expanded.append(token)
    return expanded


def input_processor_for_qwen2_vl(ctx: InputContext,
                                 llm_inputs: LLMInputs) -> LLMInputs:
    """Expand image / video placeholder tokens to their final token counts.

    Every image (video) placeholder token in the prompt is replaced by as
    many copies as ``_get_llm_num_vision_tokens`` reports for the matching
    image (video).

    Args:
        ctx: Supplies the model name and the HuggingFace config.
        llm_inputs: Prompt, optional ``prompt_token_ids`` and optional
            ``multi_modal_data``.

    Returns:
        ``llm_inputs`` unchanged when it carries no multimodal data;
        otherwise new ``LLMInputs`` with the expanded ``prompt_token_ids``.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data", None)
    if multi_modal_data is None:
        return llm_inputs

    image_inputs = multi_modal_data.get("image", None)
    video_inputs = multi_modal_data.get("video", None)

    processor = cached_get_processor(ctx.model_config.model)
    image_processor = processor.image_processor
    hf_config = ctx.get_hf_config(Qwen2VLConfig)

    # To avoid redundant processing of vision objects (resize, rescale, etc.),
    # we extract code of calculating number of vision tokens from
    # `transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor`.
    #
    # The following code is equivalent to:
    #    prompt = llm_inputs["prompt"]
    #    inputs = processor(text=[prompt],
    #                       images=image_inputs,
    #                       videos=video_inputs,
    #                       padding=True,
    #                       return_tensors="pt")
    #    prompt_token_ids = inputs["input_ids"][0].tolist()

    prompt_token_ids = llm_inputs.get("prompt_token_ids", None)
    if prompt_token_ids is None:
        prompt = llm_inputs["prompt"]
        prompt_token_ids = processor.tokenizer(
            prompt,
            padding=True,
            return_tensors=None,
        )["input_ids"]

    # Expand image pad tokens.
    if image_inputs is not None:
        num_image_placeholders = sum(
            1 for token in prompt_token_ids
            if token == hf_config.image_token_id)
        image_inputs = make_batched_images(image_inputs)
        assert num_image_placeholders == len(image_inputs)

        prompt_token_ids = _expand_placeholder_tokens(
            prompt_token_ids,
            hf_config.image_token_id,
            [
                # Each image is sized on its own, hence the 1-element list.
                _get_llm_num_vision_tokens(
                    [image],
                    data_type_key="image",
                    image_processor=image_processor,
                ) for image in image_inputs
            ],
        )

    # Expand video pad tokens.
    if video_inputs is not None:
        num_video_placeholders = sum(
            1 for token in prompt_token_ids
            if token == hf_config.video_token_id)
        video_inputs = make_batched_videos(video_inputs)
        assert num_video_placeholders == len(video_inputs)

        prompt_token_ids = _expand_placeholder_tokens(
            prompt_token_ids,
            hf_config.video_token_id,
            [
                # A video is passed as is; its length is the frame count.
                _get_llm_num_vision_tokens(
                    video,
                    data_type_key="video",
                    image_processor=image_processor,
                ) for video in video_inputs
            ],
        )

    return LLMInputs(
        prompt_token_ids=prompt_token_ids,
        prompt=llm_inputs["prompt"],
        multi_modal_data=multi_modal_data,
    )
@MULTIMODAL_REGISTRY.register_image_input_mapper(
    image_input_mapper_for_qwen2_vl)
@MULTIMODAL_REGISTRY.register_input_mapper("video",
                                           video_input_mapper_for_qwen2_vl)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen2_vl_image_tokens)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "video", get_max_qwen2_vl_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
    """Qwen2-VL: a vision transformer feeding a Qwen2 language model.

    Image / video embeddings produced by ``self.visual`` overwrite the token
    embeddings at the positions of the image / video placeholder tokens
    before the language model runs.
    """

    def __init__(self,
                 config: Qwen2VLConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        # `cache_config` is optional, so only inspect it when it is given.
        assert (cache_config is None
                or not cache_config.enable_prefix_caching), \
            "Qwen2-VL currently does not support prefix caching"

        self.config = config
        self.multimodal_config = multimodal_config

        self.visual = Qwen2VisionTransformer(
            config.vision_config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),

            # NOTE: Qwen2-VL vision encoder does not support any
            # quantization method now.
            quant_config=None,
        )

        self.model = Qwen2Model(config, cache_config, quant_config)

        if config.tie_word_embeddings:
            # Tied weights: the output head is the input embedding module.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.hidden_size,
                                          quant_config=quant_config)

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def _validate_and_reshape_mm_tensor(self,
                                        mm_input: Union[torch.Tensor,
                                                        List[torch.Tensor]],
                                        name: str) -> torch.Tensor:
        """Flatten a batched multimodal input into a single 2D tensor.

        A 2D tensor is returned as is. A 3D tensor or a list of tensors is
        concatenated along the first dimension.

        Raises:
            ValueError: If ``mm_input`` is neither a tensor nor a list, or is
                a tensor that is not 2D / 3D.
        """
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
                                 f"Got ndim: {mm_input.ndim}")
            return torch.concat(list(mm_input))
        else:
            return torch.concat(mm_input)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Qwen2VLImageInputs]:
        """Extract image inputs from ``kwargs``; ``None`` if there are none."""
        pixel_values = kwargs.pop("pixel_values", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None:
            return None

        # The helper raises ValueError for unsupported types, so no further
        # type check is needed afterwards.
        pixel_values = self._validate_and_reshape_mm_tensor(
            pixel_values, "image pixel values")
        image_grid_thw = self._validate_and_reshape_mm_tensor(
            image_grid_thw, "image grid_thw")

        return Qwen2VLImageInputs(pixel_values=pixel_values,
                                  image_grid_thw=image_grid_thw)

    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]:
        """Extract video inputs from ``kwargs``; ``None`` if there are none."""
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)

        if pixel_values_videos is None:
            return None

        pixel_values_videos = self._validate_and_reshape_mm_tensor(
            pixel_values_videos, "video pixel values")
        video_grid_thw = self._validate_and_reshape_mm_tensor(
            video_grid_thw, "video grid_thw")

        return Qwen2VLVideoInputs(
            pixel_values_videos=pixel_values_videos,
            video_grid_thw=video_grid_thw,
        )

    def _process_image_input(self,
                             image_input: Qwen2VLImageInputs) -> torch.Tensor:
        """Encode image pixel values with the vision transformer."""
        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
        image_embeds = self.visual(pixel_values,
                                   grid_thw=image_input["image_grid_thw"])
        return image_embeds

    def _process_video_input(self,
                             video_input: Qwen2VLVideoInputs) -> torch.Tensor:
        """Encode video pixel values with the vision transformer."""
        pixel_values_videos = video_input["pixel_values_videos"].type(
            self.visual.dtype)
        video_embeds = self.visual(pixel_values_videos,
                                   grid_thw=video_input["video_grid_thw"])
        return video_embeds

    def _merge_multimodal_embeddings(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        multimodal_embeddings: torch.Tensor,
        placeholder_token_id: int,
    ) -> torch.Tensor:
        """Write ``multimodal_embeddings`` over the placeholder positions.

        ``inputs_embeds`` is modified in place and also returned.
        """
        mask = (input_ids == placeholder_token_id)
        inputs_embeds[mask, :] = multimodal_embeddings
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for Qwen2-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch.
                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
                opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            pixel_values: Pixel values to be fed to a model.
                `None` if no images are passed.
            image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
                `None` if no images are passed.
            pixel_values_videos: Pixel values of videos to be fed to a model.
                `None` if no videos are passed.
            video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
                `None` if no videos are passed.

        Returns:
            Whatever ``self.model`` returns for the given inputs.
            NOTE(review): the annotation says ``SamplerOutput`` although the
            value is named ``hidden_states`` - confirm the intended type.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        video_input = self._parse_and_validate_video_input(**kwargs)

        if image_input is None and video_input is None:
            # Text-only: the language model embeds `input_ids` itself.
            inputs_embeds = None
        else:
            # `rope_scaling` may be absent or set to None in the config.
            rope_scaling = getattr(self.config, "rope_scaling", None) or {}
            if rope_scaling.get("type", None) == "mrope":
                assert positions.ndim == 2 and positions.size(0) == 3, (
                    "multimodal section rotary embedding requires "
                    f"(3, seq_len) positions, but got {positions.size()}")

            inputs_embeds = self.model.embed_tokens(input_ids)

            if image_input is not None:
                image_embeds = self._process_image_input(image_input)
                inputs_embeds = self._merge_multimodal_embeddings(
                    input_ids,
                    inputs_embeds,
                    image_embeds,
                    placeholder_token_id=self.config.image_token_id,
                )

            if video_input is not None:
                video_embeds = self._process_video_input(video_input)
                inputs_embeds = self._merge_multimodal_embeddings(
                    input_ids,
                    inputs_embeds,
                    video_embeds,
                    placeholder_token_id=self.config.video_token_id,
                )

            # The embeddings are passed explicitly, so the ids are dropped.
            input_ids = None

        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Apply the logits processor with ``self.lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Raises:
            KeyError: If a checkpoint name has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "up_proj", 1),
            ("gate_up_proj", "gate_proj", 0),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                # The head shares the embedding weight; nothing to load.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Checkpoint shard -> slice of the fused parameter.
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                if "visual" in name and "qkv.weight" in name:
                    # View rows as (3, heads, head_size), swap the first two
                    # axes and flatten back: rows become head-major.
                    visual_num_heads = self.config.vision_config.num_heads
                    visual_embed_dim = self.config.vision_config.embed_dim
                    head_size = visual_embed_dim // visual_num_heads
                    loaded_weight = loaded_weight.view(3, visual_num_heads,
                                                       head_size,
                                                       visual_embed_dim)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
                elif "visual" in name and "qkv.bias" in name:
                    # Same reordering for the bias vector.
                    visual_num_heads = self.config.vision_config.num_heads
                    visual_embed_dim = self.config.vision_config.embed_dim
                    head_size = visual_embed_dim // visual_num_heads
                    loaded_weight = loaded_weight.view(3, visual_num_heads,
                                                       head_size)
                    loaded_weight = loaded_weight.transpose(0, 1)
                    loaded_weight = loaded_weight.reshape(-1)
                try:
                    param = params_dict[name]
                except KeyError:
                    # Name the offending weight instead of dumping every
                    # parameter name to stdout, then re-raise unchanged.
                    logger.error(
                        "Checkpoint weight %s has no matching parameter",
                        name)
                    raise

                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
vllm/model_executor/models/siglip.py
View file @
4851c202
...
@@ -110,7 +110,7 @@ def input_processor_for_siglip(
...
@@ -110,7 +110,7 @@ def input_processor_for_siglip(
if
isinstance
(
image_data
,
Image
.
Image
):
if
isinstance
(
image_data
,
Image
.
Image
):
image_feature_size
=
get_siglip_image_feature_size
(
hf_config
)
image_feature_size
=
get_siglip_image_feature_size
(
hf_config
)
elif
isinstance
(
image_data
,
torch
.
Tensor
):
elif
isinstance
(
image_data
,
torch
.
Tensor
):
image_feature_size
=
image_data
.
shape
[
0
]
num_images
,
image_feature_size
,
hidden_size
=
image_data
.
shape
else
:
else
:
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
else
:
else
:
...
@@ -443,27 +443,26 @@ class SiglipVisionTransformer(nn.Module):
...
@@ -443,27 +443,26 @@ class SiglipVisionTransformer(nn.Module):
self
.
config
=
config
self
.
config
=
config
embed_dim
=
config
.
hidden_size
embed_dim
=
config
.
hidden_size
if
(
num_hidden_layers_override
is
None
or
num_hidden_layers_override
==
config
.
num_hidden_layers
):
self
.
need_post_layernorm
=
True
elif
num_hidden_layers_override
>
config
.
num_hidden_layers
:
raise
ValueError
(
"num_hidden_layers_override cannot be greater than "
"num_hidden_layers"
)
else
:
self
.
need_post_layernorm
=
False
self
.
embeddings
=
SiglipVisionEmbeddings
(
config
)
self
.
embeddings
=
SiglipVisionEmbeddings
(
config
)
self
.
encoder
=
SiglipEncoder
(
self
.
encoder
=
SiglipEncoder
(
config
,
config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
num_hidden_layers_override
=
num_hidden_layers_override
,
num_hidden_layers_override
=
num_hidden_layers_override
,
)
)
if
self
.
need_post_layernorm
:
if
len
(
self
.
encoder
.
layers
)
>
config
.
num_hidden_layers
:
raise
ValueError
(
f
"The original encoder only has
{
config
.
num_hidden_layers
}
"
f
"layers, but you requested
{
len
(
self
.
encoder
.
layers
)
}
layers."
)
elif
len
(
self
.
encoder
.
layers
)
==
config
.
num_hidden_layers
:
self
.
post_layernorm
=
nn
.
LayerNorm
(
embed_dim
,
self
.
post_layernorm
=
nn
.
LayerNorm
(
embed_dim
,
eps
=
config
.
layer_norm_eps
)
eps
=
config
.
layer_norm_eps
)
else
:
else
:
self
.
post_layernorm
=
nn
.
Identity
()
# post_layernorm is unused when we extract intermediate features
# In this case, we can skip it to conserve memory
self
.
post_layernorm
=
None
self
.
use_head
=
(
True
if
not
hasattr
(
config
,
"vision_use_head"
)
else
self
.
use_head
=
(
True
if
not
hasattr
(
config
,
"vision_use_head"
)
else
config
.
vision_use_head
)
config
.
vision_use_head
)
if
self
.
use_head
:
if
self
.
use_head
:
...
@@ -482,6 +481,9 @@ class SiglipVisionTransformer(nn.Module):
...
@@ -482,6 +481,9 @@ class SiglipVisionTransformer(nn.Module):
encoder_outputs
=
self
.
encoder
(
inputs_embeds
=
hidden_states
)
encoder_outputs
=
self
.
encoder
(
inputs_embeds
=
hidden_states
)
if
self
.
post_layernorm
is
None
:
return
encoder_outputs
last_hidden_state
=
self
.
post_layernorm
(
encoder_outputs
)
last_hidden_state
=
self
.
post_layernorm
(
encoder_outputs
)
# TODO: add this back when pooled_output is used in inference
# TODO: add this back when pooled_output is used in inference
# if self.use_head:
# if self.use_head:
...
@@ -512,8 +514,8 @@ class SiglipVisionModel(nn.Module):
...
@@ -512,8 +514,8 @@ class SiglipVisionModel(nn.Module):
)
)
@
property
@
property
def
need
_post_layernorm
(
self
):
def
_require
_post_layernorm
(
self
)
->
bool
:
return
self
.
vision_model
.
need_
post_layernorm
return
self
.
vision_model
.
post_layernorm
is
not
None
def
get_input_embeddings
(
self
)
->
nn
.
Module
:
def
get_input_embeddings
(
self
)
->
nn
.
Module
:
return
self
.
vision_model
.
embeddings
.
patch_embedding
return
self
.
vision_model
.
embeddings
.
patch_embedding
...
@@ -529,13 +531,19 @@ class SiglipVisionModel(nn.Module):
...
@@ -529,13 +531,19 @@ class SiglipVisionModel(nn.Module):
)
)
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]]):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
]
if
self
.
shard_weight
else
[]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
())
layer_count
=
len
(
self
.
vision_model
.
encoder
.
layers
)
layer_count
=
len
(
self
.
vision_model
.
encoder
.
layers
)
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
# post_layernorm is optional in SiglipVisionModel
# post_layernorm is optional in SiglipVisionModel
if
(
"vision_model.post_layernorm"
in
name
if
(
"vision_model.post_layernorm"
in
name
and
not
self
.
need
_post_layernorm
):
and
not
self
.
_require
_post_layernorm
):
continue
continue
# omit layers when num_hidden_layers_override is set
# omit layers when num_hidden_layers_override is set
...
@@ -544,7 +552,16 @@ class SiglipVisionModel(nn.Module):
...
@@ -544,7 +552,16 @@ class SiglipVisionModel(nn.Module):
if
layer_idx
>=
layer_count
:
if
layer_idx
>=
layer_count
:
continue
continue
param
=
params_dict
[
name
]
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
weight_loader
=
getattr
(
param
,
"weight_loader"
,
if
weight_name
not
in
name
:
default_weight_loader
)
continue
weight_loader
(
param
,
loaded_weight
)
param
=
params_dict
[
name
.
replace
(
weight_name
,
param_name
)]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/model_executor/models/utils.py
View file @
4851c202
...
@@ -12,6 +12,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
...
@@ -12,6 +12,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.model_loader.loader
import
build_model
from
vllm.model_executor.model_loader.loader
import
build_model
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.multimodal.base
import
NestedTensors
from
vllm.multimodal.base
import
NestedTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
is_pin_memory_available
...
@@ -279,3 +280,18 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
...
@@ -279,3 +280,18 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
if
name
.
startswith
(
missing_layer_name
):
if
name
.
startswith
(
missing_layer_name
):
return
True
return
True
return
False
return
False
def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
    """Return a factory producing zero-filled ``IntermediateTensors``.

    Args:
        keys: Names of the tensors each produced object contains.
        hidden_size: Size of the second dimension of every tensor.
    """

    def make_empty_intermediate_tensors(
            batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Create one ``(batch_size, hidden_size)`` zero tensor per key."""
        shape = (batch_size, hidden_size)
        tensors = {}
        for key in keys:
            tensors[key] = torch.zeros(shape, dtype=dtype, device=device)
        return IntermediateTensors(tensors)

    return make_empty_intermediate_tensors
vllm/multimodal/base.py
View file @
4851c202
...
@@ -79,14 +79,12 @@ class MultiModalInputs(_MultiModalInputsBase):
...
@@ -79,14 +79,12 @@ class MultiModalInputs(_MultiModalInputsBase):
if
len
(
inputs_list
)
==
0
:
if
len
(
inputs_list
)
==
0
:
return
{}
return
{}
keys
=
inputs_list
[
0
].
keys
()
item_lists
:
Dict
[
str
,
List
[
NestedTensors
]]
=
defaultdict
(
list
)
item_lists
:
Dict
[
str
,
List
[
NestedTensors
]]
=
defaultdict
(
list
)
for
inputs
in
inputs_list
:
for
inputs
in
inputs_list
:
if
inputs
.
keys
()
!=
keys
:
# For models that supports multiple modalities (e.g. Qwen2-VL),
msg
=
f
"Inputs do not share the same keys (
{
keys
}
)"
# different modalities will return different data keys,
raise
ValueError
(
msg
)
# so batch() should skip the same key check.
for
k
,
v
in
inputs
.
items
():
for
k
,
v
in
inputs
.
items
():
item_lists
[
k
].
append
(
v
)
item_lists
[
k
].
append
(
v
)
...
...
Prev
1
…
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment