Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
edf309eb
Unverified
Commit
edf309eb
authored
Feb 27, 2025
by
Isotr0py
Committed by
GitHub
Feb 27, 2025
Browse files
[VLM] Support multimodal inputs for Florence-2 models (#13320)
parent
788f284b
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
1078 additions
and
117 deletions
+1078
-117
docs/source/models/supported_models.md
docs/source/models/supported_models.md
+7
-0
examples/offline_inference/florence2_inference.py
examples/offline_inference/florence2_inference.py
+23
-16
examples/offline_inference/vision_language.py
examples/offline_inference/vision_language.py
+17
-0
tests/conftest.py
tests/conftest.py
+3
-3
tests/models/decoder_only/audio_language/test_ultravox.py
tests/models/decoder_only/audio_language/test_ultravox.py
+2
-2
tests/models/encoder_decoder/vision_language/test_florence2.py
.../models/encoder_decoder/vision_language/test_florence2.py
+88
-51
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+3
-2
tests/models/registry.py
tests/models/registry.py
+5
-5
vllm/model_executor/models/bart.py
vllm/model_executor/models/bart.py
+19
-8
vllm/model_executor/models/florence2.py
vllm/model_executor/models/florence2.py
+893
-20
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+1
-1
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+13
-7
vllm/multimodal/profiling.py
vllm/multimodal/profiling.py
+4
-2
No files found.
docs/source/models/supported_models.md
View file @
edf309eb
...
@@ -715,6 +715,13 @@ See [this page](#generative-models) for more information on how to use generativ
...
@@ -715,6 +715,13 @@ See [this page](#generative-models) for more information on how to use generativ
*
*
*
✅︎
*
✅︎
*
✅︎
*
✅︎
-
*
`Florence2ForConditionalGeneration`
*
Florence-2
*
T + I
*
`microsoft/Florence-2-base`
,
`microsoft/Florence-2-large`
etc.
*
*
*
-
*
`FuyuForCausalLM`
-
*
`FuyuForCausalLM`
*
Fuyu
*
Fuyu
*
T + I
*
T + I
...
...
examples/offline_inference/florence2_inference.py
View file @
edf309eb
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
'''
"""
Demonstrate prompting of text-to-text
Demonstrate prompting of text-to-text
encoder/decoder models, specifically Florence-2
encoder/decoder models, specifically Florence-2
'''
"""
# TODO(Isotr0py):
# TODO(Isotr0py):
# Move to offline_inference/vision_language.py
# Move to offline_inference/vision_language.py
# after porting vision backbone
# after porting vision backbone
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.assets.image
import
ImageAsset
dtype
=
"float"
# Create a Florence-2 encoder/decoder model instance
# Create a Florence-2 encoder/decoder model instance
llm
=
LLM
(
llm
=
LLM
(
model
=
"microsoft/Florence-2-
bas
e"
,
model
=
"microsoft/Florence-2-
larg
e"
,
tokenizer
=
"facebook/bart-
bas
e"
,
tokenizer
=
"facebook/bart-
larg
e"
,
dtype
=
dtype
,
max_num_seqs
=
8
,
trust_remote_code
=
True
,
trust_remote_code
=
True
,
)
)
prompts
=
[
prompts
=
[
"<CAPTION>"
,
"<DETAILED_CAPTION>"
,
"<MORE_DETAILED_CAPTION>"
,
{
# implicit prompt with task token
"<CAPTION_TO_PHRASE_GROUNDING>"
,
"<OD>"
,
"<DENSE_REGION_CAPTION>"
,
"prompt"
:
"<DETAILED_CAPTION>"
,
"<REGION_PROPOSAL>"
,
"<OCR>"
,
"<OCR_WITH_REGION>"
"multi_modal_data"
:
{
"image"
:
ImageAsset
(
"stop_sign"
).
pil_image
},
},
{
# explicit encoder/decoder prompt
"encoder_prompt"
:
{
"prompt"
:
"Describe in detail what is shown in the image."
,
"multi_modal_data"
:
{
"image"
:
ImageAsset
(
"cherry_blossom"
).
pil_image
},
},
"decoder_prompt"
:
""
,
},
]
]
# Create a sampling params object.
# Create a sampling params object.
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
temperature
=
0
,
temperature
=
0
,
top_p
=
1.0
,
top_p
=
1.0
,
min_tokens
=
0
,
min_tokens
=
0
,
max_tokens
=
20
,
max_tokens
=
128
,
)
)
# Generate output tokens from the prompts. The output is a list of
# Generate output tokens from the prompts. The output is a list of
...
@@ -38,9 +49,5 @@ outputs = llm.generate(prompts, sampling_params)
...
@@ -38,9 +49,5 @@ outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
# Print the outputs.
for
output
in
outputs
:
for
output
in
outputs
:
prompt
=
output
.
prompt
encoder_prompt
=
output
.
encoder_prompt
generated_text
=
output
.
outputs
[
0
].
text
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Encoder prompt:
{
encoder_prompt
!
r
}
, "
print
(
f
"Generated text:
{
generated_text
!
r
}
"
)
f
"Decoder prompt:
{
prompt
!
r
}
, "
f
"Generated text:
{
generated_text
!
r
}
"
)
examples/offline_inference/vision_language.py
View file @
edf309eb
...
@@ -82,6 +82,22 @@ def run_deepseek_vl2(question: str, modality: str):
...
@@ -82,6 +82,22 @@ def run_deepseek_vl2(question: str, modality: str):
return
llm
,
prompt
,
stop_token_ids
return
llm
,
prompt
,
stop_token_ids
# Florence2
def
run_florence2
(
question
:
str
,
modality
:
str
):
assert
modality
==
"image"
llm
=
LLM
(
model
=
"microsoft/Florence-2-large"
,
tokenizer
=
"facebook/bart-large"
,
max_num_seqs
=
8
,
trust_remote_code
=
True
,
dtype
=
"bfloat16"
,
disable_mm_preprocessor_cache
=
args
.
disable_mm_preprocessor_cache
)
prompt
=
"<MORE_DETAILED_CAPTION>"
stop_token_ids
=
None
return
llm
,
prompt
,
stop_token_ids
# Fuyu
# Fuyu
def
run_fuyu
(
question
:
str
,
modality
:
str
):
def
run_fuyu
(
question
:
str
,
modality
:
str
):
assert
modality
==
"image"
assert
modality
==
"image"
...
@@ -571,6 +587,7 @@ model_example_map = {
...
@@ -571,6 +587,7 @@ model_example_map = {
"blip-2"
:
run_blip2
,
"blip-2"
:
run_blip2
,
"chameleon"
:
run_chameleon
,
"chameleon"
:
run_chameleon
,
"deepseek_vl_v2"
:
run_deepseek_vl2
,
"deepseek_vl_v2"
:
run_deepseek_vl2
,
"florence2"
:
run_florence2
,
"fuyu"
:
run_fuyu
,
"fuyu"
:
run_fuyu
,
"glm4v"
:
run_glm4v
,
"glm4v"
:
run_glm4v
,
"h2ovl_chat"
:
run_h2ovl
,
"h2ovl_chat"
:
run_h2ovl
,
...
...
tests/conftest.py
View file @
edf309eb
...
@@ -600,8 +600,8 @@ class HfRunner:
...
@@ -600,8 +600,8 @@ class HfRunner:
if
images
is
not
None
and
images
[
i
]
is
not
None
:
if
images
is
not
None
and
images
[
i
]
is
not
None
:
processor_kwargs
[
"images"
]
=
images
[
i
]
processor_kwargs
[
"images"
]
=
images
[
i
]
encoder_input
_id
s
=
self
.
wrap_device
(
encoder_inputs
=
self
.
wrap_device
(
self
.
processor
(
**
processor_kwargs
)
.
input_ids
,
self
.
processor
(
**
processor_kwargs
),
device
=
self
.
model
.
device
.
type
,
device
=
self
.
model
.
device
.
type
,
)
)
...
@@ -615,13 +615,13 @@ class HfRunner:
...
@@ -615,13 +615,13 @@ class HfRunner:
)
)
output
=
self
.
model
.
generate
(
output
=
self
.
model
.
generate
(
encoder_input_ids
,
decoder_input_ids
=
decoder_input_ids
,
decoder_input_ids
=
decoder_input_ids
,
use_cache
=
True
,
use_cache
=
True
,
do_sample
=
False
,
do_sample
=
False
,
max_new_tokens
=
max_tokens
,
max_new_tokens
=
max_tokens
,
output_hidden_states
=
True
,
output_hidden_states
=
True
,
return_dict_in_generate
=
True
,
return_dict_in_generate
=
True
,
**
encoder_inputs
,
**
kwargs
,
**
kwargs
,
)
)
...
...
tests/models/decoder_only/audio_language/test_ultravox.py
View file @
edf309eb
...
@@ -15,7 +15,7 @@ from ....conftest import HfRunner, VllmRunner
...
@@ -15,7 +15,7 @@ from ....conftest import HfRunner, VllmRunner
from
....utils
import
RemoteOpenAIServer
from
....utils
import
RemoteOpenAIServer
from
...utils
import
check_logprobs_close
from
...utils
import
check_logprobs_close
MODEL_NAME
=
"fixie-ai/ultravox-v0_
5-llama-3_2-1b
"
MODEL_NAME
=
"fixie-ai/ultravox-v0_
4
"
AudioTuple
=
Tuple
[
np
.
ndarray
,
int
]
AudioTuple
=
Tuple
[
np
.
ndarray
,
int
]
...
@@ -187,7 +187,7 @@ def run_multi_audio_test(
...
@@ -187,7 +187,7 @@ def run_multi_audio_test(
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"
half
"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"
bfloat16
"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
128
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"vllm_kwargs"
,
[
@
pytest
.
mark
.
parametrize
(
"vllm_kwargs"
,
[
...
...
tests/models/encoder_decoder/vision_language/test_florence2.py
View file @
edf309eb
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
functools
import
partial
from
typing
import
Optional
,
Type
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
pytest
import
pytest
from
PIL
import
Image
from
PIL
import
Image
from
vllm.inputs.data
import
ExplicitEncoderDecoderPrompt
from
vllm.inputs.data
import
ExplicitEncoderDecoderPrompt
,
TextPrompt
from
vllm.multimodal.image
import
rescale_image_size
from
vllm.sequence
import
SampleLogprobs
from
vllm.sequence
import
SampleLogprobs
from
....conftest
import
HfRunner
,
VllmRunner
from
....conftest
import
IMAGE_ASSETS
,
HfRunner
,
VllmRunner
,
_ImageAssets
from
...utils
import
check_logprobs_close
from
...utils
import
check_logprobs_close
Florence2Prompt
=
partial
(
ExplicitEncoderDecoderPrompt
,
decoder_prompt
=
None
,
mm_processor_kwargs
=
None
)
MODELS
=
[
"microsoft/Florence-2-base"
]
MODELS
=
[
"microsoft/Florence-2-base"
]
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
# Therefore, we borrow the BartTokenizer from the original Bart model
TOKENIZER
=
"facebook/bart-base"
TOKENIZER
=
"facebook/bart-base"
PROMPTS
=
[
HF_IMAGE_PROMPTS
=
IMAGE_ASSETS
.
prompts
({
Florence2Prompt
(
encoder_prompt
=
"<CAPTION>"
),
"stop_sign"
:
Florence2Prompt
(
encoder_prompt
=
"<DETAILED_CAPTION>"
),
"<CAPTION>"
,
# special task token
Florence2Prompt
(
encoder_prompt
=
"<MORE_DETAILED_CAPTION>"
),
"cherry_blossom"
:
Florence2Prompt
(
encoder_prompt
=
"<CAPTION_TO_PHRASE_GROUNDING>"
),
"Describe in detail what is shown in the image."
,
Florence2Prompt
(
encoder_prompt
=
"<DENSE_REGION_CAPTION>"
),
})
Florence2Prompt
(
encoder_prompt
=
"<REGION_PROPOSAL>"
),
Florence2Prompt
(
encoder_prompt
=
"<OCR_WITH_REGION>"
),
Florence2Prompt
(
encoder_prompt
=
"<OCR>"
),
Florence2Prompt
(
encoder_prompt
=
"<OD>"
),
]
def
get_hf_images_prompts
(
prompts_
:
list
[
ExplicitEncoderDecoderPrompt
[
str
,
TextPrompt
]],
)
->
tuple
[
list
[
ExplicitEncoderDecoderPrompt
[
str
,
str
]],
list
[
Image
.
Image
]]:
prompts
,
images
=
[],
[]
for
prompt
in
prompts_
:
encoder_prompt
=
prompt
[
"encoder_prompt"
]
prompts
.
append
(
ExplicitEncoderDecoderPrompt
(
encoder_prompt
=
encoder_prompt
[
"prompt"
],
decoder_prompt
=
None
,
))
images
.
append
(
encoder_prompt
[
"multi_modal_data"
][
"image"
])
return
prompts
,
images
def
vllm_to_hf_output
(
vllm_output
:
Tuple
[
List
[
int
],
str
,
Optional
[
SampleLogprobs
]],
):
"""Sanitize vllm output to be comparable with hf output."""
output_ids
,
output_str
,
out_logprobs
=
vllm_output
hf_output_str
=
"</s><s>"
+
output_str
+
"</s>"
def
hf_to_vllm_output
(
hf_output
:
tuple
[
list
[
int
],
str
,
Optional
[
SampleLogprobs
]]):
"""Sanitize hf output to be comparable with vllm output."""
output_ids
,
output_str
,
out_logprobs
=
hf_output
return
output_ids
,
hf_output_str
,
out_logprobs
output_str
=
output_str
.
replace
(
"</s>"
,
""
).
replace
(
"<s>"
,
""
)
output_ids
=
[
ids
for
ids
in
output_ids
if
ids
not
in
[
0
,
2
]]
return
output_ids
,
output_str
,
out_logprobs
def
run_test
(
def
run_test
(
hf_runner
:
Type
[
HfRunner
],
hf_runner
:
Type
[
HfRunner
],
vllm_runner
:
Type
[
VllmRunner
],
vllm_runner
:
Type
[
VllmRunner
],
prompts
:
L
ist
[
ExplicitEncoderDecoderPrompt
],
inputs
:
list
[
l
ist
[
ExplicitEncoderDecoderPrompt
]
]
,
model
:
str
,
model
:
str
,
*
,
*
,
dtype
:
str
,
dtype
:
str
,
...
@@ -56,46 +63,76 @@ def run_test(
...
@@ -56,46 +63,76 @@ def run_test(
distributed_executor_backend
:
Optional
[
str
]
=
None
,
distributed_executor_backend
:
Optional
[
str
]
=
None
,
)
->
None
:
)
->
None
:
with
vllm_runner
(
model
,
with
vllm_runner
(
model
,
max_num_seqs
=
8
,
tokenizer_name
=
TOKENIZER
,
tokenizer_name
=
TOKENIZER
,
dtype
=
dtype
,
dtype
=
dtype
,
tensor_parallel_size
=
tensor_parallel_size
,
tensor_parallel_size
=
tensor_parallel_size
,
distributed_executor_backend
=
distributed_executor_backend
,
distributed_executor_backend
=
distributed_executor_backend
,
enforce_eager
=
True
)
as
vllm_model
:
enforce_eager
=
True
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_encoder_decoder_greedy_logprobs
(
vllm_outputs_per_case
=
[
prompts
,
max_tokens
,
num_logprobs
)
vllm_model
.
generate_encoder_decoder_greedy_logprobs
(
prompts
,
max_tokens
,
num_logprobs
=
num_logprobs
)
for
prompts
in
inputs
]
hf_inputs
=
[
get_hf_images_prompts
(
prompts
)
for
prompts
in
inputs
]
# Florence-2 processors require image inputs
dummy_image
=
Image
.
new
(
mode
=
"RGB"
,
size
=
(
2
,
2
))
with
hf_runner
(
model
,
dtype
=
dtype
,
skip_tokenizer_init
=
True
)
as
hf_model
:
with
hf_runner
(
model
,
dtype
=
dtype
,
skip_tokenizer_init
=
True
)
as
hf_model
:
hf_model
.
model
.
get_output_embeddings
=
lambda
:
\
hf_model
.
model
.
get_output_embeddings
=
lambda
:
\
hf_model
.
model
.
language_model
.
lm_head
hf_model
.
model
.
language_model
.
lm_head
hf_outputs
=
(
hf_model
.
generate_encoder_decoder_greedy_logprobs_limit
(
hf_outputs
_per_case
=
[
prompts
,
hf_model
.
generate_encoder_decoder_greedy_logprobs_limit
(
max_tokens
,
prompts
,
max_tokens
,
num_logprobs
=
num_logprobs
,
images
=
images
)
num_logprobs
,
for
prompts
,
images
in
hf_inputs
images
=
[
dummy_image
]
*
len
(
prompts
),
]
))
for
hf_outputs
,
vllm_outputs
in
zip
(
hf_outputs_per_case
,
check_logprobs_close
(
vllm_outputs_per_case
):
outputs_0_lst
=
hf_outputs
,
check_logprobs_close
(
outputs_
1
_lst
=
[
outputs_
0
_lst
=
[
hf_to_vllm_output
(
output
)
for
output
in
hf_outputs
],
vllm_to_hf_output
(
vllm_output
)
for
vllm_output
in
vllm_outputs
outputs_1_lst
=
vllm_outputs
,
]
,
name_0
=
"hf"
,
name_
0
=
"
hf
"
,
name_
1
=
"
vllm
"
,
name_1
=
"vllm"
,
)
)
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
,
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"size_factors"
,
[
# No image
[],
# Single-scale
[
1.0
],
# Single-scale, batched
[
1.0
,
1.0
,
1.0
],
# Multi-scale
[
0.25
,
0.5
,
1.0
],
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
64
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models
(
hf_runner
,
vllm_runner
,
model
,
dtype
,
max_tokens
,
def
test_models
(
hf_runner
:
Type
[
HfRunner
],
vllm_runner
:
Type
[
VllmRunner
],
num_logprobs
)
->
None
:
image_assets
:
_ImageAssets
,
model
:
str
,
size_factors
:
list
[
int
],
dtype
:
str
,
max_tokens
:
int
,
num_logprobs
:
int
)
->
None
:
images
=
[
asset
.
pil_image
for
asset
in
image_assets
]
inputs_per_image
=
[[
ExplicitEncoderDecoderPrompt
(
encoder_prompt
=
TextPrompt
(
prompt
=
prompt
,
multi_modal_data
=
{
"image"
:
rescale_image_size
(
image
,
factor
)}),
decoder_prompt
=
None
,
)
for
factor
in
size_factors
]
for
image
,
prompt
in
zip
(
images
,
HF_IMAGE_PROMPTS
)]
run_test
(
run_test
(
hf_runner
,
hf_runner
,
vllm_runner
,
vllm_runner
,
PROMPTS
,
inputs_per_image
,
model
,
model
,
dtype
=
dtype
,
dtype
=
dtype
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
...
...
tests/models/multimodal/processing/test_common.py
View file @
edf309eb
...
@@ -29,8 +29,8 @@ def _test_processing_correctness(
...
@@ -29,8 +29,8 @@ def _test_processing_correctness(
model_config
=
ModelConfig
(
model_config
=
ModelConfig
(
model_id
,
model_id
,
task
=
"auto"
,
task
=
"auto"
,
tokenizer
=
model_id
,
tokenizer
=
model_info
.
tokenizer
or
model_id
,
tokenizer_mode
=
"auto"
,
tokenizer_mode
=
model_info
.
tokenizer_mode
,
trust_remote_code
=
model_info
.
trust_remote_code
,
trust_remote_code
=
model_info
.
trust_remote_code
,
seed
=
0
,
seed
=
0
,
dtype
=
"float16"
,
dtype
=
"float16"
,
...
@@ -151,6 +151,7 @@ def _test_processing_correctness(
...
@@ -151,6 +151,7 @@ def _test_processing_correctness(
"Salesforce/blip2-opt-2.7b"
,
"Salesforce/blip2-opt-2.7b"
,
"facebook/chameleon-7b"
,
"facebook/chameleon-7b"
,
"deepseek-ai/deepseek-vl2-tiny"
,
"deepseek-ai/deepseek-vl2-tiny"
,
"microsoft/Florence-2-base"
,
"adept/fuyu-8b"
,
"adept/fuyu-8b"
,
"THUDM/glm-4v-9b"
,
"THUDM/glm-4v-9b"
,
"h2oai/h2ovl-mississippi-800m"
,
"h2oai/h2ovl-mississippi-800m"
,
...
...
tests/models/registry.py
View file @
edf309eb
...
@@ -193,11 +193,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -193,11 +193,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
# [Encoder-decoder]
# [Encoder-decoder]
"BartModel"
:
_HfExamplesInfo
(
"facebook/bart-base"
),
"BartModel"
:
_HfExamplesInfo
(
"facebook/bart-base"
),
"BartForConditionalGeneration"
:
_HfExamplesInfo
(
"facebook/bart-large-cnn"
),
"BartForConditionalGeneration"
:
_HfExamplesInfo
(
"facebook/bart-large-cnn"
),
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
"Florence2ForConditionalGeneration"
:
_HfExamplesInfo
(
"microsoft/Florence-2-base"
,
# noqa: E501
tokenizer
=
"facebook/bart-base"
,
trust_remote_code
=
True
),
# noqa: E501
}
}
_EMBEDDING_EXAMPLE_MODELS
=
{
_EMBEDDING_EXAMPLE_MODELS
=
{
...
@@ -288,6 +283,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -288,6 +283,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
extras
=
{
"v0.5"
:
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
},
# noqa: E501
extras
=
{
"v0.5"
:
"fixie-ai/ultravox-v0_5-llama-3_2-1b"
},
# noqa: E501
trust_remote_code
=
True
),
trust_remote_code
=
True
),
# [Encoder-decoder]
# [Encoder-decoder]
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
"Florence2ForConditionalGeneration"
:
_HfExamplesInfo
(
"microsoft/Florence-2-base"
,
# noqa: E501
tokenizer
=
"facebook/bart-base"
,
trust_remote_code
=
True
),
# noqa: E501
"MllamaForConditionalGeneration"
:
_HfExamplesInfo
(
"meta-llama/Llama-3.2-11B-Vision-Instruct"
),
# noqa: E501
"MllamaForConditionalGeneration"
:
_HfExamplesInfo
(
"meta-llama/Llama-3.2-11B-Vision-Instruct"
),
# noqa: E501
"WhisperForConditionalGeneration"
:
_HfExamplesInfo
(
"openai/whisper-large-v3"
),
# noqa: E501
"WhisperForConditionalGeneration"
:
_HfExamplesInfo
(
"openai/whisper-large-v3"
),
# noqa: E501
}
}
...
...
vllm/model_executor/models/bart.py
View file @
edf309eb
...
@@ -588,8 +588,12 @@ class BartEncoder(nn.Module):
...
@@ -588,8 +588,12 @@ class BartEncoder(nn.Module):
self
.
layernorm_embedding
=
nn
.
LayerNorm
(
embed_dim
)
self
.
layernorm_embedding
=
nn
.
LayerNorm
(
embed_dim
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
def
forward
(
positions
:
torch
.
Tensor
)
->
torch
.
Tensor
:
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
r
"""
r
"""
Args:
Args:
input_ids
input_ids
...
@@ -602,7 +606,8 @@ class BartEncoder(nn.Module):
...
@@ -602,7 +606,8 @@ class BartEncoder(nn.Module):
Decoder output torch.Tensor
Decoder output torch.Tensor
"""
"""
# retrieve input_ids and inputs_embeds
# retrieve input_ids and inputs_embeds
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
embed_pos
=
self
.
embed_positions
(
positions
)
embed_pos
=
self
.
embed_positions
(
positions
)
embed_pos
=
embed_pos
.
to
(
inputs_embeds
.
device
)
embed_pos
=
embed_pos
.
to
(
inputs_embeds
.
device
)
...
@@ -661,9 +666,13 @@ class BartDecoder(nn.Module):
...
@@ -661,9 +666,13 @@ class BartDecoder(nn.Module):
self
.
layernorm_embedding
=
nn
.
LayerNorm
(
config
.
d_model
)
self
.
layernorm_embedding
=
nn
.
LayerNorm
(
config
.
d_model
)
def
forward
(
self
,
decoder_input_ids
:
torch
.
Tensor
,
def
forward
(
decoder_positions
:
torch
.
Tensor
,
self
,
encoder_hidden_states
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
decoder_input_ids
:
torch
.
Tensor
,
decoder_positions
:
torch
.
Tensor
,
encoder_hidden_states
:
Optional
[
torch
.
Tensor
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
r
"""
r
"""
Args:
Args:
decoder_input_ids
decoder_input_ids
...
@@ -677,8 +686,10 @@ class BartDecoder(nn.Module):
...
@@ -677,8 +686,10 @@ class BartDecoder(nn.Module):
Returns:
Returns:
Decoder output torch.Tensor
Decoder output torch.Tensor
"""
"""
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
decoder_input_ids
)
inputs_embeds
=
self
.
embed_tokens
(
decoder_input_ids
)
else
:
decoder_positions
=
inputs_embeds
[:,
-
1
]
# embed positions
# embed positions
embed_pos
=
self
.
embed_positions
(
decoder_positions
)
embed_pos
=
self
.
embed_positions
(
decoder_positions
)
...
...
vllm/model_executor/models/florence2.py
View file @
edf309eb
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
math
import
math
from
typing
import
Iterable
,
Optional
,
Set
,
Tuple
from
functools
import
cached_property
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
OrderedDict
,
Set
,
Tuple
,
TypedDict
,
Union
)
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
transformers
import
BatchFeature
,
PretrainedConfig
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
...
@@ -14,11 +19,567 @@ from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
...
@@ -14,11 +19,567 @@ from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
BartParallelLMHead
,
BartParallelLMHead
,
BartScaledWordEmbedding
)
BartScaledWordEmbedding
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
NestedTensors
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
MultiModalKwargs
from
vllm.multimodal.parse
import
MultiModalDataDict
,
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseProcessingInfo
,
EncDecMultiModalProcessor
,
PromptReplacement
,
PromptReplacementDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.utils
import
AutoWeightsLoader
from
.interfaces
import
SupportsMultiModal
from
.utils
import
AutoWeightsLoader
,
flatten_bn
,
merge_multimodal_embeddings
class
Florence2ImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
data
:
torch
.
Tensor
"""Shape: (batch_size, num_channel, height, width)"""
# ViT implementation are all copied from
# https://huggingface.co/microsoft/Florence-2-base/blob/main/modeling_florence2.py
class
LearnedAbsolutePositionEmbedding2D
(
nn
.
Module
):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def
__init__
(
self
,
embedding_dim
=
256
,
num_pos
=
50
):
super
().
__init__
()
self
.
row_embeddings
=
nn
.
Embedding
(
num_pos
,
embedding_dim
//
2
)
self
.
column_embeddings
=
nn
.
Embedding
(
num_pos
,
embedding_dim
-
(
embedding_dim
//
2
))
def
forward
(
self
,
pixel_values
):
"""
pixel_values: (batch_size, height, width, num_channels)
returns: (batch_size, height, width, embedding_dim * 2)
"""
if
len
(
pixel_values
.
shape
)
!=
4
:
raise
ValueError
(
'pixel_values must be a 4D tensor'
)
height
,
width
=
pixel_values
.
shape
[
1
:
3
]
width_values
=
torch
.
arange
(
width
,
device
=
pixel_values
.
device
)
height_values
=
torch
.
arange
(
height
,
device
=
pixel_values
.
device
)
x_emb
=
self
.
column_embeddings
(
width_values
)
y_emb
=
self
.
row_embeddings
(
height_values
)
# (height, width, embedding_dim * 2)
pos
=
torch
.
cat
([
x_emb
.
unsqueeze
(
0
).
repeat
(
height
,
1
,
1
),
y_emb
.
unsqueeze
(
1
).
repeat
(
1
,
width
,
1
)
],
dim
=-
1
)
# (embedding_dim * 2, height, width)
pos
=
pos
.
permute
(
2
,
0
,
1
)
pos
=
pos
.
unsqueeze
(
0
)
# (batch_size, embedding_dim * 2, height, width)
pos
=
pos
.
repeat
(
pixel_values
.
shape
[
0
],
1
,
1
,
1
)
# (batch_size, height, width, embedding_dim * 2)
pos
=
pos
.
permute
(
0
,
2
,
3
,
1
)
return
pos
class
PositionalEmbeddingCosine1D
(
nn
.
Module
):
"""
This class implements a very simple positional encoding. It follows closely
the encoder from the link below:
https://pytorch.org/tutorials/beginner/translation_transformer.html
Args:
embed_dim: The dimension of the embeddings.
dropout_prob: The dropout probability.
max_seq_len: The maximum length to precompute the positional encodings.
"""
def
__init__
(
self
,
embed_dim
:
int
=
512
,
max_seq_len
:
int
=
1024
)
->
None
:
super
().
__init__
()
self
.
embed_dim
=
embed_dim
self
.
max_seq_len
=
max_seq_len
# Generate the sinusoidal arrays.
factor
=
math
.
log
(
10000
)
denominator
=
torch
.
exp
(
-
factor
*
torch
.
arange
(
0
,
self
.
embed_dim
,
2
)
/
self
.
embed_dim
)
# Matrix where rows correspond to a positional embedding as a function
# of the position index (i.e., the row index).
frequencies
=
\
torch
.
arange
(
0
,
self
.
max_seq_len
)
\
.
reshape
(
self
.
max_seq_len
,
1
)
*
denominator
pos_idx_to_embed
=
torch
.
zeros
((
self
.
max_seq_len
,
self
.
embed_dim
))
# Populate uneven entries.
pos_idx_to_embed
[:,
0
::
2
]
=
torch
.
sin
(
frequencies
)
pos_idx_to_embed
[:,
1
::
2
]
=
torch
.
cos
(
frequencies
)
# Save the positional embeddings in a constant buffer.
# self.register_buffer("pos_idx_to_embed", pos_idx_to_embed)
self
.
pos_idx_to_embed
=
nn
.
Parameter
(
pos_idx_to_embed
,
requires_grad
=
False
)
def
forward
(
self
,
seq_embeds
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Args:
seq_embeds: The sequence embeddings in order. Allowed size:
1. [T, D], where T is the length of the sequence, and D is the
frame embedding dimension.
2. [B, T, D], where B is the batch size and T and D are the
same as above.
Returns a tensor of with the same dimensions as the input: i.e.,
[1, T, D] or [T, D].
"""
shape_len
=
len
(
seq_embeds
.
shape
)
assert
2
<=
shape_len
<=
3
len_seq
=
seq_embeds
.
size
(
-
2
)
assert
len_seq
<=
self
.
max_seq_len
pos_embeds
=
self
.
pos_idx_to_embed
[
0
:
seq_embeds
.
size
(
-
2
),
:]
# Adapt pre-computed positional embeddings to the input.
if
shape_len
==
3
:
pos_embeds
=
pos_embeds
.
view
(
(
1
,
pos_embeds
.
size
(
0
),
pos_embeds
.
size
(
1
)))
return
pos_embeds
class
MySequential
(
nn
.
Sequential
):
def
forward
(
self
,
*
inputs
):
for
module
in
self
.
_modules
.
values
():
if
isinstance
(
inputs
,
tuple
):
inputs
=
module
(
*
inputs
)
else
:
inputs
=
module
(
inputs
)
return
inputs
class
PreNorm
(
nn
.
Module
):
def
__init__
(
self
,
norm
,
fn
):
super
().
__init__
()
self
.
norm
=
norm
self
.
fn
=
fn
def
forward
(
self
,
x
,
*
args
,
**
kwargs
):
shortcut
=
x
if
self
.
norm
is
not
None
:
x
,
size
=
self
.
fn
(
self
.
norm
(
x
),
*
args
,
**
kwargs
)
else
:
x
,
size
=
self
.
fn
(
x
,
*
args
,
**
kwargs
)
x
=
shortcut
+
x
return
x
,
size
class
Mlp
(
nn
.
Module
):
def
__init__
(
self
,
in_features
,
hidden_features
=
None
,
out_features
=
None
,
act_layer
=
nn
.
GELU
,
):
super
().
__init__
()
out_features
=
out_features
or
in_features
hidden_features
=
hidden_features
or
in_features
self
.
net
=
nn
.
Sequential
(
OrderedDict
([(
"fc1"
,
nn
.
Linear
(
in_features
,
hidden_features
)),
(
"act"
,
act_layer
()),
(
"fc2"
,
nn
.
Linear
(
hidden_features
,
out_features
))]))
def
forward
(
self
,
x
,
size
):
return
self
.
net
(
x
),
size
class
DepthWiseConv2d
(
nn
.
Module
):
def
__init__
(
self
,
dim_in
,
kernel_size
,
padding
,
stride
,
bias
=
True
,
):
super
().
__init__
()
self
.
dw
=
nn
.
Conv2d
(
dim_in
,
dim_in
,
kernel_size
=
kernel_size
,
padding
=
padding
,
groups
=
dim_in
,
stride
=
stride
,
bias
=
bias
)
def
forward
(
self
,
x
,
size
):
B
,
N
,
C
=
x
.
shape
H
,
W
=
size
assert
N
==
H
*
W
x
=
self
.
dw
(
x
.
transpose
(
1
,
2
).
view
(
B
,
C
,
H
,
W
))
size
=
(
x
.
size
(
-
2
),
x
.
size
(
-
1
))
x
=
x
.
flatten
(
2
).
transpose
(
1
,
2
)
return
x
,
size
class
ConvEmbed
(
nn
.
Module
):
""" Image to Patch Embedding
"""
def
__init__
(
self
,
patch_size
=
7
,
in_chans
=
3
,
embed_dim
=
64
,
stride
=
4
,
padding
=
2
,
norm_layer
=
None
,
pre_norm
=
True
):
super
().
__init__
()
self
.
patch_size
=
patch_size
self
.
proj
=
nn
.
Conv2d
(
in_chans
,
embed_dim
,
kernel_size
=
patch_size
,
stride
=
stride
,
padding
=
padding
)
dim_norm
=
in_chans
if
pre_norm
else
embed_dim
self
.
norm
=
norm_layer
(
dim_norm
)
if
norm_layer
else
None
self
.
pre_norm
=
pre_norm
def
forward
(
self
,
x
,
size
):
H
,
W
=
size
if
len
(
x
.
size
())
==
3
:
if
self
.
norm
and
self
.
pre_norm
:
x
=
self
.
norm
(
x
)
x
=
rearrange
(
x
,
'b (h w) c -> b c h w'
,
h
=
H
,
w
=
W
)
x
=
self
.
proj
(
x
)
_
,
_
,
H
,
W
=
x
.
shape
x
=
rearrange
(
x
,
'b c h w -> b (h w) c'
)
if
self
.
norm
and
not
self
.
pre_norm
:
x
=
self
.
norm
(
x
)
return
x
,
(
H
,
W
)
class
ChannelAttention
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
groups
=
8
,
qkv_bias
=
True
):
super
().
__init__
()
self
.
groups
=
groups
self
.
qkv
=
nn
.
Linear
(
dim
,
dim
*
3
,
bias
=
qkv_bias
)
self
.
proj
=
nn
.
Linear
(
dim
,
dim
)
def
forward
(
self
,
x
,
size
):
B
,
N
,
C
=
x
.
shape
qkv
=
self
.
qkv
(
x
).
reshape
(
B
,
N
,
3
,
self
.
groups
,
C
//
self
.
groups
).
permute
(
2
,
0
,
3
,
1
,
4
)
q
,
k
,
v
=
qkv
[
0
],
qkv
[
1
],
qkv
[
2
]
q
=
q
*
(
float
(
N
)
**-
0.5
)
attention
=
q
.
transpose
(
-
1
,
-
2
)
@
k
attention
=
attention
.
softmax
(
dim
=-
1
)
x
=
(
attention
@
v
.
transpose
(
-
1
,
-
2
)).
transpose
(
-
1
,
-
2
)
x
=
x
.
transpose
(
1
,
2
).
reshape
(
B
,
N
,
C
)
x
=
self
.
proj
(
x
)
return
x
,
size
class
ChannelBlock
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
groups
,
mlp_ratio
=
4.
,
qkv_bias
=
True
,
drop_path_rate
=
0.
,
act_layer
=
nn
.
GELU
,
norm_layer
=
nn
.
LayerNorm
,
conv_at_attn
=
True
,
conv_at_ffn
=
True
):
super
().
__init__
()
self
.
conv1
=
PreNorm
(
None
,
DepthWiseConv2d
(
dim
,
3
,
1
,
1
))
if
conv_at_attn
else
None
self
.
channel_attn
=
PreNorm
(
norm_layer
(
dim
),
ChannelAttention
(
dim
,
groups
=
groups
,
qkv_bias
=
qkv_bias
),
)
self
.
conv2
=
PreNorm
(
None
,
DepthWiseConv2d
(
dim
,
3
,
1
,
1
))
if
conv_at_ffn
else
None
self
.
ffn
=
PreNorm
(
norm_layer
(
dim
),
Mlp
(
in_features
=
dim
,
hidden_features
=
int
(
dim
*
mlp_ratio
),
act_layer
=
act_layer
),
)
def
forward
(
self
,
x
,
size
):
if
self
.
conv1
:
x
,
size
=
self
.
conv1
(
x
,
size
)
x
,
size
=
self
.
channel_attn
(
x
,
size
)
if
self
.
conv2
:
x
,
size
=
self
.
conv2
(
x
,
size
)
x
,
size
=
self
.
ffn
(
x
,
size
)
return
x
,
size
def
window_partition
(
x
,
window_size
:
int
):
B
,
H
,
W
,
C
=
x
.
shape
x
=
x
.
view
(
B
,
H
//
window_size
,
window_size
,
W
//
window_size
,
window_size
,
C
)
windows
=
x
.
permute
(
0
,
1
,
3
,
2
,
4
,
5
).
contiguous
().
view
(
-
1
,
window_size
,
window_size
,
C
)
return
windows
def
window_reverse
(
windows
,
batch_size
:
int
,
window_size
:
int
,
H
:
int
,
W
:
int
):
B
=
batch_size
x
=
windows
.
view
(
B
,
H
//
window_size
,
W
//
window_size
,
window_size
,
window_size
,
-
1
)
x
=
x
.
permute
(
0
,
1
,
3
,
2
,
4
,
5
).
contiguous
().
view
(
B
,
H
,
W
,
-
1
)
return
x
class
WindowAttention
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
num_heads
,
window_size
,
qkv_bias
=
True
):
super
().
__init__
()
self
.
dim
=
dim
self
.
window_size
=
window_size
self
.
num_heads
=
num_heads
head_dim
=
dim
//
num_heads
self
.
scale
=
float
(
head_dim
)
**-
0.5
self
.
qkv
=
nn
.
Linear
(
dim
,
dim
*
3
,
bias
=
qkv_bias
)
self
.
proj
=
nn
.
Linear
(
dim
,
dim
)
self
.
softmax
=
nn
.
Softmax
(
dim
=-
1
)
def
forward
(
self
,
x
,
size
):
H
,
W
=
size
B
,
L
,
C
=
x
.
shape
assert
L
==
H
*
W
,
"input feature has wrong size"
x
=
x
.
view
(
B
,
H
,
W
,
C
)
pad_l
=
pad_t
=
0
pad_r
=
(
self
.
window_size
-
W
%
self
.
window_size
)
%
self
.
window_size
pad_b
=
(
self
.
window_size
-
H
%
self
.
window_size
)
%
self
.
window_size
x
=
F
.
pad
(
x
,
(
0
,
0
,
pad_l
,
pad_r
,
pad_t
,
pad_b
))
_
,
Hp
,
Wp
,
_
=
x
.
shape
x
=
window_partition
(
x
,
self
.
window_size
)
x
=
x
.
view
(
-
1
,
self
.
window_size
*
self
.
window_size
,
C
)
# W-MSA/SW-MSA
# attn_windows = self.attn(x_windows)
B_
,
N
,
C
=
x
.
shape
qkv
=
self
.
qkv
(
x
).
reshape
(
B_
,
N
,
3
,
self
.
num_heads
,
C
//
self
.
num_heads
).
permute
(
2
,
0
,
3
,
1
,
4
)
q
,
k
,
v
=
qkv
[
0
],
qkv
[
1
],
qkv
[
2
]
q
=
q
*
self
.
scale
attn
=
(
q
@
k
.
transpose
(
-
2
,
-
1
))
attn
=
self
.
softmax
(
attn
)
x
=
(
attn
@
v
).
transpose
(
1
,
2
).
reshape
(
B_
,
N
,
C
)
x
=
self
.
proj
(
x
)
# merge windows
x
=
x
.
view
(
-
1
,
self
.
window_size
,
self
.
window_size
,
C
)
x
=
window_reverse
(
x
,
B
,
self
.
window_size
,
Hp
,
Wp
)
if
pad_r
>
0
or
pad_b
>
0
:
x
=
x
[:,
:
H
,
:
W
,
:].
contiguous
()
x
=
x
.
view
(
B
,
H
*
W
,
C
)
return
x
,
size
class
SpatialBlock
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
num_heads
,
window_size
,
mlp_ratio
=
4.
,
qkv_bias
=
True
,
drop_path_rate
=
0.
,
act_layer
=
nn
.
GELU
,
norm_layer
=
nn
.
LayerNorm
,
conv_at_attn
=
True
,
conv_at_ffn
=
True
):
super
().
__init__
()
self
.
conv1
=
PreNorm
(
None
,
DepthWiseConv2d
(
dim
,
3
,
1
,
1
))
if
conv_at_attn
else
None
self
.
window_attn
=
PreNorm
(
norm_layer
(
dim
),
WindowAttention
(
dim
,
num_heads
,
window_size
,
qkv_bias
=
qkv_bias
),
)
self
.
conv2
=
PreNorm
(
None
,
DepthWiseConv2d
(
dim
,
3
,
1
,
1
))
if
conv_at_ffn
else
None
self
.
ffn
=
PreNorm
(
norm_layer
(
dim
),
Mlp
(
in_features
=
dim
,
hidden_features
=
int
(
dim
*
mlp_ratio
),
act_layer
=
act_layer
),
)
def
forward
(
self
,
x
,
size
):
if
self
.
conv1
:
x
,
size
=
self
.
conv1
(
x
,
size
)
x
,
size
=
self
.
window_attn
(
x
,
size
)
if
self
.
conv2
:
x
,
size
=
self
.
conv2
(
x
,
size
)
x
,
size
=
self
.
ffn
(
x
,
size
)
return
x
,
size
class
DaViT
(
nn
.
Module
):
def
__init__
(
self
,
in_chans
=
3
,
num_classes
=
1000
,
depths
=
(
1
,
1
,
3
,
1
),
patch_size
=
(
7
,
2
,
2
,
2
),
patch_stride
=
(
4
,
2
,
2
,
2
),
patch_padding
=
(
3
,
0
,
0
,
0
),
patch_prenorm
=
(
False
,
False
,
False
,
False
),
embed_dims
=
(
64
,
128
,
192
,
256
),
num_heads
=
(
3
,
6
,
12
,
24
),
num_groups
=
(
3
,
6
,
12
,
24
),
window_size
=
7
,
mlp_ratio
=
4.
,
qkv_bias
=
True
,
drop_path_rate
=
0.1
,
norm_layer
=
nn
.
LayerNorm
,
enable_checkpoint
=
False
,
conv_at_attn
=
True
,
conv_at_ffn
=
True
,
):
super
().
__init__
()
self
.
num_classes
=
num_classes
self
.
embed_dims
=
embed_dims
self
.
num_heads
=
num_heads
self
.
num_groups
=
num_groups
self
.
num_stages
=
len
(
self
.
embed_dims
)
self
.
enable_checkpoint
=
enable_checkpoint
assert
self
.
num_stages
==
len
(
self
.
num_heads
)
==
len
(
self
.
num_groups
)
num_stages
=
len
(
embed_dims
)
dpr
=
[
x
.
item
()
for
x
in
torch
.
linspace
(
0
,
drop_path_rate
,
sum
(
depths
)
*
2
)
]
depth_offset
=
0
convs
=
[]
blocks
=
[]
for
i
in
range
(
num_stages
):
conv_embed
=
ConvEmbed
(
patch_size
=
patch_size
[
i
],
stride
=
patch_stride
[
i
],
padding
=
patch_padding
[
i
],
in_chans
=
in_chans
if
i
==
0
else
self
.
embed_dims
[
i
-
1
],
embed_dim
=
self
.
embed_dims
[
i
],
norm_layer
=
norm_layer
,
pre_norm
=
patch_prenorm
[
i
])
convs
.
append
(
conv_embed
)
block
=
MySequential
(
*
[
MySequential
(
OrderedDict
([(
'spatial_block'
,
SpatialBlock
(
embed_dims
[
i
],
num_heads
[
i
],
window_size
,
drop_path_rate
=
dpr
[
depth_offset
+
j
*
2
],
qkv_bias
=
qkv_bias
,
mlp_ratio
=
mlp_ratio
,
conv_at_attn
=
conv_at_attn
,
conv_at_ffn
=
conv_at_ffn
,
)),
(
'channel_block'
,
ChannelBlock
(
embed_dims
[
i
],
num_groups
[
i
],
drop_path_rate
=
dpr
[
depth_offset
+
j
*
2
+
1
],
qkv_bias
=
qkv_bias
,
mlp_ratio
=
mlp_ratio
,
conv_at_attn
=
conv_at_attn
,
conv_at_ffn
=
conv_at_ffn
,
))]))
for
j
in
range
(
depths
[
i
])
])
blocks
.
append
(
block
)
depth_offset
+=
depths
[
i
]
*
2
self
.
convs
=
nn
.
ModuleList
(
convs
)
self
.
blocks
=
nn
.
ModuleList
(
blocks
)
self
.
avgpool
=
nn
.
AdaptiveAvgPool1d
(
1
)
@
property
def
dim_out
(
self
):
return
self
.
embed_dims
[
-
1
]
def
forward_features_unpool
(
self
,
x
):
"""
forward until avg pooling
Args:
x (_type_): input image tensor
"""
input_size
=
(
x
.
size
(
2
),
x
.
size
(
3
))
for
conv
,
block
in
zip
(
self
.
convs
,
self
.
blocks
):
x
,
input_size
=
conv
(
x
,
input_size
)
x
,
input_size
=
block
(
x
,
input_size
)
return
x
def
forward_features
(
self
,
x
):
x
=
self
.
forward_features_unpool
(
x
)
# (batch_size, num_tokens, token_dim)
x
=
self
.
avgpool
(
x
.
transpose
(
1
,
2
))
# (batch_size, 1, num_tokens)
x
=
torch
.
flatten
(
x
,
1
)
x
=
self
.
norms
(
x
)
return
x
def
forward
(
self
,
x
):
x
=
self
.
forward_features
(
x
)
x
=
self
.
head
(
x
)
return
x
@
classmethod
def
from_config
(
cls
,
config
):
return
cls
(
depths
=
config
.
depths
,
embed_dims
=
config
.
dim_embed
,
num_heads
=
config
.
num_heads
,
num_groups
=
config
.
num_groups
,
patch_size
=
config
.
patch_size
,
patch_stride
=
config
.
patch_stride
,
patch_padding
=
config
.
patch_padding
,
patch_prenorm
=
config
.
patch_prenorm
,
drop_path_rate
=
config
.
drop_path_rate
,
window_size
=
config
.
window_size
,
)
# Language backbone and processor implementation
class
Florence2LanguageModel
(
nn
.
Module
):
class
Florence2LanguageModel
(
nn
.
Module
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
@@ -47,9 +608,14 @@ class Florence2LanguageModel(nn.Module):
...
@@ -47,9 +608,14 @@ class Florence2LanguageModel(nn.Module):
self
.
encoder
.
embed_tokens
.
weight
=
self
.
shared
.
weight
self
.
encoder
.
embed_tokens
.
weight
=
self
.
shared
.
weight
self
.
decoder
.
embed_tokens
.
weight
=
self
.
shared
.
weight
self
.
decoder
.
embed_tokens
.
weight
=
self
.
shared
.
weight
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
def
forward
(
encoder_input_ids
:
torch
.
Tensor
,
self
,
encoder_positions
:
torch
.
Tensor
)
->
torch
.
Tensor
:
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
encoder_input_ids
:
torch
.
Tensor
,
encoder_positions
:
torch
.
Tensor
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
r
"""
r
"""
Args:
Args:
input_ids
input_ids
...
@@ -68,11 +634,12 @@ class Florence2LanguageModel(nn.Module):
...
@@ -68,11 +634,12 @@ class Florence2LanguageModel(nn.Module):
encoder_hidden_states
=
None
encoder_hidden_states
=
None
if
encoder_input_ids
.
numel
()
>
0
:
if
inputs_embeds
is
not
None
or
encoder_input_ids
.
numel
()
>
0
:
# Run encoder attention if a non-zero number of encoder tokens
# Run encoder attention if a non-zero number of encoder tokens
# are provided as input
# are provided as input
encoder_hidden_states
=
self
.
encoder
(
input_ids
=
encoder_input_ids
,
encoder_hidden_states
=
self
.
encoder
(
input_ids
=
encoder_input_ids
,
positions
=
encoder_positions
)
positions
=
encoder_positions
,
inputs_embeds
=
inputs_embeds
)
# decoder outputs consists of
# decoder outputs consists of
# (dec_features, past_key_value, dec_hidden, dec_attn)
# (dec_features, past_key_value, dec_hidden, dec_attn)
...
@@ -112,6 +679,7 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
...
@@ -112,6 +679,7 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
encoder_input_ids
:
torch
.
Tensor
,
encoder_input_ids
:
torch
.
Tensor
,
encoder_positions
:
torch
.
Tensor
,
encoder_positions
:
torch
.
Tensor
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
**
kwargs
,
**
kwargs
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
r
"""
r
"""
...
@@ -127,8 +695,15 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
...
@@ -127,8 +695,15 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
Returns:
Returns:
Output torch.Tensor
Output torch.Tensor
"""
"""
return
self
.
model
(
input_ids
,
positions
,
encoder_input_ids
,
encoder_positions
)
return
self
.
model
(
input_ids
,
positions
,
encoder_input_ids
,
encoder_positions
,
inputs_embeds
=
inputs_embeds
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
encoder
.
embed_tokens
(
input_ids
)
def
compute_logits
(
def
compute_logits
(
self
,
self
,
...
@@ -177,21 +752,312 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
...
@@ -177,21 +752,312 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
return
loaded_params
return
loaded_params
class
Florence2ForConditionalGeneration
(
nn
.
Module
):
class
Florence2ProcessingInfo
(
BaseProcessingInfo
):
def
get_hf_config
(
self
):
return
self
.
ctx
.
get_hf_config
()
def
get_hf_processor
(
self
):
return
self
.
ctx
.
get_hf_processor
()
def
get_supported_mm_limits
(
self
)
->
Mapping
[
str
,
Optional
[
int
]]:
return
{
"image"
:
1
}
def
get_max_image_tokens
(
self
)
->
int
:
processor_config
=
self
.
ctx
.
get_hf_image_processor_config
()
return
processor_config
[
"image_seq_length"
]
def
get_mm_max_tokens_per_item
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
int
]:
return
{
"image"
:
self
.
get_max_image_tokens
()}
class
Florence2DummyInputsBuilder
(
BaseDummyInputsBuilder
[
Florence2ProcessingInfo
]):
def
get_dummy_processor_inputs
(
self
,
seq_len
:
int
,
mm_counts
:
Mapping
[
str
,
int
],
)
->
ProcessorInputs
:
num_images
=
mm_counts
.
get
(
"image"
,
0
)
target_width
=
target_height
=
self
.
info
.
get_hf_config
().
projection_dim
mm_data
=
{
"image"
:
self
.
_get_dummy_images
(
width
=
target_width
,
height
=
target_height
,
num_images
=
num_images
)
}
return
ProcessorInputs
(
prompt_text
=
""
,
mm_data
=
mm_data
,
)
class
Florence2MultiModalProcessor
(
EncDecMultiModalProcessor
[
Florence2ProcessingInfo
]):
def
_hf_processor_applies_repl
(
self
,
prompt_text
:
str
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
bool
:
return
False
def
create_encoder_prompt
(
self
,
prompt
:
Union
[
str
,
list
[
int
]],
mm_data
:
MultiModalDataDict
,
)
->
Union
[
str
,
list
[
int
]]:
return
prompt
def
create_decoder_prompt
(
self
,
prompt
:
Union
[
str
,
list
[
int
]],
mm_data
:
MultiModalDataDict
,
)
->
Union
[
str
,
list
[
int
]]:
return
[
self
.
info
.
get_hf_config
().
eos_token_id
]
def
_call_hf_processor
(
self
,
prompt
:
str
,
mm_data
:
Mapping
[
str
,
object
],
mm_kwargs
:
Mapping
[
str
,
object
],
)
->
BatchFeature
:
if
mm_data
:
processed_outputs
=
super
().
_call_hf_processor
(
prompt
,
mm_data
,
mm_kwargs
)
else
:
hf_processor
=
self
.
info
.
get_hf_processor
()
tokenizer
=
hf_processor
.
tokenizer
prompt
=
hf_processor
.
_construct_prompts
([
prompt
])[
0
]
processed_outputs
=
tokenizer
(
prompt
,
add_special_tokens
=
True
,
return_tensors
=
"pt"
)
return
processed_outputs
def
_get_mm_fields_config
(
self
,
hf_inputs
:
BatchFeature
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
))
def
_get_prompt_replacements
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
out_mm_kwargs
:
MultiModalKwargs
,
)
->
list
[
PromptReplacement
]:
hf_config
=
self
.
info
.
get_hf_config
()
pad_token_id
=
hf_config
.
pad_token_id
bos_token_id
=
hf_config
.
bos_token_id
num_image_tokens
=
self
.
info
.
get_max_image_tokens
()
image_tokens
=
[
pad_token_id
]
*
num_image_tokens
return
[
PromptReplacement
(
modality
=
"image"
,
target
=
[
bos_token_id
],
replacement
=
PromptReplacementDetails
(
full
=
image_tokens
+
[
bos_token_id
],
features
=
image_tokens
,
),
)
]
@
MULTIMODAL_REGISTRY
.
register_processor
(
Florence2MultiModalProcessor
,
info
=
Florence2ProcessingInfo
,
dummy_inputs
=
Florence2DummyInputsBuilder
)
class
Florence2ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
processor_config
=
vllm_config
.
model_config
.
hf_image_processor_config
# TODO(Isotr0py): Add vision backbone
self
.
config
=
config
self
.
vision_config
=
config
.
vision_config
self
.
processor_config
=
processor_config
assert
config
.
vision_config
.
model_type
==
'davit'
,
(
'only DaViT is supported for now'
)
self
.
vision_tower
=
DaViT
.
from_config
(
config
=
config
.
vision_config
)
self
.
_build_image_projection_layers
(
config
)
self
.
language_model
=
Florence2LanguageForConditionalGeneration
(
self
.
language_model
=
Florence2LanguageForConditionalGeneration
(
vllm_config
=
vllm_config
.
with_hf_config
(
config
.
text_config
),
vllm_config
=
vllm_config
.
with_hf_config
(
config
.
text_config
),
prefix
=
f
"
{
prefix
}
.language_model"
,
prefix
=
f
"
{
prefix
}
.language_model"
,
)
)
self
.
pad_token_id
=
config
.
pad_token_id
@
property
def
_build_image_projection_layers
(
self
,
config
:
PretrainedConfig
):
image_dim_out
=
config
.
vision_config
.
dim_embed
[
-
1
]
dim_projection
=
config
.
vision_config
.
projection_dim
self
.
image_projection
=
nn
.
Parameter
(
torch
.
empty
(
image_dim_out
,
dim_projection
))
self
.
image_proj_norm
=
nn
.
LayerNorm
(
dim_projection
)
image_pos_embed_config
=
config
.
vision_config
.
image_pos_embed
if
image_pos_embed_config
[
'type'
]
==
'learned_abs_2d'
:
self
.
image_pos_embed
=
LearnedAbsolutePositionEmbedding2D
(
embedding_dim
=
image_dim_out
,
num_pos
=
image_pos_embed_config
[
'max_pos_embeddings'
])
else
:
raise
NotImplementedError
(
"Florence2 only supports learned_abs_2d "
"as image position embedding."
)
self
.
image_feature_source
=
config
.
vision_config
.
image_feature_source
# temporal embedding
visual_temporal_embedding_config
=
(
self
.
vision_config
.
visual_temporal_embedding
)
if
visual_temporal_embedding_config
[
'type'
]
==
'COSINE'
:
self
.
visual_temporal_embed
=
PositionalEmbeddingCosine1D
(
embed_dim
=
image_dim_out
,
max_seq_len
=
visual_temporal_embedding_config
[
'max_temporal_embeddings'
])
else
:
raise
NotImplementedError
(
'Florence2 only supports COSINE as temporal embedding.'
)
@
cached_property
def
sampler
(
self
):
def
sampler
(
self
):
return
self
.
language_model
.
sampler
if
hasattr
(
self
.
language_model
,
"sampler"
):
return
self
.
language_model
.
sampler
return
get_sampler
()
def
_validate_pixel_values
(
self
,
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
)
->
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]:
size
=
self
.
processor_config
[
"size"
]
h
,
w
=
size
[
"height"
],
size
[
"width"
]
expected_dims
=
(
3
,
h
,
w
)
def
_validate_shape
(
d
:
torch
.
Tensor
):
actual_dims
=
tuple
(
d
.
shape
)
if
actual_dims
!=
expected_dims
:
expected_expr
=
tuple
(
*
map
(
str
,
expected_dims
))
raise
ValueError
(
"The expected shape of pixel values per batch "
f
"is
{
expected_expr
}
. You supplied
{
tuple
(
d
.
shape
)
}
."
)
for
d
in
data
:
_validate_shape
(
d
)
return
data
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
):
pixel_values
:
Optional
[
Union
[
List
[
List
[
torch
.
Tensor
]],
List
[
torch
.
Tensor
],
torch
.
Tensor
]]
=
kwargs
.
pop
(
"pixel_values"
,
None
)
image_embeds
:
Optional
[
Union
[
List
[
List
[
torch
.
Tensor
]],
List
[
torch
.
Tensor
],
torch
.
Tensor
]]
=
kwargs
.
pop
(
"image_embeds"
,
None
)
if
pixel_values
is
None
and
image_embeds
is
None
:
return
None
if
pixel_values
is
not
None
and
image_embeds
is
not
None
:
raise
ValueError
(
"Both pixel values and image embeds are provided."
)
if
pixel_values
is
not
None
:
return
Florence2ImagePixelInputs
(
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
flatten_bn
(
pixel_values
,
concat
=
True
)),
)
if
image_embeds
is
not
None
:
raise
NotImplementedError
raise
AssertionError
(
"This line should be unreachable."
)
def
_encode_image
(
self
,
pixel_values
:
torch
.
Tensor
)
->
torch
.
Tensor
:
dtype
=
next
(
self
.
vision_tower
.
parameters
()).
dtype
pixel_values
=
pixel_values
.
to
(
dtype
)
batch_size
,
T
=
pixel_values
.
size
(
0
),
1
x
=
self
.
vision_tower
.
forward_features_unpool
(
pixel_values
)
if
self
.
image_pos_embed
is
not
None
:
x
=
x
.
view
(
batch_size
*
T
,
-
1
,
x
.
shape
[
-
1
])
num_tokens
=
x
.
shape
[
-
2
]
h
,
w
=
int
(
num_tokens
**
0.5
),
int
(
num_tokens
**
0.5
)
assert
h
*
w
==
num_tokens
,
(
'only support square feature maps for now'
)
x
=
x
.
view
(
batch_size
*
T
,
h
,
w
,
x
.
shape
[
-
1
])
pos_embed
=
self
.
image_pos_embed
(
x
)
x
=
x
+
pos_embed
x
=
x
.
view
(
batch_size
,
T
*
h
*
w
,
x
.
shape
[
-
1
])
if
self
.
visual_temporal_embed
is
not
None
:
visual_temporal_embed
=
self
.
visual_temporal_embed
(
x
.
view
(
batch_size
,
T
,
-
1
,
x
.
shape
[
-
1
])[:,
:,
0
])
x
=
x
.
view
(
batch_size
,
T
,
-
1
,
x
.
shape
[
-
1
])
+
visual_temporal_embed
.
view
(
1
,
T
,
1
,
x
.
shape
[
-
1
])
x_feat_dict
=
{}
spatial_avg_pool_x
=
x
.
view
(
batch_size
,
T
,
-
1
,
x
.
shape
[
-
1
]).
mean
(
dim
=
2
)
x_feat_dict
[
'spatial_avg_pool'
]
=
spatial_avg_pool_x
temporal_avg_pool_x
=
x
.
view
(
batch_size
,
T
,
-
1
,
x
.
shape
[
-
1
]).
mean
(
dim
=
1
)
x_feat_dict
[
'temporal_avg_pool'
]
=
temporal_avg_pool_x
x
=
x
.
view
(
batch_size
,
T
,
-
1
,
x
.
shape
[
-
1
])[:,
-
1
]
x_feat_dict
[
'last_frame'
]
=
x
new_x
=
[]
for
_image_feature_source
in
self
.
image_feature_source
:
if
_image_feature_source
not
in
x_feat_dict
:
raise
ValueError
(
'invalid image feature source: {}'
.
format
(
_image_feature_source
))
new_x
.
append
(
x_feat_dict
[
_image_feature_source
])
x
=
torch
.
cat
(
new_x
,
dim
=
1
)
x
=
x
@
self
.
image_projection
x
=
self
.
image_proj_norm
(
x
)
return
x
def
_process_image_input
(
self
,
image_input
:
Florence2ImagePixelInputs
)
->
torch
.
Tensor
:
assert
image_input
[
"type"
]
==
"pixel_values"
pixel_values
=
image_input
[
"data"
]
return
self
.
_encode_image
(
pixel_values
)
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
torch
.
Tensor
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
return
None
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
return
vision_embeddings
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
,
multimodal_embeddings
:
Optional
[
NestedTensors
]
=
None
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
self
.
pad_token_id
)
return
inputs_embeds
def
forward
(
def
forward
(
self
,
self
,
...
@@ -216,8 +1082,19 @@ class Florence2ForConditionalGeneration(nn.Module):
...
@@ -216,8 +1082,19 @@ class Florence2ForConditionalGeneration(nn.Module):
Returns:
Returns:
Output torch.Tensor
Output torch.Tensor
"""
"""
return
self
.
language_model
(
input_ids
,
positions
,
encoder_input_ids
,
vision_embeddings
=
self
.
get_multimodal_embeddings
(
**
kwargs
)
encoder_positions
)
if
encoder_input_ids
.
numel
()
>
0
or
vision_embeddings
is
not
None
:
inputs_embeds
=
self
.
get_input_embeddings
(
encoder_input_ids
,
vision_embeddings
)
else
:
inputs_embeds
=
None
hidden_states
=
self
.
language_model
(
input_ids
,
positions
,
encoder_input_ids
,
encoder_positions
,
inputs_embeds
=
inputs_embeds
)
return
hidden_states
def
compute_logits
(
def
compute_logits
(
self
,
self
,
...
@@ -236,9 +1113,5 @@ class Florence2ForConditionalGeneration(nn.Module):
...
@@ -236,9 +1113,5 @@ class Florence2ForConditionalGeneration(nn.Module):
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
skip_prefixes
=
[
loader
=
AutoWeightsLoader
(
self
)
'image_projection'
,
"vision_tower"
,
"image_proj_norm"
,
"image_pos_embed"
,
"visual_temporal_embed"
]
loader
=
AutoWeightsLoader
(
self
,
skip_prefixes
=
skip_prefixes
)
return
loader
.
load_weights
(
weights
)
return
loader
.
load_weights
(
weights
)
vllm/model_executor/models/registry.py
View file @
edf309eb
...
@@ -105,7 +105,6 @@ _TEXT_GENERATION_MODELS = {
...
@@ -105,7 +105,6 @@ _TEXT_GENERATION_MODELS = {
# [Encoder-decoder]
# [Encoder-decoder]
"BartModel"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"BartModel"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"BartForConditionalGeneration"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"BartForConditionalGeneration"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"Florence2ForConditionalGeneration"
:
(
"florence2"
,
"Florence2ForConditionalGeneration"
),
# noqa: E501
}
}
_EMBEDDING_MODELS
=
{
_EMBEDDING_MODELS
=
{
...
@@ -182,6 +181,7 @@ _MULTIMODAL_MODELS = {
...
@@ -182,6 +181,7 @@ _MULTIMODAL_MODELS = {
"Qwen2AudioForConditionalGeneration"
:
(
"qwen2_audio"
,
"Qwen2AudioForConditionalGeneration"
),
# noqa: E501
"Qwen2AudioForConditionalGeneration"
:
(
"qwen2_audio"
,
"Qwen2AudioForConditionalGeneration"
),
# noqa: E501
"UltravoxModel"
:
(
"ultravox"
,
"UltravoxModel"
),
"UltravoxModel"
:
(
"ultravox"
,
"UltravoxModel"
),
# [Encoder-decoder]
# [Encoder-decoder]
"Florence2ForConditionalGeneration"
:
(
"florence2"
,
"Florence2ForConditionalGeneration"
),
# noqa: E501
"MllamaForConditionalGeneration"
:
(
"mllama"
,
"MllamaForConditionalGeneration"
),
# noqa: E501
"MllamaForConditionalGeneration"
:
(
"mllama"
,
"MllamaForConditionalGeneration"
),
# noqa: E501
"WhisperForConditionalGeneration"
:
(
"whisper"
,
"WhisperForConditionalGeneration"
),
# noqa: E501
"WhisperForConditionalGeneration"
:
(
"whisper"
,
"WhisperForConditionalGeneration"
),
# noqa: E501
}
}
...
...
vllm/multimodal/processing.py
View file @
edf309eb
...
@@ -1303,6 +1303,14 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
...
@@ -1303,6 +1303,14 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
def
create_decoder_prompt
(
self
,
prompt
:
Union
[
str
,
list
[
int
]],
mm_data
:
MultiModalDataDict
,
)
->
Union
[
str
,
list
[
int
]]:
"""Create input prompt for the decoder."""
return
prompt
def
apply
(
def
apply
(
self
,
self
,
prompt
:
Union
[
str
,
list
[
int
]],
prompt
:
Union
[
str
,
list
[
int
]],
...
@@ -1323,17 +1331,15 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
...
@@ -1323,17 +1331,15 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
,
)
)
# We assumed the decoder prompt text is copied from
# the original encoder prompt without extra process
tokenizer
=
self
.
info
.
get_tokenizer
()
tokenizer
=
self
.
info
.
get_tokenizer
()
if
isinstance
(
prompt
,
str
):
decoder_prompt
=
self
.
create_decoder_prompt
(
prompt
,
mm_data
)
decoder_prompt
=
prompt
if
isinstance
(
decoder_prompt
,
str
):
decoder_prompt_ids
=
encode_tokens
(
tokenizer
,
decoder_prompt_ids
=
encode_tokens
(
tokenizer
,
prompt
,
decoder_
prompt
,
add_special_tokens
=
False
)
add_special_tokens
=
False
)
else
:
else
:
decoder_prompt
=
decode_
tokens
(
tokenizer
,
prompt
)
decoder_prompt
_ids
=
decode
r
_prompt
decoder_prompt
_ids
=
prompt
decoder_prompt
=
decode_tokens
(
tokenizer
,
decoder_
prompt
)
mm_inputs
=
MultiModalEncDecInputs
(
mm_inputs
=
MultiModalEncDecInputs
(
encoder_prompt
=
encoder_inputs
[
"prompt"
],
encoder_prompt
=
encoder_inputs
[
"prompt"
],
...
...
vllm/multimodal/profiling.py
View file @
edf309eb
...
@@ -204,9 +204,11 @@ class MultiModalProfiler(Generic[_I]):
...
@@ -204,9 +204,11 @@ class MultiModalProfiler(Generic[_I]):
"and/or reduce `mm_counts`."
,
seq_len
,
total_len
,
"and/or reduce `mm_counts`."
,
seq_len
,
total_len
,
total_placeholders_by_modality
)
total_placeholders_by_modality
)
num_tokens_to_pad
=
max
(
total_len
,
seq_len
)
-
total_len
prompt_token_ids
.
extend
([
0
]
*
num_tokens_to_pad
)
return
DummyData
(
return
DummyData
(
seq_data
=
SequenceData
.
from_prompt_token_counts
(
seq_data
=
SequenceData
.
from_seqs
(
prompt_token_ids
),
(
0
,
max
(
seq_len
,
total_len
))),
multi_modal_data
=
None
,
multi_modal_data
=
None
,
multi_modal_placeholders
=
None
,
multi_modal_placeholders
=
None
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment