Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
759ef49b
Unverified
Commit
759ef49b
authored
Sep 15, 2025
by
Woosuk Kwon
Committed by
GitHub
Sep 15, 2025
Browse files
Remove V0 Encoder-Decoder Support (#24907)
Signed-off-by:
Woosuk Kwon
<
woosuk@thinkingmachines.ai
>
parent
5206ab20
Changes
47
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
8 additions
and
6226 deletions
+8
-6226
tests/models/multimodal/generation/test_florence2.py
tests/models/multimodal/generation/test_florence2.py
+0
-147
tests/models/multimodal/generation/test_mllama.py
tests/models/multimodal/generation/test_mllama.py
+0
-768
tests/models/multimodal/processing/test_common.py
tests/models/multimodal/processing/test_common.py
+0
-5
tests/models/multimodal/processing/test_mllama.py
tests/models/multimodal/processing/test_mllama.py
+0
-72
tests/models/multimodal/processing/test_tensor_schema.py
tests/models/multimodal/processing/test_tensor_schema.py
+0
-1
tests/models/registry.py
tests/models/registry.py
+1
-15
tests/models/test_initialization.py
tests/models/test_initialization.py
+0
-4
tests/models/test_registry.py
tests/models/test_registry.py
+0
-1
tests/test_config.py
tests/test_config.py
+1
-2
tests/utils_/test_utils.py
tests/utils_/test_utils.py
+0
-28
tests/v1/test_oracle.py
tests/v1/test_oracle.py
+0
-21
tests/worker/test_encoder_decoder_model_runner.py
tests/worker/test_encoder_decoder_model_runner.py
+0
-648
vllm/config/__init__.py
vllm/config/__init__.py
+2
-9
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+1
-1
vllm/model_executor/models/bart.py
vllm/model_executor/models/bart.py
+0
-1319
vllm/model_executor/models/donut.py
vllm/model_executor/models/donut.py
+0
-381
vllm/model_executor/models/florence2.py
vllm/model_executor/models/florence2.py
+0
-1097
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+0
-1697
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+2
-9
vllm/multimodal/profiling.py
vllm/multimodal/profiling.py
+1
-1
No files found.
tests/models/multimodal/generation/test_florence2.py
deleted
100644 → 0
View file @
5206ab20
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
import
pytest
from
PIL
import
Image
from
vllm.inputs.data
import
ExplicitEncoderDecoderPrompt
,
TextPrompt
from
vllm.multimodal.image
import
rescale_image_size
from
vllm.sequence
import
SampleLogprobs
from
....conftest
import
IMAGE_ASSETS
,
HfRunner
,
ImageTestAssets
,
VllmRunner
from
...utils
import
check_logprobs_close
# Checkpoints exercised by this test module.
MODELS = ["microsoft/Florence-2-base"]

# Florence-2 model repo's tokenizer config is missing some special tokens.
# Therefore, we use a converted tokenizer from a forked repo
TOKENIZER = "Isotr0py/Florence-2-tokenizer"

# One prompt per image asset, keyed by asset name.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(
    dict(
        # special task token which will output special tokens
        stop_sign="<OD>",
        cherry_blossom="Describe in detail what is shown in the image.",
    ))
def get_hf_images_prompts(
    prompts_: list[ExplicitEncoderDecoderPrompt[str, TextPrompt]],
) -> tuple[list[ExplicitEncoderDecoderPrompt[str, str]], list[Image.Image]]:
    """Split encoder/decoder prompts into text-only prompts and their images.

    For every entry the encoder prompt's ``"prompt"`` text is wrapped in a new
    ``ExplicitEncoderDecoderPrompt`` (with ``decoder_prompt=None``) and the
    value stored under ``["multi_modal_data"]["image"]`` is collected
    separately.

    Returns:
        A ``(prompts, images)`` pair of lists aligned with ``prompts_``.
    """

    def _split(entry):
        # Pull the text and the image out of a single encoder prompt.
        enc = entry["encoder_prompt"]
        text_prompt = ExplicitEncoderDecoderPrompt(
            encoder_prompt=enc["prompt"],
            decoder_prompt=None,
        )
        return text_prompt, enc["multi_modal_data"]["image"]

    pairs = [_split(entry) for entry in prompts_]
    prompts = [text_prompt for text_prompt, _ in pairs]
    images = [image for _, image in pairs]
    return prompts, images
def hf_to_vllm_output(hf_output: tuple[list[int], str,
                                       Optional[SampleLogprobs]]):
    """Strip the ``</s>`` and ``<s>`` markers from the HF output text.

    The token ids and logprobs are passed through untouched; only the decoded
    string is sanitized so that it can be compared with the vLLM output.
    """
    token_ids, text, logprobs = hf_output
    # Order matters: "</s>" is removed before "<s>".
    for marker in ("</s>", "<s>"):
        text = text.replace(marker, "")
    return token_ids, text, logprobs
def run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    inputs: list[list[ExplicitEncoderDecoderPrompt]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
) -> None:
    """Generate with vLLM and then HF, and check that the logprobs are close.

    Each element of ``inputs`` is one batch ("case") of encoder/decoder
    prompts. vLLM is run first; the HF runner is then fed the text-only
    prompts and images extracted by ``get_hf_images_prompts``.
    """
    vllm_outputs_per_case = []
    with vllm_runner(
            model,
            max_num_seqs=8,
            tokenizer_name=TOKENIZER,
            dtype=dtype,
            tensor_parallel_size=tensor_parallel_size,
            distributed_executor_backend=distributed_executor_backend,
            enforce_eager=True,
    ) as vllm_model:
        for case_prompts in inputs:
            vllm_outputs_per_case.append(
                vllm_model.generate_encoder_decoder_greedy_logprobs(
                    case_prompts,
                    max_tokens,
                    num_logprobs=num_logprobs,
                    skip_special_tokens=False,
                ))

    hf_inputs = [get_hf_images_prompts(case_prompts) for case_prompts in inputs]

    hf_outputs_per_case = []
    with hf_runner(model, dtype=dtype, skip_tokenizer_init=True) as hf_model:

        def _lm_head():
            # Expose the language model head as the output embeddings.
            return hf_model.model.language_model.lm_head

        hf_model.model.get_output_embeddings = _lm_head
        for case_prompts, case_images in hf_inputs:
            hf_outputs_per_case.append(
                hf_model.generate_encoder_decoder_greedy_logprobs_limit(
                    case_prompts,
                    max_tokens,
                    num_logprobs=num_logprobs,
                    images=case_images,
                ))

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
                                        vllm_outputs_per_case):
        sanitized_hf = [hf_to_vllm_output(output) for output in hf_outputs]
        check_logprobs_close(
            outputs_0_lst=sanitized_hf,
            outputs_1_lst=vllm_outputs,
            name_0="hf",
            name_1="vllm",
            num_outputs_0_skip_tokens=1,
        )
# FIXME: https://github.com/huggingface/transformers/issues/38358
@pytest.mark.skip("Model initialization fails")
@pytest.mark.core_model
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner: type[HfRunner], vllm_runner: type[VllmRunner],
                image_assets: ImageTestAssets, model: str,
                size_factors: list[int], dtype: str, max_tokens: int,
                num_logprobs: int) -> None:
    """Compare HF and vLLM outputs for every image asset at several scales."""
    images = [asset.pil_image for asset in image_assets]

    # One batch per (image, prompt) pair; one request per size factor.
    inputs_per_image = []
    for image, prompt in zip(images, HF_IMAGE_PROMPTS):
        batch = []
        for factor in size_factors:
            encoder_prompt = TextPrompt(
                prompt=prompt,
                multi_modal_data={"image": rescale_image_size(image, factor)},
            )
            batch.append(
                ExplicitEncoderDecoderPrompt(
                    encoder_prompt=encoder_prompt,
                    decoder_prompt=None,
                ))
        inputs_per_image.append(batch)

    run_test(
        hf_runner,
        vllm_runner,
        inputs_per_image,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
tests/models/multimodal/generation/test_mllama.py
deleted
100644 → 0
View file @
5206ab20
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
,
overload
import
pytest
import
torch
from
packaging.version
import
Version
from
transformers
import
AutoConfig
,
AutoModelForImageTextToText
,
AutoTokenizer
from
transformers
import
__version__
as
TRANSFORMERS_VERSION
from
vllm
import
LLM
,
SamplingParams
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.attention.selector
import
(
_Backend
,
_cached_get_attn_backend
,
global_force_attn_backend_context_manager
)
from
vllm.model_executor.models.mllama
import
MllamaForConditionalGeneration
from
vllm.multimodal.image
import
rescale_image_size
from
vllm.sequence
import
SampleLogprobs
from
....conftest
import
(
IMAGE_ASSETS
,
HfRunner
,
ImageTestAssets
,
PromptImageInput
,
VllmRunner
)
from
....quantization.utils
import
is_quant_method_supported
from
....utils
import
(
create_new_process_for_each_test
,
large_gpu_test
,
multi_gpu_test
)
from
...utils
import
check_logprobs_close
# Value passed as limit_mm_per_prompt["image"] when building the vLLM runner.
_LIMIT_IMAGE_PER_PROMPT = 3
MLLAMA_IMAGE_TOKEN_ID = 128256

LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]

# One prompt per image asset, keyed by asset name.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts(
    dict(
        stop_sign="<|image|><|begin_of_text|>The meaning of the image is",
        cherry_blossom="<|image|><|begin_of_text|>The city is",
    ))

text_only_prompts = [
    "The color of the sky is blue but sometimes it can also be",
]

models = [
    "meta-llama/Llama-3.2-11B-Vision-Instruct",
]

# Indices for inputs
TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE, TWO_IMAGES = '0', '1', '2', '3'

# Input tokenized
prompt_data = {
    # Tell me a story
    TEXT_ONLY: [41551, 757, 264, 3446],
    # <|image|> What's the content of this image
    IMAGE_AT_BEG: [
        MLLAMA_IMAGE_TOKEN_ID, 3639, 596, 279, 2262, 315, 420, 2217, 220
    ],
    # Hello <|image|>What' the content of this image
    IMAGE_AT_MIDDLE: [
        9906, 220, MLLAMA_IMAGE_TOKEN_ID, 3923, 6, 279, 2262, 315, 420, 2217
    ],
    #<|image|>Is there a duck in this image?<|image|>What's the animal in this image? # noqa: E501
    TWO_IMAGES: [
        MLLAMA_IMAGE_TOKEN_ID, 3957, 1070, 264, 37085, 304, 420, 2217, 30,
        MLLAMA_IMAGE_TOKEN_ID, 3923, 596, 279, 10065, 304, 420, 2217, 30
    ],
}
def vllm_to_hf_output(vllm_output: tuple[list[int], str,
                                         Optional[SampleLogprobs]],
                      model: str):
    """Sanitize vllm output to be comparable with hf output.

    Runs of consecutive image tokens in the vLLM token ids are collapsed to a
    single image token, and the decoded EOS token is appended to the text when
    the (collapsed) ids end with EOS.

    Args:
        vllm_output: ``(token_ids, text, logprobs)`` as produced by vLLM.
        model: Model name/path used to load the HF config and tokenizer.

    Returns:
        ``(hf_output_ids, hf_output_str, out_logprobs)``; the logprobs are
        returned unchanged.
    """
    output_ids, output_str, out_logprobs = vllm_output

    config = AutoConfig.from_pretrained(model)
    image_token_id = config.image_token_index

    tokenizer = AutoTokenizer.from_pretrained(model)
    eos_token_id = tokenizer.eos_token_id

    # Keep a token unless it is an image token directly preceded by another
    # image token. The first token has no predecessor: it must be checked
    # explicitly, because ``output_ids[-1]`` would silently wrap around to the
    # last token and could drop a leading image token by mistake.
    hf_output_ids = [
        token_id for idx, token_id in enumerate(output_ids)
        if token_id != image_token_id or idx == 0
        or output_ids[idx - 1] != image_token_id
    ]

    hf_output_str = output_str
    # Guard against an empty generation before peeking at the last token.
    if hf_output_ids and hf_output_ids[-1] == eos_token_id:
        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

    return hf_output_ids, hf_output_str, out_logprobs
def _get_inputs(
    image_assets: ImageTestAssets,
    *,
    size_factors: Optional[list[float]] = None,
    sizes: Optional[list[tuple[int, int]]] = None,
) -> list[tuple[list[str], PromptImageInput]]:
    """Build one ``(prompts, images)`` batch per image asset.

    Exactly one of ``size_factors`` / ``sizes`` drives the batch contents
    (``size_factors`` wins when both are given):

    * ``size_factors``: the asset prompt is repeated once per factor and the
      image is rescaled by each factor.
    * ``sizes``: the image is resized to each size; a ``None`` entry produces a
      text-only request (first text-only prompt, no image). An empty ``sizes``
      list additionally appends a pure text-only batch.

    Raises:
        ValueError: if neither ``size_factors`` nor ``sizes`` is provided.
    """
    images = [asset.pil_image for asset in image_assets]
    inputs_per_image = []

    if size_factors is not None:
        for image, prompt in zip(images, HF_IMAGE_PROMPTS):
            batch_prompts = [prompt for _ in size_factors]
            batch_images = [
                rescale_image_size(image, factor) for factor in size_factors
            ]
            inputs_per_image.append((batch_prompts, batch_images))
    elif sizes is not None:
        for image, prompt in zip(images, HF_IMAGE_PROMPTS):
            batch_prompts = [
                text_only_prompts[0] if size is None else prompt
                for size in sizes
            ]
            batch_images = [
                None if size is None else image.resize(size) for size in sizes
            ]
            inputs_per_image.append((batch_prompts, batch_images))
        if len(sizes) == 0:
            no_images = [None] * len(text_only_prompts)
            inputs_per_image.append((text_only_prompts, no_images))
    else:
        raise ValueError("You must provide either `size_factors` or `sizes`")

    return inputs_per_image
@overload
def run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    image_assets: ImageTestAssets,
    model: str,
    *,
    size_factors: list[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    ...


@overload
def run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    image_assets: ImageTestAssets,
    model: str,
    *,
    sizes: list[tuple[int, int]],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    ...


def run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    image_assets: ImageTestAssets,
    model: str,
    *,
    size_factors: Optional[list[float]] = None,
    sizes: Optional[list[tuple[int, int]]] = None,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Build inputs from the image assets and delegate to ``_run_test``.

    ``size_factors`` / ``sizes`` are forwarded to ``_get_inputs``; every other
    argument is passed through to ``_run_test`` unchanged.
    """
    inputs = _get_inputs(
        image_assets,
        size_factors=size_factors,
        sizes=sizes,
    )
    _run_test(
        hf_runner,
        vllm_runner,
        inputs,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=tensor_parallel_size,
        distributed_executor_backend=distributed_executor_backend,
    )
def _run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    inputs: list[tuple[list[str], PromptImageInput]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Check that HF and vLLM produce close logprobs for the same inputs.

    Each element of ``inputs`` is a ``(prompts, images)`` batch. Both runners
    are given the same batches; the vLLM outputs are passed through
    ``vllm_to_hf_output`` before being compared with the HF outputs.
    """
    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

    # max_model_len should be greater than image_feature_size
    vllm_outputs_per_image = []
    with vllm_runner(
            model,
            dtype=dtype,
            max_model_len=19212,  # 3 max size images
            max_num_seqs=3,
            tensor_parallel_size=tensor_parallel_size,
            distributed_executor_backend=distributed_executor_backend,
            limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT},
    ) as vllm_model:
        for batch_prompts, batch_images in inputs:
            vllm_outputs_per_image.append(
                vllm_model.generate_greedy_logprobs(
                    batch_prompts,
                    max_tokens,
                    num_logprobs=num_logprobs,
                    images=batch_images,
                ))

    hf_outputs_per_image = []
    with hf_runner(
            model,
            dtype=dtype,
            model_kwargs={"device_map": "auto"},
            auto_cls=AutoModelForImageTextToText,
    ) as hf_model:
        for batch_prompts, batch_images in inputs:
            hf_outputs_per_image.append(
                hf_model.generate_greedy_logprobs_limit(
                    batch_prompts,
                    max_tokens,
                    num_logprobs=num_logprobs,
                    images=batch_images,
                ))

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
                                        vllm_outputs_per_image):
        sanitized_vllm = [
            vllm_to_hf_output(vllm_output, model)
            for vllm_output in vllm_outputs
        ]
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=sanitized_vllm,
            name_0="hf",
            name_1="vllm",
        )
@pytest.fixture(autouse=True)
def clear_cache():
    """Reset the cached attention-backend lookup ahead of every test."""
    # Drop whatever backend a previous test left in the cache.
    _cached_get_attn_backend.cache_clear()
    # Hand control over to the test body.
    yield
@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "sizes",
    [
        # Text only
        [],
        # Single-size
        [(512, 512)],
        # Single-size, batched
        [(512, 512), (512, 512), (512, 512)],
        # Multi-size, batched
        # NOTE(review): (512, 2028) may be a typo for (512, 2048) - confirm.
        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
         (1024, 1024), (512, 1536), (512, 2028)],
        # Multi-size, batched, including text only
        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
         (1024, 1024), (512, 1536), (512, 2028), None],
        # mllama has 8 possible aspect ratios, carefully set the sizes
        # to cover all of them
    ])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.skipif(
    Version(TRANSFORMERS_VERSION) <= Version("4.55.2"),
    reason="Transformers v4.55 has a regression issue on mllama, "
    "see: https://github.com/huggingface/transformers/pull/40083")
def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
                                     model, sizes, dtype, max_tokens,
                                     num_logprobs,
                                     attn_backend: _Backend) -> None:
    """Compare HF and vLLM on prompts with one leading image, per backend."""
    with global_force_attn_backend_context_manager(attn_backend):
        # Flash Attention works only with bfloat16 data-type
        dtype = 'bfloat16' if attn_backend == _Backend.FLASH_ATTN else dtype
        run_test(
            hf_runner,
            vllm_runner,
            image_assets,
            model,
            sizes=sizes,
            dtype=dtype,
            max_tokens=max_tokens,
            num_logprobs=num_logprobs,
            tensor_parallel_size=1,
        )
@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.skipif(
    Version(TRANSFORMERS_VERSION) <= Version("4.55.2"),
    reason="Transformers v4.55 has a regression issue on mllama, "
    "see: https://github.com/huggingface/transformers/pull/40083")
def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
                                     model, dtype, max_tokens, num_logprobs,
                                     attn_backend: _Backend) -> None:
    """Compare HF and vLLM on prompts that start with two or three images."""
    stop_sign = image_assets[0].pil_image
    cherry_blossom = image_assets[1].pil_image

    batch_prompts = [
        "<|image|><|image|><|begin_of_text|>Describe 2 images.",  # noqa: E501
        "<|image|><|image|><|begin_of_text|>Describe 2 images.",  # noqa: E501
        "<|image|><|image|><|image|><|begin_of_text|>Describe 3 images.",  # noqa: E501
    ]
    batch_images = [
        [stop_sign, cherry_blossom],
        # Images with different sizes.
        [
            stop_sign.resize((512, 512)),
            stop_sign,
        ],
        [
            stop_sign,
            stop_sign.resize((512, 1536)),
            cherry_blossom.resize((512, 1024)),
        ],
    ]
    inputs = [(batch_prompts, batch_images)]

    with global_force_attn_backend_context_manager(attn_backend):
        # Flash Attention works only with bfloat16 data-type
        dtype = 'bfloat16' if attn_backend == _Backend.FLASH_ATTN else dtype
        _run_test(
            hf_runner,
            vllm_runner,
            inputs,
            model,
            dtype=dtype,
            max_tokens=max_tokens,
            num_logprobs=num_logprobs,
            tensor_parallel_size=1,
        )
@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.skipif(
    Version(TRANSFORMERS_VERSION) <= Version("4.55.2"),
    reason="Transformers v4.55 has a regression issue on mllama, "
    "see: https://github.com/huggingface/transformers/pull/40083")
def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
                                   dtype, max_tokens, num_logprobs,
                                   attn_backend: _Backend) -> None:
    """Compare HF and vLLM on prompts with images placed between text."""
    stop_sign = image_assets[0].pil_image
    cherry_blossom = image_assets[1].pil_image

    batch_prompts = [
        "<|begin_of_text|>The content of the image <|image|> is",  # noqa: E501
        "<|begin_of_text|>Between the first image <|image|> and the second image<|image|>, "  # noqa: E501
        "which is a stop sign and which is a cherry blossom?",  # noqa: E501
    ]
    batch_images = [
        [stop_sign],
        [stop_sign, cherry_blossom],
    ]
    inputs = [(batch_prompts, batch_images)]

    with global_force_attn_backend_context_manager(attn_backend):
        # Flash Attention works only with bfloat16 data-type
        dtype = 'bfloat16' if attn_backend == _Backend.FLASH_ATTN else dtype
        _run_test(
            hf_runner,
            vllm_runner,
            inputs,
            model,
            dtype=dtype,
            max_tokens=max_tokens,
            num_logprobs=num_logprobs,
            tensor_parallel_size=1,
        )
@create_new_process_for_each_test()
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.skipif(
    Version(TRANSFORMERS_VERSION) <= Version("4.55.2"),
    reason="Transformers v4.55 has a regression issue on mllama, "
    "see: https://github.com/huggingface/transformers/pull/40083")
def test_models_distributed(
    hf_runner,
    vllm_runner,
    image_assets,
    distributed_executor_backend,
    model,
    dtype,
    max_tokens,
    num_logprobs,
) -> None:
    """Run the HF/vLLM comparison with tensor parallel size 2."""
    scale_factors = [0.25, 0.5, 1.0]
    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model=model,
        size_factors=scale_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=2,
        distributed_executor_backend=distributed_executor_backend,
    )
@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["float16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
def test_bnb_regression(
    image_assets: ImageTestAssets,
    model: str,
    dtype: str,
    max_tokens: int,
):
    """Generate with bitsandbytes quantization for an image and a text prompt."""
    stop_sign = image_assets[0].pil_image
    image_prompt = {
        "prompt": "<|begin_of_text|>The content of the image <|image|> is",
        "multi_modal_data": {
            "image": stop_sign
        },
    }
    text_prompt = {
        "prompt": "The color of the sky is blue but sometimes it can also be",
    }
    prompts = [image_prompt, text_prompt]

    # Test regression about QKVCrossParallelLinear
    llm = LLM(
        model=model,
        dtype=dtype,
        max_model_len=8192,
        max_num_seqs=2,
        quantization="bitsandbytes",
    )
    greedy_params = SamplingParams(
        temperature=0,
        max_tokens=max_tokens,
    )
    outputs = llm.generate(prompts, greedy_params)
    assert outputs
@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
def test_explicit_implicit_prompt(
    image_assets: ImageTestAssets,
    model: str,
    dtype: str,
    max_tokens: int,
):
    """Explicit encoder/decoder prompts must generate the same text as the
    corresponding implicit prompts.

    The prompt list holds the explicit prompts in its first half and the
    implicit ones in its second half, in matching order.
    """
    stop_sign = image_assets[0].pil_image
    # yapf: disable
    explicit_prompts = [
        {
            "encoder_prompt": {
                "prompt": "<|image|>",
                "multi_modal_data": {"image": stop_sign},
            },
            "decoder_prompt": {
                "prompt_token_ids": [128000, 791, 2262, 315, 279, 2217, 220, 128256, 374],  # noqa: E501
            }
        },
        {
            "encoder_prompt": "Not <|image|>",
            "decoder_prompt": "The color of the sky is blue but sometimes it can also be",  # noqa: E501
        },
    ]
    implicit_prompts = [
        {
            "prompt": "<|begin_of_text|>The content of the image <|image|> is",  # noqa: E501
            "multi_modal_data": {"image": stop_sign},
        },
        {
            "prompt": "The color of the sky is blue but sometimes it can also be",  # noqa: E501
        },
    ]
    # yapf: enable
    prompts = explicit_prompts + implicit_prompts

    llm = LLM(
        model=model,
        dtype=dtype,
        max_model_len=8192,
        max_num_seqs=2,
        tensor_parallel_size=1,
    )
    greedy_params = SamplingParams(
        temperature=0,
        max_tokens=max_tokens,
    )
    outputs = llm.generate(prompts, greedy_params)

    # First half: explicit prompts; second half: implicit prompts.
    half = len(prompts) // 2
    explicit_outputs, implicit_outputs = outputs[:half], outputs[half:]
    for exp_output, imp_output in zip(explicit_outputs, implicit_outputs):
        assert exp_output.outputs[0].text == imp_output.outputs[0].text
@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
                    num_logprobs, attn_backend: _Backend) -> None:
    """Exercise several mixed text/image batches with the vLLM runner only."""
    stop_sign = image_assets[0].pil_image

    with global_force_attn_backend_context_manager(attn_backend), vllm_runner(
            model,
            dtype=dtype,
            max_model_len=8192,
            max_num_seqs=4,
            tensor_parallel_size=1,
            limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT},
    ) as vllm_model:

        def _generate(batch_prompts, batch_images):
            # Every scenario below uses the same generation settings.
            return vllm_model.generate_greedy_logprobs(batch_prompts,
                                                       max_tokens,
                                                       num_logprobs,
                                                       images=batch_images)

        # Regression tests for https://github.com/vllm-project/vllm/issues/10648

        # Number of groups of image tokens is greater than the number of images
        # provided (the whitespace between the tags is necessary)
        prompt = "<|begin_of_text|><|image|> <|image|> Compare the two images"  # noqa: E501
        image = stop_sign
        with pytest.raises(ValueError):
            _generate([prompt], [image])

        # Batch of a text-only and image request that requires cross-attention
        _generate(
            [
                "What is the capital of spain?",
                "Text before the image...<|image|>What is in the image?",  # noqa: E501
            ],
            [
                None,
                [stop_sign],
            ],
        )

        # Test the reverse order too for good measure
        _generate(
            [
                "<|begin_of_text|>Text before the image...<|image|>What is in the image?",  # noqa: E501
                "<|begin_of_text|>Hello!",
            ],
            [
                [stop_sign],
                None,
            ],
        )

        # Mixed batch with text and images with different numbers of tiles
        _generate(
            [
                "<|begin_of_text|>Hello!",
                "<|begin_of_text|>Some text before.<|image|>What is in the image?",  # noqa: E501
                "<|begin_of_text|>Some text before.<|image|>What is in the image?",  # noqa: E501
            ],
            [
                None,
                [stop_sign],
                # smaller image must be 2nd for the repro
                [stop_sign.resize((448, 448))],
            ],
        )
class DummyModel:
    # Minimal stand-in passed as ``self`` to MllamaForConditionalGeneration
    # methods in the tests below; it only carries the image token id.
    image_token_id = MLLAMA_IMAGE_TOKEN_ID
@pytest.mark.core_model
@pytest.mark.parametrize(
    "input_indices_and_output",
    # inputs, (cross_attention_mask, kv_range_for_decode)
    [
        ([TEXT_ONLY], (None, None)),
        ([IMAGE_AT_BEG], (None, None)),
        ([TEXT_ONLY, IMAGE_AT_BEG], (None, None)),
        ([IMAGE_AT_MIDDLE], ((10, 12), [[0, 6]])),
        ([TEXT_ONLY, IMAGE_AT_MIDDLE], ((14, 12), [[0, 6]])),
        ([TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE],
         ((23, 24), [[0, 6], [6, 12]])),
        ([IMAGE_AT_MIDDLE, TEXT_ONLY], ((14, 12), [[0, 6]])),
        ([TWO_IMAGES], ((18, 12), [[6, 12]])),
        ([TEXT_ONLY, TWO_IMAGES], ((22, 12), [[6, 12]])),
    ])
def test_get_cross_attention_mask(input_indices_and_output) -> None:
    """Check the mask shape and decode KV ranges from get_cross_attention_mask.

    Each parameter is ``(input_indices, (expected_mask_shape,
    expected_kv_range_for_decode))``; an expected shape of ``None`` means no
    mask should be returned.
    """
    input_indices, expected_output = input_indices_and_output
    expected_mask_shape, expected_kv_range = expected_output

    sequences = [torch.tensor(prompt_data[key]) for key in input_indices]
    seq_lens = [len(seq) for seq in sequences]
    # Every sequence that is not text-only gets two images of two tiles each.
    num_tiles = [[2, 2] for key in input_indices if key != TEXT_ONLY]
    flat_tokens = torch.cat(sequences)

    attn_data = FlashAttentionMetadata(
        seq_lens=seq_lens,
        # Dummy values
        enable_kv_scales_calculation=False,
        num_prefills=0,
        num_prefill_tokens=0,
        num_decode_tokens=0,
        slot_mapping=0,
        multi_modal_placeholder_index_maps=None,
        seq_lens_tensor=0,
        max_prefill_seq_len=0,
        max_decode_seq_len=0,
        context_lens_tensor=None,
        block_tables=None,
        use_cuda_graph=False,
    )

    mask_fn = MllamaForConditionalGeneration.get_cross_attention_mask
    cross_attention_mask, kv_range_for_decode = mask_fn(
        DummyModel(),
        flat_tokens,
        attn_data,
        num_tiles=num_tiles,
        num_tokens_per_tile=3,
        dtype=torch.bfloat16,
    )

    assert kv_range_for_decode == expected_kv_range
    if expected_mask_shape is None:
        assert cross_attention_mask is None
    else:
        assert cross_attention_mask is not None
        assert cross_attention_mask.shape == expected_mask_shape
@pytest.mark.core_model
@pytest.mark.parametrize("input_indices", [
    [TEXT_ONLY],
    [IMAGE_AT_BEG],
    [TEXT_ONLY, IMAGE_AT_BEG],
    [IMAGE_AT_MIDDLE],
    [TEXT_ONLY, IMAGE_AT_MIDDLE],
    [TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE],
    [IMAGE_AT_MIDDLE, TEXT_ONLY],
    [TWO_IMAGES],
    [TEXT_ONLY, TWO_IMAGES],
])
def test_get_full_text_row_masked_out_mask(input_indices) -> None:
    """Check the per-token mask: every token of a text-only sequence must map
    to ``False`` and every token of any other sequence to ``True``."""
    sequences = [torch.tensor(prompt_data[key]) for key in input_indices]
    seq_lens = [len(seq) for seq in sequences]
    num_prefill_tokens = sum(seq_lens)

    # TEXT_ONLY is zero, so it will be masked out,
    # other instances should not be.
    encoder_seq_lens = [int(key) for key in input_indices]

    attn_data = FlashAttentionMetadata(
        seq_lens=seq_lens,
        encoder_seq_lens=encoder_seq_lens,
        num_prefill_tokens=num_prefill_tokens,
        # Dummy values
        enable_kv_scales_calculation=False,
        num_prefills=0,
        num_decode_tokens=0,
        slot_mapping=0,
        multi_modal_placeholder_index_maps=None,
        seq_lens_tensor=0,
        max_prefill_seq_len=0,
        max_decode_seq_len=0,
        context_lens_tensor=None,
        block_tables=None,
        use_cuda_graph=False,
    )

    mask_fn = MllamaForConditionalGeneration.get_full_text_row_masked_out_mask
    mask_tensor = mask_fn(DummyModel(), attn_data, torch.get_default_device())
    full_text_row_masked_out_mask = mask_tensor.squeeze().tolist()

    assert len(full_text_row_masked_out_mask) == num_prefill_tokens

    # Expand the per-sequence expectation to one flag per token.
    expected_flags = [
        key != TEXT_ONLY for key, seq_len in zip(input_indices, seq_lens)
        for _ in range(seq_len)
    ]
    for idx, must_be_masked in enumerate(expected_flags):
        assert full_text_row_masked_out_mask[idx] == must_be_masked, \
            f"full_text_row_masked_out_mask[{idx}] must be " \
            f"'{must_be_masked}' "
@pytest.mark.core_model
@pytest.mark.parametrize("encoder_seq_lens, num_tiles, expected", [
    ([6404], [[4]], [6404]),
    ([0, 6404], [[4]], [6404]),
    ([0, 1601, 8005], [[1], [4, 1]], [1601, 8005]),
    ([0, 19212, 0, 3202], [[4, 4, 4], [2]], [19212, 3202]),
])
def test_parse_and_validate_encoder_lens(encoder_seq_lens, num_tiles,
                                         expected) -> None:
    """Check ``_get_and_validate_encoder_lens`` against the expected lengths,
    using 1601 tokens per tile."""
    num_tokens_per_tile = 1601
    validate = MllamaForConditionalGeneration._get_and_validate_encoder_lens
    actual_encoder_seq_lens = validate(
        DummyModel(),
        encoder_seq_lens,
        num_tiles,
        num_tokens_per_tile,
    )
    assert actual_encoder_seq_lens == expected, \
        f"Expected {expected} but got {actual_encoder_seq_lens}"
tests/models/multimodal/processing/test_common.py
View file @
759ef49b
...
@@ -167,8 +167,6 @@ def _test_processing_correctness(
...
@@ -167,8 +167,6 @@ def _test_processing_correctness(
# incorrect token ids. So we need use `add_special_tokens=False` here
# incorrect token ids. So we need use `add_special_tokens=False` here
# to leave bos_token to be added by the processor.
# to leave bos_token to be added by the processor.
_ADD_SPECIAL_TOKENS_OVERRIDES
=
{
_ADD_SPECIAL_TOKENS_OVERRIDES
=
{
"donut"
:
False
,
"mllama"
:
False
,
"ovis"
:
False
,
"ovis"
:
False
,
"ovis2_5"
:
False
,
"ovis2_5"
:
False
,
"paligemma"
:
False
,
"paligemma"
:
False
,
...
@@ -278,9 +276,7 @@ def _test_processing_correctness_one(
...
@@ -278,9 +276,7 @@ def _test_processing_correctness_one(
"facebook/chameleon-7b"
,
"facebook/chameleon-7b"
,
"CohereLabs/command-a-vision-07-2025"
,
"CohereLabs/command-a-vision-07-2025"
,
"deepseek-ai/deepseek-vl2-tiny"
,
"deepseek-ai/deepseek-vl2-tiny"
,
"naver-clova-ix/donut-base-finetuned-docvqa"
,
"baidu/ERNIE-4.5-VL-28B-A3B-PT"
,
"baidu/ERNIE-4.5-VL-28B-A3B-PT"
,
"microsoft/Florence-2-base"
,
"adept/fuyu-8b"
,
"adept/fuyu-8b"
,
"google/gemma-3-4b-it"
,
"google/gemma-3-4b-it"
,
"google/gemma-3n-E2B-it"
,
"google/gemma-3n-E2B-it"
,
...
@@ -305,7 +301,6 @@ def _test_processing_correctness_one(
...
@@ -305,7 +301,6 @@ def _test_processing_correctness_one(
"llava-hf/llava-v1.6-mistral-7b-hf"
,
"llava-hf/llava-v1.6-mistral-7b-hf"
,
"llava-hf/LLaVA-NeXT-Video-7B-hf"
,
"llava-hf/LLaVA-NeXT-Video-7B-hf"
,
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
,
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
,
"meta-llama/Llama-3.2-11B-Vision-Instruct"
,
"TIGER-Lab/Mantis-8B-siglip-llama3"
,
"TIGER-Lab/Mantis-8B-siglip-llama3"
,
"mispeech/midashenglm-7b"
,
"mispeech/midashenglm-7b"
,
"openbmb/MiniCPM-Llama3-V-2_5"
,
"openbmb/MiniCPM-Llama3-V-2_5"
,
...
...
tests/models/multimodal/processing/test_mllama.py
deleted
100644 → 0
View file @
5206ab20
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for mllama's multimodal preprocessing and profiling."""
import
pytest
from
transformers
import
MllamaConfig
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.profiling
import
MultiModalProfiler
from
...utils
import
build_model_context
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"meta-llama/Llama-3.2-11B-Vision-Instruct"
])
@
pytest
.
mark
.
parametrize
(
"max_model_len"
,
[
4096
,
8192
,
25600
,
131072
])
@
pytest
.
mark
.
parametrize
(
"max_num_seqs"
,
[
1
,
2
,
8
])
def
test_profiling
(
model_id
:
str
,
max_model_len
:
int
,
max_num_seqs
:
int
,
):
# regression test for https://github.com/vllm-project/vllm/issues/13929
from
vllm.model_executor.models.mllama
import
calc_token_per_chunk
model_config_kwargs
=
{
"max_model_len"
:
max_model_len
,
}
ctx
=
build_model_context
(
model_id
,
model_config_kwargs
=
model_config_kwargs
,
limit_mm_per_prompt
=
{
"image"
:
1
},
)
mm_config
=
ctx
.
get_mm_config
()
processor
=
MULTIMODAL_REGISTRY
.
create_processor
(
ctx
.
model_config
)
profiler
=
MultiModalProfiler
(
processor
)
dummy_encoder_data
=
profiler
.
get_encoder_dummy_data
(
max_model_len
,
mm_counts
=
mm_config
.
limit_per_prompt
,
)
dummy_mm_data
=
processor
.
dummy_inputs
.
get_dummy_processor_inputs
(
max_model_len
,
mm_counts
=
mm_config
.
limit_per_prompt
,
)
hf_config
=
ctx
.
get_hf_config
(
MllamaConfig
)
image_size
=
hf_config
.
vision_config
.
image_size
encoder_seq_lens
=
[
len
(
dummy_encoder_data
.
prompt_token_ids
)
]
*
max_num_seqs
mm_data
=
processor
.
apply
(
prompt
=
dummy_mm_data
.
prompt
,
mm_data
=
dummy_mm_data
.
mm_data
,
hf_processor_mm_kwargs
=
dict
(),
)[
"mm_kwargs"
].
get_data
()
# Get the actual number of encoder tokens for each sample.
# Because attn_metadata.encoder_seq_lens only counts the last
# group of images for each sample, which is used to cheat the
# block manager to allocate blocks for those images only.
# See MllamaMultiModalProcessor for more details.
num_tiles
=
[[
t
]
for
t
in
mm_data
.
pop
(
"num_tiles"
)]
num_tokens_per_tile
=
calc_token_per_chunk
(
image_size
)
actual_encoder_seq_lens
=
[
sum
(
num_tile
)
*
num_tokens_per_tile
for
num_tile
in
num_tiles
]
# simulate mllama image-present prefill.
for
actual_len
,
last_group_len
in
zip
(
actual_encoder_seq_lens
,
encoder_seq_lens
):
assert
actual_len
>=
last_group_len
tests/models/multimodal/processing/test_tensor_schema.py
View file @
759ef49b
...
@@ -31,7 +31,6 @@ from ...utils import dummy_hf_overrides
...
@@ -31,7 +31,6 @@ from ...utils import dummy_hf_overrides
ARCH_TO_SKIP
=
{
ARCH_TO_SKIP
=
{
"MolmoForCausalLM"
:
"incompatible requirements"
,
"MolmoForCausalLM"
:
"incompatible requirements"
,
"Florence2ForConditionalGeneration"
:
"not supported in V1"
,
}
}
ARCH_NEEDS_EXTRAS
=
[
ARCH_NEEDS_EXTRAS
=
[
"InternVLChatModel"
,
"InternVLChatModel"
,
...
...
tests/models/registry.py
View file @
759ef49b
...
@@ -354,11 +354,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
...
@@ -354,11 +354,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"MiMoForCausalLM"
:
_HfExamplesInfo
(
"XiaomiMiMo/MiMo-7B-RL"
,
"MiMoForCausalLM"
:
_HfExamplesInfo
(
"XiaomiMiMo/MiMo-7B-RL"
,
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"Dots1ForCausalLM"
:
_HfExamplesInfo
(
"rednote-hilab/dots.llm1.inst"
),
"Dots1ForCausalLM"
:
_HfExamplesInfo
(
"rednote-hilab/dots.llm1.inst"
),
# [Encoder-decoder]
"BartModel"
:
_HfExamplesInfo
(
"facebook/bart-base"
),
"BartForConditionalGeneration"
:
_HfExamplesInfo
(
"facebook/bart-large-cnn"
),
"MBartForConditionalGeneration"
:
_HfExamplesInfo
(
"facebook/mbart-large-en-ro"
,
# noqa: E501
hf_overrides
=
{
"architectures"
:
[
"MBartForConditionalGeneration"
]}),
# noqa: E501
}
}
_EMBEDDING_EXAMPLE_MODELS
=
{
_EMBEDDING_EXAMPLE_MODELS
=
{
...
@@ -496,7 +491,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -496,7 +491,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code
=
True
),
trust_remote_code
=
True
),
"Llama4ForConditionalGeneration"
:
_HfExamplesInfo
(
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
,
# noqa: E501
"Llama4ForConditionalGeneration"
:
_HfExamplesInfo
(
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
,
# noqa: E501
max_model_len
=
10240
,
max_model_len
=
10240
,
extras
=
{
"llama-guard-4"
:
"meta-llama/Llama-Guard-4-12B"
},
# noqa: E501
extras
=
{
"llama-guard-4"
:
"meta-llama/Llama-Guard-4-12B"
},
# noqa: E501
),
),
"LlavaForConditionalGeneration"
:
_HfExamplesInfo
(
"llava-hf/llava-1.5-7b-hf"
,
"LlavaForConditionalGeneration"
:
_HfExamplesInfo
(
"llava-hf/llava-1.5-7b-hf"
,
extras
=
{
"mistral"
:
"mistral-community/pixtral-12b"
,
# noqa: E501
extras
=
{
"mistral"
:
"mistral-community/pixtral-12b"
,
# noqa: E501
...
@@ -583,15 +578,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
...
@@ -583,15 +578,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
is_available_online
=
False
,
is_available_online
=
False
,
),
),
# [Encoder-decoder]
# [Encoder-decoder]
"DonutForConditionalGeneration"
:
_HfExamplesInfo
(
"naver-clova-ix/donut-base-finetuned-docvqa"
,
# noqa: E501
hf_overrides
=
{
"architectures"
:
[
"DonutForConditionalGeneration"
],
"model_type"
:
"donut"
},
# noqa: E501
extras
=
{
"dolphin"
:
"ByteDance/Dolphin"
}),
# noqa: E501
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
"Florence2ForConditionalGeneration"
:
_HfExamplesInfo
(
"microsoft/Florence-2-base"
,
# noqa: E501
tokenizer
=
"Isotr0py/Florence-2-tokenizer"
,
# noqa: E501
trust_remote_code
=
True
),
# noqa: E501
"MllamaForConditionalGeneration"
:
_HfExamplesInfo
(
"meta-llama/Llama-3.2-11B-Vision-Instruct"
),
# noqa: E501
"WhisperForConditionalGeneration"
:
_HfExamplesInfo
(
"openai/whisper-large-v3"
),
# noqa: E501
"WhisperForConditionalGeneration"
:
_HfExamplesInfo
(
"openai/whisper-large-v3"
),
# noqa: E501
# [Cross-encoder]
# [Cross-encoder]
"JinaVLForRanking"
:
_HfExamplesInfo
(
"jinaai/jina-reranker-m0"
),
# noqa: E501
"JinaVLForRanking"
:
_HfExamplesInfo
(
"jinaai/jina-reranker-m0"
),
# noqa: E501
...
...
tests/models/test_initialization.py
View file @
759ef49b
...
@@ -92,10 +92,6 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
...
@@ -92,10 +92,6 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
# has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
# has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
# L4 supports FA3.
# L4 supports FA3.
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"TRITON_ATTN_VLLM_V1"
)
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"TRITON_ATTN_VLLM_V1"
)
if
model_arch
==
"Florence2ForConditionalGeneration"
:
# An encoder-decoder model that's V0-only. Just skip it
# since V0 is about to be removed.
pytest
.
skip
(
"Skipping Florence2ForConditionalGeneration"
)
if
model_arch
==
"WhisperForConditionalGeneration"
:
if
model_arch
==
"WhisperForConditionalGeneration"
:
m
.
setenv
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
"spawn"
)
m
.
setenv
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
"spawn"
)
LLM
(
LLM
(
...
...
tests/models/test_registry.py
View file @
759ef49b
...
@@ -50,7 +50,6 @@ def test_registry_imports(model_arch):
...
@@ -50,7 +50,6 @@ def test_registry_imports(model_arch):
@
create_new_process_for_each_test
()
@
create_new_process_for_each_test
()
@
pytest
.
mark
.
parametrize
(
"model_arch,is_mm,init_cuda,is_ce"
,
[
@
pytest
.
mark
.
parametrize
(
"model_arch,is_mm,init_cuda,is_ce"
,
[
(
"LlamaForCausalLM"
,
False
,
False
,
False
),
(
"LlamaForCausalLM"
,
False
,
False
,
False
),
(
"MllamaForConditionalGeneration"
,
True
,
False
,
False
),
(
"LlavaForConditionalGeneration"
,
True
,
True
,
False
),
(
"LlavaForConditionalGeneration"
,
True
,
True
,
False
),
(
"BertForSequenceClassification"
,
False
,
False
,
True
),
(
"BertForSequenceClassification"
,
False
,
False
,
True
),
(
"RobertaForSequenceClassification"
,
False
,
False
,
True
),
(
"RobertaForSequenceClassification"
,
False
,
False
,
True
),
...
...
tests/test_config.py
View file @
759ef49b
...
@@ -299,9 +299,8 @@ def test_rope_customization():
...
@@ -299,9 +299,8 @@ def test_rope_customization():
reason
=
"Encoder Decoder models not supported on ROCm."
)
reason
=
"Encoder Decoder models not supported on ROCm."
)
@
pytest
.
mark
.
parametrize
((
"model_id"
,
"is_encoder_decoder"
),
[
@
pytest
.
mark
.
parametrize
((
"model_id"
,
"is_encoder_decoder"
),
[
(
"facebook/opt-125m"
,
False
),
(
"facebook/opt-125m"
,
False
),
(
"
facebook/bart-base
"
,
True
),
(
"
openai/whisper-tiny
"
,
True
),
(
"meta-llama/Llama-3.2-1B-Instruct"
,
False
),
(
"meta-llama/Llama-3.2-1B-Instruct"
,
False
),
(
"meta-llama/Llama-3.2-11B-Vision"
,
True
),
])
])
def
test_is_encoder_decoder
(
model_id
,
is_encoder_decoder
):
def
test_is_encoder_decoder
(
model_id
,
is_encoder_decoder
):
config
=
ModelConfig
(
model_id
)
config
=
ModelConfig
(
model_id
)
...
...
tests/utils_/test_utils.py
View file @
759ef49b
...
@@ -501,34 +501,6 @@ def test_bind_kv_cache_non_attention():
...
@@ -501,34 +501,6 @@ def test_bind_kv_cache_non_attention():
assert
ctx
[
'model.layers.28.attn'
].
kv_cache
[
0
]
is
kv_cache
[
1
]
assert
ctx
[
'model.layers.28.attn'
].
kv_cache
[
0
]
is
kv_cache
[
1
]
def
test_bind_kv_cache_encoder_decoder
(
monkeypatch
:
pytest
.
MonkeyPatch
):
# V1 TESTS: ENCODER_DECODER is not supported on V1 yet.
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
from
vllm.attention
import
Attention
,
AttentionType
# example from bart
ctx
=
{
'encoder.layers.0.self_attn.attn'
:
Attention
(
32
,
128
,
0.1
,
attn_type
=
AttentionType
.
ENCODER
),
'decoder.layers.0.encoder_attn.attn'
:
Attention
(
32
,
128
,
0.1
,
attn_type
=
AttentionType
.
ENCODER_DECODER
),
'decoder.layers.0.self_attn.attn'
:
Attention
(
32
,
128
,
0.1
,
attn_type
=
AttentionType
.
DECODER
),
}
kv_cache
=
[
torch
.
zeros
((
1
,
)),
]
encoder_kv_cache
=
ctx
[
'encoder.layers.0.self_attn.attn'
].
kv_cache
bind_kv_cache
(
ctx
,
[
kv_cache
])
assert
ctx
[
'encoder.layers.0.self_attn.attn'
].
kv_cache
is
encoder_kv_cache
assert
ctx
[
'decoder.layers.0.encoder_attn.attn'
].
kv_cache
[
0
]
is
kv_cache
[
0
]
assert
ctx
[
'decoder.layers.0.self_attn.attn'
].
kv_cache
[
0
]
is
kv_cache
[
0
]
def
test_bind_kv_cache_pp
():
def
test_bind_kv_cache_pp
():
with
patch
(
"vllm.utils.cuda_device_count_stateless"
,
lambda
:
2
):
with
patch
(
"vllm.utils.cuda_device_count_stateless"
,
lambda
:
2
):
# this test runs with 1 GPU, but we simulate 2 GPUs
# this test runs with 1 GPU, but we simulate 2 GPUs
...
...
tests/v1/test_oracle.py
View file @
759ef49b
...
@@ -9,24 +9,9 @@ from vllm import LLM
...
@@ -9,24 +9,9 @@ from vllm import LLM
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
UNSUPPORTED_MODELS_V1
=
[
"facebook/bart-large-cnn"
,
# encoder decoder
]
MODEL
=
"meta-llama/Llama-3.2-1B-Instruct"
MODEL
=
"meta-llama/Llama-3.2-1B-Instruct"
@
pytest
.
mark
.
parametrize
(
"model"
,
UNSUPPORTED_MODELS_V1
)
def
test_reject_unsupported_models
(
monkeypatch
,
model
):
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
args
=
AsyncEngineArgs
(
model
=
model
)
with
pytest
.
raises
(
NotImplementedError
):
_
=
args
.
create_engine_config
()
m
.
delenv
(
"VLLM_USE_V1"
)
def
test_reject_bad_config
(
monkeypatch
):
def
test_reject_bad_config
(
monkeypatch
):
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
m
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
...
@@ -77,12 +62,6 @@ def test_enable_by_default_fallback(monkeypatch):
...
@@ -77,12 +62,6 @@ def test_enable_by_default_fallback(monkeypatch):
assert
envs
.
VLLM_USE_V1
assert
envs
.
VLLM_USE_V1
m
.
delenv
(
"VLLM_USE_V1"
)
m
.
delenv
(
"VLLM_USE_V1"
)
# Should fall back to V0 for supported model.
_
=
AsyncEngineArgs
(
model
=
UNSUPPORTED_MODELS_V1
[
0
]).
create_engine_config
()
assert
not
envs
.
VLLM_USE_V1
m
.
delenv
(
"VLLM_USE_V1"
)
def
test_v1_llm_by_default
(
monkeypatch
):
def
test_v1_llm_by_default
(
monkeypatch
):
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
:
...
...
tests/worker/test_encoder_decoder_model_runner.py
deleted
100644 → 0
View file @
5206ab20
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
import
pytest
import
torch
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
BATCH_SIZES
=
[
1
,
4
,
16
,
64
,
256
]
def
_create_model_runner
(
model
:
str
,
*
args
,
**
kwargs
)
->
EncoderDecoderModelRunner
:
engine_args
=
EngineArgs
(
model
,
*
args
,
**
kwargs
)
engine_config
=
engine_args
.
create_engine_config
()
model_runner
=
EncoderDecoderModelRunner
(
vllm_config
=
engine_config
,
is_driver_worker
=
True
,
)
return
model_runner
@
pytest
.
mark
.
skipif
(
condition
=
current_platform
.
is_cpu
(),
reason
=
"CPU backend is currently "
"unsupported for encoder/ "
"decoder models"
)
def
test_empty_seq_group
():
"""Verify prepare prompt and decode returns empty output
for empty seq group list"""
model_runner
=
_create_model_runner
(
"facebook/bart-base"
,
seed
=
0
,
dtype
=
"float16"
,
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
False
,
enforce_eager
=
True
,
)
seq_group_metadata_list
:
list
[
SequenceGroupMetadata
]
=
[]
model_input
=
model_runner
.
_prepare_model_input_tensors
(
seq_group_metadata_list
)
(
input_tokens
,
input_positions
,
encoder_input_tokens
,
encoder_input_positions
,
attn_metadata
,
return_seq_lens
,
)
=
(
model_input
.
input_tokens
,
model_input
.
input_positions
,
model_input
.
encoder_input_tokens
,
model_input
.
encoder_input_positions
,
model_input
.
attn_metadata
,
model_input
.
seq_lens
,
)
assert
input_tokens
is
None
assert
input_positions
is
None
assert
encoder_input_tokens
is
None
assert
encoder_input_positions
is
None
assert
attn_metadata
is
None
assert
return_seq_lens
is
None
@
pytest
.
mark
.
skipif
(
condition
=
current_platform
.
is_cpu
(),
reason
=
"CPU backend is currently "
"unsupported for encoder/ "
"decoder models"
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
def
test_prepare_prompt
(
batch_size
):
'''
Test the ability of the encoder/decoder model runner subclass to
produce prefill-phase model inputs & attention metadata.
Test behavior:
* Instantiate BART base model & enc/dec model runner
* Construct sequence-group metadata for dummy prompts
* Test that encoder attention, decoder self-attention,
and encoder/decoder cross-attention inputs are correct
Arguments:
* batch_size
* backend_name: The attention backend under test
* enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
'''
model_runner
=
_create_model_runner
(
"facebook/bart-base"
,
seed
=
0
,
dtype
=
"float16"
,
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
False
,
enforce_eager
=
True
,
)
seq_lens
:
list
[
int
]
=
[]
encoder_seq_lens
:
list
[
int
]
=
[]
seq_group_metadata_list
:
list
[
SequenceGroupMetadata
]
=
[]
block_tables
=
{
0
:
[
1
]}
cross_block_table
=
[
2
]
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
.
from_seqs
(
range
(
seq_len
))
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_lens
.
append
(
encoder_seq_len
)
encoder_seq_data
=
SequenceData
.
from_seqs
(
range
(
encoder_seq_len
))
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
seq_data
},
sampling_params
=
SamplingParams
(
temperature
=
0
),
block_tables
=
block_tables
,
encoder_seq_data
=
encoder_seq_data
,
cross_block_table
=
cross_block_table
,
)
assert
seq_group_metadata
.
token_chunk_size
==
seq_data
.
get_len
()
seq_group_metadata_list
.
append
(
seq_group_metadata
)
# Build
# * Decoder model inputs
# * Decoder self-attention KV caching data structures
# * Encoder model inputs
# * Encoder/decoder cross-attention KV caching data structures
model_input
=
model_runner
.
prepare_model_input
(
seq_group_metadata_list
)
input_tokens
=
model_input
.
input_tokens
input_positions
=
model_input
.
input_positions
attn_metadata
=
model_input
.
attn_metadata
return_seq_lens
=
model_input
.
seq_lens
slot_mapping
=
attn_metadata
.
slot_mapping
encoder_input_tokens
=
model_input
.
encoder_input_tokens
encoder_input_positions
=
model_input
.
encoder_input_positions
cross_slot_mapping
=
attn_metadata
.
cross_slot_mapping
assert
return_seq_lens
==
seq_lens
assert
len
(
slot_mapping
)
==
len
(
input_tokens
)
assert
len
(
cross_slot_mapping
)
==
len
(
encoder_input_tokens
)
# Verify input metadata is correct for prompts.
# - Decoder attention metadata
device
=
model_runner
.
device
assert
attn_metadata
.
num_prefills
>
0
assert
attn_metadata
.
num_decode_tokens
==
0
assert
torch
.
equal
(
attn_metadata
.
seq_lens_tensor
,
torch
.
tensor
(
seq_lens
,
device
=
device
,
dtype
=
torch
.
int
))
assert
attn_metadata
.
seq_lens
==
seq_lens
assert
attn_metadata
.
max_prefill_seq_len
==
max
(
seq_lens
)
assert
attn_metadata
.
max_decode_seq_len
==
0
# - Encoder attention metadata
assert
attn_metadata
.
encoder_seq_lens
==
encoder_seq_lens
assert
torch
.
equal
(
attn_metadata
.
encoder_seq_lens_tensor
,
torch
.
tensor
(
encoder_seq_lens
,
device
=
device
,
dtype
=
torch
.
int
))
assert
attn_metadata
.
max_encoder_seq_len
==
max
(
encoder_seq_lens
)
assert
attn_metadata
.
num_encoder_tokens
==
sum
(
encoder_seq_lens
)
# Test decoder subquery start locs.
start_idx
=
0
start_loc
=
[
start_idx
]
for
seq_len
in
seq_lens
:
start_idx
+=
seq_len
start_loc
.
append
(
start_idx
)
assert
torch
.
equal
(
attn_metadata
.
query_start_loc
,
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
),
)
# Test decoder seq start locs & context lengths
assert
torch
.
equal
(
attn_metadata
.
seq_start_loc
,
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
),
)
assert
torch
.
equal
(
attn_metadata
.
context_lens_tensor
,
torch
.
zeros
(
attn_metadata
.
context_lens_tensor
.
shape
[
0
],
dtype
=
torch
.
int
,
device
=
device
),
)
# Verify block tables are correct for prompts
# - Decoder self-attention
expected
=
torch
.
tensor
(
[[]
for
_
in
range
(
len
(
seq_group_metadata_list
))],
dtype
=
torch
.
int32
,
device
=
model_runner
.
device
,
)
assert
torch
.
equal
(
attn_metadata
.
block_tables
,
expected
,
)
# - Encoder/decoder cross-attention
assert
torch
.
equal
(
attn_metadata
.
cross_block_tables
,
expected
,
)
# Cuda graph should not be used for prefill.
assert
attn_metadata
.
use_cuda_graph
is
False
# Verify the lengths of input tokens & positions
# - Decoder
assert
len
(
input_tokens
)
==
sum
(
seq_lens
)
assert
len
(
input_positions
)
==
sum
(
seq_lens
)
# -- An indirect check that model_input.input_tokens
# and model_input.input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert
torch
.
equal
(
input_tokens
,
input_positions
,
)
# - Encoder
assert
len
(
encoder_input_tokens
)
==
sum
(
encoder_seq_lens
)
# -- An indirect check that model_input.encoder_input_tokens
# and model_input.encoder_input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert
torch
.
equal
(
encoder_input_tokens
,
encoder_input_positions
,
)
# Test that vLLM sampling infrastructure chooses the correct
# sequence positions at which to sample (i.e. the end of
# each sequence) in the prefill phase
expected_selected_token_indices
=
[]
selected_token_start_idx
=
0
for
seq_len
in
seq_lens
:
# Compute the index offset of the final token in each
# prompt (recall that the prompts are concatenated)
expected_selected_token_indices
.
append
(
selected_token_start_idx
+
seq_len
-
1
)
selected_token_start_idx
+=
seq_len
sampling_metadata
=
model_input
.
sampling_metadata
actual
=
sampling_metadata
.
selected_token_indices
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
device
=
actual
.
device
,
dtype
=
actual
.
dtype
,
)
assert
torch
.
equal
(
actual
,
expected
)
@
pytest
.
mark
.
skipif
(
condition
=
current_platform
.
is_cpu
(),
reason
=
"CPU backend is currently "
"unsupported for encoder/ "
"decoder models"
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
@
pytest
.
mark
.
parametrize
(
"multiple_seqs_per_seq_group"
,
[
True
,
False
])
def
test_prepare_decode
(
batch_size
,
multiple_seqs_per_seq_group
):
'''
Test the ability of the encoder/decoder model runner subclass to
produce decode-phase model inputs & attention metadata.
Test behavior:
* Instantiate BART base model & enc/dec model runner
* Construct sequence-group metadata for dummy prompts
* Test that encoder attention, decoder self-attention,
and encoder/decoder cross-attention inputs are correct
Arguments:
* batch_size
* multiple_seqs_per_seq_group
* backend_name: The attention backend under test
* enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
'''
model_runner
=
_create_model_runner
(
"facebook/bart-base"
,
seed
=
0
,
dtype
=
"float16"
,
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
False
,
enforce_eager
=
True
,
)
seq_lens
:
list
[
int
]
=
[]
encoder_seq_lens
:
list
[
int
]
=
[]
seq_group_metadata_list
:
list
[
SequenceGroupMetadata
]
=
[]
block_tables
=
{
0
:
[
1
],
1
:
[
3
]
}
if
multiple_seqs_per_seq_group
else
{
0
:
[
1
]
}
cross_block_table
=
[
2
]
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_data
=
SequenceData
.
from_seqs
(
range
(
seq_len
))
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_data
=
SequenceData
.
from_seqs
(
range
(
encoder_seq_len
))
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
False
,
seq_data
=
{
0
:
seq_data
,
1
:
seq_data
}
if
multiple_seqs_per_seq_group
else
{
0
:
seq_data
},
sampling_params
=
SamplingParams
(
temperature
=
0
),
block_tables
=
block_tables
,
encoder_seq_data
=
encoder_seq_data
,
cross_block_table
=
cross_block_table
,
)
assert
seq_group_metadata
.
token_chunk_size
==
1
seq_group_metadata_list
.
append
(
seq_group_metadata
)
seq_lens
.
extend
(
[
seq_len
for
_
in
range
(
len
(
seq_group_metadata
.
seq_data
))])
encoder_seq_lens
.
extend
(
[
encoder_seq_len
for
_
in
range
(
len
(
seq_group_metadata
.
seq_data
))])
# Build
# * Decoder model inputs
# * Decoder self-attention KV caching data structures
# * Encoder model inputs
# * Encoder/decoder cross-attention KV caching data structures
model_input
=
model_runner
.
prepare_model_input
(
seq_group_metadata_list
)
input_tokens
=
model_input
.
input_tokens
input_positions
=
model_input
.
input_positions
attn_metadata
=
model_input
.
attn_metadata
return_seq_lens
=
model_input
.
seq_lens
slot_mapping
=
attn_metadata
.
slot_mapping
encoder_input_tokens
=
model_input
.
encoder_input_tokens
encoder_input_positions
=
model_input
.
encoder_input_positions
cross_slot_mapping
=
attn_metadata
.
cross_slot_mapping
assert
return_seq_lens
==
seq_lens
assert
len
(
slot_mapping
)
==
len
(
input_tokens
)
assert
len
(
cross_slot_mapping
)
==
len
(
encoder_input_tokens
)
# Verify input metadata is correct for decode phase.
# - Decoder attention metadata
device
=
model_runner
.
device
assert
attn_metadata
.
num_prefills
==
0
assert
attn_metadata
.
num_decode_tokens
>
0
assert
torch
.
equal
(
attn_metadata
.
seq_lens_tensor
,
torch
.
tensor
(
seq_lens
,
device
=
device
,
dtype
=
torch
.
int
))
assert
attn_metadata
.
seq_lens
==
seq_lens
assert
attn_metadata
.
max_prefill_seq_len
==
0
assert
attn_metadata
.
max_decode_seq_len
==
max
(
seq_lens
)
# - Encoder attention metadata
assert
attn_metadata
.
encoder_seq_lens
==
encoder_seq_lens
assert
torch
.
equal
(
attn_metadata
.
encoder_seq_lens_tensor
,
torch
.
tensor
(
encoder_seq_lens
,
device
=
device
,
dtype
=
torch
.
int
))
assert
attn_metadata
.
max_encoder_seq_len
==
max
(
encoder_seq_lens
)
assert
attn_metadata
.
num_encoder_tokens
==
sum
(
encoder_seq_lens
)
# Test decoder subquery start locs.
start_idx
=
0
start_loc
=
[
start_idx
]
for
seq_len
in
seq_lens
:
start_idx
+=
1
start_loc
.
append
(
start_idx
)
assert
torch
.
equal
(
attn_metadata
.
query_start_loc
,
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
),
)
# Test decoder seq start locs. Note that for normal prefill it is
# equivalent to query_start_loc.
start_idx
=
0
seq_start_loc
=
[
start_idx
]
for
seq_len
in
seq_lens
:
start_idx
+=
seq_len
seq_start_loc
.
append
(
start_idx
)
# Test seq_start_loc and context lengths
assert
torch
.
equal
(
attn_metadata
.
seq_start_loc
,
torch
.
tensor
(
seq_start_loc
,
dtype
=
torch
.
int32
,
device
=
device
),
)
assert
torch
.
equal
(
attn_metadata
.
context_lens_tensor
,
torch
.
tensor
([
seq_len
-
1
for
seq_len
in
seq_lens
],
dtype
=
torch
.
int
,
device
=
device
))
# Verify block tables are correct for prompts
# - Decoder self-attention
flattened_block_tables
=
[
block_table
for
block_table
in
block_tables
.
values
()
]
expected
=
torch
.
tensor
(
flattened_block_tables
*
len
(
seq_group_metadata_list
),
dtype
=
torch
.
int32
,
device
=
model_runner
.
device
)
assert
torch
.
equal
(
attn_metadata
.
block_tables
,
expected
,
)
# - Encoder/decoder cross-attention
expected
=
torch
.
tensor
([
cross_block_table
for
seq_group_metadata
in
seq_group_metadata_list
for
_
in
range
(
len
(
seq_group_metadata
.
seq_data
))
],
dtype
=
torch
.
int32
,
device
=
model_runner
.
device
)
assert
torch
.
equal
(
attn_metadata
.
cross_block_tables
,
expected
,
)
# Model runner's CUDAGraph setting should be propagated to attention
# metadata.
assert
attn_metadata
.
use_cuda_graph
is
False
# Verify the lengths of input tokens & positions
# - Decoder
assert
len
(
input_tokens
)
==
len
(
seq_lens
)
assert
len
(
input_positions
)
==
len
(
seq_lens
)
# -- An indirect check that model_input.input_tokens
# and model_input.input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert
torch
.
equal
(
input_tokens
,
input_positions
,
)
# - Encoder
assert
len
(
encoder_input_tokens
)
==
0
assert
len
(
encoder_input_tokens
)
==
0
# -- An indirect check that model_input.encoder_input_tokens
# and model_input.encoder_input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert
torch
.
equal
(
encoder_input_tokens
,
encoder_input_positions
,
)
# Test that vLLM sampling infrastructure chooses the correct
# sequence positions at which to sample (i.e. the end of
# each sequence) in the decode phase
expected_selected_token_indices
=
[]
for
selected_token_start_idx
,
seq_len
in
enumerate
(
seq_lens
):
# Compute the index offset of the final token in each
# sequence's decoded outputs; since a single token is
# decoded per iteration per sequence, then the length
# of the decoded tokens for a given sequence is 1 and
# the final index offset into a given sequence's
# generated tokens is 0 (i.e. the expected sampling index
# for a given sequence is just `selected_token_start_idx`)
expected_selected_token_indices
.
append
(
selected_token_start_idx
)
sampling_metadata
=
model_input
.
sampling_metadata
actual
=
sampling_metadata
.
selected_token_indices
expected
=
torch
.
tensor
(
expected_selected_token_indices
,
device
=
actual
.
device
,
dtype
=
actual
.
dtype
,
)
assert
torch
.
equal
(
actual
,
expected
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
list
(
range
(
1
,
257
)))
@
pytest
.
mark
.
parametrize
(
"multiple_seqs_per_seq_group"
,
[
True
,
False
])
def
test_prepare_decode_cuda_graph
(
batch_size
,
multiple_seqs_per_seq_group
):
"""
Tests that for encoder-decoder models with CUDA Graph capture and replay
enabled, the tensors used during the decode phase are correctly padded
for varying input batch sizes.
"""
model_runner
=
_create_model_runner
(
"facebook/bart-base"
,
seed
=
0
,
dtype
=
"float16"
,
max_num_batched_tokens
=
100000
,
max_num_seqs
=
100000
,
enable_chunked_prefill
=
False
,
enforce_eager
=
False
,
)
block_tables
=
{
0
:
[
1
],
1
:
[
3
]
}
if
multiple_seqs_per_seq_group
else
{
0
:
[
1
]
}
seq_lens
:
list
[
int
]
=
[]
encoder_seq_lens
:
list
[
int
]
=
[]
seq_group_metadata_list
:
list
[
SequenceGroupMetadata
]
=
[]
cross_block_table
=
[
2
]
expanded_batch_size
=
0
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_data
=
SequenceData
.
from_seqs
(
range
(
seq_len
))
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_data
=
SequenceData
.
from_seqs
(
range
(
encoder_seq_len
))
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
False
,
seq_data
=
{
0
:
seq_data
,
1
:
seq_data
}
if
multiple_seqs_per_seq_group
else
{
0
:
seq_data
},
sampling_params
=
SamplingParams
(
temperature
=
0
),
block_tables
=
block_tables
,
encoder_seq_data
=
encoder_seq_data
,
cross_block_table
=
cross_block_table
,
)
assert
seq_group_metadata
.
token_chunk_size
==
1
seq_lens
.
extend
(
[
seq_len
for
_
in
range
(
len
(
seq_group_metadata
.
seq_data
))])
encoder_seq_lens
.
extend
(
[
encoder_seq_len
for
_
in
range
(
len
(
seq_group_metadata
.
seq_data
))])
expanded_batch_size
=
expanded_batch_size
+
len
(
seq_group_metadata
.
seq_data
)
seq_group_metadata_list
.
append
(
seq_group_metadata
)
model_input
=
model_runner
.
prepare_model_input
(
seq_group_metadata_list
)
input_tokens
=
model_input
.
input_tokens
input_positions
=
model_input
.
input_positions
attn_metadata
=
model_input
.
attn_metadata
return_seq_lens
=
model_input
.
seq_lens
slot_mapping
=
attn_metadata
.
slot_mapping
encoder_input_tokens
=
model_input
.
encoder_input_tokens
encoder_input_positions
=
model_input
.
encoder_input_positions
cross_slot_mapping
=
attn_metadata
.
cross_slot_mapping
# With CUDA Graph capture and replay enabled, the decoder and encoder
# input sequences will be padded. Create the expected padded tensors
# accordingly.
graph_batch_size
=
model_runner
.
vllm_config
.
pad_for_cudagraph
(
expanded_batch_size
)
cuda_graph_pad_size
=
graph_batch_size
-
expanded_batch_size
padded_seq_lens
=
seq_lens
+
list
(
itertools
.
repeat
(
1
,
cuda_graph_pad_size
))
padded_encoder_seq_lens
=
encoder_seq_lens
+
list
(
itertools
.
repeat
(
1
,
cuda_graph_pad_size
))
assert
return_seq_lens
==
padded_seq_lens
assert
len
(
slot_mapping
)
==
len
(
input_tokens
)
assert
len
(
cross_slot_mapping
)
==
len
(
encoder_input_tokens
)
# Verify attention metadata
device
=
model_runner
.
device
assert
attn_metadata
.
num_prefills
==
0
assert
attn_metadata
.
num_decode_tokens
>
0
assert
torch
.
equal
(
attn_metadata
.
seq_lens_tensor
,
torch
.
tensor
(
padded_seq_lens
,
device
=
device
,
dtype
=
torch
.
int
))
assert
attn_metadata
.
seq_lens
==
padded_seq_lens
assert
attn_metadata
.
max_prefill_seq_len
==
0
assert
attn_metadata
.
max_decode_seq_len
==
max
(
seq_lens
)
# - Encoder attention metadata
assert
attn_metadata
.
encoder_seq_lens
==
padded_encoder_seq_lens
assert
torch
.
equal
(
attn_metadata
.
encoder_seq_lens_tensor
,
torch
.
tensor
(
padded_encoder_seq_lens
,
device
=
device
,
dtype
=
torch
.
int
))
assert
attn_metadata
.
max_encoder_seq_len
==
max
(
padded_encoder_seq_lens
)
assert
attn_metadata
.
num_encoder_tokens
==
sum
(
padded_encoder_seq_lens
)
# Verify block tables are correct for prompts
# - Decoder self-attention. Pad the block tables as expected.
flattened_block_tables
=
[
block_table
for
_
in
range
(
len
(
seq_group_metadata_list
))
for
block_table
in
block_tables
.
values
()
]
flattened_block_tables
.
extend
([[]
for
_
in
range
(
cuda_graph_pad_size
)])
expected
=
make_tensor_with_pad
(
flattened_block_tables
,
max_len
=
64
,
pad
=
0
,
dtype
=
torch
.
int32
,
device
=
model_runner
.
device
,
)
assert
torch
.
equal
(
attn_metadata
.
block_tables
,
expected
,
)
# - Encoder/decoder cross-attention. Pad the cross-attention block tables
# as expected.
expected
=
[
cross_block_table
for
seq_group_metadata
in
seq_group_metadata_list
for
_
in
range
(
len
(
seq_group_metadata
.
seq_data
))
]
expected
.
extend
([[]
for
_
in
range
(
cuda_graph_pad_size
)])
expected
=
make_tensor_with_pad
(
expected
,
max_len
=
64
,
pad
=
0
,
dtype
=
torch
.
int32
,
device
=
model_runner
.
device
,
)
assert
torch
.
equal
(
attn_metadata
.
cross_block_tables
,
expected
,
)
# Model runner's CUDAGraph setting should be propagated to attention
# metadata.
assert
attn_metadata
.
use_cuda_graph
is
True
# Verify the lengths of input tokens & positions
# - Decoder
assert
len
(
input_tokens
)
==
len
(
padded_seq_lens
)
assert
len
(
input_positions
)
==
len
(
padded_seq_lens
)
# -- An indirect check that model_input.input_tokens
# and model_input.input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert
torch
.
equal
(
input_tokens
,
input_positions
,
)
# - Encoder
assert
len
(
encoder_input_tokens
)
==
0
assert
len
(
encoder_input_tokens
)
==
0
# -- An indirect check that model_input.encoder_input_tokens
# and model_input.encoder_input_positions are correct -
# by design of the test, the input tokens are
# equal to the input position values, so if
# the model_input data structure has the correct
# values then these two should be equal
assert
torch
.
equal
(
encoder_input_tokens
,
encoder_input_positions
,
)
vllm/config/__init__.py
View file @
759ef49b
...
@@ -1201,11 +1201,8 @@ class ModelConfig:
...
@@ -1201,11 +1201,8 @@ class ModelConfig:
getattr
(
self
.
hf_config
,
"max_source_positions"
,
0
))
getattr
(
self
.
hf_config
,
"max_source_positions"
,
0
))
self
.
max_seq_len_to_capture
=
min
(
self
.
max_seq_len_to_capture
,
self
.
max_seq_len_to_capture
=
min
(
self
.
max_seq_len_to_capture
,
effective_max_seq_len
)
effective_max_seq_len
)
# CUDAGraph capture not supported for enc-dec models and mllama on ROCm
# CUDAGraph capture not supported for encoder-decoder models on ROCm
ROCM_UNSUPPORTED_MODELS
=
[
'mllama'
]
unsupported_rocm
=
self
.
is_encoder_decoder
unsupported_rocm
=
(
self
.
hf_config
.
model_type
in
ROCM_UNSUPPORTED_MODELS
or
self
.
is_encoder_decoder
)
if
(
unsupported_rocm
and
not
self
.
enforce_eager
if
(
unsupported_rocm
and
not
self
.
enforce_eager
and
current_platform
.
is_rocm
()):
and
current_platform
.
is_rocm
()):
...
@@ -1671,10 +1668,6 @@ class ModelConfig:
...
@@ -1671,10 +1668,6 @@ class ModelConfig:
@
property
@
property
def
is_encoder_decoder
(
self
)
->
bool
:
def
is_encoder_decoder
(
self
)
->
bool
:
"""Extract the HF encoder/decoder model flag."""
"""Extract the HF encoder/decoder model flag."""
"""
For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to
True to enable cross-attention
"""
return
is_encoder_decoder
(
self
.
hf_config
)
return
is_encoder_decoder
(
self
.
hf_config
)
@
property
@
property
...
...
vllm/engine/llm_engine.py
View file @
759ef49b
...
@@ -1789,7 +1789,7 @@ class LLMEngine:
...
@@ -1789,7 +1789,7 @@ class LLMEngine:
assert
isinstance
(
mm_processor
,
EncDecMultiModalProcessor
)
assert
isinstance
(
mm_processor
,
EncDecMultiModalProcessor
)
if
mm_processor
.
pad_dummy_encoder_prompt
:
if
mm_processor
.
pad_dummy_encoder_prompt
:
return
# Skip encoder length check for Whisper
and Donut
return
# Skip encoder length check for Whisper
if
model_config
.
is_multimodal_model
:
if
model_config
.
is_multimodal_model
:
suggestion
=
(
suggestion
=
(
...
...
vllm/model_executor/models/bart.py
deleted
100644 → 0
View file @
5206ab20
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Derived from BART implementation posted on HuggingFace; license below:
#
# coding=utf-8
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BART model."""
import
math
from
collections.abc
import
Iterable
from
typing
import
Optional
import
torch
from
torch
import
nn
from
transformers
import
BartConfig
from
transformers.utils
import
logging
from
vllm.attention
import
Attention
,
AttentionType
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config.lora
import
LoRAConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVCrossParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsQuant
,
SupportsV0Only
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
cast_overflow_tensors
,
maybe_prefix
)
# Module-level logger, created through the `logging` helper imported from
# `transformers.utils` at the top of this file.
logger
=
logging
.
get_logger
(
__name__
)
def get_bsz_seq_len(input_ids):
    """Return the leading (batch size, sequence length) pair of `input_ids`.

    A one-dimensional input is treated as a single sequence: the batch size
    is reported as 1 and the sequence length is the element count. For any
    other rank, the first two entries of the shape are returned unchanged.
    """
    dims = input_ids.shape
    if len(dims) != 1:
        return dims[:2]
    return 1, input_ids.numel()
class BartLearnedPositionalEmbedding(VocabParallelEmbedding):
    """
    Learned positional embedding table of a fixed maximum size, looked up
    with BART's constant index offset.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # Bart is set up so that if padding_idx is specified then the
        # embedding ids are offset by 2 and num_embeddings is adjusted
        # accordingly. Other models don't have this hack.
        self.offset = 2
        table_size = num_embeddings + self.offset
        super().__init__(table_size, embedding_dim)

    def forward(
        self,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Embed `positions` after shifting every index by `self.offset`."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)
class BartScaledWordEmbedding(VocabParallelEmbedding):
    """
    VocabParallelEmbedding variant whose forward result is multiplied by a
    constant embedding scale.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        # Constant factor applied to every embedding produced by forward().
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the parent embedding of `input_ids` times `embed_scale`."""
        embedded = super().forward(input_ids)
        return embedded * self.embed_scale
class BartParallelLMHead(ParallelLMHead):
    """
    ParallelLMHead variant whose forward result is divided by the embedding
    scale, i.e. effectively the inverse of BartScaledWordEmbedding.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        # Constant divisor applied to the parent's forward result.
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the parent forward result divided by `embed_scale`."""
        head_output = super().forward(input_ids)
        return head_output / self.embed_scale
class BartEncoderAttention(nn.Module):
    """Self-attention block used by the encoder layers.

    Projects the input with a fused QKV linear layer, runs an `Attention`
    layer created with `AttentionType.ENCODER`, and applies the output
    projection.

    Args:
        embed_dim: Attention width; must be divisible by `num_heads`.
        num_heads: Total number of attention heads, before tensor-parallel
            partitioning.
        bias: Whether the QKV and output projections use a bias term.
        config: Model config supplying `d_model`. When omitted, `embed_dim`
            is used as `d_model`.
        cache_config: Forwarded to the `Attention` layer.
        quant_config: Forwarded to the linear layers and `Attention` layer.
        prefix: Name prefix; the `Attention` layer receives
            `{prefix}.attn`.

    Raises:
        ValueError: If `embed_dim` is not divisible by `num_heads`.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        config: Optional[BartConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # `config` is declared optional, so do not dereference it blindly:
        # without a config the projection input width is `embed_dim`.
        self.d_model = config.d_model if config is not None else embed_dim
        self.embed_dim = embed_dim
        self.total_num_heads = num_heads
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads "
                             f"(got `embed_dim`: {self.embed_dim}"
                             f" and `num_heads`: {num_heads}).")
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            self.d_model,
            self.d_model // self.total_num_heads,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )

        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = self.num_heads
        # Per-rank widths used to split the fused QKV projection output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn",
                              attn_type=AttentionType.ENCODER)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Input shape: Batch x Time x Channel"""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        attn_output = self.attn(q, k, v)

        output, _ = self.out_proj(attn_output)
        return output
class BartDecoderSelfAttention(nn.Module):
    """Self-attention block used by the decoder layers.

    Projects the input with a fused QKV linear layer, runs an `Attention`
    layer created with `AttentionType.DECODER`, and applies the output
    projection.

    Args:
        embed_dim: Attention width; must be divisible by `num_heads`.
        num_heads: Total number of attention heads, before tensor-parallel
            partitioning.
        bias: Whether the QKV and output projections use a bias term.
        config: Model config supplying `d_model`. When omitted, `embed_dim`
            is used as `d_model`.
        cache_config: Forwarded to the `Attention` layer.
        quant_config: Forwarded to the linear layers and `Attention` layer.
        prefix: Name prefix; the `Attention` layer receives
            `{prefix}.attn`.

    Raises:
        ValueError: If `embed_dim` is not divisible by `num_heads`.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        config: Optional[BartConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # `config` is declared optional, so do not dereference it blindly:
        # without a config the projection input width is `embed_dim`.
        self.d_model = config.d_model if config is not None else embed_dim
        self.embed_dim = embed_dim
        self.total_num_heads = num_heads
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads "
                             f"(got `embed_dim`: {self.embed_dim}"
                             f" and `num_heads`: {num_heads}).")
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            self.d_model,
            self.d_model // self.total_num_heads,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )

        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = self.num_heads
        # Per-rank widths used to split the fused QKV projection output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn",
                              attn_type=AttentionType.DECODER)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Input shape: Batch x Time x Channel"""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        attn_output = self.attn(q, k, v)

        output, _ = self.out_proj(attn_output)
        return output
class BartCrossAttention(nn.Module):
    """Encoder/decoder cross-attention block used by the decoder layers.

    A `QKVCrossParallelLinear` layer produces q, k and v from the decoder
    and encoder hidden states, an `Attention` layer created with
    `AttentionType.ENCODER_DECODER` consumes them, and the result goes
    through the output projection.

    Args:
        embed_dim: Attention width; must be divisible by `num_heads`.
        num_heads: Total number of attention heads, before tensor-parallel
            partitioning.
        bias: Whether the QKV and output projections use a bias term.
        config: Model config supplying `d_model`. When omitted, `embed_dim`
            is used as `d_model`.
        cache_config: Forwarded to the `Attention` layer.
        quant_config: Forwarded to the linear layers and `Attention` layer.
        prefix: Name prefix; the `Attention` layer receives
            `{prefix}.attn`.

    Raises:
        ValueError: If `embed_dim` is not divisible by `num_heads`.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        config: Optional[BartConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # `config` is declared optional, so do not dereference it blindly:
        # without a config the projection input width is `embed_dim`.
        self.d_model = config.d_model if config is not None else embed_dim
        self.embed_dim = embed_dim
        self.total_num_heads = num_heads
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads "
                             f"(got `embed_dim`: {self.embed_dim}"
                             f" and `num_heads`: {num_heads}).")
        self.scaling = self.head_dim**-0.5

        # TP sharding sizes is accounted for within "*Parallel" layers.
        self.qkv_proj = QKVCrossParallelLinear(self.d_model,
                                               self.d_model //
                                               self.total_num_heads,
                                               self.total_num_heads,
                                               self.total_num_kv_heads,
                                               bias,
                                               quant_config=quant_config)

        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = self.num_heads  # No GQA in bart

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn",
                              attn_type=AttentionType.ENCODER_DECODER)

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Input shape: Batch x Time x Channel"""
        q, k, v = self.qkv_proj(decoder_hidden_states, encoder_hidden_states)

        attn_output = self.attn(q, k, v)

        output, _ = self.out_proj(attn_output)
        return output
class BartEncoderLayer(nn.Module):
    """One encoder layer: self-attention and a feed-forward block, each
    followed by a residual add and a LayerNorm (norm applied after the add).
    """

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = BartEncoderAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.activation_fn = get_act_fn(config.activation_function)

        # Feed-forward block: embed_dim -> encoder_ffn_dim -> embed_dim,
        # both projections with bias.
        self.fc1 = ColumnParallelLinear(
            self.embed_dim,
            config.encoder_ffn_dim,
            bias=True,
            quant_config=quant_config,
        )

        # Kept for parity with the original module layout; forward() uses
        # `activation_fn`, not `act`.
        self.act = get_act_fn("gelu")
        self.fc2 = RowParallelLinear(
            config.encoder_ffn_dim,
            self.embed_dim,
            bias=True,
            quant_config=quant_config,
        )

        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            hidden_states: torch.Tensor of *encoder* input embeddings.
        Returns:
            Encoder layer output torch.Tensor
        """
        # Self-attention sub-block: residual add, then LayerNorm.
        attn_out = self.self_attn(hidden_states=hidden_states)
        attn_normed = self.self_attn_layer_norm(hidden_states + attn_out)

        # Feed-forward sub-block: residual add, then LayerNorm.
        ffn_out, _ = self.fc1(attn_normed)
        ffn_out = self.activation_fn(ffn_out)
        ffn_out, _ = self.fc2(ffn_out)
        layer_out = self.final_layer_norm(attn_normed + ffn_out)

        # Only float16 outputs are checked for inf/nan overflow.
        if layer_out.dtype != torch.float16:
            return layer_out
        if torch.isinf(layer_out).any() or torch.isnan(layer_out).any():
            layer_out = cast_overflow_tensors(layer_out)
        return layer_out
class BartDecoderLayer(nn.Module):
    """One decoder layer: self-attention, cross-attention over the encoder
    output, and a feed-forward block. Each sub-block is followed by a
    residual add and a LayerNorm (norm applied after the add).

    Args:
        config: Model config; supplies `d_model`, `decoder_attention_heads`,
            `activation_function` and `decoder_ffn_dim`.
        cache_config: Forwarded to both attention blocks.
        quant_config: Forwarded to both attention blocks and to the
            feed-forward linear layers.
        prefix: Name prefix; the attention blocks receive
            `{prefix}.self_attn` and `{prefix}.encoder_attn`.
    """

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = BartDecoderSelfAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.activation_fn = get_act_fn(config.activation_function)

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        # afeldman-nm: personally I would call this "cross-attention",
        # however I left the name as "encoder_attn" to maintain consistency
        # with the name of the pretrained weights.
        #
        # The cache and quantization configs are forwarded here as well, so
        # cross-attention is configured the same way as self-attention.
        self.encoder_attn = BartCrossAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder_attn",
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        ffn_hidden_size = self.embed_dim
        # The decoder feed-forward width comes from the decoder-side setting;
        # it only coincides with `encoder_ffn_dim` for symmetric configs.
        ffn_intermediate_size = config.decoder_ffn_dim
        ffn_has_bias = True
        self.fc1 = ColumnParallelLinear(
            ffn_hidden_size,
            ffn_intermediate_size,
            bias=ffn_has_bias,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            ffn_intermediate_size,
            ffn_hidden_size,
            bias=ffn_has_bias,
            quant_config=quant_config,
        )

        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            decoder_hidden_states: torch.Tensor of *decoder* input embeddings.
            encoder_hidden_states: torch.Tensor of *encoder* input embeddings.
        Returns:
            Decoder layer output torch.Tensor
        """
        residual = decoder_hidden_states

        # Self Attention
        hidden_states = self.self_attn(hidden_states=decoder_hidden_states)

        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block
        residual = hidden_states

        hidden_states = self.encoder_attn(
            decoder_hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
        )

        hidden_states = residual + hidden_states
        hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        fc1_out, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(fc1_out)

        hidden_states, _ = self.fc2(hidden_states)

        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states
class BartEncoder(nn.Module):
    """
    Transformer encoder made of *config.encoder_layers* stacked
    [`BartEncoderLayer`] modules.
    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): optional embedding whose weight is
            reused as this encoder's token-embedding weight
    """

    def __init__(self,
                 config: BartConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 lora_config: Optional[LoRAConfig] = None,
                 embed_tokens: Optional[nn.Embedding] = None,
                 prefix: str = ""):
        super().__init__()

        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config
        hidden_dim = config.d_model
        self.max_source_positions = config.max_position_embeddings

        # Token embeddings are scaled by sqrt(d_model) only when the config
        # asks for it.
        if config.scale_embedding:
            scale = math.sqrt(hidden_dim)
        else:
            scale = 1.0
        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
                                                    hidden_dim,
                                                    embed_scale=scale)
        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            hidden_dim,
        )
        self.layers = nn.ModuleList(
            BartEncoderLayer(config,
                             cache_config,
                             quant_config,
                             prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(config.encoder_layers))

        self.layernorm_embedding = nn.LayerNorm(hidden_dim)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids: Indices of *encoder* input sequence tokens in the
                vocabulary; only used when `inputs_embeds` is None.
            positions: Positions of *encoder* input sequence tokens.
            inputs_embeds: Optional precomputed token embeddings used in
                place of embedding `input_ids`.
        Returns:
            Encoder output torch.Tensor
        """
        token_embeds = inputs_embeds
        if token_embeds is None:
            token_embeds = self.embed_tokens(input_ids)

        position_embeds = self.embed_positions(positions).to(
            token_embeds.device)

        hidden_states = self.layernorm_embedding(token_embeds +
                                                 position_embeds)

        for layer in self.layers:
            hidden_states = layer(hidden_states=hidden_states)

        return hidden_states
class BartDecoder(nn.Module):
    """
    Transformer decoder made of *config.decoder_layers* stacked
    [`BartDecoderLayer`] modules.
    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): optional embedding whose weight is
            reused as this decoder's token-embedding weight
    """

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        embed_tokens: Optional[nn.Embedding] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config
        self.max_target_positions = config.max_position_embeddings

        # Token embeddings are scaled by sqrt(d_model) only when the config
        # asks for it.
        if config.scale_embedding:
            scale = math.sqrt(config.d_model)
        else:
            scale = 1.0
        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
                                                    config.d_model,
                                                    embed_scale=scale)
        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )

        self.layers = nn.ModuleList(
            BartDecoderLayer(config,
                             cache_config,
                             quant_config,
                             prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(config.decoder_layers))

        self.layernorm_embedding = nn.LayerNorm(config.d_model)

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        decoder_positions: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            decoder_input_ids: Indices of *decoder* input sequence tokens
                in the vocabulary; only used when `inputs_embeds` is None.
            decoder_positions: Positions of *decoder* input sequence tokens.
            encoder_hidden_states: Tensor of encoder output embeddings.
            inputs_embeds: Optional precomputed token embeddings used in
                place of embedding `decoder_input_ids`.
        Returns:
            Decoder output torch.Tensor
        """
        if inputs_embeds is not None:
            # NOTE(review): when embeddings are supplied, the positions
            # argument is replaced by the last column of `inputs_embeds`.
            # This looks unintended -- confirm against callers before
            # relying on it.
            decoder_positions = inputs_embeds[:, -1]
        else:
            inputs_embeds = self.embed_tokens(decoder_input_ids)

        # embed positions
        position_embeds = self.embed_positions(decoder_positions).to(
            inputs_embeds.device)

        hidden_states = self.layernorm_embedding(inputs_embeds +
                                                 position_embeds)

        # decoder layers
        for layer in self.layers:
            hidden_states = layer(
                decoder_hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
            )

        return hidden_states
class BartModel(nn.Module, SupportsQuant):
    """Encoder/decoder backbone: a `BartEncoder` feeding a `BartDecoder`."""

    _tied_weights_keys = [
        "encoder.embed_tokens.weight",
        "decoder.embed_tokens.weight",
    ]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config

        # Extra vocabulary entries contributed by LoRA, if LoRA is configured.
        if lora_config:
            lora_vocab = lora_config.lora_extra_vocab_size * (
                lora_config.max_loras or 1)
        else:
            lora_vocab = 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.encoder = BartEncoder(config,
                                   cache_config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.encoder")
        self.decoder = BartDecoder(config,
                                   cache_config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.decoder")

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                encoder_input_ids: torch.Tensor,
                encoder_positions: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            input_ids: Indices of *decoder* input sequence tokens
                in the vocabulary.
            positions: Positions of *decoder* input sequence tokens.
            encoder_input_ids: Indices of *encoder* input sequence tokens
                in the vocabulary.
            encoder_positions: Positions of *encoder* input sequence tokens.
        Returns:
            Model output torch.Tensor
        """
        # The encoder only runs when at least one encoder token is provided;
        # otherwise the decoder receives None as encoder output.
        if encoder_input_ids.numel() > 0:
            encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
                                                 positions=encoder_positions)
        else:
            encoder_hidden_states = None

        return self.decoder(decoder_input_ids=input_ids,
                            decoder_positions=positions,
                            encoder_hidden_states=encoder_hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the names that were loaded.

        Separate q/k/v projection weights are loaded as shards of the fused
        `qkv_proj` parameter; every other weight whose name matches a model
        parameter is handed to `AutoWeightsLoader`. Names that match no
        parameter are dropped.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params = dict(self.named_parameters())
        remaining_weights = []
        fused_loaded = []

        for name, loaded_weight in weights:
            loaded_as_shard = False
            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name in name:
                    # The rename is kept even when the fused parameter does
                    # not exist, exactly like the original loop.
                    name = name.replace(shard_name, fused_name)
                    if name in params:
                        param = params[name]
                        param.weight_loader(param, loaded_weight, shard_id)
                        fused_loaded.append(name)
                        loaded_as_shard = True
                        break
            if not loaded_as_shard and name in params:
                remaining_weights.append((name, loaded_weight))

        loaded_params = AutoWeightsLoader(self).load_weights(
            remaining_weights)
        loaded_params.update(fused_loaded)
        return loaded_params
class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
    """`BartModel` backbone with a scaled LM head and a logits processor."""

    # Checkpoint-name rewrites applied by load_weights().
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "decoder.": "model.decoder.",
            "encoder.": "model.encoder.",
            "shared.": "model.shared.",
        },
        orig_to_new_substr={
            "beta": "bias",
            "gamma": "weight",
            "LayerNorm": "layernorm",
        },
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

        super().__init__()
        config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        # currently all existing BART models have `tie_word_embeddings` enabled
        assert config.tie_word_embeddings
        self.config = config
        self.model = BartModel(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))

        vocab_size = config.vocab_size
        if lora_config:
            vocab_size += lora_config.lora_extra_vocab_size
        self.unpadded_vocab_size = vocab_size

        if config.scale_embedding:
            scale = math.sqrt(config.d_model)
        else:
            scale = 1.0

        self.lm_head = BartParallelLMHead(config.vocab_size,
                                          config.d_model,
                                          embed_scale=scale)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids: torch.Tensor of *decoder* input token ids.
            positions: torch.Tensor of *decoder* position indices.
            intermediate_tensors: accepted but not used by this method.
            encoder_input_ids: torch.Tensor of *encoder* input token ids.
            encoder_positions: torch.Tensor of *encoder* position indices.
        Returns:
            Output torch.Tensor
        """
        backbone_output = self.model(input_ids, positions, encoder_input_ids,
                                     encoder_positions)
        return backbone_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Run the logits processor with the LM head on `hidden_states`."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and tie the embedding weights.

        At most one checkpoint entry may carry the shared embedding (any of
        the names listed below). When found, it is loaded into the LM head,
        whose weight is then assigned to both token-embedding modules.

        Returns:
            The set of parameter names that were loaded.
        """
        weights_tuple_list = list(weights)

        embedding_name_markers = (
            'shared.weight',
            'encoder.embed_tokens.weight',
            'decoder.embed_tokens.weight',
            'lm_head.weight',
        )
        shared_embedding_weight = None
        for name, loaded_weight in weights_tuple_list:
            if any(marker in name for marker in embedding_name_markers):
                assert shared_embedding_weight is None, (
                    "Conflicting embedding weights.")
                shared_embedding_weight = loaded_weight

        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["cls.", "pooler."]),
        )
        loaded_params = loader.load_weights(weights_tuple_list,
                                            mapper=self.hf_to_vllm_mapper)

        if shared_embedding_weight is None:
            return loaded_params

        lm_head_weight = self.lm_head.weight
        load_fn = getattr(lm_head_weight, "weight_loader",
                          default_weight_loader)
        load_fn(lm_head_weight, shared_embedding_weight)

        # Tie encoder and decoder token embeddings to the LM head weight.
        self.model.encoder.embed_tokens.weight = self.lm_head.weight
        self.model.decoder.embed_tokens.weight = self.lm_head.weight
        loaded_params.update({
            'model.encoder.embed_tokens.weight', 'lm_head.weight',
            'model.decoder.embed_tokens.weight'
        })

        return loaded_params
class MBartEncoderLayer(BartEncoderLayer):
    """Encoder layer that applies each LayerNorm *before* its sub-block
    (self-attention, feed-forward) and adds the residual afterwards.
    """

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            hidden_states: torch.Tensor of *encoder* input embeddings.
        Returns:
            Encoder layer output torch.Tensor
        """
        # Self-attention sub-block: norm, attend, then residual add.
        attn_in = self.self_attn_layer_norm(hidden_states)
        attn_out = self.self_attn(hidden_states=attn_in)
        after_attn = hidden_states + attn_out

        # Feed-forward sub-block: norm, MLP, then residual add.
        ffn_in = self.final_layer_norm(after_attn)
        ffn_out, _ = self.fc1(ffn_in)
        ffn_out = self.activation_fn(ffn_out)
        ffn_out, _ = self.fc2(ffn_out)
        layer_out = after_attn + ffn_out

        # Only float16 outputs are checked for inf/nan overflow.
        if layer_out.dtype != torch.float16:
            return layer_out
        if torch.isinf(layer_out).any() or torch.isnan(layer_out).any():
            layer_out = cast_overflow_tensors(layer_out)
        return layer_out
class MBartDecoderLayer(BartDecoderLayer):
    """Decoder layer that applies each LayerNorm *before* its sub-block
    (self-attention, cross-attention, feed-forward) and adds the residual
    afterwards.
    """

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            decoder_hidden_states: torch.Tensor of *decoder* input embeddings.
            encoder_hidden_states: torch.Tensor of *encoder* output,
                forwarded to the cross-attention block.
        Returns:
            Decoder layer output torch.Tensor
        """
        # Self Attention
        self_attn_in = self.self_attn_layer_norm(decoder_hidden_states)
        self_attn_out = self.self_attn(hidden_states=self_attn_in)
        after_self_attn = decoder_hidden_states + self_attn_out

        # Cross-Attention Block
        cross_attn_in = self.encoder_attn_layer_norm(after_self_attn)
        cross_attn_out = self.encoder_attn(
            decoder_hidden_states=cross_attn_in,
            encoder_hidden_states=encoder_hidden_states,
        )
        after_cross_attn = after_self_attn + cross_attn_out

        # Fully Connected
        ffn_in = self.final_layer_norm(after_cross_attn)
        ffn_out, _ = self.fc1(ffn_in)
        ffn_out = self.activation_fn(ffn_out)
        ffn_out, _ = self.fc2(ffn_out)

        return after_cross_attn + ffn_out
class MBartEncoder(nn.Module):
    """
    Transformer encoder made of *config.encoder_layers* stacked
    [`MBartEncoderLayer`] modules, followed by a final LayerNorm.
    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): optional embedding whose weight is
            reused as this encoder's token-embedding weight
    """

    def __init__(self,
                 config: BartConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 lora_config: Optional[LoRAConfig] = None,
                 embed_tokens: Optional[nn.Embedding] = None,
                 prefix: str = ""):
        super().__init__()

        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config
        hidden_dim = config.d_model
        self.max_source_positions = config.max_position_embeddings

        # Token embeddings are scaled by sqrt(d_model) only when the config
        # asks for it.
        if config.scale_embedding:
            scale = math.sqrt(hidden_dim)
        else:
            scale = 1.0
        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
                                                    hidden_dim,
                                                    embed_scale=scale)
        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            hidden_dim,
        )
        self.layers = nn.ModuleList(
            MBartEncoderLayer(config,
                              cache_config,
                              quant_config,
                              prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(config.encoder_layers))

        self.layernorm_embedding = nn.LayerNorm(hidden_dim)
        # Changed relative to BartEncoder: extra LayerNorm that forward()
        # applies after the last encoder layer.
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids: Indices of *encoder* input sequence tokens in the
                vocabulary; only used when `inputs_embeds` is None.
            positions: Positions of *encoder* input sequence tokens.
            inputs_embeds: Optional precomputed token embeddings used in
                place of embedding `input_ids`.
        Returns:
            Encoder output torch.Tensor
        """
        token_embeds = inputs_embeds
        if token_embeds is None:
            token_embeds = self.embed_tokens(input_ids)

        position_embeds = self.embed_positions(positions).to(
            token_embeds.device)

        hidden_states = self.layernorm_embedding(token_embeds +
                                                 position_embeds)

        for layer in self.layers:
            hidden_states = layer(hidden_states=hidden_states)

        return self.layer_norm(hidden_states)
class MBartDecoder(nn.Module):
    """
    MBart Transformer decoder consisting of *config.decoder_layers* layers.
    Each layer is a [`MBartDecoderLayer`]. A layer norm is applied to the
    embeddings before the stack and another one after it.

    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): optional embedding whose weight is
            reused as the token-embedding weight of this decoder
    """

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        embed_tokens: Optional[nn.Embedding] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config
        self.max_target_positions = config.max_position_embeddings
        # embed_scale is sqrt(d_model) when the config asks for scaled
        # embeddings and 1.0 otherwise.
        embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0

        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
                                                    config.d_model,
                                                    embed_scale=embed_scale)

        if embed_tokens is not None:
            # Reuse the weight of the embedding supplied by the caller.
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )

        self.layers = nn.ModuleList([
            MBartDecoderLayer(config,
                              cache_config,
                              quant_config,
                              prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(config.decoder_layers)
        ])

        self.layernorm_embedding = nn.LayerNorm(config.d_model)
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        decoder_positions: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            decoder_input_ids: Indices of *decoder* input sequence tokens
                in the vocabulary. Only used when `inputs_embeds` is None.
            decoder_positions: Positions of *decoder* input sequence tokens.
            encoder_hidden_states: Tensor of encoder output embeddings
                (may be None).
            inputs_embeds: Optional precomputed decoder token embeddings;
                when given, the token-embedding lookup is skipped.
        Returns:
            Decoder output torch.Tensor
        """
        # Previously, a caller-supplied `inputs_embeds` caused
        # `decoder_positions` to be overwritten with `inputs_embeds[:, -1]`,
        # i.e. a slice of embedding values, which was then passed to the
        # positional-embedding lookup. The positions given by the caller are
        # now always used.
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(decoder_input_ids)

        # embed positions
        embed_pos = self.embed_positions(decoder_positions)
        embed_pos = embed_pos.to(inputs_embeds.device)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)

        # decoder layers
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                decoder_hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
            )

        hidden_states = self.layer_norm(hidden_states)
        return hidden_states
class MBartModel(nn.Module, SupportsQuant):
    """Bare MBart model: an `MBartEncoder` feeding an `MBartDecoder`."""

    _tied_weights_keys = [
        "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
    ]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = hf_config

        # Extra vocabulary slots reserved for LoRA adapters (0 without LoRA).
        if lora_config:
            num_adapters = lora_config.max_loras or 1
            extra_vocab = lora_config.lora_extra_vocab_size * num_adapters
        else:
            extra_vocab = 0
        self.vocab_size = hf_config.vocab_size + extra_vocab
        self.org_vocab_size = hf_config.vocab_size

        self.encoder = MBartEncoder(hf_config,
                                    cache_config,
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.encoder")
        self.decoder = MBartDecoder(hf_config,
                                    cache_config,
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.decoder")

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                encoder_input_ids: torch.Tensor,
                encoder_positions: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            input_ids: Indices of *decoder* input sequence tokens
                in the vocabulary.
            positions: Positions of *decoder* input sequence tokens.
            encoder_input_ids: Indices of *encoder* input sequence tokens
                in the vocabulary.
            encoder_positions: Positions of *encoder* input sequence tokens.
        Returns:
            Model output torch.Tensor (the decoder output)
        """
        # The encoder only runs when at least one encoder token is given;
        # otherwise the decoder receives None as encoder states.
        if encoder_input_ids.numel() > 0:
            encoder_states = self.encoder(input_ids=encoder_input_ids,
                                          positions=encoder_positions)
        else:
            encoder_states = None

        return self.decoder(decoder_input_ids=input_ids,
                            decoder_positions=positions,
                            encoder_hidden_states=encoder_states)
class MBartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
    """MBart model (`MBartModel`) with an LM head and a logits processor.

    Requires `config.tie_word_embeddings`; the LM head weight is shared with
    the encoder and decoder token embeddings when weights are loaded.
    """
    base_model_prefix = "model"

    # Renaming rules applied to checkpoint weight names: prefixes are moved
    # under "model." and some substrings are rewritten.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "decoder.": "model.decoder.",
            "encoder.": "model.encoder.",
            "shared.": "model.shared."
        },
        orig_to_new_substr={
            "beta": "bias",
            "gamma": "weight",
            "LayerNorm": "layernorm",
        },
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the MBart model, LM head and logits processor.

        Args:
            vllm_config: Engine configuration; the HF config and the LoRA
                config are read from it.
            prefix: Name prefix under which the inner model is registered.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        # Only checkpoints with tied word embeddings are supported.
        assert config.tie_word_embeddings
        self.config = config
        self.model = MBartModel(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))

        # Vocabulary size including any extra LoRA tokens.
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0

        self.lm_head = BartParallelLMHead(config.vocab_size,
                                          config.d_model,
                                          embed_scale=embed_scale)

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Run the inner `MBartModel`.

        Args:
            input_ids: *Decoder* input token ids.
            positions: *Decoder* position indices.
            intermediate_tensors: Accepted but not used by this method.
            encoder_input_ids: *Encoder* input token ids.
            encoder_positions: *Encoder* position indices.
            **kwargs: Accepted but not used by this method.
        Returns:
            Output of `self.model`.
        """
        return self.model(input_ids, positions, encoder_input_ids,
                          encoder_positions)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Pass the hidden states and the LM head to the logits processor."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights.

        q/k/v projection weights are loaded into the fused `qkv_proj`
        parameters, shared/embedding weights are loaded once into the LM head
        and tied to both embeddings, and everything else is delegated to
        `AutoWeightsLoader` with `hf_to_vllm_mapper`.

        Args:
            weights: Iterable of (checkpoint name, tensor) pairs.
        Returns:
            Set of parameter names that were loaded.
        """
        # (fused param name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        model_params_dict = dict(self.named_parameters())
        loaded_params = set()
        remaining_weights = []
        shared_embedding_weight = None

        for name, loaded_weight in weights:
            # Weights that have no counterpart in this model.
            if any(skip in name
                   for skip in ["cls.", "pooler.", "final_logits_bias"]):
                continue
            # Keep only the first embedding tensor seen; all embedding
            # entries are skipped here and handled after the loop.
            if any(embed_name in name for embed_name in [
                    'shared.weight', 'encoder.embed_tokens.weight',
                    'decoder.embed_tokens.weight'
            ]):
                if shared_embedding_weight is None:
                    shared_embedding_weight = loaded_weight
                continue
            is_stacked = False
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Apply the mapper's substring and prefix renames by hand,
                # then swap the shard name for the fused parameter name.
                vllm_name = name
                for src, dst in self.hf_to_vllm_mapper.orig_to_new_substr.items(
                ):
                    vllm_name = vllm_name.replace(src, dst)
                for src, dst in self.hf_to_vllm_mapper.orig_to_new_prefix.items(
                ):
                    if vllm_name.startswith(src):
                        vllm_name = dst + vllm_name[len(src):]
                        break
                vllm_name = vllm_name.replace(weight_name, param_name)
                if vllm_name in model_params_dict:
                    param = model_params_dict[vllm_name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight, shard_id)
                    loaded_params.add(vllm_name)
                # NOTE(review): nesting reconstructed from a flattened
                # source; a q/k/v weight is treated as handled even when no
                # matching fused parameter exists — confirm against upstream.
                is_stacked = True
                break
            if not is_stacked:
                remaining_weights.append((name, loaded_weight))
        loader = AutoWeightsLoader(self, skip_prefixes=["cls.", "pooler."])
        auto_loaded_params = loader.load_weights(remaining_weights,
                                                 mapper=self.hf_to_vllm_mapper)
        loaded_params.update(auto_loaded_params)
        if shared_embedding_weight is not None:
            # Load the shared embedding into the LM head, then tie both
            # token embeddings to the LM head weight.
            lm_head_param = self.lm_head.weight
            weight_loader = getattr(lm_head_param, "weight_loader",
                                    default_weight_loader)
            weight_loader(lm_head_param, shared_embedding_weight)
            self.model.encoder.embed_tokens.weight = self.lm_head.weight
            self.model.decoder.embed_tokens.weight = self.lm_head.weight
            loaded_params.update({
                'model.encoder.embed_tokens.weight', 'lm_head.weight',
                'model.decoder.embed_tokens.weight'
            })
        return loaded_params
vllm/model_executor/models/donut.py
deleted
100644 → 0
View file @
5206ab20
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Annotated
,
Literal
,
Optional
,
Union
import
torch
import
torch.nn
as
nn
from
transformers
import
BatchFeature
,
NougatProcessor
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.bart
import
BartParallelLMHead
,
MBartDecoder
from
vllm.model_executor.models.interfaces
import
(
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsV0Only
)
from
vllm.model_executor.models.swin
import
SwinModel
from
vllm.model_executor.models.utils
import
(
AutoWeightsLoader
,
_flatten_embeddings
,
flatten_bn
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargsItems
)
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseProcessingInfo
,
EncDecMultiModalProcessor
,
PromptIndexTargets
,
PromptInsertion
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
class MBartDecoderWrapper(nn.Module):
    """Holds an `MBartDecoder` under the attribute name `decoder` and
    forwards every call to it unchanged."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        decoder_prefix = f"{prefix}.decoder"
        self.decoder = MBartDecoder(hf_config,
                                    cache_config,
                                    quant_config=quant_config,
                                    prefix=decoder_prefix)

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped decoder with the same arguments."""
        decoder_output = self.decoder(*args, **kwargs)
        return decoder_output
class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only):
    """Language (decoder) half of Donut: a wrapped `MBartDecoder`, an LM
    head and a logits processor."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config

        self.model = MBartDecoderWrapper(vllm_config=vllm_config,
                                         prefix=f"{prefix}.model")

        # embed_scale is sqrt(d_model) when the config asks for scaled
        # embeddings and 1.0 otherwise.
        if hf_config.scale_embedding:
            head_scale = math.sqrt(hf_config.d_model)
        else:
            head_scale = 1.0

        self.vocab_size = hf_config.vocab_size
        self.lm_head = BartParallelLMHead(self.vocab_size,
                                          hf_config.d_model,
                                          embed_scale=head_scale)
        self.logits_processor = LogitsProcessor(self.vocab_size,
                                                hf_config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids: torch.Tensor of *decoder* input token ids.
            positions: torch.Tensor of *decoder* position indices.
            inputs_embeds: Passed to the decoder as its
                `encoder_hidden_states`.
        Returns:
            Output torch.Tensor
        """
        decoder_kwargs = {
            "decoder_input_ids": input_ids,
            "decoder_positions": positions,
            "encoder_hidden_states": inputs_embeds,
        }
        return self.model(**decoder_kwargs)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Pass the hidden states and the LM head to the logits processor."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing q/k/v shards into `qkv_proj`.

        Args:
            weights: Iterable of (checkpoint name, tensor) pairs.
        Returns:
            Set of (possibly renamed) parameter names that were loaded.
        """
        # (param_name, shard_name, shard_id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # First stacked mapping whose shard name occurs in this weight
            # name, or None when the weight is not a q/k/v shard.
            matched = next(
                (entry
                 for entry in stacked_params_mapping if entry[1] in name),
                None,
            )

            if matched is not None:
                param_name, weight_name, shard_id = matched
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            else:
                if "final_logits_bias" in name:
                    continue
                # The original kept a disabled check that would skip
                # "embed_tokens" weights when word embeddings are tied.
                param = params_dict[name]
                load_fn = getattr(param, "weight_loader",
                                  default_weight_loader)
                load_fn(param, loaded_weight)

            loaded_params.add(name)

        return loaded_params
class DonutImagePixelInputs(TensorSchema):
    """
    Schema for Donut image inputs given as raw pixel values.

    Dimensions:
        - b: Batch size
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """
    # Discriminator tag; always the literal "pixel_values".
    type: Literal["pixel_values"]
    # Pixel tensor declared with shape (b, 3, h, w).
    data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")]
class DonutProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Donut, backed by the processing context."""

    def get_hf_config(self):
        """Return the HF config obtained from the processing context."""
        hf_config = self.ctx.get_hf_config()
        return hf_config

    def get_hf_processor(self):
        """Return the HF processor obtained from the processing context."""
        hf_processor = self.ctx.get_hf_processor()
        return hf_processor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality item limit: one image."""
        limits = {"image": 1}
        return limits

    def get_num_image_tokens(self) -> int:
        """Return the number of placeholder tokens per image (always 1)."""
        return 1
class DonutDummyInputsBuilder(BaseDummyInputsBuilder[DonutProcessingInfo]):
    """Builds dummy text and image inputs for Donut."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the dummy prompt text, which is always empty."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images sized from the encoder's `image_size`.

        Args:
            seq_len: Not used by this method.
            mm_counts: Requested item count per modality; a missing
                "image" entry counts as 0.
        """
        image_count = mm_counts.get("image", 0)

        # NOTE(review): image_size is unpacked here as (width, height),
        # whereas DonutForConditionalGeneration unpacks the same config
        # value as (h, w) — confirm which order the config actually uses.
        encoder_config = self.info.get_hf_config().encoder
        target_width, target_height = encoder_config.image_size

        dummy_images = self._get_dummy_images(width=target_width,
                                              height=target_height,
                                              num_images=image_count)
        return {"image": dummy_images}
class DonutMultiModalProcessor(EncDecMultiModalProcessor[DonutProcessingInfo]):
    """Encoder-decoder multimodal processor for Donut."""

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """Always False for this processor."""
        return False

    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """Return the prompt unchanged as the encoder prompt."""
        return prompt

    def create_decoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """Return the prompt unchanged as the decoder prompt."""
        return prompt

    @property
    def pad_dummy_encoder_prompt(self) -> bool:
        """Always True for this processor."""
        return True

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, or only the tokenizer for text-only input.

        With multimodal data, the parent implementation is used; for a
        `NougatProcessor` the returned "labels" are additionally copied to
        "input_ids". Without multimodal data, the prompt is tokenized with
        no special tokens added.
        """
        hf_processor = self.info.get_hf_processor()

        if not mm_data:
            # Text-only path: tokenize directly.
            return hf_processor.tokenizer(prompt,
                                          add_special_tokens=False,
                                          return_tensors="pt")

        outputs = super()._call_hf_processor(prompt, mm_data, mm_kwargs,
                                             tok_kwargs)
        if isinstance(hf_processor, NougatProcessor):
            outputs["input_ids"] = outputs["labels"]
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare `pixel_values` as a batched field of the image modality."""
        return {"pixel_values": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Insert pad-token placeholders for the image at the prompt start."""
        pad_token_id = self.info.get_hf_processor().tokenizer.pad_token_id
        placeholder_count = self.info.get_num_image_tokens()
        placeholders = [pad_token_id] * placeholder_count

        image_insertion = PromptInsertion(
            modality="image",
            target=PromptIndexTargets.start(),
            insertion=placeholders,
        )
        return [image_insertion]
@MULTIMODAL_REGISTRY.register_processor(DonutMultiModalProcessor,
                                        info=DonutProcessingInfo,
                                        dummy_inputs=DonutDummyInputsBuilder)
class DonutForConditionalGeneration(nn.Module, SupportsMultiModal,
                                    SupportsV0Only):
    """Donut model: a `SwinModel` image encoder plus a
    `DonutLanguageForConditionalGeneration` text decoder."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        image_processor_config = model_config.hf_image_processor_config

        self.config = hf_config
        self.vision_config = hf_config.encoder
        self.processor_config = image_processor_config

        self.encoder = SwinModel(config=hf_config.encoder)
        # The decoder is built from a copy of the engine config whose HF
        # config is replaced by the decoder sub-config.
        decoder_vllm_config = vllm_config.with_hf_config(hf_config.decoder)
        self.decoder = DonutLanguageForConditionalGeneration(
            vllm_config=decoder_vllm_config,
            prefix=f"{prefix}.decoder",
        )
        self.pad_token_id = hf_config.pad_token_id

    def _parse_and_validate_image_input(self, **kwargs: object):
        """Extract the image input from the keyword arguments.

        Returns:
            None when neither `pixel_values` nor `image_embeds` is given,
            otherwise a `DonutImagePixelInputs`.
        Raises:
            ValueError: both `pixel_values` and `image_embeds` are given.
            NotImplementedError: only `image_embeds` is given.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None:
            if image_embeds is None:
                return None
            # Precomputed image embeddings are not supported.
            raise NotImplementedError

        if image_embeds is not None:
            raise ValueError(
                "Both pixel values and image embeds are provided.")

        h, w = self.config.encoder.image_size
        return DonutImagePixelInputs(type="pixel_values",
                                     data=flatten_bn(pixel_values,
                                                     concat=True),
                                     resolve_bindings={
                                         "h": h,
                                         "w": w,
                                     })

    def _process_image_input(
            self, image_input: DonutImagePixelInputs) -> torch.Tensor:
        """Cast the pixels to the encoder's parameter dtype and encode them."""
        assert image_input["type"] == "pixel_values"
        encoder_dtype = next(self.encoder.parameters()).dtype
        pixels = image_input["data"].to(encoder_dtype)
        return self.encoder(pixels)

    def get_language_model(self) -> torch.nn.Module:
        """Return the text decoder."""
        return self.decoder

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Encode the image input found in `kwargs`, or return None when
        there is none."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings,
    ) -> torch.Tensor:
        """Return the flattened multimodal embeddings; `input_ids` is not
        used."""
        return _flatten_embeddings(multimodal_embeddings)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids: torch.Tensor of *decoder* input token ids.
            positions: torch.Tensor of *decoder* position indices.
            encoder_input_ids: torch.Tensor of *encoder* input token ids.
            encoder_positions: torch.Tensor of *encoder* position indices
        Returns:
            Output torch.Tensor
        """
        # The image encoder only runs when encoder tokens are present;
        # otherwise the decoder receives inputs_embeds=None.
        if encoder_input_ids.numel() > 0:
            image_embeddings = self.get_multimodal_embeddings(**kwargs)
            encoder_embeds = self.get_input_embeddings(
                encoder_input_ids, image_embeddings)
        else:
            encoder_embeds = None

        return self.decoder(input_ids,
                            positions,
                            inputs_embeds=encoder_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the text decoder."""
        logits = self.decoder.compute_logits(hidden_states,
                                             sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load all weights through `AutoWeightsLoader`."""
        return AutoWeightsLoader(self).load_weights(weights)
vllm/model_executor/models/florence2.py
deleted
100644 → 0
View file @
5206ab20
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
from
collections
import
OrderedDict
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Annotated
,
Literal
,
Optional
,
Union
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
transformers
import
BartTokenizer
,
BatchFeature
,
PretrainedConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.bart
import
(
BartDecoder
,
BartEncoder
,
BartParallelLMHead
,
BartScaledWordEmbedding
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargsItems
)
from
vllm.multimodal.parse
import
MultiModalDataItems
from
vllm.multimodal.processing
import
(
BaseProcessingInfo
,
EncDecMultiModalProcessor
,
PromptIndexTargets
,
PromptInsertion
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsV0Only
)
from
.utils
import
AutoWeightsLoader
,
flatten_bn
,
merge_multimodal_embeddings
class Florence2ImagePixelInputs(TensorSchema):
    """
    Schema for Florence-2 image inputs given as raw pixel values.

    Dimensions:
        - b: Batch size
        - c: Number of channels (3)
        - h: Height of the image
        - w: Width of the image
    """

    # Discriminator tag; always the literal "pixel_values".
    type: Literal["pixel_values"]

    # Pixel tensor declared with shape (b, 3, h, w).
    data: Annotated[
        torch.Tensor,
        TensorShape("b", 3, "h", "w"),
    ]
# The ViT implementation below is copied from
# https://huggingface.co/microsoft/Florence-2-base/blob/main/modeling_florence2.py
class LearnedAbsolutePositionEmbedding2D(nn.Module):
    """
    Learned 2D positional embeddings for grids of up to `num_pos` rows and
    `num_pos` columns.

    The column embedding takes `embedding_dim - embedding_dim // 2`
    channels and the row embedding takes `embedding_dim // 2`, so together
    they span `embedding_dim` channels.
    """

    def __init__(self, embedding_dim=256, num_pos=50):
        super().__init__()
        row_dim = embedding_dim // 2
        col_dim = embedding_dim - row_dim
        self.row_embeddings = nn.Embedding(num_pos, row_dim)
        self.column_embeddings = nn.Embedding(num_pos, col_dim)

    def forward(self, pixel_values):
        """
        pixel_values: (batch_size, height, width, num_channels)
        returns: (batch_size, height, width, embedding_dim), with the
            column embedding first and the row embedding second along the
            last axis
        """
        if len(pixel_values.shape) != 4:
            raise ValueError('pixel_values must be a 4D tensor')

        batch_size, height, width = pixel_values.shape[:3]
        device = pixel_values.device

        col_emb = self.column_embeddings(torch.arange(width, device=device))
        row_emb = self.row_embeddings(torch.arange(height, device=device))

        # Broadcast both embeddings over the grid: (height, width, dim).
        grid = torch.cat(
            [
                col_emb.unsqueeze(0).expand(height, -1, -1),
                row_emb.unsqueeze(1).expand(-1, width, -1),
            ],
            dim=-1,
        )

        # (dim, H, W) -> (1, dim, H, W) -> (B, dim, H, W) -> (B, H, W, dim)
        batched = grid.permute(2, 0, 1).unsqueeze(0).repeat(
            batch_size, 1, 1, 1)
        return batched.permute(0, 2, 3, 1)
class PositionalEmbeddingCosine1D(nn.Module):
    """
    This class implements a very simple positional encoding. It follows closely
    the encoder from the link below:
    https://pytorch.org/tutorials/beginner/translation_transformer.html

    Even channels hold sines and odd channels hold cosines. An odd
    `embed_dim` is supported: the sine part then has one more channel than
    the cosine part.

    Args:
        embed_dim: The dimension of the embeddings.
        max_seq_len: The maximum length to precompute the positional encodings.
    """

    def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        # Generate the sinusoidal arrays.
        factor = math.log(10000)
        denominator = torch.exp(-factor * torch.arange(0, self.embed_dim, 2) /
                                self.embed_dim)
        # Matrix where rows correspond to a positional embedding as a function
        # of the position index (i.e., the row index).
        frequencies = \
            torch.arange(0, self.max_seq_len) \
            .reshape(self.max_seq_len, 1) * denominator
        pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim))
        # Even channels get the sines, odd channels the cosines.
        pos_idx_to_embed[:, 0::2] = torch.sin(frequencies)
        # With an odd embed_dim there is one odd channel fewer than there
        # are frequencies, so the surplus frequency column is dropped; with
        # an even embed_dim the slice is a no-op.
        pos_idx_to_embed[:, 1::2] = torch.cos(
            frequencies[:, :self.embed_dim // 2])
        # Stored as a frozen parameter rather than a buffer (the original
        # kept the register_buffer variant disabled).
        self.pos_idx_to_embed = nn.Parameter(pos_idx_to_embed,
                                             requires_grad=False)

    def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
        """
        Args:
            seq_embeds: The sequence embeddings in order. Allowed size:
                1. [T, D], where T is the length of the sequence, and D is the
                frame embedding dimension.
                2. [B, T, D], where B is the batch size and T and D are the
                same as above.

        Returns the positional embeddings for the first T positions, shaped
        [T, D] for a 2D input or [1, T, D] for a 3D input (the batch axis is
        left at size 1 for broadcasting).
        """
        shape_len = len(seq_embeds.shape)
        assert 2 <= shape_len <= 3
        len_seq = seq_embeds.size(-2)
        assert len_seq <= self.max_seq_len
        pos_embeds = self.pos_idx_to_embed[0:len_seq, :]
        # Adapt pre-computed positional embeddings to the input.
        if shape_len == 3:
            pos_embeds = pos_embeds.view(
                (1, pos_embeds.size(0), pos_embeds.size(1)))
        return pos_embeds
class MySequential(nn.Sequential):
    """`nn.Sequential` variant whose stages may take and return several
    values: a tuple produced by one stage is unpacked into positional
    arguments for the next."""

    def forward(self, *inputs):
        """Run the stages in order; with no stages the argument tuple is
        returned as-is."""
        current = inputs
        for stage in self._modules.values():
            # A tuple is splatted into the stage; anything else is passed
            # as a single argument.
            current = stage(*current) if isinstance(
                current, tuple) else stage(current)
        return current
class PreNorm(nn.Module):
    """Residual wrapper: optionally normalizes the input, applies `fn`,
    and adds the un-normalized input back to the result.

    `fn` must return an `(output, size)` pair; `norm` may be None to skip
    the normalization.
    """

    def __init__(self, norm, fn):
        super().__init__()
        self.norm = norm
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        """Return `(x + fn(norm(x), ...)[0], size)`."""
        fn_input = x if self.norm is None else self.norm(x)
        branch_out, size = self.fn(fn_input, *args, **kwargs)
        return x + branch_out, size
class Mlp(nn.Module):
    """Two-layer feed-forward network (`fc1` -> activation -> `fc2`).

    `hidden_features` and `out_features` fall back to `in_features` when
    they are not given (or are otherwise falsy).
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        # Named stages, so the parameters are exposed as net.fc1.* / net.fc2.*
        stages = OrderedDict()
        stages["fc1"] = nn.Linear(in_features, hidden_features)
        stages["act"] = act_layer()
        stages["fc2"] = nn.Linear(hidden_features, out_features)
        self.net = nn.Sequential(stages)

    def forward(self, x, size):
        """Apply the network to `x`; `size` is passed through untouched."""
        projected = self.net(x)
        return projected, size
class DepthWiseConv2d(nn.Module):
    """Depth-wise 2D convolution (`groups == dim_in`) applied to a token
    sequence that is temporarily reshaped into a feature map."""

    def __init__(
        self,
        dim_in,
        kernel_size,
        padding,
        stride,
        bias=True,
    ):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size,
                           padding=padding,
                           groups=dim_in,
                           stride=stride,
                           bias=bias)
        self.dw = nn.Conv2d(dim_in, dim_in, **conv_kwargs)

    def forward(self, x, size):
        """
        x: (B, N, C) tokens with N == H * W
        size: (H, W) of the token grid
        returns: the convolved tokens as (B, N', C) and the new (H', W')
        """
        batch, num_tokens, channels = x.shape
        height, width = size
        assert num_tokens == height * width

        # (B, N, C) -> (B, C, H, W), convolve, then flatten back to tokens.
        feature_map = x.transpose(1, 2).view(batch, channels, height, width)
        feature_map = self.dw(feature_map)
        new_size = (feature_map.size(-2), feature_map.size(-1))
        tokens = feature_map.flatten(2).transpose(1, 2)
        return tokens, new_size
class ConvEmbed(nn.Module):
    """ Image to Patch Embedding

    A strided convolution turns a feature map into a token sequence, with
    an optional normalization applied either before the projection (on the
    input channels) or after it (on the embedding channels).
    """

    def __init__(self,
                 patch_size=7,
                 in_chans=3,
                 embed_dim=64,
                 stride=4,
                 padding=2,
                 norm_layer=None,
                 pre_norm=True):
        super().__init__()
        self.patch_size = patch_size

        self.proj = nn.Conv2d(in_chans,
                              embed_dim,
                              kernel_size=patch_size,
                              stride=stride,
                              padding=padding)

        # The norm acts on the input channels when it runs before the
        # projection and on the embedding channels otherwise.
        if norm_layer:
            norm_channels = in_chans if pre_norm else embed_dim
            self.norm = norm_layer(norm_channels)
        else:
            self.norm = None

        self.pre_norm = pre_norm

    def forward(self, x, size):
        """
        x: either a (B, H*W, C) token sequence or a 4D (B, C, H, W) map
        size: (H, W) of the token grid; only used for the 3D input form
        returns: (B, H'*W', embed_dim) tokens and the new (H', W')
        """
        H, W = size
        is_token_sequence = len(x.size()) == 3
        if is_token_sequence:
            # Pre-norm is only applied on the token-sequence form.
            if self.norm and self.pre_norm:
                x = self.norm(x)
            x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)

        x = self.proj(x)

        _, _, H, W = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        if self.norm and not self.pre_norm:
            x = self.norm(x)

        return x, (H, W)
class ChannelAttention(nn.Module):
    """Grouped attention computed across channels rather than tokens: the
    attention matrix of each group has shape (C/groups, C/groups)."""

    def __init__(self, dim, groups=8, qkv_bias=True):
        super().__init__()

        self.groups = groups
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, size):
        """
        x: (B, N, C) tokens
        size: passed through untouched
        returns: (B, N, C) tokens and `size`
        """
        B, N, C = x.shape
        group_dim = C // self.groups

        # (B, N, 3C) -> (3, B, groups, N, C/groups)
        fused = self.qkv(x).reshape(B, N, 3, self.groups, group_dim)
        fused = fused.permute(2, 0, 3, 1, 4)
        query, key, value = fused[0], fused[1], fused[2]

        # The query is scaled by N ** -0.5 (the token count).
        query = query * (float(N)**-0.5)
        channel_scores = query.transpose(-1, -2) @ key
        channel_probs = channel_scores.softmax(dim=-1)

        mixed = (channel_probs @ value.transpose(-1, -2)).transpose(-1, -2)
        mixed = mixed.transpose(1, 2).reshape(B, N, C)
        return self.proj(mixed), size
class ChannelBlock(nn.Module):
    """Block of (optional depth-wise conv) -> channel attention ->
    (optional depth-wise conv) -> MLP, each stage wrapped in `PreNorm`.

    `drop_path_rate` is accepted but not used by this implementation.
    """

    def __init__(self,
                 dim,
                 groups,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 drop_path_rate=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 conv_at_attn=True,
                 conv_at_ffn=True):
        super().__init__()

        # Optional 3x3 depth-wise conv (padding 1, stride 1) before the
        # attention stage.
        if conv_at_attn:
            self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))
        else:
            self.conv1 = None

        attention = ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias)
        self.channel_attn = PreNorm(norm_layer(dim), attention)

        # Optional 3x3 depth-wise conv before the feed-forward stage.
        if conv_at_ffn:
            self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))
        else:
            self.conv2 = None

        feed_forward = Mlp(in_features=dim,
                           hidden_features=int(dim * mlp_ratio),
                           act_layer=act_layer)
        self.ffn = PreNorm(norm_layer(dim), feed_forward)

    def forward(self, x, size):
        """Run the stages in order; returns the tokens and the grid size."""
        if self.conv1:
            x, size = self.conv1(x, size)
        x, size = self.channel_attn(x, size)

        if self.conv2:
            x, size = self.conv2(x, size)
        x, size = self.ffn(x, size)

        return x, size
def window_partition(x, window_size: int):
    """Split a (B, H, W, C) tensor into non-overlapping square windows.

    Returns a tensor of shape (B * H/ws * W/ws, ws, ws, C), with windows
    ordered row-major within each batch element. The `view` calls require
    H and W to be multiples of `window_size`.
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size

    tiled = x.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two window-index axes next to each other before flattening.
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, channels)
def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
    """Reassemble windows into a (batch_size, H, W, C) tensor.

    `windows` is expected in the layout produced by `window_partition`:
    (batch_size * H/ws * W/ws, ws, ws, C), windows ordered row-major.
    """
    rows = H // window_size
    cols = W // window_size

    grid = windows.view(batch_size, rows, cols, window_size, window_size, -1)
    # Interleave window rows with in-window rows (and likewise columns).
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(batch_size, H, W, -1)
class WindowAttention(nn.Module):
    """Multi-head self-attention computed independently inside
    non-overlapping `window_size` x `window_size` windows of the token grid."""

    def __init__(self, dim, num_heads, window_size, qkv_bias=True):

        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        # Queries are scaled by head_dim ** -0.5.
        self.scale = float(dim // num_heads)**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, size):
        """
        x: (B, H*W, C) tokens
        size: (H, W) of the token grid
        returns: (B, H*W, C) tokens and `size`
        """
        H, W = size
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        ws = self.window_size
        x = x.view(B, H, W, C)

        # Pad on the right/bottom so both sides are multiples of the window.
        pad_l = pad_t = 0
        pad_r = (-W) % ws
        pad_b = (-H) % ws
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # (num_windows * B, ws * ws, C)
        windows = window_partition(x, ws).view(-1, ws * ws, C)

        # W-MSA/SW-MSA
        B_, N, C = windows.shape
        head_dim = C // self.num_heads
        fused = self.qkv(windows).reshape(B_, N, 3, self.num_heads, head_dim)
        fused = fused.permute(2, 0, 3, 1, 4)
        query, key, value = fused[0], fused[1], fused[2]

        query = query * self.scale
        scores = (query @ key.transpose(-2, -1))
        probs = self.softmax(scores)

        attended = (probs @ value).transpose(1, 2).reshape(B_, N, C)
        attended = self.proj(attended)

        # merge windows
        attended = attended.view(-1, ws, ws, C)
        merged = window_reverse(attended, B, ws, Hp, Wp)

        # Drop the padded rows/columns again.
        if pad_r > 0 or pad_b > 0:
            merged = merged[:, :H, :W, :].contiguous()

        return merged.view(B, H * W, C), size
class SpatialBlock(nn.Module):
    """Optional depth-wise conv, window attention, optional conv, then MLP.

    Every stage is wrapped in ``PreNorm`` and maps ``(x, size)`` to
    ``(x, size)``. ``drop_path_rate`` is accepted for signature
    compatibility but is not used by this constructor.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 window_size,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 drop_path_rate=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 conv_at_attn=True,
                 conv_at_ffn=True):
        super().__init__()

        # Sub-modules are created in execution order: conv1, attention,
        # conv2, feed-forward.
        if conv_at_attn:
            self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))
        else:
            self.conv1 = None

        self.window_attn = PreNorm(
            norm_layer(dim),
            WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
        )

        if conv_at_ffn:
            self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))
        else:
            self.conv2 = None

        hidden_features = int(dim * mlp_ratio)
        self.ffn = PreNorm(
            norm_layer(dim),
            Mlp(in_features=dim,
                hidden_features=hidden_features,
                act_layer=act_layer),
        )

    def forward(self, x, size):
        """Run the block; returns the updated ``(x, size)`` pair."""
        if self.conv1:
            x, size = self.conv1(x, size)

        x, size = self.window_attn(x, size)

        if self.conv2:
            x, size = self.conv2(x, size)

        return self.ffn(x, size)
class DaViT(nn.Module):
    """Multi-stage vision backbone built from conv embeddings and blocks.

    Each stage consists of a ``ConvEmbed`` followed by ``depths[i]`` pairs of
    (``SpatialBlock``, ``ChannelBlock``).
    """

    def __init__(
        self,
        in_chans=3,
        num_classes=1000,
        depths=(1, 1, 3, 1),
        patch_size=(7, 2, 2, 2),
        patch_stride=(4, 2, 2, 2),
        patch_padding=(3, 0, 0, 0),
        patch_prenorm=(False, False, False, False),
        embed_dims=(64, 128, 192, 256),
        num_heads=(3, 6, 12, 24),
        num_groups=(3, 6, 12, 24),
        window_size=7,
        mlp_ratio=4.,
        qkv_bias=True,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        enable_checkpoint=False,
        conv_at_attn=True,
        conv_at_ffn=True,
    ):
        """
        Args:
            in_chans: Channels of the input image (used by the first stage).
            num_classes: Stored on the instance; no classifier is built here.
            depths: Number of (spatial, channel) block pairs per stage.
            patch_size / patch_stride / patch_padding / patch_prenorm:
                Per-stage arguments forwarded to ``ConvEmbed``.
            embed_dims: Per-stage channel width.
            num_heads: Per-stage head count for ``SpatialBlock``.
            num_groups: Per-stage group count for ``ChannelBlock``.
            window_size: Window side length for ``SpatialBlock``.
            mlp_ratio: Forwarded to both block types.
            qkv_bias: Forwarded to both block types.
            drop_path_rate: Upper end of the per-block drop-path schedule.
            norm_layer: Normalisation class forwarded to ``ConvEmbed``.
            enable_checkpoint: Stored on the instance only.
            conv_at_attn / conv_at_ffn: Forwarded to both block types.
        """
        super().__init__()

        self.num_classes = num_classes
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.num_stages = len(self.embed_dims)
        self.enable_checkpoint = enable_checkpoint
        # All per-stage tuples must describe the same number of stages.
        assert self.num_stages == len(self.num_heads) == len(self.num_groups)

        num_stages = len(embed_dims)
        # One drop-path value per block (two blocks per depth unit), spaced
        # linearly from 0 to drop_path_rate.
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate,
                                             sum(depths) * 2)
        ]

        # Index into dpr of the first block of the current stage.
        depth_offset = 0
        convs = []
        blocks = []
        for i in range(num_stages):
            # The first stage consumes the image; later stages consume the
            # previous stage's channels.
            conv_embed = ConvEmbed(
                patch_size=patch_size[i],
                stride=patch_stride[i],
                padding=patch_padding[i],
                in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
                embed_dim=self.embed_dims[i],
                norm_layer=norm_layer,
                pre_norm=patch_prenorm[i])
            convs.append(conv_embed)

            # The sub-module names 'spatial_block' / 'channel_block' become
            # part of the parameter names of this model.
            block = MySequential(*[
                MySequential(
                    OrderedDict([('spatial_block',
                                  SpatialBlock(
                                      embed_dims[i],
                                      num_heads[i],
                                      window_size,
                                      drop_path_rate=dpr[depth_offset + j * 2],
                                      qkv_bias=qkv_bias,
                                      mlp_ratio=mlp_ratio,
                                      conv_at_attn=conv_at_attn,
                                      conv_at_ffn=conv_at_ffn,
                                  )),
                                 ('channel_block',
                                  ChannelBlock(
                                      embed_dims[i],
                                      num_groups[i],
                                      drop_path_rate=dpr[depth_offset + j * 2 +
                                                         1],
                                      qkv_bias=qkv_bias,
                                      mlp_ratio=mlp_ratio,
                                      conv_at_attn=conv_at_attn,
                                      conv_at_ffn=conv_at_ffn,
                                  ))])) for j in range(depths[i])
            ])
            blocks.append(block)
            depth_offset += depths[i] * 2

        self.convs = nn.ModuleList(convs)
        self.blocks = nn.ModuleList(blocks)

        self.avgpool = nn.AdaptiveAvgPool1d(1)

    @property
    def dim_out(self):
        """Channel width of the last stage."""
        return self.embed_dims[-1]

    def forward_features_unpool(self, x):
        """Run every stage and return the features before any pooling.

        Args:
            x: Input image tensor; its dimensions 2 and 3 are taken as the
                spatial size.

        Returns:
            The output of the last stage's block. The tracked spatial size
            is discarded.
        """
        input_size = (x.size(2), x.size(3))
        for conv, block in zip(self.convs, self.blocks):
            x, input_size = conv(x, input_size)
            x, input_size = block(x, input_size)
        return x

    def forward_features(self, x):
        """Pool the un-pooled features over the token axis and normalise.

        NOTE(review): ``self.norms`` is not assigned in the ``__init__``
        shown in this class; unless it is set elsewhere this method will
        presumably raise ``AttributeError`` — confirm before calling.
        """
        x = self.forward_features_unpool(x)

        # presumably (batch_size, num_tokens, token_dim) here — TODO confirm
        x = self.avgpool(x.transpose(1, 2))
        # AdaptiveAvgPool1d(1) leaves a trailing axis of length 1, which the
        # flatten below removes.
        x = torch.flatten(x, 1)
        x = self.norms(x)

        return x

    def forward(self, x):
        """Pooled features followed by ``self.head``.

        NOTE(review): ``self.head`` is not assigned in the ``__init__``
        shown in this class — confirm it exists before calling.
        """
        x = self.forward_features(x)
        x = self.head(x)
        return x

    @classmethod
    def from_config(cls, config):
        """Build a ``DaViT`` from a config object.

        Reads ``depths``, ``dim_embed``, ``num_heads``, ``num_groups``,
        ``patch_size``, ``patch_stride``, ``patch_padding``,
        ``patch_prenorm``, ``drop_path_rate`` and ``window_size`` from
        ``config``; every other constructor argument keeps its default.
        """
        return cls(
            depths=config.depths,
            embed_dims=config.dim_embed,
            num_heads=config.num_heads,
            num_groups=config.num_groups,
            patch_size=config.patch_size,
            patch_stride=config.patch_stride,
            patch_padding=config.patch_padding,
            patch_prenorm=config.patch_prenorm,
            drop_path_rate=config.drop_path_rate,
            window_size=config.window_size,
        )
# Language backbone and processor implementation
class Florence2LanguageModel(nn.Module):
    """Encoder/decoder text backbone with a shared word-embedding table."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.vocab_size = hf_config.vocab_size

        self.shared = BartScaledWordEmbedding(self.vocab_size,
                                              hf_config.d_model)
        self.encoder = BartEncoder(
            hf_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
        )
        self.decoder = BartDecoder(
            hf_config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.decoder",
        )

        # With tied embeddings both halves reuse the shared table's weight.
        if self.config.tie_word_embeddings:
            shared_weight = self.shared.weight
            self.encoder.embed_tokens.weight = shared_weight
            self.decoder.embed_tokens.weight = shared_weight

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""Run the encoder (when it has input) and then the decoder.

        Args:
            input_ids: Token ids of the *decoder* input sequence.
            positions: Positions of the *decoder* input tokens.
            encoder_input_ids: Token ids of the *encoder* input sequence.
            encoder_positions: Positions of the *encoder* input tokens.
            inputs_embeds: Optional pre-computed encoder input embeddings.

        Returns:
            Whatever the decoder returns.
        """
        has_embeds = inputs_embeds is not None and inputs_embeds.numel() > 0

        encoder_hidden_states = None
        # The encoder only runs when there is something to encode.
        if has_embeds or encoder_input_ids.numel() > 0:
            encoder_hidden_states = self.encoder(
                input_ids=encoder_input_ids,
                positions=encoder_positions,
                inputs_embeds=inputs_embeds,
            )

        return self.decoder(
            decoder_input_ids=input_ids,
            decoder_positions=positions,
            encoder_hidden_states=encoder_hidden_states,
        )
class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
    """Text backbone plus LM head, logits processing and weight loading."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config

        self.model = Florence2LanguageModel(vllm_config=vllm_config,
                                            prefix=f"{prefix}.model")

        if hf_config.scale_embedding:
            embed_scale = math.sqrt(hf_config.d_model)
        else:
            embed_scale = 1.0

        self.vocab_size = hf_config.vocab_size
        self.lm_head = BartParallelLMHead(self.vocab_size,
                                          hf_config.d_model,
                                          embed_scale=embed_scale)
        if self.config.tie_word_embeddings:
            self.lm_head.tie_weights(self.model.shared)

        self.logits_processor = LogitsProcessor(self.vocab_size,
                                                hf_config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        r"""Delegate to the wrapped encoder/decoder model.

        Args:
            input_ids: *Decoder* input token ids.
            positions: *Decoder* position indices.
            encoder_input_ids: *Encoder* input token ids.
            encoder_positions: *Encoder* position indices.
            inputs_embeds: Optional encoder input embeddings.
            **kwargs: Accepted and ignored.

        Returns:
            Output of the wrapped model.
        """
        return self.model(
            input_ids,
            positions,
            encoder_input_ids,
            encoder_positions,
            inputs_embeds=inputs_embeds,
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the encoder's token embedding."""
        return self.model.encoder.embed_tokens(input_ids)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to logits through the LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors and return the parameter names touched.

        Separate q/k/v projection weights are routed into the fused
        ``qkv_proj`` parameter. ``final_logits_bias`` is always skipped, and
        with tied embeddings ``embed_tokens`` / ``lm_head`` entries are
        skipped too.
        """
        # (fused param name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # First mapping entry whose shard name occurs in the weight name.
            stacked = next(
                (entry
                 for entry in stacked_params_mapping if entry[1] in name),
                None,
            )

            if stacked is not None:
                param_name, weight_name, shard_id = stacked
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            else:
                if "final_logits_bias" in name:
                    continue
                tied_name = "embed_tokens" in name or "lm_head" in name
                if self.config.tie_word_embeddings and tied_name:
                    continue
                param = params_dict[name]
                loader = getattr(param, "weight_loader",
                                 default_weight_loader)
                loader(param, loaded_weight)

            loaded_params.add(name)

        return loaded_params
class Florence2ProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Florence-2."""

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Report a limit of one image per prompt."""
        limits = {"image": 1}
        return limits

    def get_num_image_tokens(self) -> int:
        """Read ``image_seq_length`` from the HF image-processor config."""
        return self.ctx.get_hf_image_processor_config()["image_seq_length"]
class Florence2DummyInputsBuilder(
        BaseDummyInputsBuilder[Florence2ProcessingInfo]):
    """Builds placeholder inputs for Florence-2."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """The dummy prompt is always the empty string."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Create ``mm_counts["image"]`` square dummy images.

        The side length is taken from the HF config's ``projection_dim``.
        """
        num_images = mm_counts.get("image", 0)
        side = self.info.get_hf_config().projection_dim

        dummy_images = self._get_dummy_images(width=side,
                                              height=side,
                                              num_images=num_images)
        return {"image": dummy_images}
class Florence2MultiModalProcessor(
        EncDecMultiModalProcessor[Florence2ProcessingInfo]):
    """Encoder/decoder multimodal processor for Florence-2."""

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """Always ``False``: prompt updates are applied by this processor."""
        return False

    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """The encoder prompt is the incoming prompt, unchanged."""
        return prompt

    def create_decoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """The decoder prompt is the single token ``eos_token_id``."""
        eos_token_id = self.info.get_hf_config().eos_token_id
        return [eos_token_id]

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Round-trip token ids through text to expand task tokens."""
        hf_processor = self.info.get_hf_processor()
        tokenizer: BartTokenizer = hf_processor.tokenizer

        decoded = tokenizer.decode(prompt_tokens)
        # convert task tokens to prompt
        expanded = hf_processor._construct_prompts([decoded])[0]
        return tokenizer.encode(expanded, add_special_tokens=False)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Use the HF processor with multimodal data, the tokenizer without."""
        if mm_data:
            return super()._call_hf_processor(prompt, mm_data, mm_kwargs,
                                              tok_kwargs)

        # Text-only: expand the prompt ourselves, then tokenize.
        hf_processor = self.info.get_hf_processor()
        tokenizer = hf_processor.tokenizer
        prompt = hf_processor._construct_prompts([prompt])[0]
        return tokenizer(prompt, add_special_tokens=True, return_tensors="pt")

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """``pixel_values`` is batched over the image modality."""
        return {"pixel_values": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Insert the image placeholder tokens at the start of the prompt.

        The placeholder is ``pad_token_id`` repeated once per image token.
        """
        pad_token_id = self.info.get_hf_config().pad_token_id
        num_image_tokens = self.info.get_num_image_tokens()

        insertion = PromptInsertion(
            modality="image",
            target=PromptIndexTargets.start(),
            insertion=[pad_token_id] * num_image_tokens,
        )
        return [insertion]
@MULTIMODAL_REGISTRY.register_processor(
    Florence2MultiModalProcessor,
    info=Florence2ProcessingInfo,
    dummy_inputs=Florence2DummyInputsBuilder)
class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsV0Only):
    """Florence-2: DaViT vision tower feeding an encoder/decoder text model."""

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """No textual placeholder is used for images; other modalities fail."""
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")
        return None

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        image_processor_config = (
            vllm_config.model_config.hf_image_processor_config)

        self.config = hf_config
        self.vision_config = hf_config.vision_config
        self.processor_config = image_processor_config

        assert hf_config.vision_config.model_type == 'davit', (
            'only DaViT is supported for now')

        self.vision_tower = DaViT.from_config(config=hf_config.vision_config)
        self._build_image_projection_layers(hf_config)
        self.language_model = Florence2LanguageForConditionalGeneration(
            vllm_config=vllm_config.with_hf_config(hf_config.text_config),
            prefix=f"{prefix}.language_model",
        )
        self.pad_token_id = hf_config.pad_token_id

    def _build_image_projection_layers(self, config: PretrainedConfig):
        """Create the projection, norm and position/temporal embeddings."""
        vision_config = config.vision_config
        image_dim_out = vision_config.dim_embed[-1]
        dim_projection = vision_config.projection_dim

        self.image_projection = nn.Parameter(
            torch.empty(image_dim_out, dim_projection))
        self.image_proj_norm = nn.LayerNorm(dim_projection)

        pos_embed_cfg = vision_config.image_pos_embed
        if pos_embed_cfg['type'] != 'learned_abs_2d':
            raise NotImplementedError("Florence2 only supports learned_abs_2d "
                                      "as image position embedding.")
        self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
            embedding_dim=image_dim_out,
            num_pos=pos_embed_cfg['max_pos_embeddings'])

        self.image_feature_source = vision_config.image_feature_source

        # temporal embedding
        temporal_cfg = self.vision_config.visual_temporal_embedding
        if temporal_cfg['type'] != 'COSINE':
            raise NotImplementedError(
                'Florence2 only supports COSINE as temporal embedding.')
        self.visual_temporal_embed = PositionalEmbeddingCosine1D(
            embed_dim=image_dim_out,
            max_seq_len=temporal_cfg['max_temporal_embeddings'])

    def _parse_and_validate_image_input(self, **kwargs: object):
        """Extract image inputs from ``kwargs``.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds``
        is present. Supplying both is an error, and ``image_embeds`` alone
        is not implemented.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        have_pixels = pixel_values is not None
        have_embeds = image_embeds is not None

        if not have_pixels and not have_embeds:
            return None

        if have_pixels and have_embeds:
            raise ValueError(
                "Both pixel values and image embeds are provided.")

        if have_pixels:
            # Expected spatial size comes from the image-processor config.
            size = self.processor_config["size"]
            expected_h, expected_w = size["height"], size["width"]

            return Florence2ImagePixelInputs(
                type="pixel_values",
                data=flatten_bn(pixel_values, concat=True),
                resolve_bindings={
                    "h": expected_h,
                    "w": expected_w
                },
            )

        if have_embeds:
            raise NotImplementedError

        raise AssertionError("This line should be unreachable.")

    def _encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode images into projected, normalised feature sequences.

        The vision-tower output gets 2D position and temporal embeddings
        added, the feature sources named in ``self.image_feature_source``
        are concatenated along dim 1, and the result is projected with
        ``self.image_projection`` and layer-normalised.
        """
        tower_dtype = next(self.vision_tower.parameters()).dtype
        pixel_values = pixel_values.to(tower_dtype)

        batch_size = pixel_values.size(0)
        # Still images are handled as clips with a single frame.
        T = 1
        x = self.vision_tower.forward_features_unpool(pixel_values)

        if self.image_pos_embed is not None:
            x = x.view(batch_size * T, -1, x.shape[-1])
            num_tokens = x.shape[-2]
            side = int(num_tokens**0.5)
            h, w = side, side
            assert h * w == num_tokens, (
                'only support square feature maps for now')
            x = x.view(batch_size * T, h, w, x.shape[-1])
            x = x + self.image_pos_embed(x)
            x = x.view(batch_size, T * h * w, x.shape[-1])

        if self.visual_temporal_embed is not None:
            per_frame = x.view(batch_size, T, -1, x.shape[-1])
            # The embedding is computed from the first token of each frame.
            temporal = self.visual_temporal_embed(per_frame[:, :, 0])
            x = per_frame + temporal.view(1, T, 1, x.shape[-1])

        frames = x.view(batch_size, T, -1, x.shape[-1])
        x_feat_dict = {
            'spatial_avg_pool': frames.mean(dim=2),
            'temporal_avg_pool': frames.mean(dim=1),
            'last_frame': frames[:, -1],
        }

        selected = []
        for source in self.image_feature_source:
            if source not in x_feat_dict:
                raise ValueError(
                    'invalid image feature source: {}'.format(source))
            selected.append(x_feat_dict[source])

        x = torch.cat(selected, dim=1)
        x = x @ self.image_projection
        return self.image_proj_norm(x)

    def _process_image_input(
            self, image_input: Florence2ImagePixelInputs) -> torch.Tensor:
        """Encode validated pixel inputs."""
        assert image_input["type"] == "pixel_values"
        return self._encode_image(image_input["data"])

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped text model."""
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings, or ``[]`` when no image was supplied."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge image embeddings at pad positions."""
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)

        nothing_to_merge = (multimodal_embeddings is None
                            or len(multimodal_embeddings) == 0)
        if nothing_to_merge:
            return inputs_embeds

        return merge_multimodal_embeddings(input_ids, inputs_embeds,
                                           multimodal_embeddings,
                                           self.pad_token_id)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""Encode any images, build encoder embeddings, run the text model.

        Args:
            input_ids: *Decoder* input token ids.
            positions: *Decoder* position indices.
            intermediate_tensors: Accepted but not used here.
            encoder_input_ids: *Encoder* input token ids.
            encoder_positions: *Encoder* position indices.
            **kwargs: Multimodal inputs such as ``pixel_values``.

        Returns:
            Output of the language model.
        """
        vision_embeddings = self.get_multimodal_embeddings(**kwargs)

        # NOTE: get_multimodal_embeddings returns [] rather than None when
        # there is no image, so the second operand is true whenever it is
        # evaluated.
        if encoder_input_ids.numel() > 0 or vision_embeddings is not None:
            inputs_embeds = self.get_input_embeddings(encoder_input_ids,
                                                      vision_embeddings)
        else:
            inputs_embeds = None

        return self.language_model(
            input_ids,
            positions,
            encoder_input_ids,
            encoder_positions,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights through ``AutoWeightsLoader``."""
        return AutoWeightsLoader(self).load_weights(weights)
vllm/model_executor/models/mllama.py
deleted
100644 → 0
View file @
5206ab20
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
typing
import
Annotated
,
Literal
,
Optional
,
Union
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
import
transformers.models.mllama.configuration_mllama
as
config_mllama
from
PIL.Image
import
Image
from
torch
import
nn
from
transformers
import
BatchFeature
,
MllamaConfig
from
transformers.modeling_outputs
import
(
BaseModelOutput
,
CausalLMOutputWithPast
)
from
transformers.models.mllama.image_processing_mllama
import
(
get_optimal_tiled_canvas
)
from
transformers.models.mllama.processing_mllama
import
(
MllamaProcessor
,
get_cross_attention_token_mask
)
import
vllm.distributed.parallel_state
as
ps
from
vllm.attention
import
Attention
,
AttentionMetadata
,
AttentionType
from
vllm.attention.layer
import
MultiHeadAttention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.selector
import
_Backend
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tp_group
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVCrossParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
MultiModalFieldConfig
,
MultiModalKwargsItems
,
MultiModalUUIDDict
)
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseProcessingInfo
,
EncDecMultiModalProcessor
,
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.clip
import
CLIPMLP
from
.interfaces
import
SupportsMultiModal
,
SupportsV0Only
from
.llama
import
LlamaDecoderLayer
,
LlamaMLP
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
,
maybe_prefix
logger
=
init_logger
(
__name__
)
class MllamaImagePixelInputs(TensorSchema):
    """Schema of the image tensors passed to the Mllama model.

    Dimensions:
        - batch_size: Batch size
        - max_num_image: Max number of images
        - max_num_chunk: Max number of chunks
        - max_num_tiles: Max number of tiles per image
        - num_channel: Number of channels
        - height: Height
        - width: Width
    """

    # Discriminator tag; the only allowed value is "pixel_values".
    type: Literal["pixel_values"] = "pixel_values"

    # Pixel data, one chunk per entry of the third axis.
    data: Annotated[torch.Tensor,
                    TensorShape("batch_size", "max_num_image", "max_num_chunk",
                                "num_channel", "height", "width")]

    # One aspect-ratio id per image.
    aspect_ratio_ids: Annotated[torch.Tensor,
                                TensorShape("batch_size", "max_num_image")]

    # One mask entry per tile of each image.
    aspect_ratio_mask: Annotated[
        torch.Tensor,
        TensorShape("batch_size", "max_num_image", "max_num_tiles")]
# TODO: support LlamaImageEmbeddingInputs
def calc_token_per_chunk(image_size: int) -> int:
    """Return the number of tokens one square chunk produces.

    The chunk is cut into 14x14 patches; the result is the patch count
    plus one extra token.

    Args:
        image_size: Side length of the chunk; must be a multiple of 14.

    Returns:
        ``(image_size // 14) ** 2 + 1``.
    """
    patches_per_side, remainder = divmod(image_size, 14)
    assert remainder == 0, "chunk size should be multiple of 14"
    return patches_per_side**2 + 1
class MllamaProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Mllama."""

    def get_hf_config(self) -> MllamaConfig:
        """Fetch the HF config, typed as ``MllamaConfig``."""
        return self.ctx.get_hf_config(MllamaConfig)

    def get_hf_processor(self, **kwargs: object) -> MllamaProcessor:
        """Fetch the HF processor, typed as ``MllamaProcessor``."""
        return self.ctx.get_hf_processor(MllamaProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Images are supported with no fixed per-prompt limit."""
        limits = {"image": None}
        return limits

    def get_token_per_chunk_from_config(self) -> int:
        """Tokens per chunk for the configured vision ``image_size``."""
        vision_config = self.get_hf_config().vision_config
        return calc_token_per_chunk(vision_config.image_size)

    def get_num_tiles_per_image(self, image_height: int,
                                image_width: int) -> int:
        """Number of tiles the optimal tiled canvas uses for an image."""
        vision_config = self.get_hf_config().vision_config
        tile_size = vision_config.image_size

        canvas_height, canvas_width = get_optimal_tiled_canvas(
            image_height,
            image_width,
            vision_config.max_num_tiles,
            tile_size=tile_size,
        )

        tiles_down = canvas_height // tile_size
        tiles_across = canvas_width // tile_size
        return tiles_down * tiles_across

    def get_image_size_with_most_features(self) -> ImageSize:
        """Image size used to provoke the largest feature count.

        A one-tile-wide, ``max_num_tiles``-tall image.
        """
        vision_config = self.get_hf_config().vision_config
        tile_size = vision_config.image_size
        tallest = vision_config.max_num_tiles * tile_size
        return ImageSize(height=tallest, width=tile_size)
class MllamaDummyInputsBuilder(BaseDummyInputsBuilder[MllamaProcessingInfo]):
    """Builds placeholder inputs for Mllama."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """One image token per requested image, concatenated."""
        num_images = mm_counts.get("image", 0)
        image_token = self.info.get_hf_processor().image_token
        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Dummy images sized by ``get_image_size_with_most_features``."""
        num_images = mm_counts.get("image", 0)

        # The returned size is unpacked positionally as (width, height).
        target_width, target_height = (
            self.info.get_image_size_with_most_features())

        dummy_images = self._get_dummy_images(width=target_width,
                                              height=target_height,
                                              num_images=num_images)
        return {"image": dummy_images}
class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
                                ):
    """Encoder/decoder multimodal processor for Mllama."""

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        mm_uuids: Optional[MultiModalUUIDDict] = None,
    ) -> MultiModalEncDecInputs:
        """Process the prompt, then size the encoder prompt from the tiles.

        Raises:
            ValueError: If the number of image tokens in the processed
                prompt differs from the number of images in ``mm_data``.
        """
        mm_inputs = super().apply(prompt,
                                  mm_data,
                                  hf_processor_mm_kwargs,
                                  tokenization_kwargs,
                                  mm_uuids=mm_uuids)

        image_token_id = self.info.get_hf_config().image_token_index

        # The decoder prompt must carry exactly one image token per image
        # supplied in mm_data.
        num_image_tokens = mm_inputs['prompt_token_ids'].count(image_token_id)
        image_data = mm_data.get("image", [])
        if isinstance(image_data, Image):
            num_images = 1
        else:
            num_images = len(image_data)

        if num_image_tokens != num_images:
            raise ValueError(
                f"The number of image tokens ({num_image_tokens}) must be"
                f" the same as the number of images ({num_images})")

        # Given prompt: <IMG0> P0 P1 <IMG1> <IMG2> P3 P4 D5 D6...., (P-prefill, D-decode) # noqa: E501
        # P0 & P1 do cross attention with placeholder of <IMG0>
        # P3 P4 D5 D6 do cross attention with placeholder of <IMG1> and <IMG2>
        # Example input to encoder and decoder:
        # {
        #     'encoder': {
        #         'type': 'token',
        #         'prompt_token_ids': [128256, 128256, ..., 128256],
        #         'prompt': '<|image|><|image|>...<|image|>',
        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
        #     },
        #     'decoder': {
        #         'type': 'token',
        #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
        #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
        #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
        #     },
        # }
        if not mm_data:
            return mm_inputs

        hf_processor = self.info.get_hf_processor()
        image_token: str = hf_processor.image_token

        # Decoded tokens only attend to the last group of consecutive
        # images, so only those images contribute encoder tokens.
        token_per_chunk = self.info.get_token_per_chunk_from_config()
        num_decode_images = self._get_num_image_in_last_group(
            mm_inputs["prompt_token_ids"])
        num_encode_images = num_images - num_decode_images

        # The encoder prompt length is derived from the tile count; this
        # tells the block manager how many encoder slots to allocate.
        num_tiles = mm_inputs["mm_kwargs"].get_data()["num_tiles"]
        decode_tiles = num_tiles[num_encode_images:num_images].sum().item()
        num_tokens = decode_tiles * token_per_chunk

        mm_inputs["encoder_prompt_token_ids"] = [image_token_id] * num_tokens
        mm_inputs["encoder_prompt"] = image_token * num_tokens

        return mm_inputs

    def _get_num_image_in_last_group(self, prompt_token_ids: list[int]) -> int:
        """Count the image tokens in the last consecutive run of them."""
        count = 0
        for token_id in reversed(prompt_token_ids):
            if token_id == self.info.get_hf_config().image_token_index:
                count += 1
            elif count > 0:
                # The run has ended at the first non-image token before it.
                break
        return count

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and post-process its outputs.

        Without multimodal data the prompt is only tokenized, with no
        special tokens added.
        """
        tokenizer = self.info.get_tokenizer()

        if not mm_data:
            return tokenizer(prompt,
                             add_special_tokens=False,
                             return_tensors="pt")

        num_tiles = [
            self.info.get_num_tiles_per_image(img.height, img.width)
            for img in mm_data["images"]
        ]
        processed_outputs = super()._call_hf_processor(prompt, mm_data,
                                                       mm_kwargs, tok_kwargs)
        processed_outputs["num_tiles"] = torch.tensor(num_tiles)

        # Drop the leading axis of the image tensors.
        for key in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"):
            processed_outputs[key] = processed_outputs[key].squeeze(0)

        processed_token_ids = processed_outputs.pop("input_ids")
        start_idx = 0
        end_idx = processed_token_ids.size(1)

        processed_prompt_text = tokenizer.decode(processed_token_ids[0])
        bos_token = self.info.get_hf_processor().bos_token

        # Remove the bos_token from the start of the prompt, since an
        # image token is expected there.
        if processed_prompt_text.startswith(bos_token):
            start_idx += 1
        # Remove the bos_token from the end of the prompt, which happens
        # when the text is empty.
        if processed_prompt_text.endswith(bos_token):
            end_idx -= 1

        processed_outputs["input_ids"] = (
            processed_token_ids[:, start_idx:end_idx])
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """All four image fields are batched over the image modality."""
        field_names = (
            "pixel_values",
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "num_tiles",
        )
        return {
            field: MultiModalFieldConfig.batched("image")
            for field in field_names
        }

    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """The encoder prompt is one image token per supplied image."""
        data = mm_data.get("image", [])
        if isinstance(data, Image):
            num_images = 1
        else:
            num_images = len(data)

        image_token_id = self.info.get_hf_config().image_token_index
        return [image_token_id] * num_images

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token by ``num_tiles * token_per_chunk`` copies."""
        token_per_chunk = self.info.get_token_per_chunk_from_config()
        image_token_id = self.info.get_hf_config().image_token_index

        def get_replacement_mllama(item_idx):
            images = mm_items.get_items("image", ImageProcessorItems)
            image_size = images.get_image_size(item_idx)
            num_tile = self.info.get_num_tiles_per_image(
                image_height=image_size.height,
                image_width=image_size.width,
            )
            return [image_token_id] * (num_tile * token_per_chunk)

        replacement = PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement_mllama,
        )
        return [replacement]
def
_prepare_aspect_ratio_attention_mask
(
aspect_ratio_mask
:
torch
.
Tensor
,
num_patches
:
int
,
target_length
:
int
,
dtype
:
torch
.
dtype
,
)
->
torch
.
Tensor
:
# Expand aspect ratio mask to target_length
batch_size
,
max_num_tiles
=
aspect_ratio_mask
.
shape
attention_mask
=
aspect_ratio_mask
.
view
(
batch_size
,
max_num_tiles
,
1
,
1
).
to
(
dtype
)
attention_mask
=
attention_mask
.
repeat
(
1
,
1
,
target_length
,
1
)
# Mask padding patches
pad_patches
=
target_length
-
num_patches
attention_mask
[:,
:,
-
pad_patches
:]
=
0
# Invert the mask (0 -> 1, 1 -> 0)
attention_mask
=
1
-
attention_mask
# Reshape to 2D and create 4D attention mask
# (batch_size, 1, max_num_tiles*target_length, max_num_tiles*target_length)
attention_mask
=
attention_mask
.
reshape
(
batch_size
,
max_num_tiles
*
target_length
,
1
)
attention_mask
=
attention_mask
@
attention_mask
.
transpose
(
-
1
,
-
2
)
*
torch
.
finfo
(
dtype
).
min
attention_mask
=
attention_mask
.
unsqueeze
(
1
)
return
attention_mask
class ColumnParallelConv2dPatch(torch.nn.Module):
    """Patchify an image with ``Unfold`` and project with a parallel linear.

    The convolution is expressed as "unfold into patches, then apply a
    column-parallel linear layer" so the projection can be sharded.

    Arguments:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Size of convolution kernel (int or (h, w) pair).
        stride: Stride for the patch extraction.
        bias (default False): Use bias in the linear projection.
    Input: (bsz, in_channels, width, height)
    Output: (bsz, num_tokens, out_channels)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, tuple[int, int]],
        stride: Union[int, tuple[int, int]],
        bias: bool = False,
    ) -> None:
        super().__init__()
        # Normalize a scalar kernel size to an explicit 2-tuple.
        kernel_hw = ((kernel_size, kernel_size) if isinstance(
            kernel_size, int) else kernel_size)
        patch_features = in_channels * kernel_hw[0] * kernel_hw[1]

        self._unfold = torch.nn.Unfold(kernel_size=kernel_hw, stride=stride)
        self._linear = ColumnParallelLinear(
            patch_features,
            out_channels,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unfold yields (bsz, features, tokens); the linear wants tokens first.
        patches = self._unfold(x).permute(0, 2, 1)
        projected, _ = self._linear(patches)
        return projected
class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
    """Adds a learned per-tile embedding selected by aspect-ratio id.

    The embedding table has ``max_aspect_ratio_id + 1`` rows, each holding
    ``max_num_tiles * hidden_size`` values. When gated, the looked-up
    embedding is scaled by ``tanh(gate)`` before being added.
    """

    def __init__(self,
                 config: config_mllama.MllamaVisionConfig,
                 is_gated: bool = True):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        self.is_gated = is_gated

        num_rows = self.max_aspect_ratio_id + 1
        row_width = self.max_num_tiles * self.hidden_size
        self.embedding = nn.Embedding(num_rows, row_width)
        if is_gated:
            self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, hidden_state: torch.Tensor,
                aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
        # Reshape to (-1, tiles, 1, hidden) so it broadcasts over patches.
        tile_embeds = self.embedding(aspect_ratio_ids).reshape(
            -1, self.max_num_tiles, 1, self.hidden_size)

        if self.is_gated:
            tile_embeds = tile_embeds * self.gate.tanh()

        return hidden_state + tile_embeds
class MllamaPrecomputedPositionEmbedding(nn.Module):
    """Adds a gated mix of patch-position and tile-position embeddings.

    The patch-position embedding is weighted by ``1 - tanh(gate)`` and the
    aspect-ratio-dependent tile-position embedding by ``tanh(gate)``.
    """

    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        # Patches per tile plus one extra slot (the class token position).
        self.num_patches = (config.image_size // config.patch_size)**2 + 1
        self.hidden_size = config.hidden_size
        self.scale = config.hidden_size**-0.5

        self.gate = nn.Parameter(torch.zeros(1))

        # position embedding
        self.embedding = nn.Parameter(
            self.scale * torch.randn(self.num_patches, self.hidden_size))

        # tile position embedding
        tile_row_width = (self.max_num_tiles * self.num_patches *
                          self.hidden_size)
        self.tile_embedding = nn.Embedding(self.max_aspect_ratio_id + 1,
                                           tile_row_width)

    def forward(self, hidden_state: torch.Tensor,
                aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
        gate_value = self.gate.tanh()

        # position embeddings, shared by every tile
        patch_pos = ((1 - gate_value) * self.embedding).view(
            1, 1, self.num_patches, self.hidden_size)
        hidden_state = hidden_state + patch_pos

        # precomputed tile position embeddings
        num_rows = hidden_state.shape[0]
        tile_pos = self.tile_embedding(aspect_ratio_ids).reshape(
            num_rows, self.max_num_tiles, self.num_patches, self.hidden_size)
        return hidden_state + gate_value * tile_pos
# TODO: support other attention backends for attention in vision model
class MllamaVisionSdpaAttention(nn.Module):
    """Self-attention for the vision tower with tensor-parallel projections.

    Heads are split evenly across the tensor-parallel group; the fused QKV
    projection output is split back into per-rank q/k/v slices.
    """

    def __init__(self,
                 config: config_mllama.MllamaVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        tp_world_size = get_tp_group().world_size
        self.embed_dim = config.hidden_size
        self.num_heads = config.attention_heads
        self.head_dim = config.hidden_size // config.attention_heads
        self.num_local_heads = self.num_heads // tp_world_size
        # q and k/v use the same number of heads, hence the same local width.
        local_width = self.num_local_heads * self.head_dim
        self.q_size = local_width
        self.kv_size = local_width

        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.embed_dim,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # Use unified MultiHeadAttention with automatic backend selection
        softmax_scale = 1.0 / math.sqrt(self.head_dim)
        self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim,
                                       softmax_scale)

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # NOTE(review): ``attention_mask`` is accepted but not forwarded to
        # the attention call below -- confirm whether this is intentional.
        fused_qkv, _ = self.qkv_proj(hidden_state)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = fused_qkv.split(split_sizes, dim=-1)

        # Use unified MultiHeadAttention with automatic backend selection
        context = self.attn(q, k, v)

        # Flatten everything after the first two dims before the output proj.
        dim0, dim1 = context.shape[0], context.shape[1]
        context = context.reshape(dim0, dim1, -1)
        projected, _ = self.o_proj(context)
        return projected
class MllamaVisionEncoderLayer(nn.Module):
    """Pre-norm transformer block (attention + MLP) for the vision tower.

    When ``is_gated`` is set, each residual branch is scaled by the tanh of
    a learned scalar gate; otherwise the branches are added unscaled.
    """

    def __init__(
        self,
        config: config_mllama.MllamaVisionConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
        is_gated: bool = False,
    ) -> None:
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.attention_heads
        self.is_gated = is_gated
        self.intermediate_size = config.intermediate_size

        self.self_attn = MllamaVisionSdpaAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = CLIPMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

        norm_eps = config.norm_eps
        self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(self.hidden_size,
                                                     eps=norm_eps)

        # there used to be an if else here, no code path
        if is_gated:
            self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4)
            self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4)

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        # Self Attention
        attn_out = self.self_attn(self.input_layernorm(hidden_state),
                                  attention_mask=attention_mask)
        attn_scale = self.gate_attn.tanh() if self.is_gated else 1
        hidden_state = hidden_state + attn_scale * attn_out

        # Feed forward
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_state))
        ffn_scale = self.gate_ffn.tanh() if self.is_gated else 1
        return hidden_state + ffn_scale * mlp_out
class MllamaVisionEncoder(nn.Module):
    """Stack of ``MllamaVisionEncoderLayer`` that can expose selected states.

    Args:
        config: Vision config forwarded to every layer.
        quant_config: Quantization config forwarded to every layer.
        num_layers: Number of encoder layers to build.
        is_gated: Whether the layers use tanh-gated residual branches.
        output_hidden_states: Layer indices whose *input* hidden state is
            collected during ``forward``; ``None`` means collect nothing.
        prefix: Parameter-name prefix for the layers.
    """

    def __init__(
        self,
        config: config_mllama.MllamaVisionConfig,
        quant_config: Optional[QuantizationConfig],
        num_layers: int = 32,
        is_gated: bool = False,
        output_hidden_states=None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([
            MllamaVisionEncoderLayer(config,
                                     quant_config=quant_config,
                                     is_gated=is_gated,
                                     prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_layers)
        ])
        self.output_hidden_states = output_hidden_states or []

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Run all layers and collect the requested intermediate states.

        Returns:
            ``(hidden_states, encoder_states)``: the output of the last
            layer, and a tuple of the collected states (a plain tuple, not
            a ``BaseModelOutput``).
        """
        encoder_states: tuple[torch.Tensor, ...] = ()

        for i, encoder_layer in enumerate(self.layers):
            # The state captured for index ``i`` is the input to layer ``i``.
            if i in self.output_hidden_states:
                encoder_states = encoder_states + (hidden_states, )
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
            )

        # If the last layer index was requested, the final output is
        # appended as well (in addition to that layer's input above).
        if len(self.layers) - 1 in self.output_hidden_states:
            encoder_states = encoder_states + (hidden_states, )

        return hidden_states, encoder_states
class MllamaVisionModel(nn.Module):
    """Mllama vision tower: patch embedding, local and global encoders.

    ``forward`` returns the global encoder output concatenated along the
    last dimension with the stacked intermediate states of the local
    encoder.
    """

    def __init__(
        self,
        config: config_mllama.MllamaVisionConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.in_channels = config.num_channels
        self.intermediate_layers_indices = config.intermediate_layers_indices

        # Patches per tile plus one slot for the class token.
        self.num_patches = (self.image_size // self.patch_size)**2 + 1
        self.scale = config.hidden_size**-0.5

        self.patch_embedding = ColumnParallelConv2dPatch(
            in_channels=config.num_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.class_embedding = nn.Parameter(self.scale *
                                            torch.randn(self.hidden_size))
        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
            config)

        self.pre_tile_positional_embedding = \
            MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True)
        self.post_tile_positional_embedding = \
            MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True)

        # layer norms
        self.layernorm_pre = nn.LayerNorm(self.hidden_size)
        self.layernorm_post = nn.LayerNorm(self.hidden_size)

        # encoders
        # The local encoder also reports the hidden states at
        # ``config.intermediate_layers_indices``.
        self.transformer = MllamaVisionEncoder(
            config,
            quant_config,
            config.num_hidden_layers,
            is_gated=False,
            output_hidden_states=config.intermediate_layers_indices,
            prefix=f"{prefix}.transformer",
        )
        self.global_transformer = MllamaVisionEncoder(
            config,
            quant_config,
            config.num_global_layers,
            is_gated=True,
            prefix=f"{prefix}.global_transformer",
        )

    def apply_class_embedding(self,
                              hidden_state: torch.Tensor) -> torch.Tensor:
        """Prepend the class embedding along dim 1 of a 3D tensor."""
        batch_size, _, hidden_size = hidden_state.shape
        class_embedding = self.class_embedding.expand(batch_size, 1,
                                                      hidden_size)
        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
        return hidden_state

    def forward(self, pixel_values: torch.Tensor,
                aspect_ratio_ids: torch.Tensor,
                aspect_ratio_mask: torch.Tensor) -> torch.Tensor:
        """Encode image tiles.

        Args:
            pixel_values: 6D tensor ``(batch, num_concurrent_media,
                num_tiles, num_channels, height, width)``.
            aspect_ratio_ids: Reshaped to ``(batch * media, -1)`` and used
                to look up the tile/position embeddings.
            aspect_ratio_mask: Reshaped to ``(batch * media, -1)`` and
                turned into the attention mask for both encoders.

        Returns:
            5D tensor ``(batch, num_concurrent_media, num_tiles,
            num_patches, features)`` where ``features`` is the global
            encoder width followed by the stacked intermediate states.
        """
        batch_size, num_concurrent_media, num_tiles, num_channels, \
            height, width = pixel_values.shape

        # Fold batch/media/tile dims together so each tile is one image.
        pixel_values = pixel_values.reshape(
            batch_size * num_concurrent_media * num_tiles, num_channels,
            height, width)
        aspect_ratio_ids = aspect_ratio_ids.reshape(
            batch_size * num_concurrent_media, -1)

        # patch embedding
        patch_embeds = self.patch_embedding(
            pixel_values.to(self.layernorm_pre.weight.dtype))
        hidden_state = patch_embeds
        # Gather the output of the column-parallel patch projection across
        # the tensor-parallel group.
        hidden_state = ps.get_tp_group().all_gather(hidden_state)

        # tile embeddings
        _, num_patches, dim = hidden_state.shape
        hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
                                            num_tiles, -1, dim)
        hidden_state = self.pre_tile_positional_embedding(
            hidden_state, aspect_ratio_ids)

        # apply cls token
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media * num_tiles, num_patches, dim)
        hidden_state = self.apply_class_embedding(hidden_state)
        # Account for the class token that was just prepended.
        num_patches += 1

        # apply position embeddings
        hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
                                            num_tiles, num_patches, dim)
        hidden_state = self.gated_positional_embedding(hidden_state,
                                                       aspect_ratio_ids)

        # apply encoder
        hidden_state = self.layernorm_pre(hidden_state)

        # Compute the number of tokens to pad
        # (rounds the per-tile length up to a multiple of 8)
        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
        # Compute padding tuple for pad function
        padding = (
            0, 0, 0, num_padding_patches
        )  # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
        # Pad the tensor
        hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
        # ``None`` keeps everything when no padding was added; ``-0`` would
        # otherwise slice to an empty tensor.
        slice_index = -num_padding_patches if num_padding_patches > 0 else None

        attention_mask = aspect_ratio_mask.reshape(
            batch_size * num_concurrent_media, -1)
        attention_mask = _prepare_aspect_ratio_attention_mask(
            aspect_ratio_mask=attention_mask,
            num_patches=self.num_patches,
            target_length=hidden_state.shape[2],
            dtype=self.layernorm_pre.weight.dtype,
        )

        # Flatten tiles and patches into a single sequence per media item.
        hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1,
                                         dim)
        output = self.transformer(
            hidden_state,
            attention_mask=attention_mask,
        )
        hidden_state, intermediate_hidden_states = output[0], output[1]
        # Stack the collected states on a new trailing dimension.
        intermediate_hidden_states = torch.stack(intermediate_hidden_states,
                                                 dim=-1)

        # apply global encoder
        hidden_state = self.layernorm_post(hidden_state)
        hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
                                            num_tiles,
                                            num_patches + num_padding_patches,
                                            dim)
        hidden_state = self.post_tile_positional_embedding(
            hidden_state, aspect_ratio_ids)
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles * (num_patches + num_padding_patches), dim)
        hidden_state = self.global_transformer(
            hidden_state, attention_mask=attention_mask)[0]
        hidden_state = hidden_state.reshape(batch_size * num_concurrent_media,
                                            num_tiles,
                                            num_patches + num_padding_patches,
                                            dim)
        # Drop the padding positions added before the encoders.
        hidden_state = hidden_state[:, :, :slice_index]

        # adding intermediate layer outputs
        hidden_state = hidden_state.reshape(batch_size, num_concurrent_media,
                                            num_tiles, num_patches, dim)
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size * num_concurrent_media, num_tiles,
            num_patches + num_padding_patches, -1)
        intermediate_hidden_states = intermediate_hidden_states[:, :, :
                                                                slice_index]
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, -1)
        hidden_state = torch.cat([hidden_state, intermediate_hidden_states],
                                 dim=-1)
        return hidden_state

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing q/k/v shards into ``qkv_proj``.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        updated_params: set[str] = set()
        for name, loaded_weight in weights:
            # The patch projection is stored as a linear layer here, so the
            # checkpoint weight is flattened to 2D (out_features, -1).
            if 'patch_embedding._linear.weight' in name:
                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                updated_params.add(name)
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a fused shard: load directly. ``pop`` raises KeyError
                # for unknown names and for a name seen a second time.
                param = params_dict.pop(name)
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                updated_params.add(name)
        return updated_params
class MllamaTextRMSNorm(nn.Module):
    """Root-mean-square layer norm computed in float32.

    The input is upcast to float32, normalized over the last dimension,
    cast back to the input dtype and finally scaled by ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """
        MllamaTextRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        # Cast back before applying the learned scale.
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.variance_epsilon}"
class MllamaTextCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Cross-attention from text hidden states (queries) to
    ``cross_attention_states`` (keys/values), with RMS norm applied to the
    per-head queries and keys.
    """

    def __init__(
        self,
        config: Optional[config_mllama.MllamaTextConfig] = None,
        layer_idx: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # NOTE(review): ``config`` defaults to None but is dereferenced
        # unconditionally below, so a real config is effectively required.
        self.config = config
        self.pipeline_parallel_rank = get_pp_group().rank_in_group
        self.tensor_parallel_size = get_tp_group().world_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads

        # Per-rank head counts under tensor parallelism.
        self.num_local_heads = self.num_heads // self.tensor_parallel_size
        self.num_local_key_value_heads = \
            self.num_key_value_heads // self.tensor_parallel_size
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // self.num_heads
        # Same value as assigned above (redundant re-assignment).
        self.num_key_value_heads = config.num_key_value_heads

        self.layer_idx = layer_idx
        # Number of query heads sharing each key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.q_local_size = self.num_local_heads * self.head_dim
        self.kv_local_size = self.num_local_key_value_heads * self.head_dim

        self.qkv_proj = QKVCrossParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.num_heads,
            self.num_key_value_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
        # use huggingface's instead
        self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.scaling = self.head_dim**-0.5

        self.attn = Attention(
            self.num_local_heads,
            self.head_dim,
            self.scaling,
            self.num_local_key_value_heads,
            prefix=f"{prefix}.attn",
            attn_type=AttentionType.ENCODER_DECODER,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        kv_range_for_decode: Optional[list[tuple[int, int]]],
        cross_attention_states: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Compute cross-attention and apply the output projection.

        With an ``attention_mask`` the masked path
        (``_attention_with_mask``) is used; otherwise the call goes through
        ``self.attn``.
        """
        q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
        # k/v are only reshaped and normalized when encoder states were
        # supplied for this call.
        if cross_attention_states is not None:
            k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
            v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
            k = self.k_norm(k)
        q = q.view(-1, self.num_local_heads, self.head_dim)
        q = self.q_norm(q)

        if attention_mask is not None:
            output = self._attention_with_mask(q, k, v, attention_mask,
                                               kv_range_for_decode)
        else:
            output = self.attn(
                q.view(-1, self.num_local_heads * self.head_dim), k, v)
        out, _ = self.o_proj(output)
        return out

    def _attention_with_mask(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_mask: torch.Tensor,
        kv_range_for_decode: list[tuple[int, int]],
    ) -> torch.Tensor:
        """Write k/v to the cache, then run masked SDPA over q/k/v.

        Args:
            q: ``(q_len, num_local_heads, head_dim)``.
            k, v: ``(kv_len, num_local_key_value_heads, head_dim)``.
            attention_mask: Viewed as ``(1, 1, q_len, kv_len)``.
            kv_range_for_decode: ``(start, end)`` row ranges of ``k``/``v``
                that are copied into the KV cache.

        Returns:
            ``(q_len, num_local_heads * head_dim)`` attention output.

        Raises:
            ValueError: If the attention backend is not one of the handled
                backends.
        """
        kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank]
        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
        # Skip writing kv-cache for the initial profiling run.
        # TODO (NickLucche) replace with custom attn bias and use standard attn
        if len(kv_cache.shape) > 1:
            # Unit tensor passed for the last two cache-write arguments
            # (presumably the k/v scales -- TODO confirm).
            i = torch.ones(1, dtype=torch.float32)
            if self.attn.backend in (_Backend.FLASH_ATTN,
                                     _Backend.FLASH_ATTN_VLLM_V1):
                cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
                cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
                torch.ops._C_cache_ops.reshape_and_cache_flash(
                    cached_k,
                    cached_v,
                    kv_cache[0],
                    kv_cache[1],
                    attn_metadata.
                    cross_slot_mapping,  # type: ignore[union-attr]
                    "auto",
                    i,
                    i,
                )
            elif self.attn.backend in (_Backend.XFORMERS, _Backend.ROCM_FLASH,
                                       _Backend.TORCH_SDPA):
                key_cache, value_cache = PagedAttention.split_kv_cache(
                    kv_cache, self.num_local_key_value_heads, self.head_dim)
                cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
                cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
                PagedAttention.write_to_paged_cache(
                    cached_k, cached_v, key_cache, value_cache,
                    attn_metadata.cross_slot_mapping, "auto", i, i)
            else:
                raise ValueError(
                    f"Unsupported Attention backend {self.attn.backend} "
                    "enum found. Expected the Attention backend to be "
                    "FLASH_ATTN, FLASH_ATTN_VLLM_V1, "
                    "XFORMERS or TORCH_SDPA.")

        # We have to call torch.sdpa for prefill when using a
        # custom cross-attention mask. Because the mask is not a
        # standard causal mask, neither a block diagonal mask which
        # can be optimized by xformers.BlockDiagonalMask.
        # The mask is specially calculated for supporting multi
        # images and interleaved images.
        q_len = q.shape[0]
        kv_len = k.shape[0]
        # Group query heads by the key/value head they share:
        # (kv_heads, groups, q_len, head_dim).
        q = q.transpose(0, 1).view(self.num_local_key_value_heads,
                                   self.num_key_value_groups, q_len,
                                   self.head_dim).contiguous()
        # Broadcast each key/value head over its query group.
        k = k.transpose(0,
                        1)[:,
                           None, :, :].expand(self.num_local_key_value_heads,
                                              self.num_key_value_groups,
                                              kv_len,
                                              self.head_dim).contiguous()
        v = v.transpose(0,
                        1)[:,
                           None, :, :].expand(self.num_local_key_value_heads,
                                              self.num_key_value_groups,
                                              kv_len,
                                              self.head_dim).contiguous()
        attention_mask = attention_mask.view(1, 1, q_len, kv_len)
        output = F.scaled_dot_product_attention(q,
                                                k,
                                                v,
                                                attn_mask=attention_mask,
                                                is_causal=False)
        # Back to token-major layout with all local heads flattened.
        output = output.permute(2, 0, 1, 3).reshape(
            q_len, self.num_local_heads * self.head_dim)
        return output
class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
    """Decoder block whose attention is cross-attention to vision states.

    Both residual branches (cross-attention and MLP) are multiplied by a
    row mask and by the tanh of a learned scalar gate before being added.
    """

    def __init__(
        self,
        config: config_mllama.MllamaTextConfig,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.layer_idx = layer_idx
        self.cross_attn = MllamaTextCrossAttention(
            config=config,
            layer_idx=layer_idx,
            quant_config=quant_config,
            prefix=f"{prefix}.cross_attn",
        )

        width, norm_eps = config.hidden_size, config.rms_norm_eps
        self.input_layernorm = RMSNorm(width, eps=norm_eps)
        self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1))

        self.mlp = LlamaMLP(
            hidden_size=width,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.post_attention_layernorm = RMSNorm(width, eps=norm_eps)
        self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: torch.Tensor,
        cross_attention_mask: torch.Tensor,
        kv_range_for_decode: Optional[list[tuple[int, int]]],
        full_text_row_masked_out_mask: torch.Tensor,
    ) -> torch.Tensor:
        # Cross-attention branch.
        attn_out = self.cross_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=cross_attention_mask,
            kv_range_for_decode=kv_range_for_decode,
            cross_attention_states=cross_attention_states,
        )
        attn_out = full_text_row_masked_out_mask * attn_out
        hidden_states = (hidden_states +
                         self.cross_attn_attn_gate.tanh() * attn_out)

        # Feed-forward branch.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        mlp_out = full_text_row_masked_out_mask * mlp_out
        return hidden_states + self.cross_attn_mlp_gate.tanh() * mlp_out
class MllamaTextModel(nn.Module):
    """Text decoder mixing self-attention and cross-attention layers.

    Layers whose index appears in ``config.cross_attention_layers`` are
    cross-attention blocks; all others are ``LlamaDecoderLayer``.
    """

    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "model"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config.text_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size
        # The embedding table has 8 rows more than the text vocabulary.
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8,
                                                   config.hidden_size)
        self.cross_attention_layers = config.cross_attention_layers

        def build_layer(layer_idx):
            layer_prefix = f"{prefix}.layers.{layer_idx}"
            if layer_idx in self.cross_attention_layers:
                return MllamaCrossAttentionDecoderLayer(
                    config,
                    layer_idx,
                    quant_config=quant_config,
                    prefix=layer_prefix,
                )
            # TODO: force LlamaDecoderLayer to config.attention_bias=False
            return LlamaDecoderLayer(
                config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=layer_prefix,
            )

        self.layers = nn.ModuleList(
            [build_layer(idx) for idx in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        kv_range_for_decode: Optional[list[tuple[int, int]]],
        full_text_row_masked_out_mask: Optional[tuple[torch.Tensor,
                                                      torch.Tensor]],
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)

        for idx, decoder_layer in enumerate(self.layers):
            if idx not in self.cross_attention_layers:
                # Self-attention layer: returns (output, residual) which are
                # summed before the next layer.
                layer_out, residual = decoder_layer(
                    positions=positions,
                    hidden_states=hidden_states,
                    residual=None,
                )
                hidden_states = layer_out + residual
                continue

            # Cross-attention layers are bypassed entirely when requested.
            if skip_cross_attention:
                continue

            hidden_states = decoder_layer(
                hidden_states=hidden_states,
                cross_attention_states=cross_attention_states,
                cross_attention_mask=cross_attention_mask,
                kv_range_for_decode=kv_range_for_decode,
                full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            )

        return self.norm(hidden_states)
class MllamaForCausalLM(nn.Module):
    """Mllama text decoder (``MllamaTextModel``) plus its LM head.

    ``forward`` returns hidden states only; the LM head is stored on the
    module but not applied here.
    """

    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "language_model"
    _no_split_modules = [
        "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"
    ]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config.text_config
        quant_config = vllm_config.quant_config
        # Kept for ``load_weights`` (kv-cache scale lookup).
        self.quant_config = quant_config

        self.vocab_size = config.vocab_size
        self.model = MllamaTextModel(vllm_config=vllm_config,
                                     prefix=f"{prefix}.model")
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
            prefix=f"{prefix}.lm_head",
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        kv_range_for_decode: Optional[list[tuple[int, int]]],
        full_text_row_masked_out_mask: Optional[tuple[torch.Tensor,
                                                      torch.Tensor]],
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        """Delegate to the wrapped text model and return its hidden states."""
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            kv_range_for_decode=kv_range_for_decode,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            skip_cross_attention=skip_cross_attention,
        )
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module.

        Handles, in order: patch-embedding renaming, kv-cache quantization
        scales, fused ``qkv_proj`` / ``gate_up_proj`` shards, and plain
        parameters.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        updated_params: set[str] = set()
        for name, loaded_weight in weights:
            # Rename the conv patch weight to the linear layer's name and
            # flatten it to 2D (out_features, -1).
            if 'patch_embedding.weight' in name:
                name = name.replace('patch_embedding.weight',
                                    'patch_embedding._linear.weight')
                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
            if (self.quant_config is not None and
                (scale_name := self.quant_config.get_cache_scale(name))):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # Use the scalar itself, or the first element otherwise.
                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
                                 loaded_weight[0])
                weight_loader(param, loaded_weight)
                updated_params.add(scale_name)
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                updated_params.add(name)
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached only when no fused-shard mapping matched.
                orig_name = name
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    # NOTE(review): ``name`` is always None at this point,
                    # so only ``orig_name`` is informative in this message.
                    logger.debug("Missing name %s, orig name %s", name,
                                 orig_name)
                    continue

                # ``pop`` raises KeyError for unknown names and for a name
                # seen a second time.
                param = params_dict.pop(name)
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                updated_params.add(name)
        return updated_params
@MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor,
                                        info=MllamaProcessingInfo,
                                        dummy_inputs=MllamaDummyInputsBuilder)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                     SupportsV0Only):
    """Mllama model combining a vision tower, a projector and a text decoder.

    The vision model output is projected by ``multi_modal_projector`` and
    handed to the language model as cross-attention states.
    """

    # Fused vLLM parameter name -> the per-shard names it is built from.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    # Renames applied to checkpoint weight names before loading.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.vision_model.": "vision_model.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "model.language_model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        },
        orig_to_new_suffix={
            "patch_embedding.weight": "patch_embedding._linear.weight",
        },
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for the given modality.

        Raises:
            ValueError: if ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return "<|image|>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision model, language model, projector and logits
        processor from ``vllm_config``; ``prefix`` namespaces the submodules.
        """
        super().__init__()
        config: MllamaConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.vocab_size = config.text_config.vocab_size
        self.hidden_size = config.text_config.hidden_size
        self.max_num_tiles = config.vision_config.max_num_tiles
        self.vision_output_dim = config.vision_config.vision_output_dim
        # Fall back to -1 when the HF config defines no pad token.
        self.pad_token_id = \
            config.pad_token_id if config.pad_token_id is not None else -1
        self.image_size = config.vision_config.image_size
        self.image_token_id = config.image_token_index

        self.vision_model = MllamaVisionModel(config.vision_config,
                                              quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "vision_model"))
        self.language_model = MllamaForCausalLM(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        # Maps vision features (vision_output_dim) to the text hidden size.
        self.multi_modal_projector = ColumnParallelLinear(
            config.vision_config.vision_output_dim,
            config.text_config.hidden_size,
            bias=True,
            quant_config=quant_config,
            gather_output=True,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )
        self.logits_processor = LogitsProcessor(config.output_hidden_states,
                                                config.text_config.vocab_size)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Apply the logits processor with the language model's ``lm_head``."""
        logits = self.logits_processor(self.language_model.lm_head,
                                       hidden_states, sampling_metadata)
        return logits

    def unpack_data(self,
                    image_data: Union[list[torch.Tensor], torch.Tensor],
                    padding_value=0) -> torch.Tensor:
        """Turn batched image data into a single tensor.

        A tensor is returned unchanged. A list of tensors is padded along
        dim 0 with ``padding_value`` up to the longest entry and stacked into
        a tensor of shape ``(len(image_data), max_length, *trailing_dims)``.
        All entries must share the same trailing dimensions.
        """
        if isinstance(image_data, torch.Tensor):
            # torch.Tensor
            return image_data
        else:
            assert isinstance(
                image_data[0],
                torch.Tensor), "Image data is not properly batched."
            # list[torch.Tensor]
            bsz = len(image_data)
            max_length = max(t.size(0) for t in image_data)
            trailing_dims = image_data[0].shape[1:]
            for data in image_data:
                cur_trailing_dims = data.shape[1:]
                assert cur_trailing_dims == trailing_dims
            output_tensor = torch.full((bsz, max_length, *trailing_dims),
                                       padding_value,
                                       dtype=image_data[0].dtype,
                                       device=image_data[0].device)
            for i, t in enumerate(image_data):
                output_tensor[i, :t.size(0)] = t
            return output_tensor

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[MllamaImagePixelInputs]:
        """Extract image inputs from the forward kwargs.

        Returns:
            ``None`` when neither ``pixel_values`` nor ``image_embeds`` is
            given, otherwise the unpacked pixel inputs.

        Raises:
            ValueError: if both ``pixel_values`` and ``image_embeds`` are given.
            NotImplementedError: if only ``image_embeds`` is given.
        """
        # tensor with the same shape will be batched together by
        # MultiModalKwargs.batch, so pixel_values here can be:
        #   - list[torch.Tensor]:
        #       with shape (num_image, num_tiles, 3, image_res, image_res)
        #   - torch.Tensor:
        #       with shape (bs, num_image, num_tiles, 3, image_res, image_res)
        pixel_values: Optional[Union[list[list[torch.Tensor]],
                                     list[torch.Tensor],
                                     torch.Tensor]] = kwargs.pop(
                                         "pixel_values", None)
        image_embeds: Optional[Union[list[list[torch.Tensor]],
                                     list[torch.Tensor],
                                     torch.Tensor]] = kwargs.pop(
                                         "image_embeds", None)
        aspect_ratio_ids: Optional[Union[list[list[torch.Tensor]],
                                         list[torch.Tensor],
                                         torch.Tensor]] = kwargs.pop(
                                             "aspect_ratio_ids", None)
        aspect_ratio_mask: Optional[Union[list[list[torch.Tensor]],
                                          list[torch.Tensor],
                                          torch.Tensor]] = kwargs.pop(
                                              "aspect_ratio_mask", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None and image_embeds is not None:
            raise ValueError(
                "Both pixel values and image embeds are provided.")

        if pixel_values is not None:
            assert aspect_ratio_ids is not None
            assert aspect_ratio_mask is not None

            return MllamaImagePixelInputs(
                type="pixel_values",
                data=self.unpack_data(pixel_values),
                aspect_ratio_ids=self.unpack_data(aspect_ratio_ids),
                aspect_ratio_mask=self.unpack_data(aspect_ratio_mask))

        if image_embeds is not None:
            raise NotImplementedError

        raise AssertionError("This line should be unreachable.")

    def _get_and_validate_encoder_lens(
        self,
        encoder_seq_lens: list[int],
        num_tiles: list[list[int]],
        num_tokens_per_tile: int,
    ) -> list[int]:
        """Return the real encoder token count of every image-bearing sample.

        Each count is ``sum(tiles of the sample) * num_tokens_per_tile`` and
        is asserted to be at least the matching non-zero entry of
        ``encoder_seq_lens``.
        """
        # Get the actual number of encoder tokens for each sample.
        # Because attn_metadata.encoder_seq_lens only counts the last
        # group of images for each sample, which is used to cheat the
        # block manager to allocate blocks for those images only.
        # See MllamaMultiModalProcessor for more details.
        actual_encoder_seq_lens = [
            sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
        ]

        # remove 0 encoder len entries for text-only requests for these
        # assertions
        attn_metadata_lens = [x for x in encoder_seq_lens if x > 0]
        assert len(actual_encoder_seq_lens) == len(attn_metadata_lens)
        for actual_len, last_group_len in zip(actual_encoder_seq_lens,
                                              attn_metadata_lens):
            assert actual_len >= last_group_len

        return actual_encoder_seq_lens

    def flat_encoder_result(self, cross_attention_states: torch.Tensor,
                            attn_metadata: AttentionMetadata,
                            actual_encoder_seq_lens: list[int]):
        """Concatenate the first ``seq_len`` rows of each sample.

        Returns a ``(sum(actual_encoder_seq_lens), last_dim)`` tensor with the
        same dtype and device as ``cross_attention_states``.
        ``attn_metadata`` is accepted but not used here.
        """
        cross_attention_states_flat = torch.zeros(
            sum(actual_encoder_seq_lens),
            cross_attention_states.shape[-1],
            device=cross_attention_states.device,
            dtype=cross_attention_states.dtype)
        start_pos = 0
        for seq_len, vision_token_in_batch in zip(actual_encoder_seq_lens,
                                                  cross_attention_states):
            end_pos = start_pos + seq_len
            # Only the leading seq_len rows of a sample are kept.
            cross_attention_states_flat[
                start_pos:end_pos] = vision_token_in_batch[:seq_len]
            start_pos = end_pos
        cross_attention_states = cross_attention_states_flat
        return cross_attention_states

    def get_language_model(self) -> torch.nn.Module:
        """Return the text decoder submodule."""
        return self.language_model

    def get_cross_attention_states(
        self,
        image_inputs: MllamaImagePixelInputs,
        attn_metadata: AttentionMetadata,
        actual_encoder_seq_lens: list[int],
    ) -> torch.Tensor:
        """Encode the images and return flattened cross-attention states.

        Runs the vision model, applies the projector, collapses everything
        between the batch dim and the feature dim, then flattens per-sample
        rows with ``flat_encoder_result``.
        """
        # NOTE: llama's reference implementation runs vision model on CPU
        pixel_values = image_inputs['data']
        aspect_ratio_ids = image_inputs['aspect_ratio_ids']
        aspect_ratio_mask = image_inputs['aspect_ratio_mask']
        cross_attention_states = self.vision_model(pixel_values,
                                                   aspect_ratio_ids,
                                                   aspect_ratio_mask)
        # The projector returns a pair; only the first element is used.
        cross_attention_states, _ = self.multi_modal_projector(
            cross_attention_states)

        # The projected output is unpacked as a 5-D tensor.
        bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
        cross_attention_states = cross_attention_states.view(
            bsz, -1, image_token_dim)

        cross_attention_states = self.flat_encoder_result(
            cross_attention_states, attn_metadata, actual_encoder_seq_lens)

        return cross_attention_states

    def get_cross_attention_mask(
        self,
        input_ids: torch.Tensor,
        attn_metadata: AttentionMetadata,
        num_tiles: list[list[int]],
        num_tokens_per_tile: int,
        dtype: torch.dtype,
    ) -> tuple[Optional[torch.Tensor], Optional[list[list[int]]]]:
        """Build the cross-attention mask and the decode KV ranges.

        Returns:
            ``(None, None)`` when ``skip_attention_mask`` says no mask is
            needed; otherwise ``(cross_attention_mask, kv_range_for_decode)``
            where each KV range is a ``[start, end]`` pair of tile indices
            scaled by ``num_tokens_per_tile``.
        """
        token_ids = input_ids.tolist()
        start = 0
        batch_token_ids = []
        # Split the flat token list back into one list per sequence.
        for seq_len in attn_metadata.seq_lens:
            batch_token_ids.append(token_ids[start:start + seq_len])
            start += seq_len
        sparse_mask = [
            get_cross_attention_token_mask(t, self.image_token_id)
            for t in batch_token_ids
        ]

        # Skip generating cross-attention mask if all samples
        # are text-only or have only 1 leading image.
        if skip_attention_mask(sparse_mask):
            return None, None

        dense_mask, tile_range_for_decode = \
            convert_sparse_cross_attention_mask_to_dense(
                sparse_mask, num_tiles, attn_metadata.seq_lens)
        cross_attention_mask = \
            convert_dense_cross_attention_mask_to_tensor(
                dense_mask, num_tokens_per_tile, input_ids.device, dtype)
        kv_range_for_decode = [[
            t[0] * num_tokens_per_tile, t[1] * num_tokens_per_tile
        ] for t in tile_range_for_decode]

        return cross_attention_mask, kv_range_for_decode

    def get_full_text_row_masked_out_mask(
        self,
        attn_metadata: AttentionMetadata,
        device: torch.device,
    ) -> torch.Tensor:
        """Return a ``(num_prefill_tokens, 1)`` bool mask on ``device``.

        Rows belonging to sequences whose encoder length is 0 are ``False``;
        all other rows are ``True``.
        """
        full_text_row_masked_out_mask = torch.ones(
            (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool)
        start_pos = 0
        for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens,
                                            attn_metadata.encoder_seq_lens):
            if encoder_seq_len == 0:
                full_text_row_masked_out_mask[start_pos:start_pos +
                                              seq_len] = False
            start_pos += seq_len
        full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
            device)
        return full_text_row_masked_out_mask

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        **kwargs: object,
    ) -> Union[CausalLMOutputWithPast]:
        """Run the language model, with cross-attention inputs when images
        are present in ``kwargs``.

        Raises:
            ValueError: if the batch mixes prefill and decode tokens.
        """
        attn_metadata = get_forward_context().attn_metadata
        if attn_metadata.num_prefill_tokens > 0 and \
            attn_metadata.num_decode_tokens > 0:
            raise ValueError("Chunk prefill not supported")
        # `**kwargs` passes a copy, so keys popped by the parser (and
        # "num_tiles", popped below) are still present in `kwargs` here.
        image_inputs = self._parse_and_validate_image_input(**kwargs)
        cross_attention_states = None
        cross_attention_mask = None
        kv_range_for_decode = None

        # For 1) text-only prefill and decode, 2) image-present decode.
        if image_inputs is None:
            full_text_row_masked_out_mask = (
                attn_metadata.encoder_seq_lens_tensor
                != 0).reshape(-1, 1).to(input_ids.device)
            skip_cross_attention = attn_metadata.max_encoder_seq_len == 0

        # For image-present prefill.
        else:
            skip_cross_attention = False

            num_tiles = [t.tolist() for t in kwargs.pop("num_tiles")]
            num_tokens_per_tile = calc_token_per_chunk(self.image_size)

            actual_encoder_seq_lens = self._get_and_validate_encoder_lens(
                attn_metadata.encoder_seq_lens,
                num_tiles,
                num_tokens_per_tile,
            )

            cross_attention_states = self.get_cross_attention_states(
                image_inputs, attn_metadata, actual_encoder_seq_lens)

            full_text_row_masked_out_mask = \
                self.get_full_text_row_masked_out_mask(
                    attn_metadata, input_ids.device)

            cross_attention_mask, kv_range_for_decode = \
                self.get_cross_attention_mask(
                    input_ids, attn_metadata, num_tiles,
                    num_tokens_per_tile, cross_attention_states.dtype)

        outputs = self.language_model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            kv_range_for_decode=kv_range_for_decode,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            skip_cross_attention=skip_cross_attention,
        )

        return outputs

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load ``weights`` through ``AutoWeightsLoader`` using
        ``hf_to_vllm_mapper`` and return the loader's result.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector",
            tower_model="vision_model")
def skip_attention_mask(sparse_mask: list[list[list[int]]]) -> bool:
    """Tell whether the cross-attention mask can be skipped for a batch.

    Args:
        sparse_mask: one entry per sample; each entry is a list of
            ``[start, end]`` pairs and is empty for a text-only sample.

    Returns:
        ``True`` only if every sample is either text-only or has exactly one
        pair equal to ``[0, -1]``; ``False`` otherwise.
    """
    for mask in sparse_mask:
        # Skip text-only samples.
        if len(mask) == 0:
            continue
        # If the sample contains more than 1 images,
        # we can't skip mask.
        if len(mask) != 1:
            return False
        # If the sample contains only 1 image,
        # but the image is not the leading one,
        # we can't skip mask.
        if mask[0][0] != 0 or mask[0][1] != -1:
            return False
    return True
def convert_sparse_cross_attention_mask_to_dense(
    sparse_mask: list[list[list[int]]],
    num_tiles: list[list[int]],
    lengths: list[int],
) -> tuple[np.ndarray, list[tuple[int, int]]]:
    """Expand per-sample ``[start, end]`` pairs into a dense token/tile mask.

    Args:
        sparse_mask: one entry per sequence; each entry is a list of
            ``[start, end]`` token ranges (``end == -1`` means "to the end of
            the sequence") and is empty for a text-only sequence.
        num_tiles: tile counts per image, one entry per sequence that has
            images (text-only sequences have no entry).
        lengths: token length of every sequence, text-only ones included.

    Returns:
        ``(dense_mask, tile_range_for_decode)``. ``dense_mask`` is an int64
        array of shape ``(sum(lengths), total number of tiles)`` holding 1
        where a token may attend to a tile. ``tile_range_for_decode`` has one
        ``(start, end)`` tile range per sequence with images, covering the
        tiles whose token range reaches the end of that sequence.
    """
    total_length = sum(lengths)
    total_tiles = sum([sum(tiles) for tiles in num_tiles])
    dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
    # A list of ranges, range[i] = [start, end] means that the i-th image will
    # use tiles[start, end] for cross-attention decoding.
    tile_range_for_decode = []

    seq_start = 0
    tile_start = 0

    # sparse_mask has an [] entry for each sequence that does not have images,
    # but num_tiles does not have these entries...
    num_tiles_idx = 0
    for masks, length in zip(sparse_mask, lengths):
        if len(masks) == 0:
            # Text only: it owns no tiles, but its tokens still occupy rows
            # of dense_mask, so the row offset must move past them. Without
            # this, later sequences would be written into the wrong rows.
            seq_start += length
            continue

        tiles = num_tiles[num_tiles_idx]
        num_tiles_idx += 1
        ts, td = -1, 0
        for mask, tile in zip(masks, tiles):
            if len(mask) != 2:
                continue
            start, end = mask
            end = min(end, length)
            if end == -1:
                end = length
            if end == length:
                # This image stays visible until the end of the sequence, so
                # its tiles are part of the range used for decoding.
                if ts == -1:
                    ts = tile_start
                td += tile
            dense_mask[seq_start + start:seq_start + end,
                       tile_start:tile_start + tile] = 1
            tile_start += tile
        assert ts != -1
        assert td != 0
        tile_range_for_decode.append((ts, ts + td))
        seq_start += length
    assert num_tiles_idx == len(num_tiles)

    return dense_mask, tile_range_for_decode
def convert_dense_cross_attention_mask_to_tensor(
    cross_attention_token_mask: np.ndarray,
    num_tokens_per_tile: int,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Convert a 0/1 token-by-tile mask into an additive attention mask.

    Every tile column is repeated ``num_tokens_per_tile`` times. Entries that
    were 1 become 0 and entries that were 0 become ``torch.finfo(dtype).min``.
    Rows in which every entry would be the minimum value are zeroed out
    entirely.
    """
    lowest = torch.finfo(dtype).min

    per_tile = torch.tensor(cross_attention_token_mask,
                            dtype=dtype,
                            device=device)
    per_token = per_tile.repeat_interleave(num_tokens_per_tile, dim=1)

    # After the inversion, non-zero entries mark the blocked positions.
    blocked = 1.0 - per_token
    additive = blocked.masked_fill(blocked.to(torch.bool), lowest)

    # Keep only rows that have at least one unblocked position.
    row_has_visible = (additive != lowest).any(dim=-1)
    additive *= row_has_visible.type_as(additive)[..., None]

    # (num_prompt_tokens, num_encoder_tokens)
    return additive
vllm/model_executor/models/registry.py
View file @
759ef49b
...
@@ -147,10 +147,6 @@ _TEXT_GENERATION_MODELS = {
...
@@ -147,10 +147,6 @@ _TEXT_GENERATION_MODELS = {
"TeleFLMForCausalLM"
:
(
"teleflm"
,
"TeleFLMForCausalLM"
),
"TeleFLMForCausalLM"
:
(
"teleflm"
,
"TeleFLMForCausalLM"
),
"XverseForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"XverseForCausalLM"
:
(
"llama"
,
"LlamaForCausalLM"
),
"Zamba2ForCausalLM"
:
(
"zamba2"
,
"Zamba2ForCausalLM"
),
"Zamba2ForCausalLM"
:
(
"zamba2"
,
"Zamba2ForCausalLM"
),
# [Encoder-decoder]
"BartModel"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"BartForConditionalGeneration"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"MBartForConditionalGeneration"
:
(
"bart"
,
"MBartForConditionalGeneration"
),
}
}
_EMBEDDING_MODELS
=
{
_EMBEDDING_MODELS
=
{
...
@@ -237,6 +233,7 @@ _MULTIMODAL_MODELS = {
...
@@ -237,6 +233,7 @@ _MULTIMODAL_MODELS = {
"RForConditionalGeneration"
:
(
"rvl"
,
"RForConditionalGeneration"
),
"RForConditionalGeneration"
:
(
"rvl"
,
"RForConditionalGeneration"
),
"KimiVLForConditionalGeneration"
:
(
"kimi_vl"
,
"KimiVLForConditionalGeneration"
),
# noqa: E501
"KimiVLForConditionalGeneration"
:
(
"kimi_vl"
,
"KimiVLForConditionalGeneration"
),
# noqa: E501
"Llama_Nemotron_Nano_VL"
:
(
"nemotron_vl"
,
"LlamaNemotronVLChatModel"
),
"Llama_Nemotron_Nano_VL"
:
(
"nemotron_vl"
,
"LlamaNemotronVLChatModel"
),
"Llama4ForConditionalGeneration"
:
(
"mllama4"
,
"Llama4ForConditionalGeneration"
),
# noqa: E501
"LlavaForConditionalGeneration"
:
(
"llava"
,
"LlavaForConditionalGeneration"
),
"LlavaForConditionalGeneration"
:
(
"llava"
,
"LlavaForConditionalGeneration"
),
"LlavaNextForConditionalGeneration"
:
(
"llava_next"
,
"LlavaNextForConditionalGeneration"
),
# noqa: E501
"LlavaNextForConditionalGeneration"
:
(
"llava_next"
,
"LlavaNextForConditionalGeneration"
),
# noqa: E501
"LlavaNextVideoForConditionalGeneration"
:
(
"llava_next_video"
,
"LlavaNextVideoForConditionalGeneration"
),
# noqa: E501
"LlavaNextVideoForConditionalGeneration"
:
(
"llava_next_video"
,
"LlavaNextVideoForConditionalGeneration"
),
# noqa: E501
...
@@ -263,16 +260,12 @@ _MULTIMODAL_MODELS = {
...
@@ -263,16 +260,12 @@ _MULTIMODAL_MODELS = {
"Qwen2_5OmniModel"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"Qwen2_5OmniModel"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"Qwen2_5OmniForConditionalGeneration"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"Qwen2_5OmniForConditionalGeneration"
:
(
"qwen2_5_omni_thinker"
,
"Qwen2_5OmniThinkerForConditionalGeneration"
),
# noqa: E501
"UltravoxModel"
:
(
"ultravox"
,
"UltravoxModel"
),
"UltravoxModel"
:
(
"ultravox"
,
"UltravoxModel"
),
"SkyworkR1VChatModel"
:
(
"skyworkr1v"
,
"SkyworkR1VChatModel"
),
"Step3VLForConditionalGeneration"
:
(
"step3_vl"
,
"Step3VLForConditionalGeneration"
),
# noqa: E501
"Step3VLForConditionalGeneration"
:
(
"step3_vl"
,
"Step3VLForConditionalGeneration"
),
# noqa: E501
"TarsierForConditionalGeneration"
:
(
"tarsier"
,
"TarsierForConditionalGeneration"
),
# noqa: E501
"TarsierForConditionalGeneration"
:
(
"tarsier"
,
"TarsierForConditionalGeneration"
),
# noqa: E501
"Tarsier2ForConditionalGeneration"
:
(
"qwen2_vl"
,
"Tarsier2ForConditionalGeneration"
),
# noqa: E501
"Tarsier2ForConditionalGeneration"
:
(
"qwen2_vl"
,
"Tarsier2ForConditionalGeneration"
),
# noqa: E501
"VoxtralForConditionalGeneration"
:
(
"voxtral"
,
"VoxtralForConditionalGeneration"
),
# noqa: E501
"VoxtralForConditionalGeneration"
:
(
"voxtral"
,
"VoxtralForConditionalGeneration"
),
# noqa: E501
# [Encoder-decoder]
# [Encoder-decoder]
"DonutForConditionalGeneration"
:
(
"donut"
,
"DonutForConditionalGeneration"
),
"Florence2ForConditionalGeneration"
:
(
"florence2"
,
"Florence2ForConditionalGeneration"
),
# noqa: E501
"MllamaForConditionalGeneration"
:
(
"mllama"
,
"MllamaForConditionalGeneration"
),
# noqa: E501
"Llama4ForConditionalGeneration"
:
(
"mllama4"
,
"Llama4ForConditionalGeneration"
),
# noqa: E501
"SkyworkR1VChatModel"
:
(
"skyworkr1v"
,
"SkyworkR1VChatModel"
),
"WhisperForConditionalGeneration"
:
(
"whisper"
,
"WhisperForConditionalGeneration"
),
# noqa: E501
"WhisperForConditionalGeneration"
:
(
"whisper"
,
"WhisperForConditionalGeneration"
),
# noqa: E501
}
}
...
...
vllm/multimodal/profiling.py
View file @
759ef49b
...
@@ -209,7 +209,7 @@ class MultiModalProfiler(Generic[_I]):
...
@@ -209,7 +209,7 @@ class MultiModalProfiler(Generic[_I]):
if
processor
.
pad_dummy_encoder_prompt
:
if
processor
.
pad_dummy_encoder_prompt
:
num_tokens_to_pad
=
max
(
total_len
,
seq_len
)
-
total_len
num_tokens_to_pad
=
max
(
total_len
,
seq_len
)
-
total_len
encoder_prompt_token_ids
.
extend
([
0
]
*
num_tokens_to_pad
)
encoder_prompt_token_ids
.
extend
([
0
]
*
num_tokens_to_pad
)
# NOTE: Whisper
and Donut
allows total_len > seq_len.
# NOTE: Whisper allows total_len > seq_len.
elif
total_len
>
seq_len
and
not
envs
.
VLLM_USE_V1
:
elif
total_len
>
seq_len
and
not
envs
.
VLLM_USE_V1
:
# `max_num_batched_tokens` is defined by `SchedulerConfig`
# `max_num_batched_tokens` is defined by `SchedulerConfig`
logger
.
warning_once
(
logger
.
warning_once
(
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment