Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
25e778aa
Unverified
Commit
25e778aa
authored
Jul 22, 2024
by
Isotr0py
Committed by
GitHub
Jul 21, 2024
Browse files
[Model] Refactor and decouple phi3v image embedding (#6621)
parent
b6df37f9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
118 additions
and
119 deletions
+118
-119
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+118
-119
No files found.
vllm/model_executor/models/phi3v.py
View file @
25e778aa
...
...
@@ -43,6 +43,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from
.clip
import
(
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
input_processor_for_clip
)
from
.interfaces
import
SupportsVision
from
.utils
import
merge_vision_embeddings
logger
=
init_logger
(
__name__
)
...
...
@@ -71,9 +72,8 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
class
Phi3ImageEmbeddingBase
(
nn
.
Module
):
def
__init__
(
self
,
wte
=
None
)
->
None
:
def
__init__
(
self
)
->
None
:
super
().
__init__
()
self
.
wte
=
wte
self
.
layer_idx
:
int
self
.
type_feature
:
str
self
.
img_processor
:
CLIPVisionModel
...
...
@@ -100,10 +100,9 @@ class Phi3ImageEmbeddingBase(nn.Module):
class
Phi3HDImageEmbedding
(
Phi3ImageEmbeddingBase
):
"""Phi3 Image embedding with HD transform."""
def
__init__
(
self
,
config
:
PretrainedConfig
,
wte
=
None
)
->
None
:
super
().
__init__
(
wte
)
def
__init__
(
self
,
config
:
PretrainedConfig
)
->
None
:
super
().
__init__
()
self
.
image_token_id
=
_IMAGE_TOKEN_ID
# n_embed or hidden_size
hidden_size
=
config
.
n_embd
if
hasattr
(
config
,
'n_embd'
)
else
config
.
hidden_size
...
...
@@ -149,118 +148,115 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
nn
.
Linear
(
dim_projection
,
dim_projection
)])
self
.
img_projection
=
nn
.
Sequential
(
*
layers
)
self
.
vocab_size
=
config
.
vocab_size
self
.
type_feature
=
config
.
img_processor
.
get
(
'type_feature'
,
'patch'
)
def
forward
(
self
,
input_ids
:
torch
.
LongTensor
,
pixel_values
:
torch
.
FloatTensor
,
def
forward
(
self
,
pixel_values
:
torch
.
FloatTensor
,
image_sizes
:
torch
.
Tensor
)
->
torch
.
FloatTensor
:
"""process and merge text embeddings with image embeddings."""
# (batch_size, max_num_crops, 3, height, width)
img_embeds
=
pixel_values
# (batch_size, 2)
img_sizes
=
image_sizes
input_shape
=
input_ids
.
size
()
input_ids
=
input_ids
.
view
(
-
1
,
input_shape
[
-
1
])
positions
=
torch
.
nonzero
(
input_ids
==
self
.
image_token_id
)
select
=
False
"""
process image and return vision embeddings.
pixel_values: (num_images, num_crops, c, h, w)
output: (num_images, num_img_tokens, hidden_size)
"""
num_images
,
num_crops
,
c
,
h
,
w
=
pixel_values
.
shape
pixel_values
=
pixel_values
.
flatten
(
0
,
1
)
img_features
=
self
.
get_img_features
(
pixel_values
)
img_features
=
img_features
.
reshape
(
num_images
,
num_crops
,
-
1
,
self
.
image_dim_out
)
image_features_proj
=
self
.
hd_feature_transform
(
img_features
,
image_sizes
)
return
image_features_proj
def
hd_feature_transform
(
self
,
image_features
,
image_sizes
):
"""
image_features: (num_images, num_crops+1, 24*24, 1024)
"""
assert
(
self
.
hd_transform_order
==
'sub_glb'
),
f
'hd_transform_order `
{
self
.
hd_transform_order
}
` not implemented'
if
isinstance
(
self
.
img_projection
,
nn
.
Sequential
):
target_device
=
self
.
img_projection
[
0
].
bias
.
device
target_dtype
=
self
.
img_projection
[
0
].
bias
.
dtype
else
:
# It's a single nn.Linear layer
target_device
=
self
.
img_projection
.
bias
.
device
target_dtype
=
self
.
img_projection
.
bias
.
dtype
global_image_features
=
image_features
[:,
0
]
# (num_images, 24*24, 1024)
# global feature can be viewed as a special HD case with num_crops 1x1
global_image_features_hd
=
self
.
reshape_hd_patches_2x2merge
(
global_image_features
,
1
,
1
)
global_image_features_hd_newline
=
self
.
add_image_newline
(
global_image_features_hd
)
all_image_embeddings
=
[]
# need a for loop to process each image because of different image sizes
# (patch arrangement is different for each image)
for
i
,
img_size
in
enumerate
(
image_sizes
):
h
,
w
=
img_size
h_crop
=
h
//
336
w_crop
=
w
//
336
num_crops
=
h_crop
*
w_crop
# NOTE: real num_crops is padded
# (num_crops, 24*24, 1024)
sub_image_features
=
image_features
[
i
,
1
:
1
+
num_crops
]
sub_image_features_hd
=
self
.
reshape_hd_patches_2x2merge
(
sub_image_features
,
h_crop
,
w_crop
)
sub_image_features_hd_newline
=
self
.
add_image_newline
(
sub_image_features_hd
)
# [sub features, separator, global features]
all_image_embeddings
.
append
(
torch
.
cat
([
sub_image_features_hd_newline
.
squeeze
(
0
),
# (h_crop*12*(w_crop*12+1), 4096)
self
.
glb_GN
.
squeeze
(
0
),
global_image_features_hd_newline
[
i
],
]))
image_features_proj
=
self
.
img_projection
(
torch
.
stack
(
all_image_embeddings
).
to
(
target_device
,
target_dtype
)
)
# (num_images, (h_crop*12*(w_crop*12+1)+1), hidden_size)
return
image_features_proj
def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
    """Merge 2x2 neighbouring patches and tile the crops into one grid.

    Args:
        image_features: (num_images*num_crops, 24*24, 1024)
        h_crop: number of crops along the height of each image.
        w_crop: number of crops along the width of each image;
            ``h_crop * w_crop`` is the number of crops per image.

    Returns:
        (num_images, h_crop*12, w_crop*12, 4096)
    """
    N, L, C = image_features.shape
    assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
    num_images = N // (h_crop * w_crop)
    side = int(L**0.5)
    half = side // 2

    # N, 24, 24, 1024
    grid = image_features.reshape(N, side, side, C)
    # N, 12, 2, 12, 2, 1024
    blocks = grid.reshape(N, half, 2, half, 2, C)
    # N, 12, 12, 2, 2, 1024
    blocks = blocks.permute(0, 1, 3, 2, 4, 5)
    # N, 144, 4096 -- each 2x2 neighbourhood folded into the channel dim
    merged = blocks.reshape(N, -1, 4 * C)
    # n_img, h_crop, w_crop, 12, 12, 4096
    tiled = merged.reshape(num_images, h_crop, w_crop, half, half, -1)
    # n_img, h_crop, 12, w_crop, 12, 4096
    tiled = tiled.permute(0, 1, 3, 2, 4, 5)
    # n_img, h_crop*12, w_crop*12, 4096
    image_features_hd = tiled.reshape(num_images, h_crop * side // 2,
                                      w_crop * side // 2, 4 * C)
    return image_features_hd
if
len
(
positions
.
tolist
())
>
0
:
# if self.use_hd_transform and img_sizes:
# img_embeds: (num_images, max_num_crops, 3, H, W)
# img_sizes: (num_images, 2).view(1, -1)
bs
=
img_embeds
.
shape
[
0
]
# Nx(HW)xC
img_features
=
self
.
get_img_features
(
img_embeds
.
flatten
(
0
,
1
))
base_feat_height
=
base_feat_width
=
int
(
img_features
.
shape
[
1
]
**
0.5
)
# bs x max_num_crops x (24x24) x C
img_features
=
img_features
.
view
(
bs
,
-
1
,
base_feat_height
*
base_feat_width
,
self
.
image_dim_out
)
C
=
self
.
image_dim_out
H
=
base_feat_height
output_imgs
=
[]
output_len
=
[]
for
_bs
in
range
(
bs
):
h
,
w
=
img_sizes
[
_bs
]
h
=
h
//
336
w
=
w
//
336
B_
=
h
*
w
# 1 x (24x24) x 1024
global_img_feature
=
img_features
[
_bs
,
:
1
]
# 1 x 12 x 12 x 4096
glb_img
=
global_img_feature
\
.
reshape
(
1
,
H
//
2
,
2
,
H
//
2
,
2
,
C
)
\
.
permute
(
0
,
1
,
3
,
2
,
4
,
5
)
\
.
reshape
(
1
,
H
//
2
,
H
//
2
,
4
*
C
)
temp_glb_GN
=
self
.
sub_GN
.
repeat
(
1
,
H
//
2
,
1
,
1
)
# 1 x 156 x 4096
glb_img
=
torch
.
cat
([
glb_img
,
temp_glb_GN
],
dim
=
2
).
reshape
(
1
,
-
1
,
4
*
C
)
# (max_num_crops-1) x (12x12) x C
sub_img
=
img_features
[
_bs
,
1
:]
# 16x574x1024
# get rid of padding sub_img
sub_img
=
sub_img
[:
B_
]
sub_img
=
sub_img
.
reshape
(
B_
,
H
//
2
,
2
,
H
//
2
,
2
,
C
)
\
.
permute
(
0
,
1
,
3
,
2
,
4
,
5
).
reshape
(
B_
,
-
1
,
4
*
C
)
sub_img
=
sub_img
.
reshape
(
1
,
h
,
w
,
12
,
12
,
-
1
)
\
.
permute
(
0
,
1
,
3
,
2
,
4
,
5
)
\
.
reshape
(
1
,
h
*
12
,
w
*
12
,
4
*
C
)
temp_sub_GN
=
self
.
sub_GN
.
repeat
(
1
,
h
*
12
,
1
,
1
)
sub_img
=
torch
.
cat
([
sub_img
,
temp_sub_GN
],
dim
=
2
).
reshape
(
1
,
-
1
,
4
*
C
)
# (1, num_img_tokens, 1024*4)
# glb + sub
if
self
.
hd_transform_order
==
'glb_sub'
:
output_imgs
.
append
(
torch
.
cat
([
glb_img
,
self
.
glb_GN
,
sub_img
],
dim
=
1
))
elif
self
.
hd_transform_order
==
'sub_glb'
:
output_imgs
.
append
(
torch
.
cat
([
sub_img
,
self
.
glb_GN
,
glb_img
],
dim
=
1
))
temp_len
=
int
((
h
*
w
+
1
)
*
144
+
1
+
(
h
+
1
)
*
12
)
output_len
.
append
(
temp_len
)
num_img_tokens
=
output_len
img_set_tensor
=
[]
for
_output_img
in
output_imgs
:
img_feature_proj
=
self
.
img_projection
(
_output_img
.
to
(
target_dtype
))
img_set_tensor
.
append
(
img_feature_proj
)
select
=
True
input_ids
.
clamp_min_
(
0
).
clamp_max_
(
self
.
vocab_size
)
hidden_states
=
self
.
wte
(
input_ids
)
if
select
:
idx
=
0
for
i
,
cnt
in
enumerate
(
num_img_tokens
):
hidden_states
[
positions
[
idx
,
0
],
positions
[
idx
,
1
]:
positions
[
idx
,
1
]
+
cnt
]
=
(
img_set_tensor
[
i
].
to
(
hidden_states
.
dtype
))
idx
+=
cnt
return
hidden_states
.
squeeze
(
0
)
def add_image_newline(self, image_features_hd):
    """Append the ``sub_GN`` newline embedding to the end of every row.

    Args:
        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)

    Returns:
        (num_images, (h_crop*12) * (w_crop*12+1), 4096)
    """
    num_images, h, w, hid_dim = image_features_hd.shape
    # One newline token per row of patches: (n_img, h, 1, hid_dim)
    row_terminators = self.sub_GN.expand(num_images, h, -1, -1)
    # Attach along the width axis, then flatten rows into a token sequence.
    with_newlines = torch.cat([image_features_hd, row_terminators], dim=2)
    return with_newlines.reshape(num_images, -1, hid_dim)
class
Phi3VImagePixelInputs
(
TypedDict
):
...
...
@@ -458,12 +454,12 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
self
.
config
=
config
self
.
multimodal_config
=
multimodal_config
self
.
image_token_id
=
_IMAGE_TOKEN_ID
self
.
model
=
LlamaModel
(
config
,
cache_config
,
quant_config
)
# TODO: Optionally initializes this for supporting embeddings.
self
.
vision_embed_tokens
=
Phi3HDImageEmbedding
(
config
,
self
.
model
.
embed_tokens
)
self
.
vision_embed_tokens
=
Phi3HDImageEmbedding
(
config
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
)
...
...
@@ -530,9 +526,12 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
not
None
:
inputs_embeds
=
self
.
vision_embed_tokens
(
input_ids
,
image_input
[
"data"
],
image_input
[
"image_sizes"
])
vision_embeddings
=
self
.
vision_embed_tokens
(
image_input
[
"data"
],
image_input
[
"image_sizes"
])
inputs_embeds
=
self
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
merge_vision_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
image_token_id
)
input_ids
=
None
else
:
inputs_embeds
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment