Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5340a2dc
Unverified
Commit
5340a2dc
authored
Aug 27, 2024
by
zifeitong
Committed by
GitHub
Aug 28, 2024
Browse files
[Model] Add multi-image input support for LLaVA-Next offline inference (#7230)
parent
345be0e2
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
174 additions
and
52 deletions
+174
-52
tests/conftest.py
tests/conftest.py
+10
-11
tests/models/test_llava_next.py
tests/models/test_llava_next.py
+80
-13
tests/multimodal/test_utils.py
tests/multimodal/test_utils.py
+34
-1
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+4
-4
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+12
-2
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+2
-2
vllm/multimodal/utils.py
vllm/multimodal/utils.py
+32
-19
No files found.
tests/conftest.py
View file @
5340a2dc
...
...
@@ -41,6 +41,10 @@ _TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS
=
[
os
.
path
.
join
(
_TEST_DIR
,
"prompts"
,
"example.txt"
)]
_LONG_PROMPTS
=
[
os
.
path
.
join
(
_TEST_DIR
,
"prompts"
,
"summary.txt"
)]
PromptImageInput
=
Union
[
List
[
Image
.
Image
],
List
[
List
[
Image
.
Image
]]]
PromptAudioInput
=
Union
[
List
[
Tuple
[
np
.
ndarray
,
int
]],
List
[
List
[
Tuple
[
np
.
ndarray
,
int
]]]]
def
_read_prompts
(
filename
:
str
)
->
List
[
str
]:
with
open
(
filename
,
"r"
)
as
f
:
...
...
@@ -578,8 +582,7 @@ class VllmRunner:
self
,
prompts
:
List
[
str
],
sampling_params
:
SamplingParams
,
images
:
Optional
[
Union
[
List
[
Image
.
Image
],
List
[
List
[
Image
.
Image
]]]]
=
None
,
images
:
Optional
[
PromptImageInput
]
=
None
,
)
->
List
[
Tuple
[
List
[
List
[
int
]],
List
[
str
]]]:
if
images
is
not
None
:
assert
len
(
prompts
)
==
len
(
images
)
...
...
@@ -623,10 +626,8 @@ class VllmRunner:
self
,
prompts
:
List
[
str
],
sampling_params
:
SamplingParams
,
images
:
Optional
[
Union
[
List
[
Image
.
Image
],
List
[
List
[
Image
.
Image
]]]]
=
None
,
audios
:
Optional
[
Union
[
List
[
Tuple
[
np
.
ndarray
,
int
]],
List
[
List
[
Tuple
[
np
.
ndarray
,
int
]]]]]
=
None
images
:
Optional
[
PromptImageInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
assert
sampling_params
.
logprobs
is
not
None
...
...
@@ -676,10 +677,8 @@ class VllmRunner:
prompts
:
List
[
str
],
max_tokens
:
int
,
num_logprobs
:
int
,
images
:
Optional
[
Union
[
List
[
Image
.
Image
],
List
[
List
[
Image
.
Image
]]]]
=
None
,
audios
:
Optional
[
Union
[
List
[
Tuple
[
np
.
ndarray
,
int
]],
List
[
List
[
Tuple
[
np
.
ndarray
,
int
]]]]]
=
None
,
images
:
Optional
[
PromptImageInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
)
->
List
[
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]]]:
greedy_logprobs_params
=
SamplingParams
(
temperature
=
0.0
,
...
...
tests/models/test_llava_next.py
View file @
5340a2dc
...
...
@@ -6,24 +6,22 @@ from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
from
vllm.multimodal.utils
import
rescale_image_size
from
vllm.sequence
import
SampleLogprobs
from
..conftest
import
IMAGE_ASSETS
,
HfRunner
,
VllmRunner
,
_ImageAssets
from
..conftest
import
(
IMAGE_ASSETS
,
HfRunner
,
PromptImageInput
,
VllmRunner
,
_ImageAssets
)
from
.utils
import
check_logprobs_close
pytestmark
=
pytest
.
mark
.
vlm
_PREFACE
=
(
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's "
"questions."
)
_LIMIT_IMAGE_PER_PROMPT
=
4
HF_IMAGE_PROMPTS
=
IMAGE_ASSETS
.
prompts
({
"stop_sign"
:
f
"
{
_PREFACE
}
USER:
<image>
\n
What's the content of the image?
ASSISTANT:
"
,
"[INST]
<image>
\n
What's the content of the image?
[/INST]
"
,
"cherry_blossom"
:
f
"
{
_PREFACE
}
USER:
<image>
\n
What is the season?
ASSISTANT:
"
,
"[INST]
<image>
\n
What is the season?
[/INST]
"
,
})
models
=
[
"llava-hf/llava-v1.6-
vicuna
-7b-hf"
]
models
=
[
"llava-hf/llava-v1.6-
mistral
-7b-hf"
]
def
vllm_to_hf_output
(
vllm_output
:
Tuple
[
List
[
int
],
str
,
...
...
@@ -114,19 +112,43 @@ def run_test(
else
:
raise
ValueError
(
"You must provide either `size_factors` or `sizes`"
)
_run_test
(
hf_runner
,
vllm_runner
,
inputs_per_image
,
model
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
)
def
_run_test
(
hf_runner
:
Type
[
HfRunner
],
vllm_runner
:
Type
[
VllmRunner
],
inputs
:
List
[
Tuple
[
List
[
str
],
PromptImageInput
]],
model
:
str
,
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
,
tensor_parallel_size
:
int
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
):
# max_model_len should be greater than image_feature_size
with
vllm_runner
(
model
,
dtype
=
dtype
,
max_model_len
=
40
96
,
max_model_len
=
102
40
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
enforce_eager
=
True
)
as
vllm_model
:
enforce_eager
=
True
,
limit_mm_per_prompt
=
{
"image"
:
_LIMIT_IMAGE_PER_PROMPT
})
as
vllm_model
:
vllm_outputs_per_image
=
[
vllm_model
.
generate_greedy_logprobs
(
prompts
,
max_tokens
,
num_logprobs
=
num_logprobs
,
images
=
images
)
for
prompts
,
images
in
inputs
_per_image
for
prompts
,
images
in
inputs
]
with
hf_runner
(
model
,
dtype
=
dtype
,
...
...
@@ -136,7 +158,7 @@ def run_test(
max_tokens
,
num_logprobs
=
num_logprobs
,
images
=
images
)
for
prompts
,
images
in
inputs
_per_image
for
prompts
,
images
in
inputs
]
for
hf_outputs
,
vllm_outputs
in
zip
(
hf_outputs_per_image
,
...
...
@@ -216,3 +238,48 @@ def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
)
@
pytest
.
mark
.
parametrize
(
"model"
,
models
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models_multiple_image_inputs
(
hf_runner
,
vllm_runner
,
image_assets
,
model
,
dtype
,
max_tokens
,
num_logprobs
)
->
None
:
stop_sign
=
image_assets
[
0
].
pil_image
cherry_blossom
=
image_assets
[
1
].
pil_image
inputs
=
[(
[
"[INST] <image><image>
\n
Describe 2 images. [/INST]"
,
"[INST] <image><image>
\n
Describe 2 images. [/INST]"
,
"[INST] <image><image><image><image>
\n
Describe 4 images. [/INST]"
,
"[INST] <image>
\n
What is the season? [/INST]"
],
[
[
stop_sign
,
cherry_blossom
],
# Images with different sizes and aspect-ratios
[
rescale_image_size
(
stop_sign
,
0.1
),
stop_sign
,
],
[
stop_sign
,
rescale_image_size
(
stop_sign
,
0.25
),
cherry_blossom
.
resize
((
183
,
488
)),
cherry_blossom
.
resize
((
488
,
183
))
],
cherry_blossom
,
])]
_run_test
(
hf_runner
,
vllm_runner
,
inputs
,
model
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
num_logprobs
=
num_logprobs
,
tensor_parallel_size
=
1
,
)
tests/multimodal/test_utils.py
View file @
5340a2dc
...
...
@@ -6,8 +6,10 @@ from typing import Dict, Tuple
import
numpy
as
np
import
pytest
from
PIL
import
Image
from
transformers
import
AutoConfig
,
AutoTokenizer
from
vllm.multimodal.utils
import
async_fetch_image
,
fetch_image
from
vllm.multimodal.utils
import
(
async_fetch_image
,
fetch_image
,
repeat_and_pad_placeholder_tokens
)
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
TEST_IMAGE_URLS
=
[
...
...
@@ -80,3 +82,34 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
data_image_async
=
await
async_fetch_image
(
data_url
)
assert
_image_equals
(
data_image_sync
,
data_image_async
)
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"llava-hf/llava-v1.6-mistral-7b-hf"
])
def
test_repeat_and_pad_placeholder_tokens
(
model
):
config
=
AutoConfig
.
from_pretrained
(
model
)
image_token_id
=
config
.
image_token_index
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model
)
test_cases
=
[
(
"<image>"
,
2
,
"<image><image>"
,
[
32000
,
32000
]),
(
"<image><image>"
,
2
,
"<image><image><image>"
,
[
32000
,
32000
,
32000
]),
(
"<image><image>"
,
[
3
,
2
],
"<image><image><image><image><image>"
,
[
32000
,
32000
,
32000
,
32000
,
32000
]),
(
"Image:<image>Image:<image>!"
,
[
3
,
2
],
"Image:<image><image><image>Image:<image><image>!"
,
[
9833
,
28747
,
32000
,
32000
,
32000
,
9833
,
28747
,
32000
,
32000
,
918
]),
(
"<image>"
,
[
3
,
2
],
"<image><image><image>"
,
[
32000
,
32000
,
32000
]),
]
for
prompt
,
repeat_count
,
expected_prompt
,
expected_token_ids
in
test_cases
:
new_prompt
,
new_token_ids
=
repeat_and_pad_placeholder_tokens
(
tokenizer
=
tokenizer
,
prompt
=
prompt
,
prompt_token_ids
=
tokenizer
.
encode
(
prompt
,
add_special_tokens
=
False
),
placeholder_token_id
=
image_token_id
,
repeat_count
=
repeat_count
,
)
assert
new_prompt
==
expected_prompt
assert
new_token_ids
==
expected_token_ids
vllm/model_executor/models/clip.py
View file @
5340a2dc
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from
array
import
array
from
typing
import
Iterable
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch.nn
as
nn
...
...
@@ -84,7 +84,7 @@ def input_processor_for_clip(
llm_inputs
:
LLMInputs
,
*
,
image_token_id
:
int
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
image_feature_size_override
:
Optional
[
Union
[
int
,
List
[
int
]]
]
=
None
,
):
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
...
...
vllm/model_executor/models/llava_next.py
View file @
5340a2dc
...
...
@@ -19,6 +19,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.utils
import
is_list_of
from
.clip
import
(
CLIPVisionModel
,
dummy_image_for_clip
,
dummy_seq_data_for_clip
,
get_clip_image_feature_size
,
...
...
@@ -223,6 +224,13 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
input_height
=
height
,
input_width
=
width
,
)
elif
is_list_of
(
image_data
,
Image
.
Image
):
image_feature_size
=
[
get_llava_next_image_feature_size
(
hf_config
,
input_height
=
img
.
height
,
input_width
=
img
.
width
)
for
img
in
image_data
]
elif
isinstance
(
image_data
,
torch
.
Tensor
):
image_feature_size
=
image_data
.
shape
[
0
]
else
:
...
...
@@ -425,7 +433,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
self
.
config
.
image_grid_pinpoints
,
self
.
config
.
vision_config
.
image_size
,
)
other_patch_embeds
=
other_patch_embeds
\
num_patches
=
num_patch_height
*
num_patch_width
# Image patches might be padded for batch processing
other_patch_embeds
=
other_patch_embeds
[:
num_patches
]
\
.
view
(
num_patch_height
,
num_patch_width
,
height
,
width
,
-
1
)
if
"unpad"
in
strategy
:
...
...
@@ -496,7 +507,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
self
,
image_input
:
LlavaNextImageInputs
,
)
->
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]:
if
image_input
[
"type"
]
==
"image_embeds"
:
return
[
image_input
[
"data"
]]
...
...
vllm/model_executor/models/siglip.py
View file @
5340a2dc
...
...
@@ -3,7 +3,7 @@ within a vision language model."""
import
math
from
array
import
array
from
typing
import
Iterable
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
PIL
import
Image
...
...
@@ -93,7 +93,7 @@ def input_processor_for_siglip(
llm_inputs
:
LLMInputs
,
*
,
image_token_id
:
int
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
image_feature_size_override
:
Optional
[
Union
[
int
,
List
[
int
]]
]
=
None
,
):
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
...
...
vllm/multimodal/utils.py
View file @
5340a2dc
...
...
@@ -189,10 +189,13 @@ def repeat_and_pad_placeholder_tokens(
prompt_token_ids
:
List
[
int
],
*
,
placeholder_token_id
:
int
,
repeat_count
:
int
=
1
,
repeat_count
:
Union
[
int
,
List
[
int
]]
,
pad_token_left
:
Optional
[
int
]
=
None
,
pad_token_right
:
Optional
[
int
]
=
None
,
)
->
Tuple
[
Optional
[
str
],
List
[
int
]]:
if
isinstance
(
repeat_count
,
int
):
repeat_count
=
[
repeat_count
]
if
prompt
is
None
:
new_prompt
=
None
else
:
...
...
@@ -201,13 +204,6 @@ def repeat_and_pad_placeholder_tokens(
tokenizer
.
decode
(
pad_token_left
))
pad_token_str_right
=
(
None
if
pad_token_right
is
None
else
tokenizer
.
decode
(
pad_token_right
))
replacement_str
=
""
.
join
(
repeat_and_pad_token
(
placeholder_token_str
,
repeat_count
=
repeat_count
,
pad_token_left
=
pad_token_str_left
,
pad_token_right
=
pad_token_str_right
,
))
placeholder_token_count
=
prompt
.
count
(
placeholder_token_str
)
# This is an arbitrary number to distinguish between the two cases
...
...
@@ -216,26 +212,43 @@ def repeat_and_pad_placeholder_tokens(
"Please follow the prompt format that is "
"documented on HuggingFace which does not involve "
"repeating %s tokens."
,
placeholder_token_str
)
elif
placeholder_token_count
>
1
:
logger
.
warning
(
"Multiple multi-modal input is not supported yet, "
"so any extra placeholder tokens will be treated "
"as plain text."
)
if
placeholder_token_count
<
len
(
repeat_count
):
logger
.
warning
(
"The number of multi-modal placeholder tokens in the prompt "
"is less than the number of multi-modal inputs. Extra "
"placeholder tokens will be treated as plain text"
)
repeat_count
=
repeat_count
[:
placeholder_token_count
]
prompt_parts
=
prompt
.
split
(
placeholder_token_str
,
maxsplit
=
len
(
repeat_count
))
new_prompt
=
""
for
i
,
repeat_count_item
in
enumerate
(
repeat_count
):
replacement_str
=
""
.
join
(
repeat_and_pad_token
(
placeholder_token_str
,
repeat_count
=
repeat_count_item
,
pad_token_left
=
pad_token_str_left
,
pad_token_right
=
pad_token_str_right
,
))
# The image tokens are removed to be consistent with HuggingFace
new_prompt
=
prompt
.
replace
(
placeholder_token_str
,
replacement_str
,
1
)
new_prompt
+=
prompt_parts
[
i
]
+
replacement_str
new_prompt
+=
prompt_parts
[
-
1
]
new_token_ids
:
List
[
int
]
=
[]
placeholder_token_idx
=
0
for
i
,
token
in
enumerate
(
prompt_token_ids
):
if
token
==
placeholder_token_id
:
replacement_ids
=
repeat_and_pad_token
(
placeholder_token_id
,
repeat_count
=
repeat_count
,
repeat_count
=
repeat_count
[
placeholder_token_idx
]
,
pad_token_left
=
pad_token_left
,
pad_token_right
=
pad_token_right
,
)
new_token_ids
.
extend
(
replacement_ids
)
placeholder_token_idx
+=
1
# No need to further scan the list since we only replace once
# No need to further scan the list since we replaced all tokens
if
placeholder_token_idx
>=
len
(
repeat_count
):
new_token_ids
.
extend
(
prompt_token_ids
[
i
+
1
:])
break
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment