Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8aaf3d53
Unverified
Commit
8aaf3d53
authored
Aug 25, 2024
by
Isotr0py
Committed by
GitHub
Aug 25, 2024
Browse files
[Model][VLM] Support multi-images inputs for Phi-3-vision models (#7783)
parent
80162c44
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
168 additions
and
29 deletions
+168
-29
tests/models/test_phi3v.py
tests/models/test_phi3v.py
+111
-0
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+57
-29
No files found.
tests/models/test_phi3v.py
View file @
8aaf3d53
...
@@ -21,6 +21,7 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
...
@@ -21,6 +21,7 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"cherry_blossom"
:
"cherry_blossom"
:
"<|user|>
\n
<|image_1|>
\n
What is the season?<|end|>
\n
<|assistant|>
\n
"
,
"<|user|>
\n
<|image_1|>
\n
What is the season?<|end|>
\n
<|assistant|>
\n
"
,
})
})
HF_MULTIIMAGE_IMAGE_PROMPT
=
"<|user|>
\n
<|image_1|>
\n
<|image_2|>
\n
Describe these images.<|end|>
\n
<|assistant|>
\n
"
# noqa: E501
models
=
[
"microsoft/Phi-3.5-vision-instruct"
]
models
=
[
"microsoft/Phi-3.5-vision-instruct"
]
...
@@ -184,3 +185,113 @@ def test_regression_7840(hf_runner, vllm_runner, image_assets, model,
...
@@ -184,3 +185,113 @@ def test_regression_7840(hf_runner, vllm_runner, image_assets, model,
num_logprobs
=
10
,
num_logprobs
=
10
,
tensor_parallel_size
=
1
,
tensor_parallel_size
=
1
,
)
)
def
run_multi_image_test
(
hf_runner
:
Type
[
HfRunner
],
vllm_runner
:
Type
[
VllmRunner
],
images
:
List
[
Image
.
Image
],
model
:
str
,
*
,
size_factors
:
List
[
float
],
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
tensor_parallel_size
:
int
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding MultiModalConfig as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
inputs_per_case
=
[
([
HF_MULTIIMAGE_IMAGE_PROMPT
for
_
in
size_factors
],
[[
rescale_image_size
(
image
,
factor
)
for
image
in
images
]
for
factor
in
size_factors
])
]
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
# max_model_len should be greater than image_feature_size
with
vllm_runner
(
model
,
max_model_len
=
4096
,
max_num_seqs
=
1
,
limit_mm_per_prompt
=
{
"image"
:
len
(
images
)},
dtype
=
dtype
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
enforce_eager
=
True
)
as
vllm_model
:
vllm_outputs_per_case
=
[
vllm_model
.
generate_greedy_logprobs
(
prompts
,
max_tokens
,
num_logprobs
=
num_logprobs
,
images
=
images
)
for
prompts
,
images
in
inputs_per_case
]
hf_model_kwargs
=
{
"_attn_implementation"
:
"eager"
}
with
hf_runner
(
model
,
dtype
=
dtype
,
model_kwargs
=
hf_model_kwargs
)
as
hf_model
:
eos_token_id
=
hf_model
.
processor
.
tokenizer
.
eos_token_id
hf_outputs_per_case
=
[
hf_model
.
generate_greedy_logprobs_limit
(
prompts
,
max_tokens
,
num_logprobs
=
num_logprobs
,
images
=
images
,
eos_token_id
=
eos_token_id
)
for
prompts
,
images
in
inputs_per_case
]
for
hf_outputs
,
vllm_outputs
in
zip
(
hf_outputs_per_case
,
vllm_outputs_per_case
):
check_logprobs_close
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
[
vllm_to_hf_output
(
vllm_output
,
model
)
for
vllm_output
in
vllm_outputs
],
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"size_factors"
,
[
# No image
[],
# Single-scale
[
1.0
],
# Single-scale, batched
[
1.0
,
1.0
,
1.0
],
# Multi-scale
[
0.25
,
0.5
,
1.0
],
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
target_dtype
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_multi_images_models
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
size_factors
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
)
->
None
:
run_multi_image_test
(
hf_runner
,
vllm_runner
,
[
asset
.
pil_image
for
asset
in
image_assets
],
model
,
size_factors
=
size_factors
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
)
vllm/model_executor/models/phi3v.py
View file @
8aaf3d53
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
itertools
import
re
import
re
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
...
@@ -37,11 +38,11 @@ from vllm.model_executor.models.clip import CLIPVisionModel
...
@@ -37,11 +38,11 @@ from vllm.model_executor.models.clip import CLIPVisionModel
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.multimodal.utils
import
cached_get_tokenizer
,
repeat_and_pad_token
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.utils
import
is_list_of
from
.clip
import
(
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
from
.clip
import
dummy_image_for_clip
,
dummy_seq_data_for_clip
input_processor_for_clip
)
from
.interfaces
import
SupportsMultiModal
from
.interfaces
import
SupportsMultiModal
from
.utils
import
merge_multimodal_embeddings
from
.utils
import
merge_multimodal_embeddings
...
@@ -400,9 +401,20 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -400,9 +401,20 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
image_data
=
multi_modal_data
[
"image"
]
image_data
=
multi_modal_data
[
"image"
]
if
isinstance
(
image_data
,
Image
.
Image
):
if
isinstance
(
image_data
,
Image
.
Image
):
w
,
h
=
image_data
.
size
w
,
h
=
image_data
.
size
image_feature_size
=
get_phi3v_image_feature_size
(
hf_config
,
image_feature_size
=
[
get_phi3v_image_feature_size
(
hf_config
,
input_width
=
w
,
input_width
=
w
,
input_height
=
h
)
input_height
=
h
)
]
image_data
=
[
image_data
]
elif
is_list_of
(
image_data
,
Image
.
Image
):
image_feature_size
=
[]
for
image
in
image_data
:
w
,
h
=
image
.
size
image_feature_size
.
append
(
get_phi3v_image_feature_size
(
hf_config
,
input_width
=
w
,
input_height
=
h
))
elif
isinstance
(
image_data
,
torch
.
Tensor
):
elif
isinstance
(
image_data
,
torch
.
Tensor
):
image_feature_size
=
image_data
.
shape
[
0
]
image_feature_size
=
image_data
.
shape
[
0
]
else
:
else
:
...
@@ -410,45 +422,61 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -410,45 +422,61 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
prompt
=
llm_inputs
.
get
(
"prompt"
)
prompt
=
llm_inputs
.
get
(
"prompt"
)
if
prompt
is
None
:
if
prompt
is
None
:
image_idx
=
[]
new_prompt
=
None
new_prompt
=
None
else
:
else
:
image_idx
=
sorted
(
map
(
int
,
re
.
findall
(
r
"<\|image_(\d+)\|>+"
,
prompt
)))
if
prompt
.
count
(
"<|image|>"
)
>
0
:
if
prompt
.
count
(
"<|image|>"
)
>
0
:
logger
.
warning
(
"Please follow the prompt format that is "
logger
.
warning
(
"Please follow the prompt format that is "
"documented on HuggingFace which does not involve "
"documented on HuggingFace which does not involve "
"repeating <|image|> tokens."
)
"repeating <|image|> tokens."
)
elif
len
(
re
.
findall
(
r
"(<\|image_\d+\|>)+"
,
prompt
))
>
1
:
elif
(
num_image_tags
:
=
len
(
image_idx
))
>
1
:
logger
.
warning
(
"Multiple image input is not supported yet, "
assert
num_image_tags
==
len
(
"so any extra image tokens will be treated "
image_data
),
"The count of image_placeholder not match image's"
"as plain text."
)
new_prompt
=
prompt
new_prompt
=
prompt
prompt_token_ids
=
llm_inputs
[
"prompt_token_ids"
]
prompt_token_ids
=
llm_inputs
[
"prompt_token_ids"
].
copy
()
image_1_token_ids
=
_get_image_placeholder_token_ids
(
model_config
,
idx
=
1
)
# masked place_holder with image token id
for
idx
in
image_idx
:
image_token_ids
=
_get_image_placeholder_token_ids
(
model_config
,
idx
=
idx
)
for
i
in
range
(
len
(
prompt_token_ids
)
-
len
(
image_token_ids
)
+
1
):
if
prompt_token_ids
[
i
:
i
+
len
(
image_token_ids
)]
==
image_token_ids
:
prompt_token_ids
[
i
:
i
+
len
(
image_token_ids
)]
=
[
_IMAGE_TOKEN_ID
]
*
len
(
image_token_ids
)
break
new_token_ids
:
List
[
int
]
=
[]
# merge consecutive tag ids
for
i
in
range
(
len
(
prompt_token_ids
)
-
len
(
image_1_token_ids
)
+
1
):
merged_token_ids
:
List
[
int
]
=
[]
if
prompt_token_ids
[
i
:
i
+
len
(
image_1_token_ids
)]
==
image_1_token_ids
:
for
is_placeholder
,
token_ids
in
itertools
.
groupby
(
new_token_ids
.
append
(
_IMAGE_TOKEN_ID
)
prompt_token_ids
,
lambda
x
:
x
==
_IMAGE_TOKEN_ID
):
if
is_placeholder
:
merged_token_ids
.
append
(
_IMAGE_TOKEN_ID
)
else
:
merged_token_ids
.
extend
(
list
(
token_ids
))
# No need to further scan the list since we only replace once
# TODO: Move this to utils or integrate with clip.
new_token_ids
.
extend
(
prompt_token_ids
[
i
+
len
(
image_1_token_ids
):])
new_token_ids
:
List
[
int
]
=
[]
break
placeholder_idx
=
0
while
merged_token_ids
:
token_id
=
merged_token_ids
.
pop
(
0
)
if
token_id
==
_IMAGE_TOKEN_ID
:
new_token_ids
.
extend
(
repeat_and_pad_token
(
_IMAGE_TOKEN_ID
,
repeat_count
=
image_feature_size
[
placeholder_idx
],
))
placeholder_idx
+=
1
else
:
else
:
new_token_ids
.
append
(
prompt_
token_id
s
[
i
]
)
new_token_ids
.
append
(
token_id
)
# NOTE: Create a defensive copy of the original inputs
# NOTE: Create a defensive copy of the original inputs
llm_inputs
=
LLMInputs
(
prompt_token_ids
=
new_token_ids
,
llm_inputs
=
LLMInputs
(
prompt_token_ids
=
new_token_ids
,
prompt
=
new_prompt
,
prompt
=
new_prompt
,
multi_modal_data
=
multi_modal_data
)
multi_modal_data
=
multi_modal_data
)
return
llm_inputs
return
input_processor_for_clip
(
model_config
,
CLIP_VIT_LARGE_PATCH14_336_CONFIG
,
llm_inputs
,
image_token_id
=
_IMAGE_TOKEN_ID
,
image_feature_size_override
=
image_feature_size
,
)
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
()
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment