Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
831540cf
"vscode:/vscode.git/clone" did not exist on "4f8f47e87e4f65195cf77b0de93feb63fc5a5b2f"
Unverified
Commit
831540cf
authored
Oct 23, 2024
by
Cyrus Leung
Committed by
GitHub
Oct 23, 2024
Browse files
[Model] Support E5-V (#9576)
parent
29061ed9
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
532 additions
and
90 deletions
+532
-90
docs/source/models/supported_models.rst
docs/source/models/supported_models.rst
+14
-0
examples/offline_inference_vision_language.py
examples/offline_inference_vision_language.py
+3
-3
examples/offline_inference_vision_language_embedding.py
examples/offline_inference_vision_language_embedding.py
+169
-21
examples/offline_inference_vision_language_multi_image.py
examples/offline_inference_vision_language_multi_image.py
+4
-3
tests/conftest.py
tests/conftest.py
+36
-24
tests/models/embedding/utils.py
tests/models/embedding/utils.py
+2
-1
tests/models/embedding/vision_language/test_llava_next.py
tests/models/embedding/vision_language/test_llava_next.py
+135
-0
tests/models/embedding/vision_language/test_phi3v.py
tests/models/embedding/vision_language/test_phi3v.py
+77
-16
vllm/model_executor/models/llava_next.py
vllm/model_executor/models/llava_next.py
+22
-11
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+0
-2
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-0
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+69
-9
No files found.
docs/source/models/supported_models.rst
View file @
831540cf
...
...
@@ -334,6 +334,14 @@ The following modalities are supported depending on the model:
- **V**\ ideo
- **A**\ udio
Any combination of modalities joined by :code:`+` is supported.
- e.g.: :code:`T + I` means that the model supports text-only, image-only, and text-with-image inputs.
On the other hand, modalities separated by :code:`/` are mutually exclusive.
- e.g.: :code:`T / I` means that the model supports text-only and image-only inputs, but not text-with-image inputs.
.. _supported_vlms:
Text Generation
...
...
@@ -484,6 +492,12 @@ Multimodal Embedding
- Example HF Models
- :ref:`LoRA <lora>`
- :ref:`PP <distributed_serving>`
* - :code:`LlavaNextForConditionalGeneration`
- LLaVA-NeXT-based
- T / I
- :code:`royokong/e5-v`
-
- ✅︎
* - :code:`Phi3VForCausalLM`
- Phi-3-Vision-based
- T + I
...
...
examples/offline_inference_vision_language.py
View file @
831540cf
"""
This example shows how to use vLLM for running offline inference
with
the correct prompt format on vision language models.
This example shows how to use vLLM for running offline inference
with
the correct prompt format on vision language models
for text generation
.
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
...
...
@@ -450,7 +450,7 @@ def main(args):
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
(
description
=
'Demo on using vLLM for offline inference with '
'vision language models'
)
'vision language models
for text generation
'
)
parser
.
add_argument
(
'--model-type'
,
'-m'
,
type
=
str
,
...
...
examples/offline_inference_vision_language_embedding.py
View file @
831540cf
"""
This example shows how to use vLLM for running offline inference with
the correct prompt format on vision language models for multimodal embedding.
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
from
argparse
import
Namespace
from
typing
import
Literal
,
NamedTuple
,
Optional
,
TypedDict
,
Union
,
get_args
from
PIL.Image
import
Image
from
vllm
import
LLM
from
vllm.assets.image
import
ImageAsset
image
=
ImageAsset
(
"cherry_blossom"
).
pil_image
.
convert
(
"RGB"
)
prompt
=
"<|image_1|> Represent the given image with the following question: What is in the image"
# noqa: E501
# Create an LLM.
llm
=
LLM
(
model
=
"TIGER-Lab/VLM2Vec-Full"
,
task
=
"embedding"
,
trust_remote_code
=
True
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
mm_processor_kwargs
=
{
"num_crops"
:
16
},
)
# Generate embedding. The output is a list of EmbeddingRequestOutputs.
outputs
=
llm
.
encode
({
"prompt"
:
prompt
,
"multi_modal_data"
:
{
"image"
:
image
}})
# Print the outputs.
for
output
in
outputs
:
print
(
output
.
outputs
.
embedding
)
# list of 3072 floats
from
vllm.multimodal.utils
import
fetch_image
from
vllm.utils
import
FlexibleArgumentParser
class TextQuery(TypedDict):
    """Embedding query that carries only text."""
    # Tag the run_* helpers compare against to choose a prompt format.
    modality: Literal["text"]
    # Text the run_* helpers interpolate into the prompt.
    text: str
class ImageQuery(TypedDict):
    """Embedding query that carries only an image."""
    # Tag the run_* helpers compare against to choose a prompt format.
    modality: Literal["image"]
    # PIL image that is passed on as the request's image input.
    image: Image
class TextImageQuery(TypedDict):
    """Embedding query that carries both text and an image."""
    # Tag the run_* helpers compare against to choose a prompt format.
    modality: Literal["text+image"]
    # Text the run_* helpers interpolate into the prompt.
    text: str
    # PIL image that is passed on as the request's image input.
    image: Image
# Every modality tag a query can carry; its members are also offered as the
# choices of the ``--modality`` command-line option.
QueryModality = Literal["text", "image", "text+image"]
# Any of the query shapes accepted by the run_* helpers.
Query = Union[TextQuery, ImageQuery, TextImageQuery]
class ModelRequestData(NamedTuple):
    """Bundle returned by the run_* helpers and consumed by ``run_encode``."""
    # Engine instance created for the selected model.
    llm: LLM
    # Fully formatted prompt string.
    prompt: str
    # Image to attach to the request, or None for a text-only query.
    image: Optional[Image]
def run_e5_v(query: Query):
    """Prepare an E5-V embedding request for ``query``.

    The instruction for the query's modality is wrapped in the Llama-3 chat
    template, then the ``royokong/e5-v`` engine is created.

    Returns:
        A ``ModelRequestData`` holding the engine, the formatted prompt and
        the image (``None`` for text queries).

    Raises:
        ValueError: If the modality is neither ``"text"`` nor ``"image"``.
            This is raised before the engine is constructed.
    """
    llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n'  # noqa: E501

    # Pick the instruction first; the chat template is applied once below.
    if query["modality"] == "text":
        instruction = f"{query['text']}\nSummary above sentence in one word: "
        image = None
    elif query["modality"] == "image":
        instruction = "<image>\nSummary above image in one word: "
        image = query["image"]
    else:
        modality = query['modality']
        raise ValueError(f"Unsupported query modality: '{modality}'")

    prompt = llama3_template.format(instruction)

    engine = LLM(
        model="royokong/e5-v",
        task="embedding",
        max_model_len=4096,
    )

    return ModelRequestData(llm=engine, prompt=prompt, image=image)
def run_vlm2vec(query: Query):
    """Prepare a VLM2Vec embedding request for ``query``.

    Supports the ``"text"``, ``"image"`` and ``"text+image"`` modalities, each
    with its own prompt wording, then creates the
    ``TIGER-Lab/VLM2Vec-Full`` engine.

    Returns:
        A ``ModelRequestData`` holding the engine, the prompt and the image
        (``None`` for text queries).

    Raises:
        ValueError: If the modality is not one of the three listed above.
            This is raised before the engine is constructed.
    """
    image = None
    if query["modality"] == "text":
        caption = query["text"]
        prompt = f"Find me an everyday image that matches the given caption: {caption}"  # noqa: E501
    elif query["modality"] == "image":
        prompt = "<|image_1|> Find a day-to-day image that looks similar to the provided image."  # noqa: E501
        image = query["image"]
    elif query["modality"] == "text+image":
        question = query["text"]
        prompt = f"<|image_1|> Represent the given image with the following question: {question}"  # noqa: E501
        image = query["image"]
    else:
        modality = query['modality']
        raise ValueError(f"Unsupported query modality: '{modality}'")

    engine = LLM(
        model="TIGER-Lab/VLM2Vec-Full",
        task="embedding",
        trust_remote_code=True,
        mm_processor_kwargs={"num_crops": 4},
    )

    return ModelRequestData(llm=engine, prompt=prompt, image=image)
def get_query(modality: QueryModality):
    """Return the built-in sample query for ``modality``.

    Image-bearing queries download their picture with ``fetch_image``.

    Raises:
        ValueError: If ``modality`` is not ``"text"``, ``"image"`` or
            ``"text+image"``.
    """
    if modality == "text":
        return TextQuery(modality="text", text="A dog sitting in the grass")
    elif modality == "image":
        dog_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/American_Eskimo_Dog.jpg/360px-American_Eskimo_Dog.jpg"  # noqa: E501
        return ImageQuery(modality="image", image=fetch_image(dog_url))
    elif modality == "text+image":
        cat_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/b/b6/Felis_catus-cat_on_snow.jpg/179px-Felis_catus-cat_on_snow.jpg"  # noqa: E501
        return TextImageQuery(
            modality="text+image",
            text="A cat standing in the snow.",
            image=fetch_image(cat_url),
        )
    else:
        raise ValueError(f"Modality {modality} is not supported.")
def run_encode(model: str, modality: QueryModality):
    """Embed the sample query for ``modality`` with ``model`` and print it.

    ``model`` must be a key of ``model_example_map``. Each output's embedding
    is printed on its own line.
    """
    sample_query = get_query(modality)
    build_request = model_example_map[model]
    request = build_request(sample_query)

    # Text-only requests are sent with an empty multimodal payload.
    multi_modal_data = {}
    if request.image is not None:
        multi_modal_data["image"] = request.image

    encode_input = {
        "prompt": request.prompt,
        "multi_modal_data": multi_modal_data,
    }

    for result in request.llm.encode(encode_input):
        print(result.outputs.embedding)
def main(args: Namespace):
    """Run the embedding demo using the parsed command-line arguments."""
    model_name = args.model_name
    modality = args.modality
    run_encode(model_name, modality)
# Maps each ``--model-name`` choice to the function that prepares its request.
model_example_map = {
    "e5_v": run_e5_v,
    "vlm2vec": run_vlm2vec,
}
if __name__ == "__main__":
    # Command-line entry point: choose the embedding model and the modality
    # of the sample query, then run the demo.
    parser = FlexibleArgumentParser(
        description=('Demo on using vLLM for offline inference with '
                     'vision language models for multimodal embedding'),
    )
    parser.add_argument(
        '--model-name',
        '-m',
        type=str,
        default="vlm2vec",
        choices=model_example_map.keys(),
        help='The name of the embedding model.',
    )
    parser.add_argument(
        '--modality',
        type=str,
        default="image",
        choices=get_args(QueryModality),
        help='Modality of the input.',
    )

    main(parser.parse_args())
examples/offline_inference_vision_language_multi_image.py
View file @
831540cf
"""
This example shows how to use vLLM for running offline inference with
multi-image input on vision language models
, using the chat template defined
by the model.
multi-image input on vision language models
for text generation,
using the chat template defined
by the model.
"""
from
argparse
import
Namespace
from
typing
import
List
,
NamedTuple
,
Optional
...
...
@@ -334,7 +334,8 @@ def main(args: Namespace):
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
(
description
=
'Demo on using vLLM for offline inference with '
'vision language models that support multi-image input'
)
'vision language models that support multi-image input for text '
'generation'
)
parser
.
add_argument
(
'--model-type'
,
'-m'
,
type
=
str
,
...
...
tests/conftest.py
View file @
831540cf
...
...
@@ -43,10 +43,12 @@ _TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS
=
[
os
.
path
.
join
(
_TEST_DIR
,
"prompts"
,
"example.txt"
)]
_LONG_PROMPTS
=
[
os
.
path
.
join
(
_TEST_DIR
,
"prompts"
,
"summary.txt"
)]
PromptImageInput
=
Union
[
List
[
Image
.
Image
],
List
[
List
[
Image
.
Image
]]]
PromptAudioInput
=
Union
[
List
[
Tuple
[
np
.
ndarray
,
int
]],
List
[
List
[
Tuple
[
np
.
ndarray
,
int
]]]]
PromptVideoInput
=
Union
[
List
[
np
.
ndarray
],
List
[
List
[
np
.
ndarray
]]]
_M
=
TypeVar
(
"_M"
)
_PromptMultiModalInput
=
Union
[
List
[
_M
],
List
[
List
[
_M
]]]
PromptImageInput
=
_PromptMultiModalInput
[
Image
.
Image
]
PromptAudioInput
=
_PromptMultiModalInput
[
Tuple
[
np
.
ndarray
,
int
]]
PromptVideoInput
=
_PromptMultiModalInput
[
np
.
ndarray
]
def
_read_prompts
(
filename
:
str
)
->
List
[
str
]:
...
...
@@ -318,12 +320,12 @@ class HfRunner:
"text"
:
prompt
,
"return_tensors"
:
"pt"
,
}
if
images
is
not
None
and
images
[
i
]
is
not
None
:
processor_kwargs
[
"images"
]
=
image
s
[
i
]
if
videos
is
not
None
and
videos
[
i
]
is
not
None
:
processor_kwargs
[
"videos"
]
=
video
s
[
i
]
if
audios
is
not
None
and
audios
[
i
]
is
not
None
:
audio
,
sr
=
audio
s
[
i
]
if
images
is
not
None
and
(
image
:
=
images
[
i
]
)
is
not
None
:
processor_kwargs
[
"images"
]
=
image
if
videos
is
not
None
and
(
video
:
=
videos
[
i
]
)
is
not
None
:
processor_kwargs
[
"videos"
]
=
video
if
audios
is
not
None
and
(
audio_tuple
:
=
audios
[
i
]
)
is
not
None
:
audio
,
sr
=
audio
_tuple
processor_kwargs
[
"audio"
]
=
audio
processor_kwargs
[
"sampling_rate"
]
=
sr
...
...
@@ -338,7 +340,7 @@ class HfRunner:
self
,
prompts
:
List
[
str
],
images
:
Optional
[
PromptImageInput
]
=
None
,
videos
:
Optional
[
List
[
np
.
ndarray
]
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
**
kwargs
:
Any
,
)
->
List
[
Tuple
[
List
[
List
[
int
]],
List
[
str
]]]:
...
...
@@ -368,7 +370,7 @@ class HfRunner:
prompts
:
List
[
str
],
max_tokens
:
int
,
images
:
Optional
[
PromptImageInput
]
=
None
,
videos
:
Optional
[
List
[
np
.
ndarray
]
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
**
kwargs
:
Any
,
)
->
List
[
Tuple
[
List
[
int
],
str
]]:
...
...
@@ -409,7 +411,7 @@ class HfRunner:
prompts
:
List
[
str
],
max_tokens
:
int
,
images
:
Optional
[
PromptImageInput
]
=
None
,
videos
:
Optional
[
List
[
np
.
ndarray
]
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
**
kwargs
:
Any
,
)
->
List
[
List
[
torch
.
Tensor
]]:
...
...
@@ -488,7 +490,7 @@ class HfRunner:
num_logprobs
:
int
,
images
:
Optional
[
PromptImageInput
]
=
None
,
audios
:
Optional
[
PromptAudioInput
]
=
None
,
videos
:
Optional
[
List
[
np
.
ndarray
]
]
=
None
,
videos
:
Optional
[
PromptVideoInput
]
=
None
,
**
kwargs
:
Any
,
)
->
List
[
TokensTextLogprobs
]:
all_inputs
=
self
.
get_inputs
(
prompts
,
...
...
@@ -657,15 +659,18 @@ class VllmRunner:
inputs
=
[
TextPrompt
(
prompt
=
prompt
)
for
prompt
in
prompts
]
if
images
is
not
None
:
for
i
,
image
in
enumerate
(
images
):
inputs
[
i
][
"multi_modal_data"
]
=
{
"image"
:
image
}
if
image
is
not
None
:
inputs
[
i
][
"multi_modal_data"
]
=
{
"image"
:
image
}
if
videos
is
not
None
:
for
i
,
video
in
enumerate
(
videos
):
inputs
[
i
][
"multi_modal_data"
]
=
{
"video"
:
video
}
if
video
is
not
None
:
inputs
[
i
][
"multi_modal_data"
]
=
{
"video"
:
video
}
if
audios
is
not
None
:
for
i
,
audio
in
enumerate
(
audios
):
inputs
[
i
][
"multi_modal_data"
]
=
{
"audio"
:
audio
}
if
audio
is
not
None
:
inputs
[
i
][
"multi_modal_data"
]
=
{
"audio"
:
audio
}
return
inputs
...
...
@@ -837,13 +842,20 @@ class VllmRunner:
returned_outputs
.
append
((
token_ids
,
texts
))
return
returned_outputs
def
encode
(
self
,
prompts
:
List
[
str
])
->
List
[
List
[
float
]]:
req_outputs
=
self
.
model
.
encode
(
prompts
)
outputs
=
[]
for
req_output
in
req_outputs
:
embedding
=
req_output
.
outputs
.
embedding
outputs
.
append
(
embedding
)
return
outputs
def encode(
    self,
    prompts: List[str],
    images: Optional[PromptImageInput] = None,
    videos: Optional[PromptVideoInput] = None,
    audios: Optional[PromptAudioInput] = None,
) -> List[List[float]]:
    """Encode ``prompts`` (with optional multimodal data) into embeddings.

    The inputs are assembled by ``self.get_inputs`` and passed to
    ``self.model.encode``; one embedding is returned per request output.
    """
    model_inputs = self.get_inputs(
        prompts,
        images=images,
        videos=videos,
        audios=audios,
    )

    embeddings = []
    for request_output in self.model.encode(model_inputs):
        embeddings.append(request_output.outputs.embedding)
    return embeddings
def
__enter__
(
self
):
return
self
...
...
tests/models/embedding/utils.py
View file @
831540cf
...
...
@@ -16,7 +16,8 @@ def check_embeddings_close(
for
prompt_idx
,
(
embeddings_0
,
embeddings_1
)
in
enumerate
(
zip
(
embeddings_0_lst
,
embeddings_1_lst
)):
assert
len
(
embeddings_0
)
==
len
(
embeddings_1
)
assert
len
(
embeddings_0
)
==
len
(
embeddings_1
),
(
f
"Length mismatch:
{
len
(
embeddings_0
)
}
vs.
{
len
(
embeddings_1
)
}
"
)
sim
=
F
.
cosine_similarity
(
torch
.
tensor
(
embeddings_0
),
torch
.
tensor
(
embeddings_1
),
...
...
tests/models/embedding/vision_language/test_llava_next.py
0 → 100644
View file @
831540cf
from
typing
import
List
,
Type
import
pytest
import
torch.nn.functional
as
F
from
transformers
import
AutoModelForVision2Seq
from
....conftest
import
IMAGE_ASSETS
,
HfRunner
,
PromptImageInput
,
VllmRunner
from
....utils
import
large_gpu_test
from
..utils
import
check_embeddings_close
llama3_template
=
'<|start_header_id|>user<|end_header_id|>
\n\n
{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
\n\n
\n
'
# noqa: E501
HF_TEXT_PROMPTS
=
[
# T -> X
llama3_template
.
format
(
"The label of the object is stop sign
\n
Summary above sentence in one word: "
# noqa: E501
),
# T -> X
llama3_template
.
format
(
"cherry blossom
\n
Summary above sentence in one word: "
),
]
HF_IMAGE_PROMPTS
=
IMAGE_ASSETS
.
prompts
({
# I -> X
"stop_sign"
:
llama3_template
.
format
(
"<image>
\n
Summary above image in one word: "
),
# I -> X
"cherry_blossom"
:
llama3_template
.
format
(
"<image>
\n
Summary above image in one word: "
),
})
MODELS
=
[
"royokong/e5-v"
]
def _run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    input_texts: List[str],
    input_images: PromptImageInput,
    model: str,
    *,
    dtype: str,
) -> None:
    """Check that vLLM and HF produce close embeddings for the same inputs.

    The HF reference embedding is the L2-normalized last-layer hidden state
    of the final token, computed one prompt at a time.
    """
    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).
    vllm_kwargs = dict(
        task="embedding",
        dtype=dtype,
        max_model_len=4096,
        enforce_eager=True,
    )
    with vllm_runner(model, **vllm_kwargs) as vllm_model:
        vllm_outputs = vllm_model.encode(input_texts, images=input_images)

    with hf_runner(model, dtype=dtype,
                   auto_cls=AutoModelForVision2Seq) as hf_model:
        # Patch the issue where image_token_id
        # exceeds the maximum allowed vocab size
        patched_vocab_size = hf_model.model.language_model.vocab_size + 1
        hf_model.model.resize_token_embeddings(patched_vocab_size)

        hf_outputs = []
        for inputs in hf_model.get_inputs(input_texts, images=input_images):
            # Based on: https://huggingface.co/royokong/e5-v
            device_inputs = hf_model.wrap_device(
                inputs, device=hf_model.model.device.type)
            outputs = hf_model.model(
                **device_inputs,
                return_dict=True,
                output_hidden_states=True,
            )
            # Last layer, first (only) sequence in the batch, final token.
            last_token_state = outputs.hidden_states[-1][0, -1, :]
            hf_outputs.append(F.normalize(last_token_state, dim=-1).tolist())

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models_text(
    hf_runner,
    vllm_runner,
    image_assets,
    model: str,
    dtype: str,
) -> None:
    """Compare HF and vLLM embeddings for the text-only prompts."""
    input_texts = list(HF_TEXT_PROMPTS)
    # Text-only prompts: every prompt is paired with "no image".
    input_images = [None for _ in input_texts]

    _run_test(
        hf_runner,
        vllm_runner,
        input_texts,
        input_images,  # type: ignore
        model,
        dtype=dtype,
    )
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models_image(
    hf_runner,
    vllm_runner,
    image_assets,
    model: str,
    dtype: str,
) -> None:
    """Compare HF and vLLM embeddings for the image prompts."""
    input_texts = []
    input_images = []
    # Pair each image prompt with the PIL image of the matching asset.
    for prompt, asset in zip(HF_IMAGE_PROMPTS, image_assets):
        input_texts.append(prompt)
        input_images.append(asset.pil_image)

    _run_test(
        hf_runner,
        vllm_runner,
        input_texts,
        input_images,
        model,
        dtype=dtype,
    )
tests/models/embedding/vision_language/test_phi3v.py
View file @
831540cf
from
typing
import
List
,
Type
import
pytest
import
torch.nn.functional
as
F
from
....conftest
import
IMAGE_ASSETS
from
....conftest
import
IMAGE_ASSETS
,
HfRunner
,
PromptImageInput
,
VllmRunner
from
....utils
import
large_gpu_test
from
..utils
import
check_embeddings_close
HF_TEXT_PROMPTS
=
[
# T -> X
"Find me an everyday image that matches the given caption: The label of the object is stop sign"
,
# noqa: E501
# T -> X
"Retrieve an image of this caption: cherry blossom"
,
]
HF_IMAGE_PROMPTS
=
IMAGE_ASSETS
.
prompts
({
# T + I -> X
"stop_sign"
:
"<|image_1|> Select the portion of the image that isolates the object of the given label: The label of the object is stop sign"
,
# noqa: E501
# I -> X
"cherry_blossom"
:
"<|image_1|> Represent the given image
with the following question: What is in the image
"
,
# noqa: E501
"<|image_1|> Represent the given image
for classification
"
,
# noqa: E501
})
MODELS
=
[
"TIGER-Lab/VLM2Vec-Full"
]
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
def
test_models
(
hf_runner
,
vllm_runner
,
example_prompts
,
def
_run_test
(
hf_runner
:
Type
[
HfRunner
],
vllm_runner
:
Type
[
VllmRunner
],
input_texts
:
List
[
str
],
input_images
:
PromptImageInput
,
model
:
str
,
*
,
dtype
:
str
,
)
->
None
:
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
with
vllm_runner
(
model
,
task
=
"embedding"
,
max_model_len
=
4096
,
max_num_seqs
=
2
,
dtype
=
dtype
,
with
vllm_runner
(
model
,
task
=
"embedding"
,
dtype
=
dtype
,
enforce_eager
=
True
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
encode
(
example_prompt
s
)
vllm_outputs
=
vllm_model
.
encode
(
input_texts
,
images
=
input_image
s
)
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
all_inputs
=
hf_model
.
get_inputs
(
example_prompts
)
# use eager mode for hf runner, since phi3_v didn't work with flash_attn
hf_model_kwargs
=
{
"_attn_implementation"
:
"eager"
}
with
hf_runner
(
model
,
dtype
=
dtype
,
model_kwargs
=
hf_model_kwargs
)
as
hf_model
:
all_inputs
=
hf_model
.
get_inputs
(
input_texts
,
images
=
input_images
)
all_outputs
=
[]
for
inputs
in
all_inputs
:
...
...
@@ -61,3 +72,53 @@ def test_models(
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models_text(
    hf_runner,
    vllm_runner,
    image_assets,
    model: str,
    dtype: str,
) -> None:
    """Compare HF and vLLM embeddings for the text-only prompts."""
    input_texts = list(HF_TEXT_PROMPTS)
    # Text-only prompts: every prompt is paired with "no image".
    input_images = [None for _ in input_texts]

    _run_test(
        hf_runner,
        vllm_runner,
        input_texts,
        input_images,  # type: ignore
        model,
        dtype=dtype,
    )
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models_image(
    hf_runner,
    vllm_runner,
    image_assets,
    model: str,
    dtype: str,
) -> None:
    """Compare HF and vLLM embeddings for the image prompts."""
    input_texts = []
    input_images = []
    # Pair each image prompt with the PIL image of the matching asset.
    for prompt, asset in zip(HF_IMAGE_PROMPTS, image_assets):
        input_texts.append(prompt)
        input_images.append(asset.pil_image)

    _run_test(
        hf_runner,
        vllm_runner,
        input_texts,
        input_images,
        model,
        dtype=dtype,
    )
vllm/model_executor/models/llava_next.py
View file @
831540cf
...
...
@@ -13,11 +13,13 @@ from typing_extensions import NotRequired
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
DecoderOnlyInputs
,
InputContext
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.utils
import
is_list_of
from
.clip
import
(
CLIPVisionModel
,
dummy_image_for_clip
,
...
...
@@ -28,8 +30,8 @@ from .llava import LlavaMultiModalProjector
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_siglip_image_feature_size
,
get_siglip_patch_grid_length
,
input_processor_for_siglip
)
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
from
.utils
import
(
AutoWeightsLoader
,
embed_multimodal
,
flatten_bn
,
init_vllm_registered_model
)
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT
=
MAX_IMAGE_FEATURE_SIZE_WIDTH
=
448
...
...
@@ -312,6 +314,10 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self
.
language_model
=
init_vllm_registered_model
(
config
.
text_config
,
cache_config
,
quant_config
)
# The same model class supports both language generation and embedding
# because the architecture name is the same
self
.
_pooler
=
Pooler
(
pooling_type
=
PoolingType
.
LAST
,
normalize
=
True
)
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
...
...
@@ -605,14 +611,12 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
not
None
:
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
config
.
image_token_index
)
inputs_embeds
=
embed_multimodal
(
input_ids
,
self
.
config
.
image_token_index
,
self
.
language_model
.
model
.
get_input_embeddings
,
lambda
_
:
self
.
_process_image_input
(
image_input
),
)
input_ids
=
None
else
:
inputs_embeds
=
None
...
...
@@ -641,6 +645,13 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
)
->
Optional
[
SamplerOutput
]:
return
self
.
language_model
.
sample
(
logits
,
sampling_metadata
)
def pooler(
    self,
    hidden_states: torch.Tensor,
    pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
    """Pool ``hidden_states`` by delegating to ``self._pooler``."""
    pooled = self._pooler(hidden_states, pooling_metadata)
    return pooled
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Load ``weights`` (name/tensor pairs) through ``AutoWeightsLoader``."""
    AutoWeightsLoader(self).load_weights(weights)
vllm/model_executor/models/phi3v.py
View file @
831540cf
...
...
@@ -467,8 +467,6 @@ def input_processor_for_phi3v(ctx: InputContext,
prompt_token_ids
=
inputs
[
"prompt_token_ids"
].
copy
()
print
(
"prompt_token_ids (old)"
,
prompt_token_ids
)
# masked placeholder with image token id
for
idx
in
image_idx
:
candidates
=
_get_image_placeholder_token_id_candidates
(
model_config
,
...
...
vllm/model_executor/models/registry.py
View file @
831540cf
...
...
@@ -94,6 +94,7 @@ _EMBEDDING_MODELS = {
"MistralModel"
:
(
"llama"
,
"LlamaEmbeddingModel"
),
"Qwen2ForRewardModel"
:
(
"qwen2_rm"
,
"Qwen2ForRewardModel"
),
# [Multimodal]
"LlavaNextForConditionalGeneration"
:
(
"llava_next"
,
"LlavaNextForConditionalGeneration"
),
# noqa: E501
"Phi3VForCausalLM"
:
(
"phi3v"
,
"Phi3VForCausalLM"
),
}
...
...
vllm/model_executor/models/utils.py
View file @
831540cf
import
itertools
from
dataclasses
import
dataclass
,
field
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Protocol
,
Tuple
,
Union
,
overload
)
from
typing
import
(
Any
,
Callable
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Protocol
,
Tuple
,
Union
,
overload
)
import
torch
import
torch.nn
as
nn
...
...
@@ -294,10 +294,11 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
_embedding_count_expression
(
inner
)
for
inner
in
embeddings
)
def
merge_multimodal_embeddings
(
input_ids
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
,
multimodal_embeddings
:
NestedTensors
,
placeholder_token_id
:
int
)
->
torch
.
Tensor
:
def
_merge_multimodal_embeddings
(
inputs_embeds
:
torch
.
Tensor
,
is_multimodal
:
torch
.
Tensor
,
multimodal_embeddings
:
NestedTensors
,
)
->
torch
.
Tensor
:
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
positions in ``inputs_embeds`` corresponding to placeholder tokens in
...
...
@@ -306,8 +307,7 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
Note:
This updates ``inputs_embeds`` in place.
"""
mask
=
(
input_ids
==
placeholder_token_id
)
num_expected_tokens
=
mask
.
sum
().
item
()
num_expected_tokens
=
is_multimodal
.
sum
().
item
()
assert
isinstance
(
num_expected_tokens
,
int
)
flattened
=
_flatten_embeddings
(
multimodal_embeddings
)
...
...
@@ -317,10 +317,70 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
f
"Attempted to assign
{
expr
}
=
{
flattened
.
shape
[
0
]
}
"
f
"multimodal tokens to
{
num_expected_tokens
}
placeholders"
)
inputs_embeds
[
mask
]
=
flattened
inputs_embeds
[
is_multimodal
]
=
flattened
return
inputs_embeds
def embed_multimodal(
    input_ids: torch.Tensor,
    multimodal_token_id: int,
    get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
    get_multimodal_embeds: Callable[[torch.Tensor], Union[torch.Tensor,
                                                          List[torch.Tensor]]],
) -> torch.Tensor:
    """
    Embed token IDs and multimodal inputs and combine their embeddings.

    ``multimodal_token_id`` is used to determine whether a token ID should
    be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``.

    Compared to ``merge_multimodal_embeddings``, this avoids running
    ``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]``
    which causes issues when the placeholder token ID exceeds the
    vocabulary size of the language model.
    """
    placeholder_mask = input_ids == multimodal_token_id
    text_mask = ~placeholder_mask

    # Only the non-placeholder IDs are handed to the text embedding function.
    text_part = get_text_embeds(input_ids[text_mask])
    multimodal_part = get_multimodal_embeds(input_ids[placeholder_mask])

    # Uninitialised buffer with one row per input position; the text rows are
    # filled here and the placeholder rows by the merge helper below.
    num_tokens = input_ids.shape[0]
    embed_dim = text_part.shape[1]
    combined = torch.empty(
        (num_tokens, embed_dim),
        dtype=text_part.dtype,
        device=text_part.device,
    )
    combined[text_mask] = text_part

    return _merge_multimodal_embeddings(
        combined,
        placeholder_mask,
        multimodal_part,
    )
def merge_multimodal_embeddings(
    input_ids: torch.Tensor,
    inputs_embeds: torch.Tensor,
    multimodal_embeddings: NestedTensors,
    placeholder_token_id: int,
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
    positions in ``inputs_embeds`` corresponding to placeholder tokens in
    ``input_ids``.

    Note:
        This updates ``inputs_embeds`` in place.
    """
    # Positions whose token ID equals the placeholder are the ones replaced.
    placeholder_mask = input_ids == placeholder_token_id
    return _merge_multimodal_embeddings(
        inputs_embeds,
        placeholder_mask,
        multimodal_embeddings,
    )
class
LayerFn
(
Protocol
):
def
__call__
(
self
,
prefix
:
str
)
->
torch
.
nn
.
Module
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment