Unverified commit 759ef49b, authored by Woosuk Kwon and committed by GitHub
Browse files

Remove V0 Encoder-Decoder Support (#24907)


Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
parent 5206ab20
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import pytest
from PIL import Image
from vllm.inputs.data import ExplicitEncoderDecoderPrompt, TextPrompt
from vllm.multimodal.image import rescale_image_size
from vllm.sequence import SampleLogprobs
from ....conftest import IMAGE_ASSETS, HfRunner, ImageTestAssets, VllmRunner
from ...utils import check_logprobs_close
# Model IDs fed to the ``model`` parametrization of the tests in this module.
MODELS = ["microsoft/Florence-2-base"]
# Florence-2 model repo's tokenizer config is missing some special tokens.
# Therefore, we use a converted tokenizer from a forked repo
# (passed to the vLLM runner as ``tokenizer_name``).
TOKENIZER = "Isotr0py/Florence-2-tokenizer"
# One prompt per image asset, keyed by asset name and built through
# ``IMAGE_ASSETS.prompts``.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "<OD>",  # special task token which will output special tokens
    "cherry_blossom":
    "Describe in detail what is shown in the image.",
})
def get_hf_images_prompts(
    prompts_: list[ExplicitEncoderDecoderPrompt[str, TextPrompt]],
) -> tuple[list[ExplicitEncoderDecoderPrompt[str, str]], list[Image.Image]]:
    """Split vLLM-style prompts into HF text prompts and their images.

    For every entry the encoder prompt text is re-wrapped in a new
    ``ExplicitEncoderDecoderPrompt`` (with no decoder prompt) and the image
    stored under ``multi_modal_data["image"]`` is collected separately.

    Returns:
        A ``(prompts, images)`` pair of equally long lists.
    """
    encoder_inputs = [entry["encoder_prompt"] for entry in prompts_]
    text_prompts = [
        ExplicitEncoderDecoderPrompt(
            encoder_prompt=encoder_input["prompt"],
            decoder_prompt=None,
        ) for encoder_input in encoder_inputs
    ]
    pil_images = [
        encoder_input["multi_modal_data"]["image"]
        for encoder_input in encoder_inputs
    ]
    return text_prompts, pil_images
def hf_to_vllm_output(hf_output: tuple[list[int], str,
                                       Optional[SampleLogprobs]]):
    """Sanitize hf output to be comparable with vllm output.

    Strips the ``</s>`` and ``<s>`` markers from the decoded text; token ids
    and logprobs are passed through untouched.
    """
    token_ids, text, logprobs = hf_output
    for marker in ("</s>", "<s>"):
        text = text.replace(marker, "")
    return token_ids, text, logprobs
def run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    inputs: list[list[ExplicitEncoderDecoderPrompt]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
) -> None:
    """Check that vLLM and HF greedy logprobs agree for each prompt group.

    vLLM generates first for every group in ``inputs``; afterwards the HF
    runner generates for the same groups (converted with
    ``get_hf_images_prompts``) and each pair of result lists is handed to
    ``check_logprobs_close``.
    """
    with vllm_runner(
            model,
            max_num_seqs=8,
            tokenizer_name=TOKENIZER,
            dtype=dtype,
            tensor_parallel_size=tensor_parallel_size,
            distributed_executor_backend=distributed_executor_backend,
            enforce_eager=True,
    ) as vllm_model:
        vllm_results = []
        for prompt_group in inputs:
            vllm_results.append(
                vllm_model.generate_encoder_decoder_greedy_logprobs(
                    prompt_group,
                    max_tokens,
                    num_logprobs=num_logprobs,
                    skip_special_tokens=False,
                ))

    hf_cases = [get_hf_images_prompts(prompt_group) for prompt_group in inputs]

    with hf_runner(model, dtype=dtype, skip_tokenizer_init=True) as hf_model:
        # Point the HF model's output-embedding accessor at the language
        # model's lm_head.
        hf_model.model.get_output_embeddings = (
            lambda: hf_model.model.language_model.lm_head)

        hf_results = []
        for hf_prompts, hf_images in hf_cases:
            hf_results.append(
                hf_model.generate_encoder_decoder_greedy_logprobs_limit(
                    hf_prompts,
                    max_tokens,
                    num_logprobs=num_logprobs,
                    images=hf_images))

    for hf_outputs, vllm_outputs in zip(hf_results, vllm_results):
        sanitized_hf = [hf_to_vllm_output(output) for output in hf_outputs]
        check_logprobs_close(
            outputs_0_lst=sanitized_hf,
            outputs_1_lst=vllm_outputs,
            name_0="hf",
            name_1="vllm",
            num_outputs_0_skip_tokens=1,
        )
# FIXME: https://github.com/huggingface/transformers/issues/38358
@pytest.mark.skip("Model initialization fails")
@pytest.mark.core_model
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner: type[HfRunner], vllm_runner: type[VllmRunner],
                image_assets: ImageTestAssets, model: str,
                size_factors: list[float], dtype: str, max_tokens: int,
                num_logprobs: int) -> None:
    """Compare HF and vLLM outputs on rescaled copies of each image asset.

    For every (image, prompt) pair one encoder-decoder prompt is built per
    entry in ``size_factors``; an empty ``size_factors`` therefore yields
    empty prompt groups. The groups are then passed to ``run_test`` with a
    tensor-parallel size of 1.
    """
    images = [asset.pil_image for asset in image_assets]

    inputs_per_image = []
    for image, prompt in zip(images, HF_IMAGE_PROMPTS):
        # One prompt per scale factor, all sharing the same text.
        inputs_per_image.append([
            ExplicitEncoderDecoderPrompt(
                encoder_prompt=TextPrompt(
                    prompt=prompt,
                    multi_modal_data={
                        "image": rescale_image_size(image, factor)
                    }),
                decoder_prompt=None,
            ) for factor in size_factors
        ])

    run_test(
        hf_runner,
        vllm_runner,
        inputs_per_image,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, overload
import pytest
import torch
from packaging.version import Version
from transformers import AutoConfig, AutoModelForImageTextToText, AutoTokenizer
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm import LLM, SamplingParams
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager)
from vllm.model_executor.models.mllama import MllamaForConditionalGeneration
from vllm.multimodal.image import rescale_image_size
from vllm.sequence import SampleLogprobs
from ....conftest import (IMAGE_ASSETS, HfRunner, ImageTestAssets,
PromptImageInput, VllmRunner)
from ....quantization.utils import is_quant_method_supported
from ....utils import (create_new_process_for_each_test, large_gpu_test,
multi_gpu_test)
from ...utils import check_logprobs_close
# Upper bound on images per prompt, passed to the vLLM runner as
# ``limit_mm_per_prompt["image"]``.
_LIMIT_IMAGE_PER_PROMPT = 3
# Token id that stands for <|image|> in the pre-tokenized prompts below.
MLLAMA_IMAGE_TOKEN_ID = 128256
# Attention backends the end-to-end tests are parametrized over.
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
# One prompt per image asset, keyed by asset name.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "<|image|><|begin_of_text|>The meaning of the image is",
    "cherry_blossom":
    "<|image|><|begin_of_text|>The city is",
})
# Prompts used for requests that carry no image.
text_only_prompts = [
    "The color of the sky is blue but sometimes it can also be",
]
# Model IDs fed to the ``model`` parametrization.
models = [
    "meta-llama/Llama-3.2-11B-Vision-Instruct",
]
# Indices for inputs
# (digit strings: used as keys of ``prompt_data`` and converted with
# ``int(...)`` to build dummy encoder sequence lengths in the mask tests)
TEXT_ONLY = '0'
IMAGE_AT_BEG = '1'
IMAGE_AT_MIDDLE = '2'
TWO_IMAGES = '3'
# Input tokenized
prompt_data = {
    # Tell me a story
    TEXT_ONLY: [41551, 757, 264, 3446],
    # <|image|> What's the content of this image
    IMAGE_AT_BEG:
    [MLLAMA_IMAGE_TOKEN_ID, 3639, 596, 279, 2262, 315, 420, 2217, 220],
    # Hello <|image|>What' the content of this image
    IMAGE_AT_MIDDLE:
    [9906, 220, MLLAMA_IMAGE_TOKEN_ID, 3923, 6, 279, 2262, 315, 420, 2217],
    #<|image|>Is there a duck in this image?<|image|>What's the animal in this image? # noqa: E501
    TWO_IMAGES: [
        MLLAMA_IMAGE_TOKEN_ID, 3957, 1070, 264, 37085, 304, 420, 2217, 30,
        MLLAMA_IMAGE_TOKEN_ID, 3923, 596, 279, 10065, 304, 420, 2217, 30
    ]
}
def vllm_to_hf_output(vllm_output: tuple[list[int], str,
                                         Optional[SampleLogprobs]],
                      model: str):
    """Sanitize vllm output to be comparable with hf output.

    Runs of consecutive image-placeholder tokens are collapsed to a single
    token, and the decoded EOS token is appended to the text when the
    (collapsed) token list ends with EOS.

    Args:
        vllm_output: ``(token_ids, text, logprobs)`` as produced by vLLM.
        model: Model ID used to load the HF config and tokenizer.

    Returns:
        ``(token_ids, text, logprobs)`` with the adjustments above applied;
        the logprobs are passed through untouched.
    """
    output_ids, output_str, out_logprobs = vllm_output

    config = AutoConfig.from_pretrained(model)
    image_token_id = config.image_token_index

    tokenizer = AutoTokenizer.from_pretrained(model)
    eos_token_id = tokenizer.eos_token_id

    # Keep a token unless it is an image token directly preceded by another
    # image token. The explicit ``idx == 0`` check matters: without it,
    # ``output_ids[idx - 1]`` would wrap around to the *last* element and
    # could wrongly drop a leading image token.
    hf_output_ids = [
        token_id for idx, token_id in enumerate(output_ids)
        if (token_id != image_token_id or idx == 0
            or output_ids[idx - 1] != image_token_id)
    ]

    hf_output_str = output_str
    # Guard against an empty generation before peeking at the last token.
    if hf_output_ids and hf_output_ids[-1] == eos_token_id:
        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

    return hf_output_ids, hf_output_str, out_logprobs
def _get_inputs(
image_assets: ImageTestAssets,
*,
size_factors: Optional[list[float]] = None,
sizes: Optional[list[tuple[int, int]]] = None,
) -> list[tuple[list[str], PromptImageInput]]:
images = [asset.pil_image for asset in image_assets]
if size_factors is not None:
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
elif sizes is not None:
inputs_per_image = [(
[
prompt if size is not None else text_only_prompts[0]
for size in sizes
],
[
image.resize(size) if size is not None else None
for size in sizes
],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
if len(sizes) == 0:
inputs_per_image.append(
(text_only_prompts, [None] * len(text_only_prompts)))
else:
raise ValueError("You must provide either `size_factors` or `sizes`")
return inputs_per_image
@overload
def run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    image_assets: ImageTestAssets,
    model: str,
    *,
    size_factors: list[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    ...


@overload
def run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    image_assets: ImageTestAssets,
    model: str,
    *,
    sizes: list[tuple[int, int]],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    ...


def run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    image_assets: ImageTestAssets,
    model: str,
    *,
    size_factors: Optional[list[float]] = None,
    sizes: Optional[list[tuple[int, int]]] = None,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Build inputs from the image assets and delegate to ``_run_test``.

    The two overloads document the supported call shapes: image variants are
    requested either through ``size_factors`` or through ``sizes``.
    """
    test_inputs = _get_inputs(image_assets,
                              size_factors=size_factors,
                              sizes=sizes)
    _run_test(
        hf_runner,
        vllm_runner,
        test_inputs,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=tensor_parallel_size,
        distributed_executor_backend=distributed_executor_backend,
    )
def _run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    inputs: list[tuple[list[str], PromptImageInput]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be the same between hf and vllm.

    Each entry of ``inputs`` is a ``(prompts, images)`` group. Both runners
    generate greedily for every group, and the vLLM results are sanitized
    with ``vllm_to_hf_output`` before being compared to the HF results via
    ``check_logprobs_close``.
    """
    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

    # max_model_len should be greater than image_feature_size
    with vllm_runner(
            model,
            dtype=dtype,
            max_model_len=19212,  # 3 max size images
            max_num_seqs=3,
            tensor_parallel_size=tensor_parallel_size,
            distributed_executor_backend=distributed_executor_backend,
            limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT},
    ) as vllm_model:
        vllm_results = []
        for group_prompts, group_images in inputs:
            vllm_results.append(
                vllm_model.generate_greedy_logprobs(
                    group_prompts,
                    max_tokens,
                    num_logprobs=num_logprobs,
                    images=group_images))

    with hf_runner(
            model,
            dtype=dtype,
            model_kwargs={"device_map": "auto"},
            auto_cls=AutoModelForImageTextToText,
    ) as hf_model:
        hf_results = []
        for group_prompts, group_images in inputs:
            hf_results.append(
                hf_model.generate_greedy_logprobs_limit(
                    group_prompts,
                    max_tokens,
                    num_logprobs=num_logprobs,
                    images=group_images))

    for hf_outputs, vllm_outputs in zip(hf_results, vllm_results):
        sanitized_vllm = [
            vllm_to_hf_output(single_output, model)
            for single_output in vllm_outputs
        ]
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=sanitized_vllm,
            name_0="hf",
            name_1="vllm",
        )
@pytest.fixture(autouse=True)
def clear_cache():
    """Reset the cached attention-backend lookup before every test."""
    # Autouse fixture: the cache is cleared on the way in; there is no
    # teardown work after the test body has run.
    _cached_get_attn_backend.cache_clear()
    yield
@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "sizes",
    [
        # Text only
        [],
        # Single-size
        [(512, 512)],
        # Single-size, batched
        [(512, 512), (512, 512), (512, 512)],
        # Multi-size, batched
        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
         (1024, 1024), (512, 1536), (512, 2028)],
        # Multi-size, batched, including text only
        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
         (1024, 1024), (512, 1536), (512, 2028), None],
        # mllama has 8 possible aspect ratios, carefully set the sizes
        # to cover all of them
    ])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.skipif(
    Version(TRANSFORMERS_VERSION) <= Version("4.55.2"),
    reason="Transformers v4.55 has a regression issue on mllama, "
    "see: https://github.com/huggingface/transformers/pull/40083")
def test_models_single_leading_image(hf_runner, vllm_runner, image_assets,
                                     model, sizes, dtype, max_tokens,
                                     num_logprobs,
                                     attn_backend: _Backend) -> None:
    """Compare HF and vLLM on prompts with a single leading image.

    Runs under each forced attention backend, with the image assets resized
    to every entry of ``sizes``.
    """
    with global_force_attn_backend_context_manager(attn_backend):
        # Flash Attention works only with bfloat16 data-type
        effective_dtype = ('bfloat16'
                           if attn_backend == _Backend.FLASH_ATTN else dtype)
        run_test(
            hf_runner,
            vllm_runner,
            image_assets,
            model,
            sizes=sizes,
            dtype=effective_dtype,
            max_tokens=max_tokens,
            num_logprobs=num_logprobs,
            tensor_parallel_size=1,
        )
@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.skipif(
    Version(TRANSFORMERS_VERSION) <= Version("4.55.2"),
    reason="Transformers v4.55 has a regression issue on mllama, "
    "see: https://github.com/huggingface/transformers/pull/40083")
def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets,
                                     model, dtype, max_tokens, num_logprobs,
                                     attn_backend: _Backend) -> None:
    """Compare HF and vLLM on prompts that start with several images.

    A single batch of three prompts is used: two with two leading images
    and one with three, mixing original and resized assets.
    """
    stop_sign = image_assets[0].pil_image
    cherry_blossom = image_assets[1].pil_image

    prompts = [
        "<|image|><|image|><|begin_of_text|>Describe 2 images.",  # noqa: E501
        "<|image|><|image|><|begin_of_text|>Describe 2 images.",  # noqa: E501
        "<|image|><|image|><|image|><|begin_of_text|>Describe 3 images.",  # noqa: E501
    ]
    images = [
        [stop_sign, cherry_blossom],
        # Images with different sizes.
        [
            stop_sign.resize((512, 512)),
            stop_sign,
        ],
        [
            stop_sign,
            stop_sign.resize((512, 1536)),
            cherry_blossom.resize((512, 1024)),
        ],
    ]
    inputs = [(prompts, images)]

    with global_force_attn_backend_context_manager(attn_backend):
        # Flash Attention works only with bfloat16 data-type
        effective_dtype = ('bfloat16'
                           if attn_backend == _Backend.FLASH_ATTN else dtype)
        _run_test(
            hf_runner,
            vllm_runner,
            inputs,
            model,
            dtype=effective_dtype,
            max_tokens=max_tokens,
            num_logprobs=num_logprobs,
            tensor_parallel_size=1,
        )
@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
@pytest.mark.skipif(
    Version(TRANSFORMERS_VERSION) <= Version("4.55.2"),
    reason="Transformers v4.55 has a regression issue on mllama, "
    "see: https://github.com/huggingface/transformers/pull/40083")
def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
                                   dtype, max_tokens, num_logprobs,
                                   attn_backend: _Backend) -> None:
    """Compare HF and vLLM on prompts with images placed mid-text.

    One batch of two prompts is used: the first embeds one image, the second
    embeds two.
    """
    stop_sign = image_assets[0].pil_image
    cherry_blossom = image_assets[1].pil_image

    prompts = [
        "<|begin_of_text|>The content of the image <|image|> is",  # noqa: E501
        "<|begin_of_text|>Between the first image <|image|> and the second image<|image|>, "  # noqa: E501
        "which is a stop sign and which is a cherry blossom?",  # noqa: E501
    ]
    images = [
        [stop_sign],
        [stop_sign, cherry_blossom],
    ]
    inputs = [(prompts, images)]

    with global_force_attn_backend_context_manager(attn_backend):
        # Flash Attention works only with bfloat16 data-type
        effective_dtype = ('bfloat16'
                           if attn_backend == _Backend.FLASH_ATTN else dtype)
        _run_test(
            hf_runner,
            vllm_runner,
            inputs,
            model,
            dtype=effective_dtype,
            max_tokens=max_tokens,
            num_logprobs=num_logprobs,
            tensor_parallel_size=1,
        )
@create_new_process_for_each_test()
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.skipif(
    Version(TRANSFORMERS_VERSION) <= Version("4.55.2"),
    reason="Transformers v4.55 has a regression issue on mllama, "
    "see: https://github.com/huggingface/transformers/pull/40083")
def test_models_distributed(
    hf_runner,
    vllm_runner,
    image_assets,
    distributed_executor_backend,
    model,
    dtype,
    max_tokens,
    num_logprobs,
) -> None:
    """Compare HF and vLLM with tensor parallelism across two GPUs.

    Runs once per distributed executor backend, using three rescaled copies
    of each image asset.
    """
    scale_factors = [0.25, 0.5, 1.0]
    run_test(
        hf_runner,
        vllm_runner,
        image_assets,
        model=model,
        size_factors=scale_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=2,
        distributed_executor_backend=distributed_executor_backend,
    )
@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["float16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
def test_bnb_regression(
    image_assets: ImageTestAssets,
    model: str,
    dtype: str,
    max_tokens: int,
):
    """Smoke-test generation with bitsandbytes quantization enabled.

    Sends one image prompt and one text-only prompt and only checks that
    some outputs come back.
    """
    stop_sign = image_assets[0].pil_image

    image_request = {
        "prompt": "<|begin_of_text|>The content of the image <|image|> is",
        "multi_modal_data": {
            "image": stop_sign
        },
    }
    text_request = {
        "prompt":
        "The color of the sky is blue but sometimes it can also be",
    }
    prompts = [image_request, text_request]

    # Test regression about QKVCrossParallelLinear
    llm = LLM(
        model=model,
        dtype=dtype,
        max_model_len=8192,
        max_num_seqs=2,
        quantization="bitsandbytes",
    )
    greedy_params = SamplingParams(
        temperature=0,
        max_tokens=max_tokens,
    )
    outputs = llm.generate(prompts, greedy_params)
    assert outputs
@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
def test_explicit_implicit_prompt(
    image_assets: ImageTestAssets,
    model: str,
    dtype: str,
    max_tokens: int,
):
    """Explicit encoder/decoder prompts must match their implicit form.

    The first half of ``prompts`` is written as explicit encoder/decoder
    prompts and the second half as the corresponding implicit prompts; the
    generated texts are compared pairwise.
    """
    stop_sign = image_assets[0].pil_image

    # yapf: disable
    explicit_prompts = [
        {
            "encoder_prompt": {
                "prompt": "<|image|>",
                "multi_modal_data": {"image": stop_sign},
            },
            "decoder_prompt": {
                "prompt_token_ids": [128000, 791, 2262, 315, 279, 2217, 220, 128256, 374],  # noqa: E501
            }
        },
        {
            "encoder_prompt": "Not <|image|>",
            "decoder_prompt": "The color of the sky is blue but sometimes it can also be",  # noqa: E501
        },
    ]
    implicit_prompts = [
        {
            "prompt": "<|begin_of_text|>The content of the image <|image|> is",  # noqa: E501
            "multi_modal_data": {"image": stop_sign},
        },
        {
            "prompt": "The color of the sky is blue but sometimes it can also be",  # noqa: E501
        },
    ]
    # yapf: enable
    prompts = explicit_prompts + implicit_prompts

    llm = LLM(
        model=model,
        dtype=dtype,
        max_model_len=8192,
        max_num_seqs=2,
        tensor_parallel_size=1,
    )
    greedy_params = SamplingParams(
        temperature=0,
        max_tokens=max_tokens,
    )
    outputs = llm.generate(prompts, greedy_params)

    # Outputs come back in prompt order: explicit half, then implicit half.
    half = len(prompts) // 2
    for exp_output, imp_output in zip(outputs[:half], outputs[half:]):
        assert exp_output.outputs[0].text == imp_output.outputs[0].text
@large_gpu_test(min_gb=48)
@pytest.mark.core_model
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
                    num_logprobs, attn_backend: _Backend) -> None:
    """Exercise prompt/image combinations taken from earlier bug reports.

    The first case must raise ``ValueError``; the remaining batches only
    need to generate without error.
    """
    stop_sign = image_assets[0].pil_image

    with global_force_attn_backend_context_manager(attn_backend), vllm_runner(
            model,
            dtype=dtype,
            max_model_len=8192,
            max_num_seqs=4,
            tensor_parallel_size=1,
            limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT},
    ) as vllm_model:

        def generate(batch_prompts, batch_images):
            # Shared call shape for every scenario below.
            return vllm_model.generate_greedy_logprobs(batch_prompts,
                                                       max_tokens,
                                                       num_logprobs,
                                                       images=batch_images)

        # Regression tests for https://github.com/vllm-project/vllm/issues/10648

        # Number of groups of image tokens is greater than the number of images
        # provided (the whitespace between the tags is necessary)
        prompt = "<|begin_of_text|><|image|> <|image|> Compare the two images"  # noqa: E501
        image = stop_sign
        with pytest.raises(ValueError):
            generate([prompt], [image])

        # Batch of a text-only and image request that requires cross-attention
        generate(
            [
                "What is the capital of spain?",
                "Text before the image...<|image|>What is in the image?",  # noqa: E501
            ],
            [
                None,
                [stop_sign],
            ],
        )

        # Test the reverse order too for good measure
        generate(
            [
                "<|begin_of_text|>Text before the image...<|image|>What is in the image?",  # noqa: E501
                "<|begin_of_text|>Hello!",
            ],
            [
                [stop_sign],
                None,
            ],
        )

        # Mixed batch with text and images with different numbers of tiles
        generate(
            [
                "<|begin_of_text|>Hello!",
                "<|begin_of_text|>Some text before.<|image|>What is in the image?",  # noqa: E501
                "<|begin_of_text|>Some text before.<|image|>What is in the image?",  # noqa: E501
            ],
            [
                None,
                [stop_sign],
                # smaller image must be 2nd for the repro
                [stop_sign.resize((448, 448))],
            ],
        )
class DummyModel:
    """Minimal stand-in passed as ``self`` to unbound mllama model methods."""

    # The only attribute this stub provides.
    image_token_id = MLLAMA_IMAGE_TOKEN_ID
@pytest.mark.core_model
@pytest.mark.parametrize(
    "input_indices_and_output",
    # inputs, (cross_attention_mask, kv_range_for_decode)
    [([TEXT_ONLY], (None, None)), ([IMAGE_AT_BEG], (None, None)),
     ([TEXT_ONLY, IMAGE_AT_BEG], (None, None)),
     ([IMAGE_AT_MIDDLE], ((10, 12), [[0, 6]])),
     ([TEXT_ONLY, IMAGE_AT_MIDDLE], ((14, 12), [[0, 6]])),
     ([TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE],
      ((23, 24), [[0, 6], [6, 12]])),
     ([IMAGE_AT_MIDDLE, TEXT_ONLY], ((14, 12), [[0, 6]])),
     ([TWO_IMAGES], ((18, 12), [[6, 12]])),
     ([TEXT_ONLY, TWO_IMAGES], ((22, 12), [[6, 12]]))])
def test_get_cross_attention_mask(input_indices_and_output) -> None:
    """Check the shape of the cross-attention mask and the decode KV ranges.

    Each parameter is ``(input_indices, (expected_mask_shape,
    expected_kv_range_for_decode))``; an expected shape of ``None`` means no
    mask should be produced.
    """
    input_indices, expected_output = input_indices_and_output

    sequences = [torch.tensor(prompt_data[i]) for i in input_indices]
    # One [2, 2] tile entry per image-bearing sequence. Text-only sequences
    # are excluded by the filter, so no per-item conditional is needed.
    num_tiles = [[2, 2] for i in input_indices if i != TEXT_ONLY]
    # Named ``input_ids`` rather than ``input`` to avoid shadowing the builtin.
    input_ids = torch.cat(sequences)

    seq_lens = [len(s) for s in sequences]

    attn_data = FlashAttentionMetadata(
        seq_lens=seq_lens,
        # Dummy values
        enable_kv_scales_calculation=False,
        num_prefills=0,
        num_prefill_tokens=0,
        num_decode_tokens=0,
        slot_mapping=0,
        multi_modal_placeholder_index_maps=None,
        seq_lens_tensor=0,
        max_prefill_seq_len=0,
        max_decode_seq_len=0,
        context_lens_tensor=None,
        block_tables=None,
        use_cuda_graph=False,
    )

    dummy = DummyModel()

    cross_attention_mask, kv_range_for_decode = MllamaForConditionalGeneration\
        .get_cross_attention_mask(dummy,
                                  input_ids,
                                  attn_data,
                                  num_tiles=num_tiles,
                                  num_tokens_per_tile=3,
                                  dtype=torch.bfloat16)

    expected_cross_attention_mask, expected_kv_range_for_decode = \
        expected_output

    assert kv_range_for_decode == expected_kv_range_for_decode
    if expected_cross_attention_mask is not None:
        assert cross_attention_mask is not None
        assert cross_attention_mask.shape == expected_cross_attention_mask
    else:
        assert cross_attention_mask is None
@pytest.mark.core_model
@pytest.mark.parametrize(
    "input_indices",
    [[TEXT_ONLY], [IMAGE_AT_BEG], [TEXT_ONLY, IMAGE_AT_BEG], [IMAGE_AT_MIDDLE],
     [TEXT_ONLY, IMAGE_AT_MIDDLE], [TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE],
     [IMAGE_AT_MIDDLE, TEXT_ONLY], [TWO_IMAGES], [TEXT_ONLY, TWO_IMAGES]])
def test_get_full_text_row_masked_out_mask(input_indices) -> None:
    """Check the per-token mask returned for mixed text/image batches.

    Every token of a text-only sequence is expected to carry a falsy mask
    value, and every token of any other sequence a truthy one.
    """
    sequences = [torch.tensor(prompt_data[i]) for i in input_indices]
    seq_lens = [len(s) for s in sequences]
    num_prefill_tokens = sum(seq_lens)

    # TEXT_ONLY is zero, so it will be masked out,
    # other instances should not be.
    encoder_seq_lens = [int(i) for i in input_indices]

    attn_data = FlashAttentionMetadata(
        seq_lens=seq_lens,
        encoder_seq_lens=encoder_seq_lens,
        num_prefill_tokens=num_prefill_tokens,
        # Dummy values
        enable_kv_scales_calculation=False,
        num_prefills=0,
        num_decode_tokens=0,
        slot_mapping=0,
        multi_modal_placeholder_index_maps=None,
        seq_lens_tensor=0,
        max_prefill_seq_len=0,
        max_decode_seq_len=0,
        context_lens_tensor=None,
        block_tables=None,
        use_cuda_graph=False,
    )

    dummy = DummyModel()

    mask_tensor = MllamaForConditionalGeneration\
        .get_full_text_row_masked_out_mask(dummy,
                                           attn_data,
                                           torch.get_default_device())
    mask_values = mask_tensor.squeeze().tolist()

    assert len(mask_values) == num_prefill_tokens

    # Expand the per-sequence expectation to one flag per token.
    expected_flags = [
        index != TEXT_ONLY
        for index, seq_len in zip(input_indices, seq_lens)
        for _ in range(seq_len)
    ]
    for idx, must_be_masked in enumerate(expected_flags):
        assert mask_values[idx] == must_be_masked, \
            f"full_text_row_masked_out_mask[{idx}] must be " \
            f"'{must_be_masked}' "
@pytest.mark.core_model
@pytest.mark.parametrize("encoder_seq_lens, num_tiles, expected", [
    ([6404], [[4]], [6404]),
    ([0, 6404], [[4]], [6404]),
    ([0, 1601, 8005], [[1], [4, 1]], [1601, 8005]),
    ([0, 19212, 0, 3202], [[4, 4, 4], [2]], [19212, 3202]),
])
def test_parse_and_validate_encoder_lens(encoder_seq_lens, num_tiles,
                                         expected) -> None:
    """Check ``_get_and_validate_encoder_lens`` against known results.

    The method is invoked unbound with a ``DummyModel`` instance and a fixed
    1601 tokens per tile.
    """
    tokens_per_tile = 1601
    stub_model = DummyModel()

    validate_lens = \
        MllamaForConditionalGeneration._get_and_validate_encoder_lens
    actual_encoder_seq_lens = validate_lens(
        stub_model,
        encoder_seq_lens,
        num_tiles,
        tokens_per_tile,
    )

    assert actual_encoder_seq_lens == expected, \
        f"Expected {expected} but got {actual_encoder_seq_lens}"
......@@ -167,8 +167,6 @@ def _test_processing_correctness(
# incorrect token ids. So we need to use `add_special_tokens=False` here
# to leave bos_token to be added by the processor.
_ADD_SPECIAL_TOKENS_OVERRIDES = {
"donut": False,
"mllama": False,
"ovis": False,
"ovis2_5": False,
"paligemma": False,
......@@ -278,9 +276,7 @@ def _test_processing_correctness_one(
"facebook/chameleon-7b",
"CohereLabs/command-a-vision-07-2025",
"deepseek-ai/deepseek-vl2-tiny",
"naver-clova-ix/donut-base-finetuned-docvqa",
"baidu/ERNIE-4.5-VL-28B-A3B-PT",
"microsoft/Florence-2-base",
"adept/fuyu-8b",
"google/gemma-3-4b-it",
"google/gemma-3n-E2B-it",
......@@ -305,7 +301,6 @@ def _test_processing_correctness_one(
"llava-hf/llava-v1.6-mistral-7b-hf",
"llava-hf/LLaVA-NeXT-Video-7B-hf",
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
"meta-llama/Llama-3.2-11B-Vision-Instruct",
"TIGER-Lab/Mantis-8B-siglip-llama3",
"mispeech/midashenglm-7b",
"openbmb/MiniCPM-Llama3-V-2_5",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for mllama's multimodal preprocessing and profiling."""
import pytest
from transformers import MllamaConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.profiling import MultiModalProfiler
from ...utils import build_model_context
@pytest.mark.parametrize("model_id",
                         ["meta-llama/Llama-3.2-11B-Vision-Instruct"])
@pytest.mark.parametrize("max_model_len", [4096, 8192, 25600, 131072])
@pytest.mark.parametrize("max_num_seqs", [1, 2, 8])
def test_profiling(
    model_id: str,
    max_model_len: int,
    max_num_seqs: int,
):
    """Dummy encoder data must not exceed the real per-sample encoder length.

    The profiler's dummy encoder prompt length is compared with the encoder
    token count derived from the processed dummy inputs (tiles times tokens
    per tile).
    """
    # regression test for https://github.com/vllm-project/vllm/issues/13929
    from vllm.model_executor.models.mllama import calc_token_per_chunk

    ctx = build_model_context(
        model_id,
        model_config_kwargs={"max_model_len": max_model_len},
        limit_mm_per_prompt={"image": 1},
    )

    mm_config = ctx.get_mm_config()
    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
    profiler = MultiModalProfiler(processor)

    dummy_encoder_data = profiler.get_encoder_dummy_data(
        max_model_len,
        mm_counts=mm_config.limit_per_prompt,
    )
    dummy_mm_data = processor.dummy_inputs.get_dummy_processor_inputs(
        max_model_len,
        mm_counts=mm_config.limit_per_prompt,
    )

    hf_config = ctx.get_hf_config(MllamaConfig)
    image_size = hf_config.vision_config.image_size

    dummy_prompt_len = len(dummy_encoder_data.prompt_token_ids)
    encoder_seq_lens = [dummy_prompt_len] * max_num_seqs

    processed = processor.apply(
        prompt=dummy_mm_data.prompt,
        mm_data=dummy_mm_data.mm_data,
        hf_processor_mm_kwargs=dict(),
    )
    mm_data = processed["mm_kwargs"].get_data()

    # Get the actual number of encoder tokens for each sample.
    # Because attn_metadata.encoder_seq_lens only counts the last
    # group of images for each sample, which is used to cheat the
    # block manager to allocate blocks for those images only.
    # See MllamaMultiModalProcessor for more details.
    num_tiles = [[t] for t in mm_data.pop("num_tiles")]
    num_tokens_per_tile = calc_token_per_chunk(image_size)
    actual_encoder_seq_lens = [
        sum(tiles) * num_tokens_per_tile for tiles in num_tiles
    ]

    # simulate mllama image-present prefill.
    for actual_len, last_group_len in zip(actual_encoder_seq_lens,
                                          encoder_seq_lens):
        assert actual_len >= last_group_len
......@@ -31,7 +31,6 @@ from ...utils import dummy_hf_overrides
ARCH_TO_SKIP = {
"MolmoForCausalLM": "incompatible requirements",
"Florence2ForConditionalGeneration": "not supported in V1",
}
ARCH_NEEDS_EXTRAS = [
"InternVLChatModel",
......
......@@ -354,11 +354,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"MiMoForCausalLM": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True),
"Dots1ForCausalLM": _HfExamplesInfo("rednote-hilab/dots.llm1.inst"),
# [Encoder-decoder]
"BartModel": _HfExamplesInfo("facebook/bart-base"),
"BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
"MBartForConditionalGeneration": _HfExamplesInfo("facebook/mbart-large-en-ro", # noqa: E501
hf_overrides={"architectures": ["MBartForConditionalGeneration"]}), # noqa: E501
}
_EMBEDDING_EXAMPLE_MODELS = {
......@@ -496,7 +491,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
trust_remote_code=True),
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
max_model_len=10240,
extras={"llama-guard-4": "meta-llama/Llama-Guard-4-12B"}, # noqa: E501
extras={"llama-guard-4": "meta-llama/Llama-Guard-4-12B"}, # noqa: E501
),
"LlavaForConditionalGeneration": _HfExamplesInfo("llava-hf/llava-1.5-7b-hf",
extras={"mistral": "mistral-community/pixtral-12b", # noqa: E501
......@@ -583,15 +578,6 @@ _MULTIMODAL_EXAMPLE_MODELS = {
is_available_online=False,
),
# [Encoder-decoder]
"DonutForConditionalGeneration": _HfExamplesInfo("naver-clova-ix/donut-base-finetuned-docvqa", # noqa: E501
hf_overrides={"architectures": ["DonutForConditionalGeneration"], "model_type": "donut"}, # noqa: E501
extras={"dolphin": "ByteDance/Dolphin"}), # noqa: E501
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
"Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501
tokenizer="Isotr0py/Florence-2-tokenizer", # noqa: E501
trust_remote_code=True), # noqa: E501
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501
# [Cross-encoder]
"JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), # noqa: E501
......
......@@ -92,10 +92,6 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
# has cc==8.9 which hasn't supported FA3 yet. Remove this hack when
# L4 supports FA3.
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
if model_arch == "Florence2ForConditionalGeneration":
# An encoder-decoder model that's V0-only. Just skip it
# since V0 is about to be removed.
pytest.skip("Skipping Florence2ForConditionalGeneration")
if model_arch == "WhisperForConditionalGeneration":
m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
LLM(
......
......@@ -50,7 +50,6 @@ def test_registry_imports(model_arch):
@create_new_process_for_each_test()
@pytest.mark.parametrize("model_arch,is_mm,init_cuda,is_ce", [
("LlamaForCausalLM", False, False, False),
("MllamaForConditionalGeneration", True, False, False),
("LlavaForConditionalGeneration", True, True, False),
("BertForSequenceClassification", False, False, True),
("RobertaForSequenceClassification", False, False, True),
......
......@@ -299,9 +299,8 @@ def test_rope_customization():
reason="Encoder Decoder models not supported on ROCm.")
@pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [
("facebook/opt-125m", False),
("facebook/bart-base", True),
("openai/whisper-tiny", True),
("meta-llama/Llama-3.2-1B-Instruct", False),
("meta-llama/Llama-3.2-11B-Vision", True),
])
def test_is_encoder_decoder(model_id, is_encoder_decoder):
config = ModelConfig(model_id)
......
......@@ -501,34 +501,6 @@ def test_bind_kv_cache_non_attention():
assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1]
def test_bind_kv_cache_encoder_decoder(monkeypatch: pytest.MonkeyPatch):
    """Check `bind_kv_cache` on a BART-like set of attention layers.

    The encoder self-attention layer must keep the KV cache object it had
    before binding, while the decoder cross-attention and decoder
    self-attention layers must both be bound to the supplied cache tensor.
    """
    # V1 TESTS: ENCODER_DECODER is not supported on V1 yet.
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "0")

        from vllm.attention import Attention, AttentionType

        enc_self = 'encoder.layers.0.self_attn.attn'
        dec_cross = 'decoder.layers.0.encoder_attn.attn'
        dec_self = 'decoder.layers.0.self_attn.attn'

        # example from bart
        layer_kinds = (
            (enc_self, AttentionType.ENCODER),
            (dec_cross, AttentionType.ENCODER_DECODER),
            (dec_self, AttentionType.DECODER),
        )
        ctx = {
            layer_name: Attention(32, 128, 0.1, attn_type=kind)
            for layer_name, kind in layer_kinds
        }

        kv_cache = [torch.zeros((1, ))]
        # Remember the encoder layer's cache so we can check it is untouched.
        untouched_encoder_cache = ctx[enc_self].kv_cache

        bind_kv_cache(ctx, [kv_cache])

        assert ctx[enc_self].kv_cache is untouched_encoder_cache
        assert ctx[dec_cross].kv_cache[0] is kv_cache[0]
        assert ctx[dec_self].kv_cache[0] is kv_cache[0]
def test_bind_kv_cache_pp():
with patch("vllm.utils.cuda_device_count_stateless", lambda: 2):
# this test runs with 1 GPU, but we simulate 2 GPUs
......
......@@ -9,24 +9,9 @@ from vllm import LLM
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
# Model ids used to parametrize `test_reject_unsupported_models`, which
# expects `create_engine_config()` to raise NotImplementedError for them
# when VLLM_USE_V1=1.
UNSUPPORTED_MODELS_V1 = [
"facebook/bart-large-cnn", # encoder decoder
]
MODEL = "meta-llama/Llama-3.2-1B-Instruct"
@pytest.mark.parametrize("model", UNSUPPORTED_MODELS_V1)
def test_reject_unsupported_models(monkeypatch, model):
    """With VLLM_USE_V1=1, building the engine config for `model` must
    raise NotImplementedError."""
    with monkeypatch.context() as patched_env:
        patched_env.setenv("VLLM_USE_V1", "1")
        engine_args = AsyncEngineArgs(model=model)

        with pytest.raises(NotImplementedError):
            engine_args.create_engine_config()

        # Drop the variable again before leaving the patched context.
        patched_env.delenv("VLLM_USE_V1")
def test_reject_bad_config(monkeypatch):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "0")
......@@ -77,12 +62,6 @@ def test_enable_by_default_fallback(monkeypatch):
assert envs.VLLM_USE_V1
m.delenv("VLLM_USE_V1")
# Should fall back to V0 for supported model.
_ = AsyncEngineArgs(
model=UNSUPPORTED_MODELS_V1[0]).create_engine_config()
assert not envs.VLLM_USE_V1
m.delenv("VLLM_USE_V1")
def test_v1_llm_by_default(monkeypatch):
with monkeypatch.context() as m:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import pytest
import torch
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
# Batch sizes used to parametrize `test_prepare_prompt` and
# `test_prepare_decode`.
BATCH_SIZES = [1, 4, 16, 64, 256]
def _create_model_runner(model: str, *args,
                         **kwargs) -> EncoderDecoderModelRunner:
    """Build an `EncoderDecoderModelRunner` for `model`.

    `model` and any extra positional/keyword arguments are forwarded to
    `EngineArgs`; the resulting engine config is handed to the runner,
    which is created as the driver worker.
    """
    vllm_config = EngineArgs(model, *args, **kwargs).create_engine_config()
    return EncoderDecoderModelRunner(
        vllm_config=vllm_config,
        is_driver_worker=True,
    )
@pytest.mark.skipif(condition=current_platform.is_cpu(),
                    reason=("CPU backend is currently unsupported for "
                            "encoder/ decoder models"))
def test_empty_seq_group():
    """Preparing model inputs for an empty sequence-group list must leave
    every decoder/encoder input field and the attention metadata as None."""
    runner = _create_model_runner(
        "facebook/bart-base",
        seed=0,
        dtype="float16",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enforce_eager=True,
    )
    no_groups: list[SequenceGroupMetadata] = []
    model_input = runner._prepare_model_input_tensors(no_groups)

    # Decoder-side inputs
    assert model_input.input_tokens is None
    assert model_input.input_positions is None
    # Encoder-side inputs
    assert model_input.encoder_input_tokens is None
    assert model_input.encoder_input_positions is None
    # Attention metadata and returned sequence lengths
    assert model_input.attn_metadata is None
    assert model_input.seq_lens is None
@pytest.mark.skipif(condition=current_platform.is_cpu(),
                    reason=("CPU backend is currently unsupported for "
                            "encoder/ decoder models"))
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
def test_prepare_prompt(batch_size):
    '''
    Check the prefill-phase model inputs & attention metadata produced by
    the encoder/decoder model runner.

    Test behavior:

    * Build an enc/dec model runner for BART base (eager mode)
    * Create one prompt-phase sequence group per batch entry, each with
      dummy decoder and encoder token sequences
    * Verify decoder self-attention, encoder attention and
      encoder/decoder cross-attention metadata, the token/position
      tensors, and the sampling indices

    Arguments:

    * batch_size: number of sequence groups in the batch
    '''
    runner = _create_model_runner(
        "facebook/bart-base",
        seed=0,
        dtype="float16",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enforce_eager=True,
    )

    decoder_lens: list[int] = []
    encoder_lens: list[int] = []
    groups: list[SequenceGroupMetadata] = []
    block_tables = {0: [1]}
    cross_block_table = [2]
    # make sure all tokens fit into one block
    longest = runner.block_size - 1
    for idx in range(batch_size):
        decoder_len = idx % longest + 1
        decoder_lens.append(decoder_len)
        decoder_data = SequenceData.from_seqs(range(decoder_len))
        encoder_len = (idx + 1) % longest + 1
        encoder_lens.append(encoder_len)
        encoder_data = SequenceData.from_seqs(range(encoder_len))
        group = SequenceGroupMetadata(
            request_id=f"test_{idx}",
            is_prompt=True,
            seq_data={0: decoder_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
            encoder_seq_data=encoder_data,
            cross_block_table=cross_block_table,
        )
        assert group.token_chunk_size == decoder_data.get_len()
        groups.append(group)

    # Build
    # * Decoder model inputs
    # * Decoder self-attention KV caching data structures
    # * Encoder model inputs
    # * Encoder/decoder cross-attention KV caching data structures
    model_input = runner.prepare_model_input(groups)

    attn_metadata = model_input.attn_metadata
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    encoder_input_tokens = model_input.encoder_input_tokens
    encoder_input_positions = model_input.encoder_input_positions

    assert model_input.seq_lens == decoder_lens
    assert len(attn_metadata.slot_mapping) == len(input_tokens)
    assert len(attn_metadata.cross_slot_mapping) == len(encoder_input_tokens)

    # Verify input metadata is correct for prompts.
    # - Decoder attention metadata
    device = runner.device
    assert attn_metadata.num_prefills > 0
    assert attn_metadata.num_decode_tokens == 0
    decoder_lens_tensor = torch.tensor(decoder_lens,
                                       device=device,
                                       dtype=torch.int)
    assert torch.equal(attn_metadata.seq_lens_tensor, decoder_lens_tensor)
    assert attn_metadata.seq_lens == decoder_lens
    assert attn_metadata.max_prefill_seq_len == max(decoder_lens)
    assert attn_metadata.max_decode_seq_len == 0
    # - Encoder attention metadata
    assert attn_metadata.encoder_seq_lens == encoder_lens
    encoder_lens_tensor = torch.tensor(encoder_lens,
                                       device=device,
                                       dtype=torch.int)
    assert torch.equal(attn_metadata.encoder_seq_lens_tensor,
                       encoder_lens_tensor)
    assert attn_metadata.max_encoder_seq_len == max(encoder_lens)
    assert attn_metadata.num_encoder_tokens == sum(encoder_lens)

    # Test decoder subquery start locs: the running sum of the prompt
    # lengths, starting from 0.
    start_loc = list(itertools.accumulate(decoder_lens, initial=0))
    start_loc_tensor = torch.tensor(start_loc,
                                    dtype=torch.int32,
                                    device=device)
    assert torch.equal(attn_metadata.query_start_loc, start_loc_tensor)

    # Test decoder seq start locs & context lengths
    assert torch.equal(
        attn_metadata.seq_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device),
    )
    num_context_lens = attn_metadata.context_lens_tensor.shape[0]
    assert torch.equal(
        attn_metadata.context_lens_tensor,
        torch.zeros(num_context_lens, dtype=torch.int, device=device),
    )

    # Verify block tables are correct for prompts: one empty row per
    # sequence group.
    # - Decoder self-attention
    expected = torch.tensor(
        [[] for _ in groups],
        dtype=torch.int32,
        device=runner.device,
    )
    assert torch.equal(attn_metadata.block_tables, expected)
    # - Encoder/decoder cross-attention
    assert torch.equal(attn_metadata.cross_block_tables, expected)

    # Cuda graph should not be used for prefill.
    assert attn_metadata.use_cuda_graph is False

    # Verify the lengths of input tokens & positions
    # - Decoder
    assert len(input_tokens) == sum(decoder_lens)
    assert len(input_positions) == sum(decoder_lens)
    # -- Indirect check of model_input.input_tokens and
    #    model_input.input_positions: the dummy token ids were chosen
    #    to coincide with the position values, so the two tensors
    #    must be equal.
    assert torch.equal(input_tokens, input_positions)

    # - Encoder
    assert len(encoder_input_tokens) == sum(encoder_lens)
    # -- Same indirect check for the encoder tokens/positions.
    assert torch.equal(encoder_input_tokens, encoder_input_positions)

    # Test that vLLM sampling infrastructure chooses the correct
    # sequence positions at which to sample (i.e. the end of
    # each sequence) in the prefill phase. Prompts are concatenated,
    # so the final token of each prompt sits just before the next
    # prompt's start offset.
    expected_selected_token_indices = [end - 1 for end in start_loc[1:]]

    actual = model_input.sampling_metadata.selected_token_indices
    expected = torch.tensor(
        expected_selected_token_indices,
        device=actual.device,
        dtype=actual.dtype,
    )
    assert torch.equal(actual, expected)
@pytest.mark.skipif(condition=current_platform.is_cpu(),
                    reason="CPU backend is currently "
                    "unsupported for encoder/ "
                    "decoder models")
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False])
def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
    '''
    Test the ability of the encoder/decoder model runner subclass to
    produce decode-phase model inputs & attention metadata.

    Test behavior:

    * Instantiate BART base model & enc/dec model runner (eager mode)
    * Construct decode-phase sequence-group metadata for dummy sequences
    * Test that encoder attention, decoder self-attention,
      and encoder/decoder cross-attention inputs are correct

    Arguments:

    * batch_size: number of sequence groups in the batch
    * multiple_seqs_per_seq_group: if True every sequence group holds two
      sequences (ids 0 and 1), otherwise a single sequence (id 0)
    '''
    model_runner = _create_model_runner(
        "facebook/bart-base",
        seed=0,
        dtype="float16",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enforce_eager=True,
    )

    seq_lens: list[int] = []
    encoder_seq_lens: list[int] = []
    seq_group_metadata_list: list[SequenceGroupMetadata] = []
    block_tables = {
        0: [1],
        1: [3]
    } if multiple_seqs_per_seq_group else {
        0: [1]
    }
    cross_block_table = [2]
    for i in range(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_data = SequenceData.from_seqs(range(seq_len))
        encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={
                0: seq_data,
                1: seq_data
            } if multiple_seqs_per_seq_group else {0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
            encoder_seq_data=encoder_seq_data,
            cross_block_table=cross_block_table,
        )
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)
        # One expected entry per sequence in the group.
        num_seqs = len(seq_group_metadata.seq_data)
        seq_lens.extend([seq_len] * num_seqs)
        encoder_seq_lens.extend([encoder_seq_len] * num_seqs)

    # Build
    # * Decoder model inputs
    # * Decoder self-attention KV caching data structures
    # * Encoder model inputs
    # * Encoder/decoder cross-attention KV caching data structures
    model_input = model_runner.prepare_model_input(seq_group_metadata_list)

    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = attn_metadata.slot_mapping
    encoder_input_tokens = model_input.encoder_input_tokens
    encoder_input_positions = model_input.encoder_input_positions
    cross_slot_mapping = attn_metadata.cross_slot_mapping

    assert return_seq_lens == seq_lens
    assert len(slot_mapping) == len(input_tokens)
    assert len(cross_slot_mapping) == len(encoder_input_tokens)

    # Verify input metadata is correct for decode phase.
    # - Decoder attention metadata
    device = model_runner.device
    assert attn_metadata.num_prefills == 0
    assert attn_metadata.num_decode_tokens > 0
    assert torch.equal(attn_metadata.seq_lens_tensor,
                       torch.tensor(seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.seq_lens == seq_lens
    assert attn_metadata.max_prefill_seq_len == 0
    assert attn_metadata.max_decode_seq_len == max(seq_lens)
    # - Encoder attention metadata
    assert attn_metadata.encoder_seq_lens == encoder_seq_lens
    assert torch.equal(
        attn_metadata.encoder_seq_lens_tensor,
        torch.tensor(encoder_seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.max_encoder_seq_len == max(encoder_seq_lens)
    assert attn_metadata.num_encoder_tokens == sum(encoder_seq_lens)

    # Test decoder subquery start locs: each sequence contributes exactly
    # one query token during decode, so the offsets are 0, 1, 2, ...
    start_loc = list(range(len(seq_lens) + 1))
    assert torch.equal(
        attn_metadata.query_start_loc,
        torch.tensor(start_loc, dtype=torch.int32, device=device),
    )

    # Test decoder seq start locs. Note that for normal prefill it is
    # equivalent to query_start_loc.
    seq_start_loc = list(itertools.accumulate(seq_lens, initial=0))

    # Test seq_start_loc and context lengths
    assert torch.equal(
        attn_metadata.seq_start_loc,
        torch.tensor(seq_start_loc, dtype=torch.int32, device=device),
    )
    assert torch.equal(
        attn_metadata.context_lens_tensor,
        torch.tensor([seq_len - 1 for seq_len in seq_lens],
                     dtype=torch.int,
                     device=device))

    # Verify block tables are correct for decode
    # - Decoder self-attention: the per-sequence block tables of one group,
    #   repeated once per sequence group.
    flattened_block_tables = list(block_tables.values())
    expected = torch.tensor(flattened_block_tables *
                            len(seq_group_metadata_list),
                            dtype=torch.int32,
                            device=model_runner.device)
    assert torch.equal(
        attn_metadata.block_tables,
        expected,
    )
    # - Encoder/decoder cross-attention: one copy of the cross block table
    #   per sequence.
    expected = torch.tensor([
        cross_block_table for seq_group_metadata in seq_group_metadata_list
        for _ in range(len(seq_group_metadata.seq_data))
    ],
                            dtype=torch.int32,
                            device=model_runner.device)
    assert torch.equal(
        attn_metadata.cross_block_tables,
        expected,
    )

    # Model runner's CUDAGraph setting should be propagated to attention
    # metadata.
    assert attn_metadata.use_cuda_graph is False

    # Verify the lengths of input tokens & positions
    # - Decoder
    assert len(input_tokens) == len(seq_lens)
    assert len(input_positions) == len(seq_lens)
    # -- An indirect check that model_input.input_tokens
    #    and model_input.input_positions are correct -
    #    by design of the test, the input tokens are
    #    equal to the input position values, so if
    #    the model_input data structure has the correct
    #    values then these two should be equal
    assert torch.equal(
        input_tokens,
        input_positions,
    )

    # - Encoder: no encoder tokens or positions are expected during decode.
    #   (The original asserted on encoder_input_tokens twice; the second
    #   check is meant for the positions tensor.)
    assert len(encoder_input_tokens) == 0
    assert len(encoder_input_positions) == 0
    # -- Same indirect tokens-vs-positions check for the encoder side.
    assert torch.equal(
        encoder_input_tokens,
        encoder_input_positions,
    )

    # Test that vLLM sampling infrastructure chooses the correct
    # sequence positions at which to sample (i.e. the end of
    # each sequence) in the decode phase. A single token is decoded per
    # iteration per sequence, so the expected sampling index for a
    # sequence is simply its position in the flattened batch.
    expected_selected_token_indices = list(range(len(seq_lens)))

    sampling_metadata = model_input.sampling_metadata
    actual = sampling_metadata.selected_token_indices
    expected = torch.tensor(
        expected_selected_token_indices,
        device=actual.device,
        dtype=actual.dtype,
    )
    assert torch.equal(actual, expected)
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False])
def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
    """
    Tests that for encoder-decoder models with CUDA Graph capture and replay
    enabled, the tensors used during the decode phase are correctly padded
    for varying input batch sizes.

    Arguments:

    * batch_size: number of sequence groups in the batch
    * multiple_seqs_per_seq_group: if True every sequence group holds two
      sequences (ids 0 and 1), otherwise a single sequence (id 0)
    """
    model_runner = _create_model_runner(
        "facebook/bart-base",
        seed=0,
        dtype="float16",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enforce_eager=False,
    )
    block_tables = {
        0: [1],
        1: [3]
    } if multiple_seqs_per_seq_group else {
        0: [1]
    }
    seq_lens: list[int] = []
    encoder_seq_lens: list[int] = []
    seq_group_metadata_list: list[SequenceGroupMetadata] = []

    cross_block_table = [2]
    # Total number of sequences across all groups (the un-padded batch).
    expanded_batch_size = 0
    for i in range(batch_size):
        # make sure all tokens fit into one block
        seq_len = i % (model_runner.block_size - 1) + 1
        seq_data = SequenceData.from_seqs(range(seq_len))
        encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
        encoder_seq_data = SequenceData.from_seqs(range(encoder_seq_len))
        seq_group_metadata = SequenceGroupMetadata(
            request_id=f"test_{i}",
            is_prompt=False,
            seq_data={
                0: seq_data,
                1: seq_data
            } if multiple_seqs_per_seq_group else {0: seq_data},
            sampling_params=SamplingParams(temperature=0),
            block_tables=block_tables,
            encoder_seq_data=encoder_seq_data,
            cross_block_table=cross_block_table,
        )
        assert seq_group_metadata.token_chunk_size == 1
        num_seqs = len(seq_group_metadata.seq_data)
        seq_lens.extend([seq_len] * num_seqs)
        encoder_seq_lens.extend([encoder_seq_len] * num_seqs)
        expanded_batch_size += num_seqs
        seq_group_metadata_list.append(seq_group_metadata)

    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = attn_metadata.slot_mapping
    encoder_input_tokens = model_input.encoder_input_tokens
    encoder_input_positions = model_input.encoder_input_positions
    cross_slot_mapping = attn_metadata.cross_slot_mapping

    # With CUDA Graph capture and replay enabled, the decoder and encoder
    # input sequences will be padded. Create the expected padded tensors
    # accordingly: every padding slot is expected to have length 1.
    graph_batch_size = model_runner.vllm_config.pad_for_cudagraph(
        expanded_batch_size)
    cuda_graph_pad_size = graph_batch_size - expanded_batch_size
    padded_seq_lens = seq_lens + [1] * cuda_graph_pad_size
    padded_encoder_seq_lens = encoder_seq_lens + [1] * cuda_graph_pad_size

    assert return_seq_lens == padded_seq_lens
    assert len(slot_mapping) == len(input_tokens)
    assert len(cross_slot_mapping) == len(encoder_input_tokens)

    # Verify attention metadata
    device = model_runner.device
    assert attn_metadata.num_prefills == 0
    assert attn_metadata.num_decode_tokens > 0
    assert torch.equal(
        attn_metadata.seq_lens_tensor,
        torch.tensor(padded_seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.seq_lens == padded_seq_lens
    assert attn_metadata.max_prefill_seq_len == 0
    assert attn_metadata.max_decode_seq_len == max(seq_lens)
    # - Encoder attention metadata
    assert attn_metadata.encoder_seq_lens == padded_encoder_seq_lens
    assert torch.equal(
        attn_metadata.encoder_seq_lens_tensor,
        torch.tensor(padded_encoder_seq_lens, device=device, dtype=torch.int))
    assert attn_metadata.max_encoder_seq_len == max(padded_encoder_seq_lens)
    assert attn_metadata.num_encoder_tokens == sum(padded_encoder_seq_lens)

    # Verify block tables are correct for decode
    # - Decoder self-attention. Pad the block tables as expected: padding
    #   slots get an empty block table.
    flattened_block_tables = [
        block_table for _ in range(len(seq_group_metadata_list))
        for block_table in block_tables.values()
    ]
    flattened_block_tables.extend([[] for _ in range(cuda_graph_pad_size)])
    # NOTE(review): max_len=64 is the padded block-table width the runner
    # is expected to produce here - confirm against the runner's config.
    expected = make_tensor_with_pad(
        flattened_block_tables,
        max_len=64,
        pad=0,
        dtype=torch.int32,
        device=model_runner.device,
    )
    assert torch.equal(
        attn_metadata.block_tables,
        expected,
    )
    # - Encoder/decoder cross-attention. Pad the cross-attention block tables
    #   as expected.
    expected = [
        cross_block_table for seq_group_metadata in seq_group_metadata_list
        for _ in range(len(seq_group_metadata.seq_data))
    ]
    expected.extend([[] for _ in range(cuda_graph_pad_size)])
    expected = make_tensor_with_pad(
        expected,
        max_len=64,
        pad=0,
        dtype=torch.int32,
        device=model_runner.device,
    )
    assert torch.equal(
        attn_metadata.cross_block_tables,
        expected,
    )

    # Model runner's CUDAGraph setting should be propagated to attention
    # metadata.
    assert attn_metadata.use_cuda_graph is True

    # Verify the lengths of input tokens & positions
    # - Decoder
    assert len(input_tokens) == len(padded_seq_lens)
    assert len(input_positions) == len(padded_seq_lens)
    # -- An indirect check that model_input.input_tokens
    #    and model_input.input_positions are correct -
    #    by design of the test, the input tokens are
    #    equal to the input position values, so if
    #    the model_input data structure has the correct
    #    values then these two should be equal
    assert torch.equal(
        input_tokens,
        input_positions,
    )

    # - Encoder: no encoder tokens or positions are expected during decode.
    #   (The original asserted on encoder_input_tokens twice; the second
    #   check is meant for the positions tensor.)
    assert len(encoder_input_tokens) == 0
    assert len(encoder_input_positions) == 0
    # -- Same indirect tokens-vs-positions check for the encoder side.
    assert torch.equal(
        encoder_input_tokens,
        encoder_input_positions,
    )
......@@ -1201,11 +1201,8 @@ class ModelConfig:
getattr(self.hf_config, "max_source_positions", 0))
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
effective_max_seq_len)
# CUDAGraph capture not supported for enc-dec models and mllama on ROCm
ROCM_UNSUPPORTED_MODELS = ['mllama']
unsupported_rocm = (self.hf_config.model_type
in ROCM_UNSUPPORTED_MODELS
or self.is_encoder_decoder)
# CUDAGraph capture not supported for encoder-decoder models on ROCm
unsupported_rocm = self.is_encoder_decoder
if (unsupported_rocm and not self.enforce_eager
and current_platform.is_rocm()):
......@@ -1671,10 +1668,6 @@ class ModelConfig:
@property
def is_encoder_decoder(self) -> bool:
    """Return the encoder/decoder flag derived from the HF config."""
    hf_config = self.hf_config
    return is_encoder_decoder(hf_config)
@property
......
......@@ -1789,7 +1789,7 @@ class LLMEngine:
assert isinstance(mm_processor, EncDecMultiModalProcessor)
if mm_processor.pad_dummy_encoder_prompt:
return # Skip encoder length check for Whisper and Donut
return # Skip encoder length check for Whisper
if model_config.is_multimodal_model:
suggestion = (
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Derived from BART implementation posted on HuggingFace; license below:
#
# coding=utf-8
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BART model."""
import math
from collections.abc import Iterable
from typing import Optional
import torch
from torch import nn
from transformers import BartConfig
from transformers.utils import logging
from vllm.attention import Attention, AttentionType
from vllm.config import CacheConfig, VllmConfig
from vllm.config.lora import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVCrossParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsQuant, SupportsV0Only
from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
maybe_prefix)
logger = logging.get_logger(__name__)
def get_bsz_seq_len(input_ids):
    """Return the (batch size, sequence length) pair for `input_ids`.

    A 1-D tensor is treated as a single sequence, giving
    `(1, input_ids.numel())`. For any other rank the first two entries of
    `input_ids.shape` are returned as-is (a slice of the shape object).
    """
    dims = input_ids.shape
    if len(dims) != 1:
        return dims[:2]
    return 1, input_ids.numel()
class BartLearnedPositionalEmbedding(VocabParallelEmbedding):
    """
    Learned positional embedding table with a fixed maximum size.

    Position ids are shifted by a constant offset of 2 before lookup, and
    the table is created with `num_embeddings + 2` rows to make room for
    that shift.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # BART-specific quirk: embedding ids are offset by 2 and the table
        # size is enlarged accordingly. Other models don't have this hack.
        self.offset = 2
        table_size = num_embeddings + self.offset
        super().__init__(table_size, embedding_dim)

    def forward(
        self,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """Look up embeddings for `positions` after applying the offset."""
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)
class BartScaledWordEmbedding(VocabParallelEmbedding):
    """
    VocabParallelEmbedding whose forward output is multiplied by a
    constant embedding scale.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        # Factor applied to every looked-up embedding in forward().
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the parent's embeddings for `input_ids`, scaled."""
        embeddings = super().forward(input_ids)
        return embeddings * self.embed_scale
class BartParallelLMHead(ParallelLMHead):
    """
    ParallelLMHead whose forward output is divided by the embedding
    scale, i.e. effectively the inverse of the scaling applied by
    BartScaledWordEmbedding.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        # Divisor applied to the parent's forward output.
        self.embed_scale = embed_scale

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the parent's forward result for `input_ids`, un-scaled."""
        head_output = super().forward(input_ids)
        return head_output / self.embed_scale
class BartEncoderAttention(nn.Module):
    """Self-attention block built with `AttentionType.ENCODER`.

    Consists of a fused QKV projection, the attention op, and an output
    projection. Heads are split evenly across tensor-parallel ranks.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        config: Optional[BartConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """
        Args:
            embed_dim: width of the attention block; must be divisible by
                `num_heads`.
            num_heads: total number of attention heads (before TP split).
            bias: whether the projections carry a bias term.
            config: BART config; `config.d_model` is read, so a config
                must actually be supplied despite the `None` default.
            cache_config: forwarded to `Attention`.
            quant_config: forwarded to the projections and `Attention`.
            prefix: module-name prefix; the attention op is registered
                under `f"{prefix}.attn"`.

        Raises:
            ValueError: if `embed_dim` is not a multiple of `num_heads`.
        """
        super().__init__()
        self.d_model = config.d_model
        self.config = config
        self.embed_dim = embed_dim
        # BART uses the same number of KV heads as query heads.
        self.total_num_heads = num_heads
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = embed_dim // num_heads

        if self.embed_dim != self.head_dim * num_heads:
            raise ValueError(f"embed_dim must be divisible by num_heads "
                             f"(got `embed_dim`: {self.embed_dim}"
                             f" and `num_heads`: {num_heads}).")
        self.scaling = self.head_dim**-0.5

        per_head_size = self.d_model // self.total_num_heads
        self.qkv_proj = QKVParallelLinear(
            self.d_model,
            per_head_size,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )

        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )

        tp_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        if self.total_num_kv_heads < tp_size:
            # Fewer KV heads than TP ranks: the KV heads are replicated
            # across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as TP ranks: the KV heads are
            # partitioned across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = self.num_heads

        # Per-rank widths of the Q and K/V slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn",
                              attn_type=AttentionType.ENCODER)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project to Q/K/V, run attention, and apply the output projection.

        NOTE(review): the original docstring gave the input shape as
        Batch x Time x Channel - confirm against callers.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)

        context = self.attn(query, key, value)

        projected, _ = self.out_proj(context)
        return projected
class BartDecoderSelfAttention(nn.Module):
    """Self-attention block built with `AttentionType.DECODER`.

    Consists of a fused QKV projection, the attention op, and an output
    projection. Heads are split evenly across tensor-parallel ranks.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        config: Optional[BartConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """
        Args:
            embed_dim: width of the attention block; must be divisible by
                `num_heads`.
            num_heads: total number of attention heads (before TP split).
            bias: whether the projections carry a bias term.
            config: BART config; `config.d_model` is read, so a config
                must actually be supplied despite the `None` default.
            cache_config: forwarded to `Attention`.
            quant_config: forwarded to the projections and `Attention`.
            prefix: module-name prefix; the attention op is registered
                under `f"{prefix}.attn"`.

        Raises:
            ValueError: if `embed_dim` is not a multiple of `num_heads`.
        """
        super().__init__()
        self.d_model = config.d_model
        self.config = config
        self.embed_dim = embed_dim
        # BART uses the same number of KV heads as query heads.
        self.total_num_heads = num_heads
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = embed_dim // num_heads

        if self.embed_dim != self.head_dim * num_heads:
            raise ValueError(f"embed_dim must be divisible by num_heads "
                             f"(got `embed_dim`: {self.embed_dim}"
                             f" and `num_heads`: {num_heads}).")
        self.scaling = self.head_dim**-0.5

        per_head_size = self.d_model // self.total_num_heads
        self.qkv_proj = QKVParallelLinear(
            self.d_model,
            per_head_size,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
        )

        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )

        tp_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        if self.total_num_kv_heads < tp_size:
            # Fewer KV heads than TP ranks: the KV heads are replicated
            # across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as TP ranks: the KV heads are
            # partitioned across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = self.num_heads

        # Per-rank widths of the Q and K/V slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn",
                              attn_type=AttentionType.DECODER)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project to Q/K/V, run attention, and apply the output projection.

        NOTE(review): the original docstring gave the input shape as
        Batch x Time x Channel - confirm against callers.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)

        context = self.attn(query, key, value)

        projected, _ = self.out_proj(context)
        return projected
class BartCrossAttention(nn.Module):
    """Encoder-decoder (cross) attention sub-layer of a BART decoder layer.

    Queries are projected from the decoder states and keys/values from the
    encoder states by a ``QKVCrossParallelLinear``; the attention op is
    created with ``AttentionType.ENCODER_DECODER``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        config: Optional[BartConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the projections and the attention op.

        Args:
            embed_dim: Width of the attention input/output features.
            num_heads: Total number of attention heads, before
                tensor-parallel sharding.
            bias: Whether the QKV and output projections carry a bias.
            config: Model config; only ``d_model`` is read from it. When it
                is omitted, ``embed_dim`` is used as the projection width.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
            prefix: Module-name prefix; the attention op is named
                ``"{prefix}.attn"``.

        Raises:
            ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
        """
        super().__init__()
        # `config` is advertised as optional, so never dereference it
        # unconditionally; fall back to the attention width instead.
        self.d_model = config.d_model if config is not None else embed_dim
        self.embed_dim = embed_dim
        self.total_num_heads = num_heads
        self.total_num_kv_heads = self.total_num_heads
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(f"embed_dim must be divisible by num_heads "
                             f"(got `embed_dim`: {self.embed_dim}"
                             f" and `num_heads`: {num_heads}).")
        self.scaling = self.head_dim**-0.5

        # TP sharding sizes is accounted for within "*Parallel" layers.
        self.qkv_proj = QKVCrossParallelLinear(self.d_model,
                                               self.d_model //
                                               self.total_num_heads,
                                               self.total_num_heads,
                                               self.total_num_kv_heads,
                                               bias,
                                               quant_config=quant_config)

        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = self.num_heads  # No GQA in bart

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn",
                              attn_type=AttentionType.ENCODER_DECODER)

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend from decoder states to encoder states.

        Args:
            decoder_hidden_states: Decoder features passed to ``qkv_proj``
                as its first argument.
            encoder_hidden_states: Encoder features passed to ``qkv_proj``
                as its second argument; may be ``None``.

        Returns:
            The output of ``out_proj`` applied to the attention result.
        """
        q, k, v = self.qkv_proj(decoder_hidden_states, encoder_hidden_states)

        attn_output = self.attn(q, k, v)

        output, _ = self.out_proj(attn_output)
        return output
class BartEncoderLayer(nn.Module):
    """One post-LayerNorm BART encoder layer: self-attention then an FFN."""

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the attention, feed-forward and LayerNorm sub-modules.

        Args:
            config: Supplies ``d_model``, ``encoder_attention_heads``,
                ``activation_function`` and ``encoder_ffn_dim``.
            cache_config: Forwarded to the self-attention module.
            quant_config: Forwarded to the attention and linear layers.
            prefix: Module-name prefix; self-attention is created under
                ``"{prefix}.self_attn"``.
        """
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = BartEncoderAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.activation_fn = get_act_fn(config.activation_function)

        # Feed-forward block: d_model -> encoder_ffn_dim -> d_model, biased.
        model_width = self.embed_dim
        inner_width = config.encoder_ffn_dim
        use_bias = True
        self.fc1 = ColumnParallelLinear(
            model_width,
            inner_width,
            bias=use_bias,
            quant_config=quant_config,
        )
        # Not referenced by forward() in this class, which uses
        # `activation_fn` instead.
        self.act = get_act_fn("gelu")
        self.fc2 = RowParallelLinear(
            inner_width,
            model_width,
            bias=use_bias,
            quant_config=quant_config,
        )

        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            hidden_states: torch.Tensor of *encoder* input embeddings.
        Returns:
            Encoder layer output torch.Tensor
        """
        # Self-attention, residual add, then LayerNorm.
        shortcut = hidden_states
        attn_out = self.self_attn(hidden_states=hidden_states)
        normed = self.self_attn_layer_norm(shortcut + attn_out)

        # Feed-forward, residual add, then LayerNorm.
        shortcut = normed
        inner, _ = self.fc1(normed)
        inner = self.activation_fn(inner)
        ffn_out, _ = self.fc2(inner)
        result = self.final_layer_norm(shortcut + ffn_out)

        # float16 results containing inf/NaN are passed through
        # cast_overflow_tensors.
        if result.dtype == torch.float16:
            if torch.isinf(result).any() or torch.isnan(result).any():
                result = cast_overflow_tensors(result)

        return result
class BartDecoderLayer(nn.Module):
    """One post-LayerNorm BART decoder layer.

    Applies self-attention, cross-attention over the encoder output and a
    feed-forward block, each followed by a residual add and a LayerNorm.
    """

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the attention, feed-forward and LayerNorm sub-modules.

        Args:
            config: Supplies ``d_model``, ``decoder_attention_heads``,
                ``activation_function`` and ``decoder_ffn_dim``.
            cache_config: Forwarded to the self-attention module.
            quant_config: Forwarded to self-attention and the linear layers.
            prefix: Module-name prefix used for the attention sub-modules.
        """
        super().__init__()
        self.embed_dim = config.d_model

        self.self_attn = BartDecoderSelfAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.activation_fn = get_act_fn(config.activation_function)

        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        # afeldman-nm: personally I would call this "cross-attention",
        # however I left the name as "encoder_attn" to maintain consistency
        # with the name of the pretrained weights.
        # NOTE(review): unlike `self_attn`, this module is built without
        # cache_config/quant_config - confirm whether that is intentional.
        self.encoder_attn = BartCrossAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            config=config,
            prefix=f"{prefix}.encoder_attn",
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        ffn_hidden_size = self.embed_dim
        # The decoder FFN must be sized from the *decoder* setting; reading
        # `encoder_ffn_dim` here only works when both widths are equal.
        ffn_intermediate_size = config.decoder_ffn_dim
        ffn_has_bias = True
        self.fc1 = ColumnParallelLinear(
            ffn_hidden_size,
            ffn_intermediate_size,
            bias=ffn_has_bias,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            ffn_intermediate_size,
            ffn_hidden_size,
            bias=ffn_has_bias,
            quant_config=quant_config,
        )

        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            decoder_hidden_states: torch.Tensor of *decoder* input embeddings.
            encoder_hidden_states: torch.Tensor of *encoder* output states
                attended to by the cross-attention block; may be None.
        Returns:
            Decoder layer output torch.Tensor
        """
        residual = decoder_hidden_states

        # Self Attention
        hidden_states = self.self_attn(hidden_states=decoder_hidden_states)

        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Cross-Attention Block
        residual = hidden_states

        hidden_states = self.encoder_attn(
            decoder_hidden_states=hidden_states,
            encoder_hidden_states=encoder_hidden_states,
        )

        hidden_states = residual + hidden_states
        hidden_states = self.encoder_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        fc1_out, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(fc1_out)

        hidden_states, _ = self.fc2(hidden_states)

        hidden_states = residual + hidden_states
        hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states
class BartEncoder(nn.Module):
    """
    Transformer encoder made of *config.encoder_layers* stacked
    [`BartEncoderLayer`] modules on top of scaled token embeddings and
    learned positional embeddings.
    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): optional embedding whose weight is
            reused as this encoder's token-embedding weight
    """

    def __init__(self,
                 config: BartConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 lora_config: Optional[LoRAConfig] = None,
                 embed_tokens: Optional[nn.Embedding] = None,
                 prefix: str = ""):
        super().__init__()

        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config

        hidden_size = config.d_model
        self.max_source_positions = config.max_position_embeddings

        # Token embeddings are scaled by sqrt(d_model) only when requested.
        if config.scale_embedding:
            scale = math.sqrt(hidden_size)
        else:
            scale = 1.0

        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
                                                    hidden_size,
                                                    embed_scale=scale)
        if embed_tokens is not None:
            # Share the caller-provided embedding weight.
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            hidden_size,
        )

        encoder_layers = [
            BartEncoderLayer(config,
                             cache_config,
                             quant_config,
                             prefix=f"{prefix}.layers.{idx}")
            for idx in range(config.encoder_layers)
        ]
        self.layers = nn.ModuleList(encoder_layers)

        self.layernorm_embedding = nn.LayerNorm(hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids: Indices of *encoder* input sequence tokens in the
                vocabulary.
                Padding will be ignored by default should you provide it.
            positions: Positions of *encoder* input sequence tokens.
            inputs_embeds: Optional precomputed token embeddings; when
                given, `input_ids` is not embedded.
        Returns:
            Encoder output torch.Tensor
        """
        # Embed the tokens unless the caller already did.
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        pos_embeds = self.embed_positions(positions).to(inputs_embeds.device)

        states = self.layernorm_embedding(inputs_embeds + pos_embeds)

        for layer in self.layers:
            states = layer(hidden_states=states)

        return states
class BartDecoder(nn.Module):
    """
    Transformer decoder made of *config.decoder_layers* stacked
    [`BartDecoderLayer`] modules on top of scaled token embeddings and
    learned positional embeddings.
    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): optional embedding whose weight is
            reused as this decoder's token-embedding weight
    """

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        embed_tokens: Optional[nn.Embedding] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config
        self.max_target_positions = config.max_position_embeddings

        # Token embeddings are scaled by sqrt(d_model) only when requested.
        if config.scale_embedding:
            scale = math.sqrt(config.d_model)
        else:
            scale = 1.0

        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
                                                    config.d_model,
                                                    embed_scale=scale)
        if embed_tokens is not None:
            # Share the caller-provided embedding weight.
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )

        decoder_layers = [
            BartDecoderLayer(config,
                             cache_config,
                             quant_config,
                             prefix=f"{prefix}.layers.{idx}")
            for idx in range(config.decoder_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)

        self.layernorm_embedding = nn.LayerNorm(config.d_model)

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        decoder_positions: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            decoder_input_ids: Indices of *decoder* input sequence tokens
                in the vocabulary.
                Padding will be ignored by default should you provide it.
            decoder_positions: Positions of *decoder* input sequence tokens.
            encoder_hidden_states: Tensor of encoder output embeddings.
            inputs_embeds: Optional precomputed token embeddings; when
                given, `decoder_input_ids` is not embedded.
        Returns:
            Decoder output torch.Tensor
        """
        if inputs_embeds is not None:
            # NOTE(review): positions are replaced by the last column of the
            # provided embeddings here - confirm this is intended.
            decoder_positions = inputs_embeds[:, -1]
        else:
            inputs_embeds = self.embed_tokens(decoder_input_ids)

        # Add positional embeddings, then normalise.
        pos_embeds = self.embed_positions(decoder_positions)
        pos_embeds = pos_embeds.to(inputs_embeds.device)
        states = self.layernorm_embedding(inputs_embeds + pos_embeds)

        # Run the decoder layers in order.
        for layer in self.layers:
            states = layer(
                decoder_hidden_states=states,
                encoder_hidden_states=encoder_hidden_states,
            )

        return states
class BartModel(nn.Module, SupportsQuant):
    """Bare BART encoder-decoder: a `BartEncoder` feeding a `BartDecoder`."""

    _tied_weights_keys = [
        "encoder.embed_tokens.weight",
        "decoder.embed_tokens.weight",
    ]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the encoder and decoder from `vllm_config`.

        Args:
            vllm_config: Source of the HF model config plus the cache,
                quantization and LoRA configs.
            prefix: Module-name prefix for the encoder/decoder sub-modules.
        """
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = hf_config

        # Extra vocabulary rows reserved for LoRA adapters, if any.
        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        else:
            extra_vocab = 0
        self.vocab_size = hf_config.vocab_size + extra_vocab
        self.org_vocab_size = hf_config.vocab_size

        self.encoder = BartEncoder(hf_config,
                                   cache_config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.encoder")
        self.decoder = BartDecoder(hf_config,
                                   cache_config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.decoder")

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                encoder_input_ids: torch.Tensor,
                encoder_positions: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            input_ids: Indices of *decoder* input sequence tokens
                in the vocabulary.
                Padding will be ignored by default should you provide it.
            positions: Positions of *decoder* input sequence tokens.
            encoder_input_ids: Indices of *encoder* input sequence tokens
                in the vocabulary.
            encoder_positions: Positions of *encoder* input sequence tokens.
        Returns:
            Model output torch.Tensor
        """
        # The encoder only runs when encoder tokens were supplied;
        # otherwise the decoder receives None for the encoder states.
        if encoder_input_ids.numel() > 0:
            encoder_states = self.encoder(input_ids=encoder_input_ids,
                                          positions=encoder_positions)
        else:
            encoder_states = None

        return self.decoder(decoder_input_ids=input_ids,
                            decoder_positions=positions,
                            encoder_hidden_states=encoder_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, fusing q/k/v projections into `qkv_proj`.

        Args:
            weights: Iterable of `(name, tensor)` checkpoint entries.

        Returns:
            Names of the parameters that were loaded.
        """
        qkv_shards = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        deferred = []
        fused_loaded = []
        named_params = dict(self.named_parameters())

        for name, tensor in weights:
            for (fused_name, shard_name, shard_id) in qkv_shards:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, fused_name)
                if name not in named_params:
                    continue
                target = named_params[name]
                target.weight_loader(target, tensor, shard_id)
                fused_loaded.append(name)
                break
            else:
                # Not loaded as a q/k/v shard: hand over to the generic
                # loader, but only if the model owns a parameter of that name.
                if name in named_params:
                    deferred.append((name, tensor))

        auto_loader = AutoWeightsLoader(self)
        loaded_names = auto_loader.load_weights(deferred)
        loaded_names.update(fused_loaded)
        return loaded_names
class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
    """`BartModel` plus an LM head whose weight is tied to the embeddings."""

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "decoder.": "model.decoder.",
            "encoder.": "model.encoder.",
            "shared.": "model.shared."
        },
        orig_to_new_substr={
            "beta": "bias",
            "gamma": "weight",
            "LayerNorm": "layernorm",
        },
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the base model, LM head and logits processor.

        Args:
            vllm_config: Source of the HF model config and the LoRA config.
            prefix: Module-name prefix; the base model lives under "model".
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        # currently all existing BART models have `tie_word_embeddings` enabled
        assert hf_config.tie_word_embeddings
        self.config = hf_config
        self.model = BartModel(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))

        self.unpadded_vocab_size = hf_config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        if hf_config.scale_embedding:
            scale = math.sqrt(hf_config.d_model)
        else:
            scale = 1.0

        self.lm_head = BartParallelLMHead(hf_config.vocab_size,
                                          hf_config.d_model,
                                          embed_scale=scale)

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                hf_config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids: torch.Tensor of *decoder* input token ids.
            positions: torch.Tensor of *decoder* position indices.
            intermediate_tensors: Accepted but not used by this method.
            encoder_input_ids: torch.Tensor of *encoder* input token ids.
            encoder_positions: torch.Tensor of *encoder* position indices.
        Returns:
            Output torch.Tensor
        """
        hidden = self.model(input_ids, positions, encoder_input_ids,
                            encoder_positions)
        return hidden

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Run `hidden_states` through the logits processor and LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load the checkpoint and tie the embeddings to the LM head.

        Args:
            weights: Iterable of `(name, tensor)` checkpoint entries.

        Returns:
            Names of the parameters that were loaded.
        """
        all_weights = list(weights)

        embedding_keys = (
            'shared.weight',
            'encoder.embed_tokens.weight',
            'decoder.embed_tokens.weight',
            'lm_head.weight',
        )
        tied_weight = None
        for name, tensor in all_weights:
            if any(key in name for key in embedding_keys):
                # Only one embedding tensor may appear in the checkpoint.
                assert tied_weight is None, (
                    "Conflicting embedding weights.")
                tied_weight = tensor

        auto_loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["cls.", "pooler."]),
        )
        loaded_names = auto_loader.load_weights(all_weights,
                                                mapper=self.hf_to_vllm_mapper)

        if tied_weight is not None:
            load_fn = getattr(self.lm_head.weight, "weight_loader",
                              default_weight_loader)
            load_fn(self.lm_head.weight, tied_weight)

            # Point both embedding tables at the LM-head parameter.
            self.model.encoder.embed_tokens.weight = self.lm_head.weight
            self.model.decoder.embed_tokens.weight = self.lm_head.weight
            loaded_names.update({
                'model.encoder.embed_tokens.weight', 'lm_head.weight',
                'model.decoder.embed_tokens.weight'
            })

        return loaded_names
class MBartEncoderLayer(BartEncoderLayer):
    """Encoder layer that reuses the `BartEncoderLayer` sub-modules but
    applies each LayerNorm *before* its sub-block (pre-norm)."""

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            hidden_states: torch.Tensor of *encoder* input embeddings.
        Returns:
            Encoder layer output torch.Tensor
        """
        # Pre-norm self-attention with a residual connection.
        shortcut = hidden_states
        normed = self.self_attn_layer_norm(hidden_states)
        attn_out = self.self_attn(hidden_states=normed)
        after_attn = shortcut + attn_out

        # Pre-norm feed-forward with a residual connection.
        shortcut = after_attn
        normed = self.final_layer_norm(after_attn)
        inner, _ = self.fc1(normed)
        inner = self.activation_fn(inner)
        ffn_out, _ = self.fc2(inner)
        result = shortcut + ffn_out

        # float16 results containing inf/NaN are passed through
        # cast_overflow_tensors.
        if result.dtype == torch.float16:
            if torch.isinf(result).any() or torch.isnan(result).any():
                result = cast_overflow_tensors(result)

        return result
class MBartDecoderLayer(BartDecoderLayer):
    """Decoder layer that reuses the `BartDecoderLayer` sub-modules but
    applies each LayerNorm *before* its sub-block (pre-norm)."""

    def forward(
        self,
        decoder_hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            decoder_hidden_states: torch.Tensor of *decoder* input embeddings.
            encoder_hidden_states: torch.Tensor of *encoder* output states;
                may be None.
        Returns:
            Decoder layer output torch.Tensor
        """
        # Pre-norm self-attention with a residual connection.
        shortcut = decoder_hidden_states
        normed = self.self_attn_layer_norm(decoder_hidden_states)
        attn_out = self.self_attn(hidden_states=normed)
        states = shortcut + attn_out

        # Pre-norm cross-attention with a residual connection.
        shortcut = states
        normed = self.encoder_attn_layer_norm(states)
        cross_out = self.encoder_attn(
            decoder_hidden_states=normed,
            encoder_hidden_states=encoder_hidden_states,
        )
        states = shortcut + cross_out

        # Pre-norm feed-forward with a residual connection.
        shortcut = states
        normed = self.final_layer_norm(states)
        inner, _ = self.fc1(normed)
        inner = self.activation_fn(inner)
        ffn_out, _ = self.fc2(inner)

        return shortcut + ffn_out
class MBartEncoder(nn.Module):
    """
    Transformer encoder made of *config.encoder_layers* stacked
    [`MBartEncoderLayer`] modules, followed by a final LayerNorm.
    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): optional embedding whose weight is
            reused as this encoder's token-embedding weight
    """

    def __init__(self,
                 config: BartConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 lora_config: Optional[LoRAConfig] = None,
                 embed_tokens: Optional[nn.Embedding] = None,
                 prefix: str = ""):
        super().__init__()

        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config

        hidden_size = config.d_model
        self.max_source_positions = config.max_position_embeddings

        # Token embeddings are scaled by sqrt(d_model) only when requested.
        if config.scale_embedding:
            scale = math.sqrt(hidden_size)
        else:
            scale = 1.0

        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
                                                    hidden_size,
                                                    embed_scale=scale)
        if embed_tokens is not None:
            # Share the caller-provided embedding weight.
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            hidden_size,
        )

        encoder_layers = [
            MBartEncoderLayer(config,
                              cache_config,
                              quant_config,
                              prefix=f"{prefix}.layers.{idx}")
            for idx in range(config.encoder_layers)
        ]
        self.layers = nn.ModuleList(encoder_layers)

        self.layernorm_embedding = nn.LayerNorm(hidden_size)
        # Final LayerNorm applied to the output of the last layer.
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids: Indices of *encoder* input sequence tokens in the
                vocabulary.
                Padding will be ignored by default should you provide it.
            positions: Positions of *encoder* input sequence tokens.
            inputs_embeds: Optional precomputed token embeddings; when
                given, `input_ids` is not embedded.
        Returns:
            Encoder output torch.Tensor
        """
        # Embed the tokens unless the caller already did.
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        pos_embeds = self.embed_positions(positions).to(inputs_embeds.device)

        states = self.layernorm_embedding(inputs_embeds + pos_embeds)

        for layer in self.layers:
            states = layer(hidden_states=states)

        return self.layer_norm(states)
class MBartDecoder(nn.Module):
    """
    Transformer decoder made of *config.decoder_layers* stacked
    [`MBartDecoderLayer`] modules, followed by a final LayerNorm.
    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): optional embedding whose weight is
            reused as this decoder's token-embedding weight
    """

    def __init__(
        self,
        config: BartConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        embed_tokens: Optional[nn.Embedding] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.lora_config = lora_config
        self.max_target_positions = config.max_position_embeddings

        # Token embeddings are scaled by sqrt(d_model) only when requested.
        if config.scale_embedding:
            scale = math.sqrt(config.d_model)
        else:
            scale = 1.0

        self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
                                                    config.d_model,
                                                    embed_scale=scale)
        if embed_tokens is not None:
            # Share the caller-provided embedding weight.
            self.embed_tokens.weight = embed_tokens.weight

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
        )

        decoder_layers = [
            MBartDecoderLayer(config,
                              cache_config,
                              quant_config,
                              prefix=f"{prefix}.layers.{idx}")
            for idx in range(config.decoder_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)

        self.layernorm_embedding = nn.LayerNorm(config.d_model)
        # Final LayerNorm applied to the output of the last layer.
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        decoder_positions: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Args:
            decoder_input_ids: Indices of *decoder* input sequence tokens
                in the vocabulary.
                Padding will be ignored by default should you provide it.
            decoder_positions: Positions of *decoder* input sequence tokens.
            encoder_hidden_states: Tensor of encoder output embeddings.
            inputs_embeds: Optional precomputed token embeddings; when
                given, `decoder_input_ids` is not embedded.
        Returns:
            Decoder output torch.Tensor
        """
        if inputs_embeds is not None:
            # NOTE(review): positions are replaced by the last column of the
            # provided embeddings here - confirm this is intended.
            decoder_positions = inputs_embeds[:, -1]
        else:
            inputs_embeds = self.embed_tokens(decoder_input_ids)

        # Add positional embeddings, then normalise.
        pos_embeds = self.embed_positions(decoder_positions)
        pos_embeds = pos_embeds.to(inputs_embeds.device)
        states = self.layernorm_embedding(inputs_embeds + pos_embeds)

        # Run the decoder layers in order.
        for layer in self.layers:
            states = layer(
                decoder_hidden_states=states,
                encoder_hidden_states=encoder_hidden_states,
            )

        return self.layer_norm(states)
class MBartModel(nn.Module, SupportsQuant):
    """Bare mBART encoder-decoder: an `MBartEncoder` feeding an
    `MBartDecoder`."""

    _tied_weights_keys = [
        "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
    ]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the encoder and decoder from `vllm_config`.

        Args:
            vllm_config: Source of the HF model config plus the cache,
                quantization and LoRA configs.
            prefix: Module-name prefix for the encoder/decoder sub-modules.
        """
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = hf_config

        # Extra vocabulary rows reserved for LoRA adapters, if any.
        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        else:
            extra_vocab = 0
        self.vocab_size = hf_config.vocab_size + extra_vocab
        self.org_vocab_size = hf_config.vocab_size

        self.encoder = MBartEncoder(hf_config,
                                    cache_config,
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.encoder")
        self.decoder = MBartDecoder(hf_config,
                                    cache_config,
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.decoder")

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                encoder_input_ids: torch.Tensor,
                encoder_positions: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            input_ids: Indices of *decoder* input sequence tokens
                in the vocabulary.
                Padding will be ignored by default should you provide it.
            positions: Positions of *decoder* input sequence tokens.
            encoder_input_ids: Indices of *encoder* input sequence tokens
                in the vocabulary.
            encoder_positions: Positions of *encoder* input sequence tokens.
        Returns:
            Model output torch.Tensor
        """
        # The encoder only runs when encoder tokens were supplied;
        # otherwise the decoder receives None for the encoder states.
        if encoder_input_ids.numel() > 0:
            encoder_states = self.encoder(input_ids=encoder_input_ids,
                                          positions=encoder_positions)
        else:
            encoder_states = None

        return self.decoder(decoder_input_ids=input_ids,
                            decoder_positions=positions,
                            encoder_hidden_states=encoder_states)
class MBartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
    """`MBartModel` plus an LM head whose weight is tied to the embeddings."""

    base_model_prefix = "model"

    # Renames applied to checkpoint names: prefix rewrites that move
    # encoder/decoder/shared entries under "model.", and substring rewrites
    # for legacy LayerNorm parameter names.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "decoder.": "model.decoder.",
            "encoder.": "model.encoder.",
            "shared.": "model.shared."
        },
        orig_to_new_substr={
            "beta": "bias",
            "gamma": "weight",
            "LayerNorm": "layernorm",
        },
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the base model, LM head and logits processor.

        Args:
            vllm_config: Source of the HF model config and the LoRA config.
            prefix: Module-name prefix; the base model lives under "model".
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        # Only checkpoints with tied word embeddings are accepted.
        assert config.tie_word_embeddings
        self.config = config
        self.model = MBartModel(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        # The LM head uses the same optional sqrt(d_model) scale as the
        # token embeddings.
        embed_scale = math.sqrt(
            config.d_model) if config.scale_embedding else 1.0

        self.lm_head = BartParallelLMHead(config.vocab_size,
                                          config.d_model,
                                          embed_scale=embed_scale)

        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Run the wrapped `MBartModel`.

        Args:
            input_ids: torch.Tensor of *decoder* input token ids.
            positions: torch.Tensor of *decoder* position indices.
            intermediate_tensors: Accepted but not used by this method.
            encoder_input_ids: torch.Tensor of *encoder* input token ids.
            encoder_positions: torch.Tensor of *encoder* position indices.

        Returns:
            Output torch.Tensor of the wrapped model.
        """
        return self.model(input_ids, positions, encoder_input_ids,
                          encoder_positions)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Run `hidden_states` through the logits processor and LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load the checkpoint, fusing q/k/v and tying the embeddings.

        Args:
            weights: Iterable of `(name, tensor)` checkpoint entries.

        Returns:
            Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        model_params_dict = dict(self.named_parameters())
        loaded_params = set()
        remaining_weights = []
        shared_embedding_weight = None

        for name, loaded_weight in weights:
            # Entries that are never loaded.
            if any(skip in name
                   for skip in ["cls.", "pooler.", "final_logits_bias"]):
                continue
            # Embedding tensors are handled after the loop; the first one
            # seen is kept and later ones are ignored.
            if any(embed_name in name for embed_name in [
                    'shared.weight', 'encoder.embed_tokens.weight',
                    'decoder.embed_tokens.weight'
            ]):
                if shared_embedding_weight is None:
                    shared_embedding_weight = loaded_weight
                continue
            is_stacked = False
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Apply the mapper's substring and prefix renames by hand,
                # then swap the shard name for the fused parameter name.
                vllm_name = name
                for src, dst in self.hf_to_vllm_mapper.orig_to_new_substr.items(
                ):
                    vllm_name = vllm_name.replace(src, dst)
                for src, dst in self.hf_to_vllm_mapper.orig_to_new_prefix.items(
                ):
                    if vllm_name.startswith(src):
                        vllm_name = dst + vllm_name[len(src):]
                        break
                vllm_name = vllm_name.replace(weight_name, param_name)
                if vllm_name in model_params_dict:
                    param = model_params_dict[vllm_name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight, shard_id)
                    loaded_params.add(vllm_name)
                is_stacked = True
                break
            if not is_stacked:
                remaining_weights.append((name, loaded_weight))
        # Everything that is not a q/k/v shard goes through the generic loader.
        loader = AutoWeightsLoader(self, skip_prefixes=["cls.", "pooler."])
        auto_loaded_params = loader.load_weights(remaining_weights,
                                                 mapper=self.hf_to_vllm_mapper)
        loaded_params.update(auto_loaded_params)
        if shared_embedding_weight is not None:
            lm_head_param = self.lm_head.weight
            weight_loader = getattr(lm_head_param, "weight_loader",
                                    default_weight_loader)
            weight_loader(lm_head_param, shared_embedding_weight)
            # Point both embedding tables at the LM-head parameter.
            self.model.encoder.embed_tokens.weight = self.lm_head.weight
            self.model.decoder.embed_tokens.weight = self.lm_head.weight
            loaded_params.update({
                'model.encoder.embed_tokens.weight', 'lm_head.weight',
                'model.decoder.embed_tokens.weight'
            })
        return loaded_params
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal, Optional, Union
import torch
import torch.nn as nn
from transformers import BatchFeature, NougatProcessor
from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bart import BartParallelLMHead, MBartDecoder
from vllm.model_executor.models.interfaces import (MultiModalEmbeddings,
SupportsMultiModal,
SupportsV0Only)
from vllm.model_executor.models.swin import SwinModel
from vllm.model_executor.models.utils import (AutoWeightsLoader,
_flatten_embeddings, flatten_bn)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptIndexTargets, PromptInsertion,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.utils.tensor_schema import TensorSchema, TensorShape
class MBartDecoderWrapper(nn.Module):
    """Thin holder that exposes an `MBartDecoder` as the `decoder`
    attribute."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the wrapped decoder under ``"{prefix}.decoder"``.

        Args:
            vllm_config: Source of the HF model config plus the cache and
                quantization configs.
            prefix: Module-name prefix of this wrapper.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config

        self.decoder = MBartDecoder(hf_config,
                                    vllm_config.cache_config,
                                    quant_config=vllm_config.quant_config,
                                    prefix=f"{prefix}.decoder")

    def forward(self, *args, **kwargs):
        """Delegate every argument to the wrapped decoder."""
        decoder_output = self.decoder(*args, **kwargs)
        return decoder_output
class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only):
    """Text-decoder half of Donut: an mBART decoder plus an LM head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the wrapped decoder, LM head and logits processor.

        Args:
            vllm_config: Source of the HF model config.
            prefix: Module-name prefix; the decoder wrapper is created
                under ``"{prefix}.model"``.
        """
        super().__init__()

        hf_config = vllm_config.model_config.hf_config

        self.config = hf_config
        self.model = MBartDecoderWrapper(vllm_config=vllm_config,
                                         prefix=f"{prefix}.model")

        if hf_config.scale_embedding:
            scale = math.sqrt(hf_config.d_model)
        else:
            scale = 1.0

        self.vocab_size = hf_config.vocab_size
        self.lm_head = BartParallelLMHead(self.vocab_size,
                                          hf_config.d_model,
                                          embed_scale=scale)

        self.logits_processor = LogitsProcessor(self.vocab_size,
                                                hf_config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""
        Args:
            input_ids: torch.Tensor of *decoder* input token ids.
            positions: torch.Tensor of *decoder* position indices.
            inputs_embeds: Tensor handed to the decoder as its
                `encoder_hidden_states`.
        Returns:
            Output torch.Tensor
        """
        decoder_output = self.model(decoder_input_ids=input_ids,
                                    decoder_positions=positions,
                                    encoder_hidden_states=inputs_embeds)
        return decoder_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Run `hidden_states` through the logits processor and LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors, fusing q/k/v projections into `qkv_proj`.

        Args:
            weights: Iterable of `(name, tensor)` checkpoint entries.

        Returns:
            Names of the parameters that were loaded.
        """
        qkv_shards = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        named_params = dict(self.named_parameters())
        loaded: set[str] = set()
        for name, tensor in weights:
            for (fused_name, shard_name, shard_id) in qkv_shards:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, fused_name)
                target = named_params[name]
                target.weight_loader(target, tensor, shard_id)
                break
            else:
                # `final_logits_bias` entries are ignored.
                if "final_logits_bias" in name:
                    continue
                target = named_params[name]
                load_fn = getattr(target, "weight_loader",
                                  default_weight_loader)
                load_fn(target, tensor)
            loaded.add(name)
        return loaded
class DonutImagePixelInputs(TensorSchema):
    """Schema for raw pixel inputs handed to the Donut image encoder.

    Dimensions:
        - b: Batch size
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """

    # Discriminator tag; this schema only describes pixel values.
    type: Literal["pixel_values"]

    # Pixel tensor laid out as (b, 3, h, w).
    data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")]
class DonutProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Donut: config/processor access and limits."""

    def get_hf_config(self):
        """Return the HF config obtained from the processing context."""
        ctx = self.ctx
        return ctx.get_hf_config()

    def get_hf_processor(self):
        """Return the HF processor obtained from the processing context."""
        ctx = self.ctx
        return ctx.get_hf_processor()

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality item limit: one image."""
        return dict(image=1)

    def get_num_image_tokens(self) -> int:
        """Return the number of placeholder tokens used per image (1)."""
        return 1
class DonutDummyInputsBuilder(BaseDummyInputsBuilder[DonutProcessingInfo]):
    """Builds dummy text and image inputs for Donut."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the dummy prompt text, which is always empty."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Create dummy images sized from the encoder config.

        Args:
            seq_len: Unused here.
            mm_counts: Requested item count per modality; only ``"image"``
                is read (defaults to 0).

        Returns:
            A dict with a single ``"image"`` entry holding the dummy images.
        """
        num_images = mm_counts.get("image", 0)

        # ``encoder.image_size`` is unpacked as (height, width) so that it
        # agrees with DonutForConditionalGeneration, which binds the same
        # field as ``h, w`` when validating pixel values. The previous
        # (width, height) order produced transposed dummy images.
        target_height, target_width = (
            self.info.get_hf_config().encoder.image_size)

        return {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images)
        }
class DonutMultiModalProcessor(EncDecMultiModalProcessor[DonutProcessingInfo]):
    """Encoder/decoder multimodal processor for Donut-style models."""

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """Report that the HF processor does not apply prompt updates."""
        return False

    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """Return the encoder prompt, which is the prompt unchanged."""
        return prompt

    def create_decoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """Return the decoder prompt, which is the prompt unchanged."""
        return prompt

    @property
    def pad_dummy_encoder_prompt(self) -> bool:
        """Always ``True`` for this processor."""
        return True

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, or only the tokenizer for text-only input.

        With multimodal data the parent implementation is used; for a
        ``NougatProcessor`` the resulting ``labels`` are copied into
        ``input_ids``. Without multimodal data the prompt is tokenized
        directly with ``add_special_tokens=False``.
        """
        hf_processor = self.info.get_hf_processor()

        if not mm_data:
            return hf_processor.tokenizer(prompt,
                                          add_special_tokens=False,
                                          return_tensors="pt")

        outputs = super()._call_hf_processor(prompt, mm_data, mm_kwargs,
                                             tok_kwargs)
        if isinstance(hf_processor, NougatProcessor):
            outputs["input_ids"] = outputs["labels"]
        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare ``pixel_values`` as a field batched over ``"image"``."""
        return {"pixel_values": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Insert pad-token image placeholders at the start of the prompt."""
        tokenizer = self.info.get_hf_processor().tokenizer
        placeholder_count = self.info.get_num_image_tokens()
        placeholders = [tokenizer.pad_token_id] * placeholder_count

        insertion = PromptInsertion(
            modality="image",
            target=PromptIndexTargets.start(),
            insertion=placeholders,
        )
        return [insertion]
@MULTIMODAL_REGISTRY.register_processor(DonutMultiModalProcessor,
                                        info=DonutProcessingInfo,
                                        dummy_inputs=DonutDummyInputsBuilder)
class DonutForConditionalGeneration(nn.Module, SupportsMultiModal,
                                    SupportsV0Only):
    """Donut model: a Swin image encoder feeding a text decoder."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the Swin encoder and the decoder from the HF config.

        Args:
            vllm_config: Engine configuration; ``hf_config.encoder`` and
                ``hf_config.decoder`` configure the two sub-models.
            prefix: Module-name prefix; the decoder is created under
                ``{prefix}.decoder``.
        """
        super().__init__()
        model_config = vllm_config.model_config
        config = model_config.hf_config

        self.config = config
        self.vision_config = config.encoder
        self.processor_config = model_config.hf_image_processor_config

        self.encoder = SwinModel(config=config.encoder)
        self.decoder = DonutLanguageForConditionalGeneration(
            vllm_config=vllm_config.with_hf_config(config.decoder),
            prefix=f"{prefix}.decoder",
        )
        self.pad_token_id = config.pad_token_id

    def _parse_and_validate_image_input(self, **kwargs: object):
        """Extract and validate the image input from keyword arguments.

        Returns:
            ``None`` when neither ``pixel_values`` nor ``image_embeds`` is
            given, otherwise a ``DonutImagePixelInputs`` whose ``h``/``w``
            are bound to ``config.encoder.image_size``.

        Raises:
            ValueError: If both ``pixel_values`` and ``image_embeds`` are
                supplied.
            NotImplementedError: If only ``image_embeds`` is supplied.
        """
        pixel_values: Optional[Union[list[list[torch.Tensor]],
                                     list[torch.Tensor],
                                     torch.Tensor]] = kwargs.pop(
                                         "pixel_values", None)
        image_embeds: Optional[Union[list[list[torch.Tensor]],
                                     list[torch.Tensor],
                                     torch.Tensor]] = kwargs.pop(
                                         "image_embeds", None)
        has_pixels = pixel_values is not None
        has_embeds = image_embeds is not None

        if not has_pixels and not has_embeds:
            return None
        if has_pixels and has_embeds:
            raise ValueError(
                "Both pixel values and image embeds are provided.")
        if has_embeds:
            # Precomputed embeddings are not supported.
            raise NotImplementedError

        h, w = self.config.encoder.image_size
        return DonutImagePixelInputs(
            type="pixel_values",
            data=flatten_bn(pixel_values, concat=True),
            resolve_bindings={
                "h": h,
                "w": w,
            },
        )

    def _process_image_input(
            self, image_input: DonutImagePixelInputs) -> torch.Tensor:
        """Encode pixel values with the Swin encoder.

        The pixels are cast to the dtype of the encoder's first parameter
        before being passed through ``self.encoder``.
        """
        assert image_input["type"] == "pixel_values"
        encoder_dtype = next(self.encoder.parameters()).dtype
        pixels = image_input["data"].to(encoder_dtype)
        return self.encoder(pixels)

    def get_language_model(self) -> torch.nn.Module:
        """Return the decoder module."""
        return self.decoder

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Return encoder outputs for the image input, or ``None`` if absent."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings,
    ) -> torch.Tensor:
        """Return the flattened multimodal embeddings; ``input_ids`` is unused."""
        return _flatten_embeddings(multimodal_embeddings)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        *,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        r"""Encode the image (when encoder tokens exist) and run the decoder.

        Args:
            input_ids: torch.Tensor of *decoder* input token ids.
            positions: torch.Tensor of *decoder* position indices.
            encoder_input_ids: torch.Tensor of *encoder* input token ids.
            encoder_positions: torch.Tensor of *encoder* position indices
        Returns:
            Output torch.Tensor
        """
        inputs_embeds = None
        # The image is only encoded when there are encoder tokens.
        if encoder_input_ids.numel() > 0:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(encoder_input_ids,
                                                      vision_embeddings)

        return self.decoder(input_ids,
                            positions,
                            inputs_embeds=inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the decoder."""
        decoder = self.decoder
        return decoder.compute_logits(hidden_states, sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``."""
        return AutoWeightsLoader(self).load_weights(weights)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections import OrderedDict
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BartTokenizer, BatchFeature, PretrainedConfig
from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
BartParallelLMHead,
BartScaledWordEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptIndexTargets, PromptInsertion,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
SupportsV0Only)
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings
class Florence2ImagePixelInputs(TensorSchema):
    """Schema for raw pixel inputs handed to the Florence-2 vision tower.

    Dimensions:
        - b: Batch size
        - c: Number of channels (3)
        - h: Height of the image
        - w: Width of the image
    """

    # Discriminator tag; this schema only describes pixel values.
    type: Literal["pixel_values"]

    # Pixel tensor laid out as (b, 3, h, w).
    data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")]
# The ViT implementation below is copied from
# https://huggingface.co/microsoft/Florence-2-base/blob/main/modeling_florence2.py
class LearnedAbsolutePositionEmbedding2D(nn.Module):
    """
    Learned 2D positional embeddings for grids up to ``num_pos`` per side.

    The row table has ``embedding_dim // 2`` channels and the column table
    has the remaining channels, so their concatenation is ``embedding_dim``
    wide.
    """

    def __init__(self, embedding_dim=256, num_pos=50):
        super().__init__()
        row_dim = embedding_dim // 2
        col_dim = embedding_dim - row_dim
        self.row_embeddings = nn.Embedding(num_pos, row_dim)
        self.column_embeddings = nn.Embedding(num_pos, col_dim)

    def forward(self, pixel_values):
        """
        pixel_values: (batch_size, height, width, num_channels)
        returns: (batch_size, height, width, embedding_dim), with the
            column embedding first and the row embedding last along the
            channel axis.
        """
        if pixel_values.dim() != 4:
            raise ValueError('pixel_values must be a 4D tensor')

        batch_size, height, width = pixel_values.shape[:3]
        device = pixel_values.device
        width_values = torch.arange(width, device=device)
        height_values = torch.arange(height, device=device)
        col_emb = self.column_embeddings(width_values)
        row_emb = self.row_embeddings(height_values)

        # (height, width, embedding_dim): every row shares the column
        # embeddings and every column shares the row embeddings.
        grid = torch.cat(
            [
                col_emb.unsqueeze(0).repeat(height, 1, 1),
                row_emb.unsqueeze(1).repeat(1, width, 1),
            ],
            dim=-1,
        )
        # (batch_size, height, width, embedding_dim)
        return grid.unsqueeze(0).repeat(batch_size, 1, 1, 1)
class PositionalEmbeddingCosine1D(nn.Module):
    """
    Fixed sinusoidal positional encoding, precomputed up to a maximum length.

    It follows closely the encoder from the link below:
    https://pytorch.org/tutorials/beginner/translation_transformer.html

    Args:
        embed_dim: The dimension of the embeddings.
        max_seq_len: The maximum length to precompute the positional encodings.
    """

    def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        # One frequency per (sin, cos) channel pair.
        factor = math.log(10000)
        denominator = torch.exp(-factor * torch.arange(0, self.embed_dim, 2) /
                                self.embed_dim)
        # Row p holds position p multiplied by every frequency.
        positions = torch.arange(0, self.max_seq_len).reshape(
            self.max_seq_len, 1)
        frequencies = positions * denominator

        table = torch.zeros((self.max_seq_len, self.embed_dim))
        # Even channels carry the sine, odd channels the cosine.
        table[:, 0::2] = torch.sin(frequencies)
        table[:, 1::2] = torch.cos(frequencies)

        # Stored as a parameter with gradients disabled (not as a buffer).
        self.pos_idx_to_embed = nn.Parameter(table, requires_grad=False)

    def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor:
        """
        Args:
            seq_embeds: The sequence embeddings in order. Allowed size:
                1. [T, D], where T is the length of the sequence, and D is the
                frame embedding dimension.
                2. [B, T, D], where B is the batch size and T and D are the
                same as above.

        Returns the positional embeddings for the first T positions, shaped
        [T, D] for 2D input or [1, T, D] for 3D input.
        """
        rank = seq_embeds.dim()
        assert 2 <= rank <= 3
        len_seq = seq_embeds.size(-2)
        assert len_seq <= self.max_seq_len

        pos_embeds = self.pos_idx_to_embed[0:len_seq, :]
        if rank == 3:
            # Add a leading axis of size 1 so it broadcasts over the batch.
            pos_embeds = pos_embeds.view(
                (1, pos_embeds.size(0), pos_embeds.size(1)))
        return pos_embeds
class MySequential(nn.Sequential):
    """``nn.Sequential`` variant whose stages may take several arguments.

    A stage that returns a tuple has that tuple unpacked into the next
    stage's positional arguments; any other return value is passed as a
    single argument.
    """

    def forward(self, *inputs):
        state = inputs
        for layer in self._modules.values():
            state = layer(*state) if isinstance(state, tuple) else layer(state)
        return state
class PreNorm(nn.Module):
    """Residual wrapper: ``x + fn(norm(x))``, with ``norm`` optional.

    ``fn`` must return a ``(tensor, size)`` pair; ``size`` is passed through.
    """

    def __init__(self, norm, fn):
        super().__init__()
        self.norm = norm
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        residual = x
        # Normalisation is skipped entirely when no norm layer was given.
        normed = x if self.norm is None else self.norm(x)
        out, size = self.fn(normed, *args, **kwargs)
        return residual + out, size
class Mlp(nn.Module):
    """Two-layer feed-forward network (``fc1`` -> activation -> ``fc2``).

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when they are not given. ``forward`` passes ``size`` through unchanged.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        # The sub-module names ("fc1", "act", "fc2") are part of the
        # parameter names, so they are kept exactly.
        net = nn.Sequential()
        net.add_module("fc1", nn.Linear(in_features, hidden_features))
        net.add_module("act", act_layer())
        net.add_module("fc2", nn.Linear(hidden_features, out_features))
        self.net = net

    def forward(self, x, size):
        out = self.net(x)
        return out, size
class DepthWiseConv2d(nn.Module):
    """Depth-wise 2D convolution applied to a flattened token sequence.

    ``forward`` takes tokens shaped (B, H*W, C) plus their (H, W) grid
    size, convolves them as a (B, C, H, W) map with one group per channel,
    and returns the re-flattened tokens with the new grid size.
    """

    def __init__(
        self,
        dim_in,
        kernel_size,
        padding,
        stride,
        bias=True,
    ):
        super().__init__()
        conv_kwargs = dict(
            kernel_size=kernel_size,
            padding=padding,
            groups=dim_in,
            stride=stride,
            bias=bias,
        )
        self.dw = nn.Conv2d(dim_in, dim_in, **conv_kwargs)

    def forward(self, x, size):
        batch, num_tokens, channels = x.shape
        height, width = size
        assert num_tokens == height * width

        feature_map = x.transpose(1, 2).view(batch, channels, height, width)
        feature_map = self.dw(feature_map)
        new_size = (feature_map.size(-2), feature_map.size(-1))
        tokens = feature_map.flatten(2).transpose(1, 2)
        return tokens, new_size
class ConvEmbed(nn.Module):
    """Image to patch embedding via a strided convolution.

    The optional norm layer is applied either before the projection
    (``pre_norm=True``, sized to ``in_chans``) or after it (sized to
    ``embed_dim``).
    """

    def __init__(self,
                 patch_size=7,
                 in_chans=3,
                 embed_dim=64,
                 stride=4,
                 padding=2,
                 norm_layer=None,
                 pre_norm=True):
        super().__init__()
        self.patch_size = patch_size
        self.pre_norm = pre_norm

        self.proj = nn.Conv2d(in_chans,
                              embed_dim,
                              kernel_size=patch_size,
                              stride=stride,
                              padding=padding)

        if norm_layer:
            # A pre-norm sees the input channels, a post-norm the output.
            norm_width = in_chans if pre_norm else embed_dim
            self.norm = norm_layer(norm_width)
        else:
            self.norm = None

    def forward(self, x, size):
        """Project ``x`` to patch tokens.

        Args:
            x: Either a 4D input that is convolved directly, or 3D tokens
                that are first rearranged to ``b c h w`` using ``size``.
            size: ``(H, W)`` grid of the incoming tokens.

        Returns:
            ``(tokens, (H, W))`` where tokens are ``b (h w) c`` and the
            size is the grid after the convolution.
        """
        height, width = size
        if x.dim() == 3:
            if self.norm and self.pre_norm:
                x = self.norm(x)
            x = rearrange(x, 'b (h w) c -> b c h w', h=height, w=width)

        x = self.proj(x)
        height, width = x.shape[2:]
        x = rearrange(x, 'b c h w -> b (h w) c')

        if self.norm and not self.pre_norm:
            x = self.norm(x)
        return x, (height, width)
class ChannelAttention(nn.Module):
    """Grouped attention computed across channels instead of tokens.

    Within each group the attention matrix is channel-by-channel; queries
    are scaled by ``N ** -0.5`` where ``N`` is the token count.
    """

    def __init__(self, dim, groups=8, qkv_bias=True):
        super().__init__()
        self.groups = groups
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, size):
        batch, num_tokens, channels = x.shape
        group_dim = channels // self.groups

        # (3, B, groups, N, C // groups)
        qkv = self.qkv(x).reshape(batch, num_tokens, 3, self.groups,
                                  group_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        q = q * (float(num_tokens)**-0.5)
        # Channel-by-channel attention weights within each group.
        attention = (q.transpose(-1, -2) @ k).softmax(dim=-1)

        out = (attention @ v.transpose(-1, -2)).transpose(-1, -2)
        out = out.transpose(1, 2).reshape(batch, num_tokens, channels)
        return self.proj(out), size
class ChannelBlock(nn.Module):
    """Block of (optional conv) -> channel attention -> (optional conv) -> FFN.

    Every stage is wrapped in ``PreNorm`` and therefore residual.
    ``drop_path_rate`` is accepted but not used by this implementation.
    """

    def __init__(self,
                 dim,
                 groups,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 drop_path_rate=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 conv_at_attn=True,
                 conv_at_ffn=True):
        super().__init__()

        if conv_at_attn:
            self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))
        else:
            self.conv1 = None

        attention = ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias)
        self.channel_attn = PreNorm(norm_layer(dim), attention)

        if conv_at_ffn:
            self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))
        else:
            self.conv2 = None

        feed_forward = Mlp(in_features=dim,
                           hidden_features=int(dim * mlp_ratio),
                           act_layer=act_layer)
        self.ffn = PreNorm(norm_layer(dim), feed_forward)

    def forward(self, x, size):
        """Apply the stages in order; returns ``(tokens, size)``."""
        if self.conv1 is not None:
            x, size = self.conv1(x, size)
        x, size = self.channel_attn(x, size)

        if self.conv2 is not None:
            x, size = self.conv2(x, size)
        x, size = self.ffn(x, size)
        return x, size
def window_partition(x, window_size: int):
    """Split a (B, H, W, C) map into non-overlapping square windows.

    Returns a tensor shaped (B * H/ws * W/ws, ws, ws, C); windows are
    ordered by batch, then window row, then window column.
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size

    tiled = x.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two window-index axes next to each other.
    tiled = tiled.transpose(2, 3).contiguous()
    return tiled.view(-1, window_size, window_size, channels)
def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int):
    """Reassemble windows into a (batch_size, H, W, C) map.

    ``windows`` is expected in the ordering produced by a
    batch / window-row / window-column partition of an H x W grid.
    """
    rows = H // window_size
    cols = W // window_size

    grid = windows.view(batch_size, rows, cols, window_size, window_size, -1)
    # Interleave window rows with in-window rows before flattening.
    grid = grid.transpose(2, 3).contiguous()
    return grid.view(batch_size, H, W, -1)
class WindowAttention(nn.Module):
    """Multi-head self-attention restricted to square local windows.

    The token grid is zero-padded on the right/bottom up to a multiple of
    ``window_size``, attention runs inside each window, and the padding is
    cropped away again afterwards.
    """

    def __init__(self, dim, num_heads, window_size, qkv_bias=True):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.scale = float(dim // num_heads)**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, size):
        """Attend within windows; returns ``(tokens, size)`` unchanged in shape."""
        H, W = size
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        ws = self.window_size

        # Pad right and bottom so both sides are divisible by the window.
        pad_r = (ws - W % ws) % ws
        pad_b = (ws - H % ws) % ws
        x = F.pad(x.view(B, H, W, C), (0, 0, 0, pad_r, 0, pad_b))
        _, Hp, Wp, _ = x.shape

        windows = window_partition(x, ws).view(-1, ws * ws, C)
        num_windows, tokens, _ = windows.shape

        # (3, num_windows, heads, tokens, C // heads)
        qkv = self.qkv(windows).reshape(num_windows, tokens, 3,
                                        self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        attn = self.softmax((q * self.scale) @ k.transpose(-2, -1))
        out = (attn @ v).transpose(1, 2).reshape(num_windows, tokens, C)
        out = self.proj(out)

        # Merge the windows back and drop the padded border, if any.
        out = window_reverse(out.view(-1, ws, ws, C), B, ws, Hp, Wp)
        if pad_r > 0 or pad_b > 0:
            out = out[:, :H, :W, :].contiguous()

        return out.view(B, H * W, C), size
class SpatialBlock(nn.Module):
    """Block of (optional conv) -> window attention -> (optional conv) -> FFN.

    Every stage is wrapped in ``PreNorm`` and therefore residual.
    ``drop_path_rate`` is accepted but not used by this implementation.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 window_size,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 drop_path_rate=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 conv_at_attn=True,
                 conv_at_ffn=True):
        super().__init__()

        if conv_at_attn:
            self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))
        else:
            self.conv1 = None

        attention = WindowAttention(dim,
                                    num_heads,
                                    window_size,
                                    qkv_bias=qkv_bias)
        self.window_attn = PreNorm(norm_layer(dim), attention)

        if conv_at_ffn:
            self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1))
        else:
            self.conv2 = None

        feed_forward = Mlp(in_features=dim,
                           hidden_features=int(dim * mlp_ratio),
                           act_layer=act_layer)
        self.ffn = PreNorm(norm_layer(dim), feed_forward)

    def forward(self, x, size):
        """Apply the stages in order; returns ``(tokens, size)``."""
        if self.conv1 is not None:
            x, size = self.conv1(x, size)
        x, size = self.window_attn(x, size)

        if self.conv2 is not None:
            x, size = self.conv2(x, size)
        x, size = self.ffn(x, size)
        return x, size
class DaViT(nn.Module):
    """Multi-stage vision backbone alternating spatial and channel blocks.

    Each stage is a ``ConvEmbed`` patch embedding followed by
    ``depths[i]`` pairs of (``SpatialBlock``, ``ChannelBlock``).
    """

    def __init__(
        self,
        in_chans=3,
        num_classes=1000,
        depths=(1, 1, 3, 1),
        patch_size=(7, 2, 2, 2),
        patch_stride=(4, 2, 2, 2),
        patch_padding=(3, 0, 0, 0),
        patch_prenorm=(False, False, False, False),
        embed_dims=(64, 128, 192, 256),
        num_heads=(3, 6, 12, 24),
        num_groups=(3, 6, 12, 24),
        window_size=7,
        mlp_ratio=4.,
        qkv_bias=True,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        enable_checkpoint=False,
        conv_at_attn=True,
        conv_at_ffn=True,
    ):
        super().__init__()

        self.num_classes = num_classes
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.num_stages = len(self.embed_dims)
        self.enable_checkpoint = enable_checkpoint
        assert self.num_stages == len(self.num_heads) == len(self.num_groups)

        # One rate per block: two blocks (spatial + channel) per depth unit.
        dpr = torch.linspace(0, drop_path_rate, sum(depths) * 2).tolist()

        convs = []
        blocks = []
        depth_offset = 0
        for stage in range(len(embed_dims)):
            prev_dim = in_chans if stage == 0 else self.embed_dims[stage - 1]
            convs.append(
                ConvEmbed(patch_size=patch_size[stage],
                          stride=patch_stride[stage],
                          padding=patch_padding[stage],
                          in_chans=prev_dim,
                          embed_dim=self.embed_dims[stage],
                          norm_layer=norm_layer,
                          pre_norm=patch_prenorm[stage]))

            stage_layers = []
            for j in range(depths[stage]):
                rate_index = depth_offset + j * 2
                spatial = SpatialBlock(
                    embed_dims[stage],
                    num_heads[stage],
                    window_size,
                    drop_path_rate=dpr[rate_index],
                    qkv_bias=qkv_bias,
                    mlp_ratio=mlp_ratio,
                    conv_at_attn=conv_at_attn,
                    conv_at_ffn=conv_at_ffn,
                )
                channel = ChannelBlock(
                    embed_dims[stage],
                    num_groups[stage],
                    drop_path_rate=dpr[rate_index + 1],
                    qkv_bias=qkv_bias,
                    mlp_ratio=mlp_ratio,
                    conv_at_attn=conv_at_attn,
                    conv_at_ffn=conv_at_ffn,
                )
                # The sub-module names become part of the parameter names.
                stage_layers.append(
                    MySequential(
                        OrderedDict([('spatial_block', spatial),
                                     ('channel_block', channel)])))

            blocks.append(MySequential(*stage_layers))
            depth_offset += depths[stage] * 2

        self.convs = nn.ModuleList(convs)
        self.blocks = nn.ModuleList(blocks)
        self.avgpool = nn.AdaptiveAvgPool1d(1)

    @property
    def dim_out(self):
        """Channel width of the last stage."""
        return self.embed_dims[-1]

    def forward_features_unpool(self, x):
        """
        Run every stage and return the final tokens, before any pooling.

        Args:
            x: input image tensor; dims 2 and 3 are used as the initial
                (height, width) grid size.
        """
        grid_size = (x.size(2), x.size(3))
        for conv, block in zip(self.convs, self.blocks):
            x, grid_size = conv(x, grid_size)
            x, grid_size = block(x, grid_size)
        return x

    def forward_features(self, x):
        """Pool the unpooled tokens over the token axis and normalise them.

        NOTE(review): ``self.norms`` is not created in ``__init__`` of this
        class, so this method fails unless it is attached elsewhere -- confirm.
        """
        tokens = self.forward_features_unpool(x)

        # (batch_size, num_tokens, token_dim) -> pooled over num_tokens
        pooled = self.avgpool(tokens.transpose(1, 2))
        pooled = torch.flatten(pooled, 1)
        return self.norms(pooled)

    def forward(self, x):
        """Apply ``self.head`` on top of the pooled features.

        NOTE(review): ``self.head`` is not created in ``__init__`` of this
        class either -- confirm before relying on this entry point.
        """
        features = self.forward_features(x)
        return self.head(features)

    @classmethod
    def from_config(cls, config):
        """Build a ``DaViT`` from the matching fields of a vision config."""
        init_kwargs = dict(
            depths=config.depths,
            embed_dims=config.dim_embed,
            num_heads=config.num_heads,
            num_groups=config.num_groups,
            patch_size=config.patch_size,
            patch_stride=config.patch_stride,
            patch_padding=config.patch_padding,
            patch_prenorm=config.patch_prenorm,
            drop_path_rate=config.drop_path_rate,
            window_size=config.window_size,
        )
        return cls(**init_kwargs)
# Language backbone and processor implementation
class Florence2LanguageModel(nn.Module):
    """BART-style encoder/decoder pair used as the Florence-2 text backbone."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the shared embedding, the encoder and the decoder.

        When ``config.tie_word_embeddings`` is set, the encoder and decoder
        token embeddings are pointed at the shared embedding weight.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        self.config = config
        self.vocab_size = config.vocab_size

        self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model)

        backbone_kwargs = dict(
            cache_config=vllm_config.cache_config,
            quant_config=vllm_config.quant_config,
        )
        self.encoder = BartEncoder(config,
                                   prefix=f"{prefix}.encoder",
                                   **backbone_kwargs)
        self.decoder = BartDecoder(config,
                                   prefix=f"{prefix}.decoder",
                                   **backbone_kwargs)

        if self.config.tie_word_embeddings:
            shared_weight = self.shared.weight
            self.encoder.embed_tokens.weight = shared_weight
            self.decoder.embed_tokens.weight = shared_weight

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""Run the encoder when there is encoder input, then the decoder.

        Args:
            input_ids: Indices of *decoder* input sequence tokens
                in the vocabulary.
            positions: Positions of *decoder* input sequence tokens.
            encoder_input_ids: Indices of *encoder* input sequence tokens
                in the vocabulary.
            encoder_positions: Positions of *encoder* input sequence tokens.
            inputs_embeds: Optional embeddings passed to the encoder.
        Returns:
            Model output torch.Tensor
        """
        has_embeds = inputs_embeds is not None and inputs_embeds.numel() > 0

        # The encoder only runs when a non-zero number of encoder inputs
        # (embeddings or token ids) is provided.
        encoder_hidden_states = None
        if has_embeds or encoder_input_ids.numel() > 0:
            encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
                                                 positions=encoder_positions,
                                                 inputs_embeds=inputs_embeds)

        return self.decoder(decoder_input_ids=input_ids,
                            decoder_positions=positions,
                            encoder_hidden_states=encoder_hidden_states)
class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
    """Florence-2 text backbone with an LM head and logits processor."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the language model and its LM head.

        The head is tied to the shared embedding when
        ``config.tie_word_embeddings`` is set.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        self.config = config
        self.vocab_size = config.vocab_size

        self.model = Florence2LanguageModel(vllm_config=vllm_config,
                                            prefix=f"{prefix}.model")

        if config.scale_embedding:
            embed_scale = math.sqrt(config.d_model)
        else:
            embed_scale = 1.0
        self.lm_head = BartParallelLMHead(self.vocab_size,
                                          config.d_model,
                                          embed_scale=embed_scale)
        if self.config.tie_word_embeddings:
            self.lm_head.tie_weights(self.model.shared)

        self.logits_processor = LogitsProcessor(self.vocab_size,
                                                config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_input_ids: torch.Tensor,
        encoder_positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        r"""Forward everything except ``**kwargs`` to the language model.

        Args:
            input_ids: torch.Tensor of *decoder* input token ids.
            positions: torch.Tensor of *decoder* position indices.
            encoder_input_ids: torch.Tensor of *encoder* input token ids.
            encoder_positions: torch.Tensor of *encoder* position indices
            inputs_embeds: Optional embeddings passed on to the model.
        Returns:
            Output torch.Tensor
        """
        positional = (input_ids, positions, encoder_input_ids,
                      encoder_positions)
        return self.model(*positional, inputs_embeds=inputs_embeds)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids with the encoder's token embedding."""
        embed_tokens = self.model.encoder.embed_tokens
        return embed_tokens(input_ids)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn hidden states into logits through the LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        ``q_proj`` / ``k_proj`` / ``v_proj`` entries are routed into the
        fused ``qkv_proj`` parameter. ``final_logits_bias`` is skipped, and
        with tied word embeddings so are ``embed_tokens`` and ``lm_head``.

        Returns:
            The set of parameter names (after renaming) that were loaded.
        """
        # (fused parameter name, checkpoint shard name, shard id)
        fused_qkv = (
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        )
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # Only the first mapping entry found in the name is applied.
            match = next((entry for entry in fused_qkv if entry[1] in name),
                         None)
            if match is not None:
                fused_name, shard_name, shard_id = match
                name = name.replace(shard_name, fused_name)
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
            else:
                if "final_logits_bias" in name:
                    continue
                is_tied_name = "embed_tokens" in name or "lm_head" in name
                if self.config.tie_word_embeddings and is_tied_name:
                    continue
                param = params_dict[name]
                loader = getattr(param, "weight_loader",
                                 default_weight_loader)
                loader(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params
class Florence2ProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Florence-2: limits and image token count."""

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality item limit: one image."""
        return dict(image=1)

    def get_num_image_tokens(self) -> int:
        """Return ``image_seq_length`` from the HF image processor config."""
        image_processor_config = self.ctx.get_hf_image_processor_config()
        return image_processor_config["image_seq_length"]
class Florence2DummyInputsBuilder(
        BaseDummyInputsBuilder[Florence2ProcessingInfo]):
    """Builds dummy text and image inputs for Florence-2."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the dummy prompt text, which is always empty."""
        return ""

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Create square dummy images.

        The side length is taken from ``projection_dim`` of the HF config;
        ``seq_len`` is unused. Only the ``"image"`` count is read from
        ``mm_counts`` (defaults to 0).
        """
        side = self.info.get_hf_config().projection_dim
        dummy_images = self._get_dummy_images(
            width=side,
            height=side,
            num_images=mm_counts.get("image", 0),
        )
        return {"image": dummy_images}
class Florence2MultiModalProcessor(
        EncDecMultiModalProcessor[Florence2ProcessingInfo]):
    """Encoder/decoder multimodal processor for Florence-2."""

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        """Report that the HF processor does not apply prompt updates."""
        return False

    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """Return the encoder prompt, which is the prompt unchanged."""
        return prompt

    def create_decoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        """Return a decoder prompt holding only the config's EOS token id."""
        eos_token_id = self.info.get_hf_config().eos_token_id
        return [eos_token_id]

    def _apply_hf_processor_tokens_only(
        self,
        prompt_tokens: list[int],
    ) -> list[int]:
        """Expand task tokens in an already-tokenized prompt.

        The tokens are decoded to text, passed through the HF processor's
        ``_construct_prompts``, and re-encoded without special tokens.
        """
        hf_processor = self.info.get_hf_processor()
        tokenizer: BartTokenizer = hf_processor.tokenizer

        decoded = tokenizer.decode(prompt_tokens)
        # convert task tokens to prompt
        expanded = hf_processor._construct_prompts([decoded])[0]
        return tokenizer.encode(expanded, add_special_tokens=False)

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, or only the tokenizer for text-only input.

        Without multimodal data the prompt is first expanded with
        ``_construct_prompts`` and then tokenized with special tokens.
        """
        if mm_data:
            return super()._call_hf_processor(prompt, mm_data, mm_kwargs,
                                              tok_kwargs)

        hf_processor = self.info.get_hf_processor()
        tokenizer = hf_processor.tokenizer
        expanded = hf_processor._construct_prompts([prompt])[0]
        return tokenizer(expanded,
                         add_special_tokens=True,
                         return_tensors="pt")

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare ``pixel_values`` as a field batched over ``"image"``."""
        return {"pixel_values": MultiModalFieldConfig.batched("image")}

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Insert pad-token image placeholders at the start of the prompt."""
        pad_token_id = self.info.get_hf_config().pad_token_id
        placeholder_count = self.info.get_num_image_tokens()
        placeholders = [pad_token_id] * placeholder_count

        insertion = PromptInsertion(
            modality="image",
            target=PromptIndexTargets.start(),
            insertion=placeholders,
        )
        return [insertion]
@MULTIMODAL_REGISTRY.register_processor(
Florence2MultiModalProcessor,
info=Florence2ProcessingInfo,
dummy_inputs=Florence2DummyInputsBuilder)
class Florence2ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsV0Only):
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
    """Return the prompt placeholder for a modality item.

    Image modalities have no placeholder string (``None``).

    Raises:
        ValueError: For any modality whose name does not start with
            ``"image"``.
    """
    if not modality.startswith("image"):
        raise ValueError("Only image modality is supported")
    return None
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the DaViT vision tower, projection layers and language model.

    Args:
        vllm_config: Engine configuration; ``hf_config.vision_config`` and
            ``hf_config.text_config`` configure the two sub-models.
        prefix: Module-name prefix; the language model is created under
            ``{prefix}.language_model``.
    """
    super().__init__()
    model_config = vllm_config.model_config
    config = model_config.hf_config

    self.config = config
    self.vision_config = config.vision_config
    self.processor_config = model_config.hf_image_processor_config

    assert config.vision_config.model_type == 'davit', (
        'only DaViT is supported for now')
    self.vision_tower = DaViT.from_config(config=config.vision_config)
    self._build_image_projection_layers(config)

    text_vllm_config = vllm_config.with_hf_config(config.text_config)
    self.language_model = Florence2LanguageForConditionalGeneration(
        vllm_config=text_vllm_config,
        prefix=f"{prefix}.language_model",
    )
    self.pad_token_id = config.pad_token_id
def _build_image_projection_layers(self, config: PretrainedConfig):
    """Create the projection, its norm and the positional embeddings.

    Raises:
        NotImplementedError: If the image position embedding type is not
            ``learned_abs_2d`` or the temporal embedding type is not
            ``COSINE``.
    """
    vision_config = config.vision_config
    image_dim_out = vision_config.dim_embed[-1]
    dim_projection = vision_config.projection_dim

    # Uninitialised here; the values are expected to come from the
    # checkpoint.
    self.image_projection = nn.Parameter(
        torch.empty(image_dim_out, dim_projection))
    self.image_proj_norm = nn.LayerNorm(dim_projection)

    pos_embed_cfg = vision_config.image_pos_embed
    if pos_embed_cfg['type'] != 'learned_abs_2d':
        raise NotImplementedError("Florence2 only supports learned_abs_2d "
                                  "as image position embedding.")
    self.image_pos_embed = LearnedAbsolutePositionEmbedding2D(
        embedding_dim=image_dim_out,
        num_pos=pos_embed_cfg['max_pos_embeddings'])

    self.image_feature_source = vision_config.image_feature_source

    # temporal embedding
    temporal_cfg = self.vision_config.visual_temporal_embedding
    if temporal_cfg['type'] != 'COSINE':
        raise NotImplementedError(
            'Florence2 only supports COSINE as temporal embedding.')
    self.visual_temporal_embed = PositionalEmbeddingCosine1D(
        embed_dim=image_dim_out,
        max_seq_len=temporal_cfg['max_temporal_embeddings'])
def _parse_and_validate_image_input(self, **kwargs: object):
    """Extract and validate the image input from keyword arguments.

    Returns:
        ``None`` when neither ``pixel_values`` nor ``image_embeds`` is
        given, otherwise a ``Florence2ImagePixelInputs`` whose ``h``/``w``
        are bound to ``processor_config["size"]``.

    Raises:
        ValueError: If both ``pixel_values`` and ``image_embeds`` are
            supplied.
        NotImplementedError: If only ``image_embeds`` is supplied.
    """
    pixel_values: Optional[Union[list[list[torch.Tensor]],
                                 list[torch.Tensor],
                                 torch.Tensor]] = kwargs.pop(
                                     "pixel_values", None)
    image_embeds: Optional[Union[list[list[torch.Tensor]],
                                 list[torch.Tensor],
                                 torch.Tensor]] = kwargs.pop(
                                     "image_embeds", None)
    has_pixels = pixel_values is not None
    has_embeds = image_embeds is not None

    if not has_pixels and not has_embeds:
        return None
    if has_pixels and has_embeds:
        raise ValueError(
            "Both pixel values and image embeds are provided.")
    if has_embeds:
        # Precomputed embeddings are not supported.
        raise NotImplementedError

    size = self.processor_config["size"]
    return Florence2ImagePixelInputs(
        type="pixel_values",
        data=flatten_bn(pixel_values, concat=True),
        resolve_bindings={
            "h": size["height"],
            "w": size["width"]
        },
    )
def _encode_image(self, pixel_values: torch.Tensor) -> torch.Tensor:
    """Run the vision tower and project the selected features.

    The pixel values are cast to the vision tower's parameter dtype, passed
    through ``forward_features_unpool``, enriched with the spatial and
    temporal position embeddings (when present), pooled according to
    ``self.image_feature_source``, concatenated along dim 1, projected with
    ``self.image_projection`` and normalized.

    Raises:
        ValueError: if ``self.image_feature_source`` names an unknown source.
    """
    tower_dtype = next(self.vision_tower.parameters()).dtype
    pixel_values = pixel_values.to(tower_dtype)

    # A single image is treated as one frame.
    batch_size = pixel_values.size(0)
    num_frames = 1
    feats = self.vision_tower.forward_features_unpool(pixel_values)

    if self.image_pos_embed is not None:
        feats = feats.view(batch_size * num_frames, -1, feats.shape[-1])
        num_tokens = feats.shape[-2]
        side = int(num_tokens**0.5)
        assert side * side == num_tokens, (
            'only support square feature maps for now')
        # Reshape to a 2D grid so the position embedding can be added.
        feats = feats.view(batch_size * num_frames, side, side,
                           feats.shape[-1])
        feats = feats + self.image_pos_embed(feats)
        feats = feats.view(batch_size, num_frames * side * side,
                           feats.shape[-1])

    if self.visual_temporal_embed is not None:
        per_frame = feats.view(batch_size, num_frames, -1, feats.shape[-1])
        # The temporal embedding is computed from the first token per frame.
        temporal = self.visual_temporal_embed(per_frame[:, :, 0])
        feats = per_frame + temporal.view(1, num_frames, 1, feats.shape[-1])

    per_frame = feats.view(batch_size, num_frames, -1, feats.shape[-1])
    candidates = {
        'spatial_avg_pool': per_frame.mean(dim=2),
        'temporal_avg_pool': per_frame.mean(dim=1),
        'last_frame': per_frame[:, -1],
    }

    selected = []
    for source in self.image_feature_source:
        if source not in candidates:
            raise ValueError(f'invalid image feature source: {source}')
        selected.append(candidates[source])

    projected = torch.cat(selected, dim=1) @ self.image_projection
    return self.image_proj_norm(projected)
def _process_image_input(
        self, image_input: Florence2ImagePixelInputs) -> torch.Tensor:
    """Encode a validated pixel-value input into image embeddings."""
    # Only the pixel-value variant is supported here.
    assert image_input["type"] == "pixel_values"
    return self._encode_image(image_input["data"])
def get_language_model(self) -> torch.nn.Module:
    """Return the wrapped language model module."""
    language_model = self.language_model
    return language_model
def get_multimodal_embeddings(self,
                              **kwargs: object) -> MultiModalEmbeddings:
    """Return image embeddings for the given kwargs.

    Returns an empty list (not ``None``) when the kwargs carry no image
    input.
    """
    parsed_input = self._parse_and_validate_image_input(**kwargs)
    return [] if parsed_input is None else self._process_image_input(
        parsed_input)
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
    """Embed ``input_ids`` and merge in any multimodal embeddings.

    When ``multimodal_embeddings`` is ``None`` or empty, the plain text
    embeddings are returned; otherwise they are merged with the multimodal
    embeddings using ``self.pad_token_id`` as the placeholder token id.
    """
    text_embeds = self.language_model.get_input_embeddings(input_ids)
    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        return text_embeds
    return merge_multimodal_embeddings(input_ids, text_embeds,
                                       multimodal_embeddings,
                                       self.pad_token_id)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    *,
    encoder_input_ids: torch.Tensor,
    encoder_positions: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    r"""
    Run the language model with optional image-conditioned encoder inputs.

    Args:
        input_ids: torch.Tensor of *decoder* input token ids.
        positions: torch.Tensor of *decoder* position indices.
        intermediate_tensors: accepted but not used by this method.
        encoder_input_ids: torch.Tensor of *encoder* input token ids.
        encoder_positions: torch.Tensor of *encoder* position indices
        **kwargs: multimodal inputs forwarded to
            ``get_multimodal_embeddings``.
    Returns:
        Output torch.Tensor
    """
    vision_embeddings = self.get_multimodal_embeddings(**kwargs)

    # NOTE(review): ``get_multimodal_embeddings`` returns ``[]`` rather than
    # ``None`` when there is no image input, so the ``is not None`` test
    # below is always true and the ``None`` fallback is never taken.
    needs_encoder_embeds = (encoder_input_ids.numel() > 0
                            or vision_embeddings is not None)
    inputs_embeds = (self.get_input_embeddings(encoder_input_ids,
                                               vision_embeddings)
                     if needs_encoder_embeds else None)

    return self.language_model(input_ids,
                               positions,
                               encoder_input_ids,
                               encoder_positions,
                               inputs_embeds=inputs_embeds)
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    """Delegate logits computation to the wrapped language model."""
    logits = self.language_model.compute_logits(hidden_states,
                                                sampling_metadata)
    return logits
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load ``(name, tensor)`` pairs into this module via ``AutoWeightsLoader``.

    Returns whatever ``AutoWeightsLoader.load_weights`` returns (annotated as
    a set of names).
    """
    return AutoWeightsLoader(self).load_weights(weights)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
import transformers.models.mllama.configuration_mllama as config_mllama
from PIL.Image import Image
from torch import nn
from transformers import BatchFeature, MllamaConfig
from transformers.modeling_outputs import (BaseModelOutput,
CausalLMOutputWithPast)
from transformers.models.mllama.image_processing_mllama import (
get_optimal_tiled_canvas)
from transformers.models.mllama.processing_mllama import (
MllamaProcessor, get_cross_attention_token_mask)
import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.selector import _Backend
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVCrossParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalFieldConfig,
MultiModalKwargsItems, MultiModalUUIDDict)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPMLP
from .interfaces import SupportsMultiModal, SupportsV0Only
from .llama import LlamaDecoderLayer, LlamaMLP
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
logger = init_logger(__name__)
class MllamaImagePixelInputs(TensorSchema):
    """
    Schema of the pixel-value inputs consumed by the Mllama vision model.

    Dimensions:
        - batch_size: Batch size
        - max_num_image: Max number of images
        - max_num_chunk: Max number of chunks
        - max_num_tiles: Max number of tiles per image
        - num_channel: Number of channels
        - height: Height
        - width: Width
    """

    # Discriminator for this input variant.
    type: Literal["pixel_values"] = "pixel_values"

    # Pixel data, one entry per (batch, image, chunk).
    data: Annotated[
        torch.Tensor,
        TensorShape("batch_size", "max_num_image", "max_num_chunk",
                    "num_channel", "height", "width"),
    ]

    # One aspect-ratio id per image.
    aspect_ratio_ids: Annotated[
        torch.Tensor,
        TensorShape("batch_size", "max_num_image"),
    ]

    # Per-image mask over tile slots.
    aspect_ratio_mask: Annotated[
        torch.Tensor,
        TensorShape("batch_size", "max_num_image", "max_num_tiles"),
    ]
# TODO: support LlamaImageEmbeddingInputs
def calc_token_per_chunk(image_size: int, patch_size: int = 14) -> int:
    """Return the number of vision tokens produced for one image chunk.

    A chunk of ``image_size`` x ``image_size`` pixels is split into square
    patches of ``patch_size`` pixels; the result is the patch count plus one
    extra token, mirroring the ``(image_size // patch_size)**2 + 1`` formula
    used by the vision modules in this file.

    Args:
        image_size: Side length of a (square) chunk in pixels.
        patch_size: Side length of a patch in pixels. Defaults to 14, the
            value that was previously hard-coded.

    Raises:
        AssertionError: if ``image_size`` is not a multiple of ``patch_size``.
    """
    assert image_size % patch_size == 0, (
        f"chunk size should be multiple of {patch_size}")
    patches_per_side = image_size // patch_size
    return patches_per_side**2 + 1
class MllamaProcessingInfo(BaseProcessingInfo):
    """Accessors for Mllama config/processor data used during preprocessing."""

    def get_hf_config(self) -> MllamaConfig:
        """Return the HF config, typed as ``MllamaConfig``."""
        return self.ctx.get_hf_config(MllamaConfig)

    def get_hf_processor(self, **kwargs: object) -> MllamaProcessor:
        """Return the HF processor, forwarding ``kwargs`` to the context."""
        return self.ctx.get_hf_processor(MllamaProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Images are the only modality; ``None`` is reported as its limit."""
        return {"image": None}

    def get_token_per_chunk_from_config(self) -> int:
        """Return tokens per chunk for the configured vision image size."""
        vision_config = self.get_hf_config().vision_config
        return calc_token_per_chunk(vision_config.image_size)

    def get_num_tiles_per_image(self, image_height: int,
                                image_width: int) -> int:
        """Return how many tiles an image of the given size is split into.

        The tile grid is derived from the canvas chosen by
        ``get_optimal_tiled_canvas`` divided by the tile (image) size.
        """
        vision_config = self.get_hf_config().vision_config
        tile_size = vision_config.image_size
        canvas_height, canvas_width = get_optimal_tiled_canvas(
            image_height,
            image_width,
            vision_config.max_num_tiles,
            tile_size=tile_size,
        )
        rows = canvas_height // tile_size
        cols = canvas_width // tile_size
        return rows * cols

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return an image size that uses the maximum number of tiles."""
        vision_config = self.get_hf_config().vision_config
        tile_size = vision_config.image_size
        # A single column of ``max_num_tiles`` tiles
        # (h:w = max_num_tiles:1) yields the largest feature size.
        return ImageSize(height=vision_config.max_num_tiles * tile_size,
                         width=tile_size)
class MllamaDummyInputsBuilder(BaseDummyInputsBuilder[MllamaProcessingInfo]):
    """Builds dummy text and image inputs for Mllama."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the image token repeated once per requested image."""
        image_count = mm_counts.get("image", 0)
        image_token = self.info.get_hf_processor().image_token
        return image_token * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Return dummy images of the size that yields the most features.

        ``seq_len`` is accepted but not used by this method.
        """
        image_count = mm_counts.get("image", 0)
        width, height = self.info.get_image_size_with_most_features()
        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=image_count)
        return {"image": dummy_images}
class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
):
def apply(
    self,
    prompt: Union[str, list[int]],
    mm_data: MultiModalDataDict,
    hf_processor_mm_kwargs: Mapping[str, object],
    tokenization_kwargs: Optional[Mapping[str, object]] = None,
    mm_uuids: Optional[MultiModalUUIDDict] = None,
) -> MultiModalEncDecInputs:
    """Process the prompt and size the encoder prompt from the image tiles.

    After the base-class processing, this checks that the decoder prompt
    contains exactly one image token per provided image and, when
    multimodal data is present, rewrites the encoder prompt so that its
    length equals the number of vision tokens of the last group of
    consecutive images.

    Raises:
        ValueError: if the number of image tokens differs from the number
            of images.
    """
    mm_inputs = super().apply(prompt,
                              mm_data,
                              hf_processor_mm_kwargs,
                              tokenization_kwargs,
                              mm_uuids=mm_uuids)

    image_token_id = self.info.get_hf_config().image_token_index

    # The decoder prompt must hold one image token per image in mm_data.
    num_image_tokens = mm_inputs['prompt_token_ids'].count(image_token_id)
    image_data = mm_data.get("image", [])
    num_images = 1 if isinstance(image_data, Image) else len(image_data)
    if num_image_tokens != num_images:
        raise ValueError(
            f"The number of image tokens ({num_image_tokens}) must be"
            f" the same as the number of images ({num_images})")

    # Given prompt: <IMG0> P0 P1 <IMG1> <IMG2> P3 P4 D5 D6...., (P-prefill, D-decode) # noqa: E501
    # P0 & P1 do cross attention with placeholder of <IMG0>
    # P3 P4 D5 D6 do cross attention with placeholder of <IMG1> and <IMG2>
    # Example input to encoder and decoder:
    # {
    #     'encoder': {
    #         'type': 'token',
    #         'prompt_token_ids': [128256, 128256, ..., 128256],
    #         'prompt': '<|image|><|image|>...<|image|>',
    #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
    #     },
    #     'decoder': {
    #         'type': 'token',
    #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
    #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
    #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
    #     },
    # }
    if not mm_data:
        return mm_inputs

    image_token: str = self.info.get_hf_processor().image_token

    # Decoded tokens only attend to the last group of consecutive images,
    # so only those images contribute encoder tokens.
    token_per_chunk = self.info.get_token_per_chunk_from_config()
    num_decode_images = self._get_num_image_in_last_group(
        mm_inputs["prompt_token_ids"])
    first_decode_image = num_images - num_decode_images

    # The encoder prompt length tells the block manager how many slots to
    # allocate for encoder tokens; it is driven by the tile count.
    num_tiles = mm_inputs["mm_kwargs"].get_data()["num_tiles"]
    decode_tiles = num_tiles[first_decode_image:num_images].sum().item()
    num_tokens = decode_tiles * token_per_chunk

    mm_inputs["encoder_prompt_token_ids"] = [image_token_id] * num_tokens
    mm_inputs["encoder_prompt"] = image_token * num_tokens
    return mm_inputs
def _get_num_image_in_last_group(self, prompt_token_ids: list[int]) -> int:
    """Count the image tokens in the last run of consecutive image tokens.

    The prompt is scanned from the end: trailing non-image tokens are
    skipped, then image tokens are counted until the first non-image token
    that precedes them.

    Args:
        prompt_token_ids: Decoder prompt token ids.

    Returns:
        The length of the last group of consecutive image tokens, or 0 if
        the prompt contains no image token.
    """
    # Look the token id up once instead of on every loop iteration.
    image_token_id = self.info.get_hf_config().image_token_index

    num_images = 0
    # ``reversed`` iterates backwards without copying the list.
    for token_id in reversed(prompt_token_ids):
        if token_id == image_token_id:
            num_images += 1
        elif num_images > 0:
            # First non-image token before the last image group: stop.
            break
    return num_images
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Run the HF processor and post-process its outputs.

    Without multimodal data the prompt is only tokenized (no special tokens
    added). With multimodal data the base processor is called, a
    ``num_tiles`` tensor is attached, the leading dimension of the image
    tensors is squeezed, and BOS tokens at the start/end of the processed
    prompt are stripped from ``input_ids``.
    """
    tokenizer = self.info.get_tokenizer()

    if not mm_data:
        return tokenizer(prompt,
                         add_special_tokens=False,
                         return_tensors="pt")

    tiles_per_image = [
        self.info.get_num_tiles_per_image(img.height, img.width)
        for img in mm_data["images"]
    ]
    processed_outputs = super()._call_hf_processor(prompt, mm_data,
                                                   mm_kwargs, tok_kwargs)
    processed_outputs["num_tiles"] = torch.tensor(tiles_per_image)
    for key in ('pixel_values', 'aspect_ratio_ids', "aspect_ratio_mask"):
        processed_outputs[key] = processed_outputs[key].squeeze(0)

    token_ids = processed_outputs.pop("input_ids")
    first, last = 0, token_ids.size(1)
    decoded_prompt = tokenizer.decode(token_ids[0])
    bos_token = self.info.get_hf_processor().bos_token

    # Drop a leading BOS token: the prompt is known to contain an image
    # token in this branch.
    if decoded_prompt.startswith(bos_token):
        first += 1
    # Drop a trailing BOS token: the text part is empty in that case.
    if decoded_prompt.endswith(bos_token):
        last -= 1

    processed_outputs["input_ids"] = token_ids[:, first:last]
    return processed_outputs
def _get_mm_fields_config(
    self,
    hf_inputs: BatchFeature,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
    """Declare every processor output field as batched over the image modality.

    Neither argument is consulted; the mapping is fixed.
    """
    image_fields = ("pixel_values", "aspect_ratio_ids", "aspect_ratio_mask",
                    "num_tiles")
    return {
        field: MultiModalFieldConfig.batched("image")
        for field in image_fields
    }
def create_encoder_prompt(
    self,
    prompt: Union[str, list[int]],
    mm_data: MultiModalDataDict,
) -> Union[str, list[int]]:
    """Return an encoder prompt made of one image token id per image.

    ``prompt`` is not used; only the number of images in ``mm_data``
    matters. A single ``Image`` counts as one image.
    """
    images = mm_data.get("image", [])
    image_count = 1 if isinstance(images, Image) else len(images)
    token_id = self.info.get_hf_config().image_token_index
    return [token_id] * image_count
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
    """Describe how each image token is expanded into per-tile tokens.

    Every single image token is replaced by
    ``num_tiles * token_per_chunk`` copies of the image token id, where the
    tile count depends on the size of the corresponding image.
    """
    token_per_chunk = self.info.get_token_per_chunk_from_config()
    image_token_id = self.info.get_hf_config().image_token_index

    def get_replacement_mllama(item_idx):
        # Size of the item_idx-th image decides how many tiles it occupies.
        size = mm_items.get_items(
            "image", ImageProcessorItems).get_image_size(item_idx)
        tiles = self.info.get_num_tiles_per_image(
            image_height=size.height,
            image_width=size.width,
        )
        return [image_token_id] * (tiles * token_per_chunk)

    replacement = PromptReplacement(
        modality="image",
        target=[image_token_id],
        replacement=get_replacement_mllama,
    )
    return [replacement]
def _prepare_aspect_ratio_attention_mask(
    aspect_ratio_mask: torch.Tensor,
    num_patches: int,
    target_length: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Build an additive 4D attention mask from a per-tile validity mask.

    Args:
        aspect_ratio_mask: 2D tensor ``(batch_size, max_num_tiles)``; a
            value of 1 marks a tile as valid, 0 as padding.
        num_patches: Number of real (non-padding) patch positions per tile.
        target_length: Padded number of patch positions per tile.
        dtype: Floating dtype of the returned mask.

    Returns:
        Tensor of shape ``(batch_size, 1, max_num_tiles * target_length,
        max_num_tiles * target_length)``. An entry is
        ``torch.finfo(dtype).min`` when both its row and column positions
        are masked (padding tile or padding patch), and 0 otherwise.
    """
    # Expand the per-tile mask to one entry per patch position.
    batch_size, max_num_tiles = aspect_ratio_mask.shape
    attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1,
                                            1).to(dtype)
    attention_mask = attention_mask.repeat(1, 1, target_length, 1)

    # Mask the padding patches at the end of every tile. The guard matters:
    # with zero padding the slice ``[-0:]`` would select *all* positions and
    # mask every patch of every tile.
    pad_patches = target_length - num_patches
    if pad_patches > 0:
        attention_mask[:, :, -pad_patches:] = 0

    # Invert the mask (0 -> 1, 1 -> 0) so that 1 marks a masked position.
    attention_mask = 1 - attention_mask

    # Flatten tiles and patches, then take the outer product to obtain a
    # (batch_size, 1, max_num_tiles*target_length,
    #  max_num_tiles*target_length) additive mask.
    attention_mask = attention_mask.reshape(batch_size,
                                            max_num_tiles * target_length, 1)
    attention_mask = attention_mask @ attention_mask.transpose(
        -1, -2) * torch.finfo(dtype).min
    return attention_mask.unsqueeze(1)
class ColumnParallelConv2dPatch(torch.nn.Module):
    """Patch-embedding layer expressed as unfold + column-parallel linear.

    The input is unfolded into flattened patches, which are then projected
    by a ``ColumnParallelLinear`` layer.

    Arguments:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Size of convolution kernel (int or pair).
        stride: Stride for convolution.
        bias (default False): Use bias in the linear projection.
    Input: (bsz, in_channels, width, height)
    Output: (bsz, num_tokens, out_channels)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, tuple[int, int]],
        stride: Union[int, tuple[int, int]],
        bias: bool = False,
    ) -> None:
        super().__init__()
        kernel_hw = ((kernel_size, kernel_size) if isinstance(
            kernel_size, int) else kernel_size)
        self._unfold = torch.nn.Unfold(kernel_size=kernel_hw, stride=stride)
        # Each unfolded patch holds in_channels * kh * kw values.
        patch_dim = in_channels * kernel_hw[0] * kernel_hw[1]
        self._linear = ColumnParallelLinear(
            patch_dim,
            out_channels,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Unfold ``x`` into patches and project each patch."""
        # Unfold yields (bsz, patch_dim, num_tokens); move tokens to dim 1.
        patches = self._unfold(x).permute(0, 2, 1)
        projected, _ = self._linear(patches)
        return projected
class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
    """Adds a per-tile embedding looked up by aspect-ratio id.

    The embedding table has ``max_aspect_ratio_id + 1`` rows, each holding
    ``max_num_tiles * hidden_size`` values. When gated, the embedding is
    scaled by ``tanh(gate)``; the gate starts at zero.
    """

    def __init__(self,
                 config: config_mllama.MllamaVisionConfig,
                 is_gated: bool = True):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        self.is_gated = is_gated

        num_rows = self.max_aspect_ratio_id + 1
        row_width = self.max_num_tiles * self.hidden_size
        self.embedding = nn.Embedding(num_rows, row_width)
        if is_gated:
            self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, hidden_state: torch.Tensor,
                aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_state`` plus the (optionally gated) tile embedding."""
        # One embedding vector per tile, broadcast over the patch dimension.
        tile_embeds = self.embedding(aspect_ratio_ids).reshape(
            -1, self.max_num_tiles, 1, self.hidden_size)
        if self.is_gated:
            tile_embeds = tile_embeds * self.gate.tanh()
        return hidden_state + tile_embeds
class MllamaPrecomputedPositionEmbedding(nn.Module):
    """Adds gated patch-position and tile-position embeddings.

    A single gate blends the two: the patch-position embedding is weighted
    by ``1 - tanh(gate)`` and the tile-position embedding by ``tanh(gate)``.
    """

    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        # Patch grid positions plus one extra position.
        self.num_patches = (config.image_size // config.patch_size)**2 + 1
        self.hidden_size = config.hidden_size
        self.scale = config.hidden_size**-0.5

        self.gate = nn.Parameter(torch.zeros(1))

        # Patch position embedding, randomly initialized and scaled.
        initial_positions = torch.randn(self.num_patches, self.hidden_size)
        self.embedding = nn.Parameter(self.scale * initial_positions)

        # Tile position embedding, one row per aspect-ratio id.
        tile_row_width = (self.max_num_tiles * self.num_patches *
                          self.hidden_size)
        self.tile_embedding = nn.Embedding(self.max_aspect_ratio_id + 1,
                                           tile_row_width)

    def forward(self, hidden_state: torch.Tensor,
                aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_state`` with both gated embeddings added."""
        gate = self.gate.tanh()

        # Patch position embedding, shared by all tiles.
        patch_positions = ((1 - gate) * self.embedding).view(
            1, 1, self.num_patches, self.hidden_size)
        hidden_state = hidden_state + patch_positions

        # Precomputed tile position embedding selected by aspect-ratio id.
        batch_size = hidden_state.shape[0]
        tile_positions = self.tile_embedding(aspect_ratio_ids).reshape(
            batch_size, self.max_num_tiles, self.num_patches,
            self.hidden_size)
        return hidden_state + gate * tile_positions
# TODO: support other attention backends for attention in vision model
class MllamaVisionSdpaAttention(nn.Module):
    """Self-attention block of the Mllama vision encoder.

    Projects the input to Q/K/V with a fused parallel linear layer, runs
    ``MultiHeadAttention`` on the local heads and applies the output
    projection.
    """

    def __init__(self,
                 config: config_mllama.MllamaVisionConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()

        tp_size = get_tp_group().world_size
        self.embed_dim = config.hidden_size
        self.num_heads = config.attention_heads
        self.head_dim = config.hidden_size // config.attention_heads
        # Heads are divided across tensor-parallel ranks.
        self.num_local_heads = self.num_heads // tp_size
        local_width = self.num_local_heads * self.head_dim
        self.q_size = local_width
        self.kv_size = local_width

        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.embed_dim,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # Unified MultiHeadAttention; the backend is picked by that class.
        softmax_scale = 1.0 / math.sqrt(self.head_dim)
        self.attn = MultiHeadAttention(self.num_local_heads, self.head_dim,
                                       softmax_scale)

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_state``.

        NOTE(review): ``attention_mask`` is accepted but never passed to
        ``self.attn`` in this method -- confirm whether masking of padded
        tiles/patches is intentionally dropped.
        """
        qkv, _ = self.qkv_proj(hidden_state)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

        attn_output = self.attn(q, k, v)

        # Collapse any trailing head dimensions before the output projection.
        lead, second = attn_output.shape[0], attn_output.shape[1]
        attn_output = attn_output.reshape(lead, second, -1)
        output, _ = self.o_proj(attn_output)
        return output
class MllamaVisionEncoderLayer(nn.Module):
    """Pre-norm transformer layer (self-attention + MLP) of the vision encoder.

    When ``is_gated`` is true, each residual branch is scaled by the tanh of
    a learned gate (initialized to pi/4).
    """

    def __init__(
        self,
        config: config_mllama.MllamaVisionConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
        is_gated: bool = False,
    ) -> None:
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.attention_heads
        self.is_gated = is_gated
        self.intermediate_size = config.intermediate_size

        self.self_attn = MllamaVisionSdpaAttention(
            config, quant_config=quant_config, prefix=f"{prefix}.self_attn")
        self.mlp = CLIPMLP(config,
                           quant_config=quant_config,
                           prefix=f"{prefix}.mlp")

        self.input_layernorm = nn.LayerNorm(self.hidden_size,
                                            eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(self.hidden_size,
                                                     eps=config.norm_eps)

        # Gates only exist on gated layers.
        if is_gated:
            self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4)
            self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4)

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Apply the attention and feed-forward residual branches in turn."""
        # Self-attention branch.
        attn_out = self.self_attn(self.input_layernorm(hidden_state),
                                  attention_mask=attention_mask)
        attn_scale = self.gate_attn.tanh() if self.is_gated else 1
        hidden_state = hidden_state + attn_scale * attn_out

        # Feed-forward branch.
        ffn_out = self.mlp(self.post_attention_layernorm(hidden_state))
        ffn_scale = self.gate_ffn.tanh() if self.is_gated else 1
        hidden_state = hidden_state + ffn_scale * ffn_out

        return hidden_state
class MllamaVisionEncoder(nn.Module):
    """Stack of ``MllamaVisionEncoderLayer`` modules.

    Optionally records the hidden states seen at selected layer indices
    (``output_hidden_states``) so callers can use intermediate features.
    """

    def __init__(
        self,
        config: config_mllama.MllamaVisionConfig,
        quant_config: Optional[QuantizationConfig],
        num_layers: int = 32,
        is_gated: bool = False,
        output_hidden_states=None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([
            MllamaVisionEncoderLayer(config,
                                     quant_config=quant_config,
                                     is_gated=is_gated,
                                     prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_layers)
        ])
        # Layer indices whose hidden states are collected; none by default.
        self.output_hidden_states = output_hidden_states or []

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Run all layers and collect the requested intermediate states.

        Returns:
            A ``(hidden_states, encoder_states)`` tuple: the output of the
            last layer and a tuple of collected hidden states. (The return
            annotation previously claimed ``BaseModelOutput``, which this
            method never returns.)
        """
        encoder_states: tuple[torch.Tensor, ...] = ()

        for i, encoder_layer in enumerate(self.layers):
            # For a selected index, record the *input* of that layer.
            if i in self.output_hidden_states:
                encoder_states = encoder_states + (hidden_states, )
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
            )

        # If the last layer index is selected, its output is recorded too.
        if len(self.layers) - 1 in self.output_hidden_states:
            encoder_states = encoder_states + (hidden_states, )

        return hidden_states, encoder_states
class MllamaVisionModel(nn.Module):
def __init__(
    self,
    config: config_mllama.MllamaVisionConfig,
    quant_config: Optional[QuantizationConfig],
    prefix: str = "",
) -> None:
    """Build the patch embedding, positional embeddings and both encoders."""
    super().__init__()

    self.image_size = config.image_size
    self.patch_size = config.patch_size
    self.max_num_tiles = config.max_num_tiles
    self.hidden_size = config.hidden_size
    self.in_channels = config.num_channels
    self.intermediate_layers_indices = config.intermediate_layers_indices

    # Patch grid positions plus one class token per tile.
    self.num_patches = (self.image_size // self.patch_size)**2 + 1
    self.scale = config.hidden_size**-0.5

    self.patch_embedding = ColumnParallelConv2dPatch(
        in_channels=config.num_channels,
        out_channels=self.hidden_size,
        kernel_size=self.patch_size,
        stride=self.patch_size,
        bias=False,
    )

    self.class_embedding = nn.Parameter(self.scale *
                                        torch.randn(self.hidden_size))
    self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
        config)

    # Tile embeddings applied before and after the local transformer.
    self.pre_tile_positional_embedding = (
        MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True))
    self.post_tile_positional_embedding = (
        MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True))

    # Layer norms around the local transformer.
    self.layernorm_pre = nn.LayerNorm(self.hidden_size)
    self.layernorm_post = nn.LayerNorm(self.hidden_size)

    # Local (ungated) encoder; also exposes the configured intermediate
    # layer states.
    self.transformer = MllamaVisionEncoder(
        config,
        quant_config,
        config.num_hidden_layers,
        is_gated=False,
        output_hidden_states=config.intermediate_layers_indices,
        prefix=f"{prefix}.transformer",
    )
    # Global (gated) encoder.
    self.global_transformer = MllamaVisionEncoder(
        config,
        quant_config,
        config.num_global_layers,
        is_gated=True,
        prefix=f"{prefix}.global_transformer",
    )
def apply_class_embedding(self,
                          hidden_state: torch.Tensor) -> torch.Tensor:
    """Prepend the class embedding as the first token of every sequence.

    ``hidden_state`` is 3D ``(batch, tokens, hidden)``; the result has one
    more token along dim 1.
    """
    batch_size, _, hidden_size = hidden_state.shape
    cls_tokens = self.class_embedding.expand(batch_size, 1, hidden_size)
    return torch.cat([cls_tokens, hidden_state], dim=1)
def forward(self, pixel_values: torch.Tensor,
            aspect_ratio_ids: torch.Tensor,
            aspect_ratio_mask: torch.Tensor) -> torch.Tensor:
    """Encode tiled images into per-patch features.

    ``pixel_values`` is 6D ``(batch, media, tiles, channels, height,
    width)``. The result is 5D ``(batch, media, tiles, patches, features)``
    where the last dimension concatenates the global-encoder output with
    the stacked intermediate states of the local encoder.
    """
    (batch_size, num_concurrent_media, num_tiles, num_channels, height,
     width) = pixel_values.shape
    num_media = batch_size * num_concurrent_media
    compute_dtype = self.layernorm_pre.weight.dtype

    pixel_values = pixel_values.reshape(num_media * num_tiles, num_channels,
                                        height, width)
    aspect_ratio_ids = aspect_ratio_ids.reshape(num_media, -1)

    # Patch embedding; the column-parallel output is gathered across
    # tensor-parallel ranks.
    hidden_state = self.patch_embedding(pixel_values.to(compute_dtype))
    hidden_state = ps.get_tp_group().all_gather(hidden_state)

    # Tile embeddings.
    _, num_patches, dim = hidden_state.shape
    hidden_state = hidden_state.reshape(num_media, num_tiles, -1, dim)
    hidden_state = self.pre_tile_positional_embedding(
        hidden_state, aspect_ratio_ids)

    # Class token, one per tile.
    hidden_state = hidden_state.reshape(num_media * num_tiles, num_patches,
                                        dim)
    hidden_state = self.apply_class_embedding(hidden_state)
    num_patches += 1

    # Position embeddings.
    hidden_state = hidden_state.reshape(num_media, num_tiles, num_patches,
                                        dim)
    hidden_state = self.gated_positional_embedding(hidden_state,
                                                   aspect_ratio_ids)

    hidden_state = self.layernorm_pre(hidden_state)

    # Pad the patch dimension (dim -2) up to a multiple of 8. The padding
    # tuple is (left, right) for the last dim, then (left, right) for dim -2.
    num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
    hidden_state = F.pad(hidden_state, (0, 0, 0, num_padding_patches),
                         mode="constant",
                         value=0)
    # Slice end used later to strip the padding again (None = keep all).
    slice_index = -num_padding_patches if num_padding_patches > 0 else None
    padded_patches = num_patches + num_padding_patches

    attention_mask = _prepare_aspect_ratio_attention_mask(
        aspect_ratio_mask=aspect_ratio_mask.reshape(num_media, -1),
        num_patches=self.num_patches,
        target_length=hidden_state.shape[2],
        dtype=compute_dtype,
    )

    # Local encoder over the flattened (tiles * patches) sequence.
    hidden_state = hidden_state.view(num_media, -1, dim)
    output = self.transformer(
        hidden_state,
        attention_mask=attention_mask,
    )
    hidden_state = output[0]
    intermediate_hidden_states = torch.stack(output[1], dim=-1)

    # Global encoder.
    hidden_state = self.layernorm_post(hidden_state)
    hidden_state = hidden_state.reshape(num_media, num_tiles, padded_patches,
                                        dim)
    hidden_state = self.post_tile_positional_embedding(
        hidden_state, aspect_ratio_ids)
    hidden_state = hidden_state.reshape(num_media,
                                        num_tiles * padded_patches, dim)
    hidden_state = self.global_transformer(
        hidden_state, attention_mask=attention_mask)[0]

    # Strip the padding patches from the global features.
    hidden_state = hidden_state.reshape(num_media, num_tiles, padded_patches,
                                        dim)
    hidden_state = hidden_state[:, :, :slice_index]
    hidden_state = hidden_state.reshape(batch_size, num_concurrent_media,
                                        num_tiles, num_patches, dim)

    # Same for the intermediate states, then append them as extra features.
    intermediate_hidden_states = intermediate_hidden_states.reshape(
        num_media, num_tiles, padded_patches, -1)
    intermediate_hidden_states = (
        intermediate_hidden_states[:, :, :slice_index])
    intermediate_hidden_states = intermediate_hidden_states.reshape(
        batch_size, num_concurrent_media, num_tiles, num_patches, -1)

    return torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint weights, fusing q/k/v projections into ``qkv_proj``.

    Weights whose name contains ``.q_proj``/``.k_proj``/``.v_proj`` are
    routed to the matching shard of the fused ``.qkv_proj`` parameter. The
    patch-embedding weight is flattened to 2D before loading.

    Returns:
        The set of parameter names that were loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
    ]
    params_dict = dict(self.named_parameters())
    updated_params: set[str] = set()

    for name, loaded_weight in weights:
        if 'patch_embedding._linear.weight' in name:
            # Flatten the checkpoint weight to (out_channels, -1).
            loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)

        # First mapping entry whose shard name occurs in the weight name.
        match = next(
            (entry for entry in stacked_params_mapping if entry[1] in name),
            None)

        if match is not None:
            param_name, weight_name, shard_id = match
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            updated_params.add(name)
            param.weight_loader(param, loaded_weight, shard_id)
        else:
            # Unfused parameter: use its own loader if it has one.
            param = params_dict.pop(name)
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            updated_params.add(name)

    return updated_params
class MllamaTextRMSNorm(nn.Module):
    """RMS normalization over the last dimension with a learned scale."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        MllamaTextRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize in float32, cast back to the input dtype, then scale."""
        input_dtype = hidden_states.dtype
        states_fp32 = hidden_states.to(torch.float32)
        # Mean of squares over the last dimension (no mean subtraction).
        mean_square = states_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = states_fp32 * inv_rms
        return self.weight * normalized.to(input_dtype)

    def extra_repr(self):
        """Show the weight shape and epsilon in the module repr."""
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.variance_epsilon}"
class MllamaTextCrossAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
    self,
    config: Optional[config_mllama.MllamaTextConfig] = None,
    layer_idx: Optional[int] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    """Build the projections, per-head norms and the cross-attention op.

    ``config`` is dereferenced unconditionally, so it must be provided
    despite its ``None`` default.
    """
    super().__init__()
    self.config = config
    self.pipeline_parallel_rank = get_pp_group().rank_in_group
    self.tensor_parallel_size = get_tp_group().world_size

    # Global and per-rank head counts.
    self.num_heads = config.num_attention_heads
    self.num_key_value_heads = config.num_key_value_heads
    self.num_local_heads = self.num_heads // self.tensor_parallel_size
    self.num_local_key_value_heads = (self.num_key_value_heads //
                                      self.tensor_parallel_size)

    self.hidden_size = config.hidden_size
    self.head_dim = config.hidden_size // self.num_heads
    self.layer_idx = layer_idx
    self.num_key_value_groups = self.num_heads // self.num_key_value_heads

    # Widths of the local query and key/value projections.
    self.q_local_size = self.num_local_heads * self.head_dim
    self.kv_local_size = self.num_local_key_value_heads * self.head_dim

    self.qkv_proj = QKVCrossParallelLinear(
        self.hidden_size,
        self.head_dim,
        self.num_heads,
        self.num_key_value_heads,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.qkv_proj",
    )
    self.o_proj = RowParallelLinear(
        self.num_heads * self.head_dim,
        self.hidden_size,
        bias=False,
        input_is_parallel=True,
        quant_config=quant_config,
        prefix=f"{prefix}.o_proj",
    )

    # vllm.model_executor.layers.layernorm.RMSNorm has a precision issue,
    # so the HuggingFace-style implementation is used instead.
    self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
    self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)

    self.scaling = self.head_dim**-0.5
    self.attn = Attention(
        self.num_local_heads,
        self.head_dim,
        self.scaling,
        self.num_local_key_value_heads,
        prefix=f"{prefix}.attn",
        attn_type=AttentionType.ENCODER_DECODER,
    )
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
kv_range_for_decode: Optional[list[tuple[int, int]]],
cross_attention_states: Optional[torch.Tensor],
) -> torch.Tensor:
q, k, v = self.qkv_proj(hidden_states, cross_attention_states)
if cross_attention_states is not None:
k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
k = self.k_norm(k)
q = q.view(-1, self.num_local_heads, self.head_dim)
q = self.q_norm(q)
if attention_mask is not None:
output = self._attention_with_mask(q, k, v, attention_mask,
kv_range_for_decode)
else:
output = self.attn(
q.view(-1, self.num_local_heads * self.head_dim), k, v)
out, _ = self.o_proj(output)
return out
def _attention_with_mask(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attention_mask: torch.Tensor,
kv_range_for_decode: list[tuple[int, int]],
) -> torch.Tensor:
kv_cache = self.attn.kv_cache[self.pipeline_parallel_rank]
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
# Skip writing kv-cache for the initial profiling run.
# TODO (NickLucche) replace with custom attn bias and use standard attn
if len(kv_cache.shape) > 1:
i = torch.ones(1, dtype=torch.float32)
if self.attn.backend in (_Backend.FLASH_ATTN,
_Backend.FLASH_ATTN_VLLM_V1):
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
torch.ops._C_cache_ops.reshape_and_cache_flash(
cached_k,
cached_v,
kv_cache[0],
kv_cache[1],
attn_metadata.
cross_slot_mapping, # type: ignore[union-attr]
"auto",
i,
i,
)
elif self.attn.backend in (_Backend.XFORMERS, _Backend.ROCM_FLASH,
_Backend.TORCH_SDPA):
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_local_key_value_heads, self.head_dim)
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
PagedAttention.write_to_paged_cache(
cached_k, cached_v, key_cache, value_cache,
attn_metadata.cross_slot_mapping, "auto", i, i)
else:
raise ValueError(
f"Unsupported Attention backend {self.attn.backend} "
"enum found. Expected the Attention backend to be "
"FLASH_ATTN, FLASH_ATTN_VLLM_V1, "
"XFORMERS or TORCH_SDPA.")
# We have to call torch.sdpa for prefill when using a
# custom cross-attention mask. Because the mask is not a
# standard causal mask, neither a block diagonal mask which
# can be optimized by xformers.BlockDiagonalMask.
# The mask is specially calculated for supporting multi
# images and interleaved images.
q_len = q.shape[0]
kv_len = k.shape[0]
q = q.transpose(0, 1).view(self.num_local_key_value_heads,
self.num_key_value_groups, q_len,
self.head_dim).contiguous()
k = k.transpose(0,
1)[:,
None, :, :].expand(self.num_local_key_value_heads,
self.num_key_value_groups,
kv_len,
self.head_dim).contiguous()
v = v.transpose(0,
1)[:,
None, :, :].expand(self.num_local_key_value_heads,
self.num_key_value_groups,
kv_len,
self.head_dim).contiguous()
attention_mask = attention_mask.view(1, 1, q_len, kv_len)
output = F.scaled_dot_product_attention(q,
k,
v,
attn_mask=attention_mask,
is_causal=False)
output = output.permute(2, 0, 1, 3).reshape(
q_len, self.num_local_heads * self.head_dim)
return output
class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
    """Decoder block with cross-attention and an MLP, both tanh-gated.

    Each sub-layer output is multiplied by ``full_text_row_masked_out_mask``
    and by ``tanh`` of a learned scalar gate (initialized to zero) before
    being added to the residual stream.
    """

    def __init__(
        self,
        config: config_mllama.MllamaTextConfig,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> None:
        super().__init__()
        width = config.hidden_size
        norm_eps = config.rms_norm_eps

        self.layer_idx = layer_idx

        # Attention sub-layer: pre-norm, cross-attention, scalar gate.
        self.cross_attn = MllamaTextCrossAttention(
            config=config,
            layer_idx=layer_idx,
            quant_config=quant_config,
            prefix=f"{prefix}.cross_attn",
        )
        self.input_layernorm = RMSNorm(width, eps=norm_eps)
        self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1))

        # Feed-forward sub-layer: pre-norm, MLP, scalar gate.
        self.mlp = LlamaMLP(
            hidden_size=width,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.post_attention_layernorm = RMSNorm(width, eps=norm_eps)
        self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: torch.Tensor,
        cross_attention_mask: torch.Tensor,
        kv_range_for_decode: Optional[list[tuple[int, int]]],
        full_text_row_masked_out_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the gated cross-attention and gated MLP sub-layers.

        Returns the updated hidden states.
        """
        # --- gated cross-attention ---
        attn_out = self.cross_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=cross_attention_mask,
            kv_range_for_decode=kv_range_for_decode,
            cross_attention_states=cross_attention_states,
        )
        attn_out = full_text_row_masked_out_mask * attn_out
        after_attn = hidden_states + self.cross_attn_attn_gate.tanh(
        ) * attn_out

        # --- gated feed-forward ---
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        mlp_out = full_text_row_masked_out_mask * mlp_out
        return after_attn + self.cross_attn_mlp_gate.tanh() * mlp_out
class MllamaTextModel(nn.Module):
    """Text decoder stack mixing self-attention and cross-attention layers.

    Layers whose index is listed in ``config.cross_attention_layers`` are
    ``MllamaCrossAttentionDecoderLayer``; all others are
    ``LlamaDecoderLayer``.
    """
    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "model"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config.text_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size
        # The embedding table has 8 more rows than the vocabulary size.
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8,
                                                   config.hidden_size)
        self.cross_attention_layers = config.cross_attention_layers

        def make_layer(layer_idx: int) -> nn.Module:
            layer_prefix = f"{prefix}.layers.{layer_idx}"
            if layer_idx in self.cross_attention_layers:
                return MllamaCrossAttentionDecoderLayer(
                    config,
                    layer_idx,
                    quant_config=quant_config,
                    prefix=layer_prefix,
                )
            # TODO: force LlamaDecoderLayer to config.attention_bias=False
            return LlamaDecoderLayer(
                config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=layer_prefix,
            )

        self.layers = nn.ModuleList(
            [make_layer(i) for i in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        kv_range_for_decode: Optional[list[tuple[int, int]]],
        full_text_row_masked_out_mask: Optional[tuple[torch.Tensor,
                                                      torch.Tensor]],
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every layer and apply the final norm.

        When ``skip_cross_attention`` is true, cross-attention layers are
        bypassed entirely (the hidden states pass through unchanged).
        """
        hidden_states = self.embed_tokens(input_ids)

        for idx, decoder_layer in enumerate(self.layers):
            if idx not in self.cross_attention_layers:
                # Self-attention layer: returns (hidden, residual), which
                # are summed back into a single stream here.
                hidden_states, residual = decoder_layer(
                    positions=positions,
                    hidden_states=hidden_states,
                    residual=None,
                )
                hidden_states = hidden_states + residual
                continue

            if skip_cross_attention:
                continue

            hidden_states = decoder_layer(
                hidden_states=hidden_states,
                cross_attention_states=cross_attention_states,
                cross_attention_mask=cross_attention_mask,
                kv_range_for_decode=kv_range_for_decode,
                full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            )

        return self.norm(hidden_states)
class MllamaForCausalLM(nn.Module):
    """Mllama text model together with its LM head.

    ``forward`` returns the hidden states of the text model; the LM head
    is held here but not applied in ``forward``.
    """
    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "language_model"
    _no_split_modules = [
        "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"
    ]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        text_config = vllm_config.model_config.hf_config.text_config
        self.quant_config = vllm_config.quant_config
        self.vocab_size = text_config.vocab_size
        self.model = MllamaTextModel(vllm_config=vllm_config,
                                     prefix=f"{prefix}.model")
        self.lm_head = ParallelLMHead(
            text_config.vocab_size,
            text_config.hidden_size,
            org_num_embeddings=text_config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=self.quant_config,
            prefix=f"{prefix}.lm_head",
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        kv_range_for_decode: Optional[list[tuple[int, int]]],
        full_text_row_masked_out_mask: Optional[tuple[torch.Tensor,
                                                      torch.Tensor]],
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        """Delegate to the text model and return its hidden states."""
        return self.model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            kv_range_for_decode=kv_range_for_decode,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            skip_cross_attention=skip_cross_attention,
        )

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        updated_params: set[str] = set()

        for name, loaded_weight in weights:
            if 'patch_embedding.weight' in name:
                # Flatten everything after the first dim for the linear form.
                name = name.replace('patch_embedding.weight',
                                    'patch_embedding._linear.weight')
                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)

            # Loading kv cache quantization scales
            scale_name = None
            if self.quant_config is not None:
                scale_name = self.quant_config.get_cache_scale(name)
            if scale_name:
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                if loaded_weight.dim() != 0:
                    loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                updated_params.add(scale_name)
                continue

            # First stacked-parameter rule whose shard name occurs in `name`.
            stacked_rule = next(
                (rule for rule in stacked_params_mapping if rule[1] in name),
                None)
            if stacked_rule is not None:
                param_name, weight_name, shard_id = stacked_rule
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                updated_params.add(name)
                param.weight_loader(param, loaded_weight, shard_id)
                continue

            # Plain (non-stacked) parameter.
            orig_name = name
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                logger.debug("Missing name %s, orig name %s", name,
                             orig_name)
                continue
            param = params_dict.pop(name)
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            updated_params.add(name)

        return updated_params
@MULTIMODAL_REGISTRY.register_processor(MllamaMultiModalProcessor,
                                        info=MllamaProcessingInfo,
                                        dummy_inputs=MllamaDummyInputsBuilder)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
                                     SupportsV0Only):
    """Mllama vision-language model.

    Combines a vision model, a linear projector into the text hidden size,
    and the text decoder. Image features reach the text decoder through
    its cross-attention layers.
    """
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"]
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.vision_model.": "vision_model.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "model.language_model.": "language_model.model.",
            "lm_head.": "language_model.lm_head.",
        },
        orig_to_new_suffix={
            "patch_embedding.weight": "patch_embedding._linear.weight",
        },
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for an image item.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return "<|image|>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config: MllamaConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.vocab_size = config.text_config.vocab_size
        self.hidden_size = config.text_config.hidden_size
        self.max_num_tiles = config.vision_config.max_num_tiles
        self.vision_output_dim = config.vision_config.vision_output_dim
        # -1 stands in for a missing pad token id.
        self.pad_token_id = \
            config.pad_token_id if config.pad_token_id is not None else -1
        self.image_size = config.vision_config.image_size
        self.image_token_id = config.image_token_index

        self.vision_model = MllamaVisionModel(config.vision_config,
                                              quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "vision_model"))
        self.language_model = MllamaForCausalLM(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        self.multi_modal_projector = ColumnParallelLinear(
            config.vision_config.vision_output_dim,
            config.text_config.hidden_size,
            bias=True,
            quant_config=quant_config,
            gather_output=True,
            prefix=maybe_prefix(prefix, "multi_modal_projector"),
        )
        self.logits_processor = LogitsProcessor(config.output_hidden_states,
                                                config.text_config.vocab_size)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from hidden states using the language model head."""
        logits = self.logits_processor(self.language_model.lm_head,
                                       hidden_states, sampling_metadata)
        return logits

    def unpack_data(self,
                    image_data: Union[list[torch.Tensor], torch.Tensor],
                    padding_value=0) -> torch.Tensor:
        """Turn batched image data into a single tensor.

        A tensor is returned unchanged. A list of tensors is stacked into
        one tensor of shape ``(len(list), max_dim0, *trailing_dims)``,
        padding each entry along dim 0 with ``padding_value``. All entries
        must share the same trailing dimensions.
        """
        if isinstance(image_data, torch.Tensor):
            # torch.Tensor
            return image_data

        assert isinstance(
            image_data[0],
            torch.Tensor), "Image data is not properly batched."
        # list[torch.Tensor]
        bsz = len(image_data)
        max_length = max(t.size(0) for t in image_data)
        trailing_dims = image_data[0].shape[1:]
        for data in image_data:
            cur_trailing_dims = data.shape[1:]
            assert cur_trailing_dims == trailing_dims
        output_tensor = torch.full((bsz, max_length, *trailing_dims),
                                   padding_value,
                                   dtype=image_data[0].dtype,
                                   device=image_data[0].device)
        for i, t in enumerate(image_data):
            output_tensor[i, :t.size(0)] = t
        return output_tensor

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[MllamaImagePixelInputs]:
        """Extract image inputs from the forward kwargs.

        Returns:
            ``None`` when neither ``pixel_values`` nor ``image_embeds`` is
            present, otherwise the packed pixel inputs.

        Raises:
            ValueError: If both ``pixel_values`` and ``image_embeds`` are
                given.
            NotImplementedError: If only ``image_embeds`` is given.
        """
        # tensor with the same shape will be batched together by
        # MultiModalKwargs.batch, so pixel_values here can be:
        #   - list[torch.Tensor]:
        #       with shape (num_image, num_tiles, 3, image_res, image_res)
        #   - torch.Tensor:
        #       with shape (bs, num_image, num_tiles, 3, image_res, image_res)
        pixel_values: Optional[Union[list[list[torch.Tensor]],
                                     list[torch.Tensor],
                                     torch.Tensor]] = kwargs.pop(
                                         "pixel_values", None)
        image_embeds: Optional[Union[list[list[torch.Tensor]],
                                     list[torch.Tensor],
                                     torch.Tensor]] = kwargs.pop(
                                         "image_embeds", None)
        aspect_ratio_ids: Optional[Union[list[list[torch.Tensor]],
                                         list[torch.Tensor],
                                         torch.Tensor]] = kwargs.pop(
                                             "aspect_ratio_ids", None)
        aspect_ratio_mask: Optional[Union[list[list[torch.Tensor]],
                                          list[torch.Tensor],
                                          torch.Tensor]] = kwargs.pop(
                                              "aspect_ratio_mask", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None and image_embeds is not None:
            raise ValueError(
                "Both pixel values and image embeds are provided.")

        if pixel_values is not None:
            assert aspect_ratio_ids is not None
            assert aspect_ratio_mask is not None

            return MllamaImagePixelInputs(
                type="pixel_values",
                data=self.unpack_data(pixel_values),
                aspect_ratio_ids=self.unpack_data(aspect_ratio_ids),
                aspect_ratio_mask=self.unpack_data(aspect_ratio_mask))

        if image_embeds is not None:
            raise NotImplementedError

        raise AssertionError("This line should be unreachable.")

    def _get_and_validate_encoder_lens(
        self,
        encoder_seq_lens: list[int],
        num_tiles: list[list[int]],
        num_tokens_per_tile: int,
    ) -> list[int]:
        """Return the true encoder token count of each image sample.

        Also asserts that the counts are consistent with the non-zero
        entries of ``encoder_seq_lens``.
        """
        # Get the actual number of encoder tokens for each sample.
        # Because attn_metadata.encoder_seq_lens only counts the last
        # group of images for each sample, which is used to cheat the
        # block manager to allocate blocks for those images only.
        # See MllamaMultiModalProcessor for more details.
        actual_encoder_seq_lens = [
            sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
        ]

        # remove 0 encoder len entries for text-only requests for these
        # assertions
        attn_metadata_lens = [x for x in encoder_seq_lens if x > 0]
        assert len(actual_encoder_seq_lens) == len(attn_metadata_lens)
        for actual_len, last_group_len in zip(actual_encoder_seq_lens,
                                              attn_metadata_lens):
            assert actual_len >= last_group_len

        return actual_encoder_seq_lens

    def flat_encoder_result(self, cross_attention_states: torch.Tensor,
                            attn_metadata: AttentionMetadata,
                            actual_encoder_seq_lens: list[int]):
        """Concatenate the valid vision tokens of every sample.

        For each sample the first ``seq_len`` rows are copied, so the
        result has ``sum(actual_encoder_seq_lens)`` rows and the same last
        dimension as the input. ``attn_metadata`` is accepted but not used.
        """
        cross_attention_states_flat = torch.zeros(
            sum(actual_encoder_seq_lens),
            cross_attention_states.shape[-1],
            device=cross_attention_states.device,
            dtype=cross_attention_states.dtype)
        start_pos = 0
        for seq_len, vision_token_in_batch in zip(actual_encoder_seq_lens,
                                                  cross_attention_states):
            end_pos = start_pos + seq_len
            cross_attention_states_flat[
                start_pos:end_pos] = vision_token_in_batch[:seq_len]
            start_pos = end_pos
        cross_attention_states = cross_attention_states_flat
        return cross_attention_states

    def get_language_model(self) -> torch.nn.Module:
        """Return the text decoder sub-module."""
        return self.language_model

    def get_cross_attention_states(
        self,
        image_inputs: MllamaImagePixelInputs,
        attn_metadata: AttentionMetadata,
        actual_encoder_seq_lens: list[int],
    ) -> torch.Tensor:
        """Encode the images and return flattened cross-attention states.

        The vision model output is projected to the text hidden size,
        collapsed to ``(batch, tokens, dim)`` and then flattened across the
        batch by ``flat_encoder_result``.
        """
        # NOTE: llama's reference implementation runs vision model on CPU
        pixel_values = image_inputs['data']
        aspect_ratio_ids = image_inputs['aspect_ratio_ids']
        aspect_ratio_mask = image_inputs['aspect_ratio_mask']
        cross_attention_states = self.vision_model(pixel_values,
                                                   aspect_ratio_ids,
                                                   aspect_ratio_mask)
        cross_attention_states, _ = self.multi_modal_projector(
            cross_attention_states)

        # The projected states are 5-D; merge the three middle dims.
        bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
        cross_attention_states = cross_attention_states.view(
            bsz, -1, image_token_dim)

        cross_attention_states = self.flat_encoder_result(
            cross_attention_states, attn_metadata, actual_encoder_seq_lens)

        return cross_attention_states

    def get_cross_attention_mask(
        self,
        input_ids: torch.Tensor,
        attn_metadata: AttentionMetadata,
        num_tiles: list[list[int]],
        num_tokens_per_tile: int,
        dtype: torch.dtype,
    ) -> tuple[Optional[torch.Tensor], Optional[list[tuple[int, int]]]]:
        """Build the cross-attention mask and the KV ranges to cache.

        Returns:
            ``(None, None)`` when no custom mask is needed; otherwise the
            mask tensor and, per image sample, the ``(start, end)`` range
            of encoder tokens to keep for decoding.
        """
        # Split the flat token ids back into one list per sequence.
        token_ids = input_ids.tolist()
        start = 0
        batch_token_ids = []
        for seq_len in attn_metadata.seq_lens:
            batch_token_ids.append(token_ids[start:start + seq_len])
            start += seq_len
        sparse_mask = [
            get_cross_attention_token_mask(t, self.image_token_id)
            for t in batch_token_ids
        ]

        # Skip generating cross-attention mask if all samples
        # are text-only or have only 1 leading image.
        if skip_attention_mask(sparse_mask):
            return None, None

        dense_mask, tile_range_for_decode = \
            convert_sparse_cross_attention_mask_to_dense(
                sparse_mask, num_tiles, attn_metadata.seq_lens)
        cross_attention_mask = \
            convert_dense_cross_attention_mask_to_tensor(
                dense_mask, num_tokens_per_tile, input_ids.device, dtype)
        # Convert tile ranges into encoder-token ranges.
        kv_range_for_decode = [
            (t[0] * num_tokens_per_tile, t[1] * num_tokens_per_tile)
            for t in tile_range_for_decode
        ]

        return cross_attention_mask, kv_range_for_decode

    def get_full_text_row_masked_out_mask(
        self,
        attn_metadata: AttentionMetadata,
        device: torch.device,
    ) -> torch.Tensor:
        """Return a ``(num_prefill_tokens, 1)`` bool mask on ``device``.

        Rows belonging to sequences whose encoder length is 0 are
        ``False``; all other rows are ``True``.
        """
        full_text_row_masked_out_mask = torch.ones(
            (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool)
        start_pos = 0
        for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens,
                                            attn_metadata.encoder_seq_lens):
            if encoder_seq_len == 0:
                full_text_row_masked_out_mask[start_pos:start_pos +
                                              seq_len] = False
            start_pos += seq_len
        full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
            device)
        return full_text_row_masked_out_mask

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        **kwargs: object,
    ) -> torch.Tensor:
        """Run the model and return the language model's hidden states.

        Raises:
            ValueError: If the batch mixes prefill and decode tokens.
        """
        attn_metadata = get_forward_context().attn_metadata
        if attn_metadata.num_prefill_tokens > 0 and \
            attn_metadata.num_decode_tokens > 0:
            raise ValueError("Chunk prefill not supported")
        image_inputs = self._parse_and_validate_image_input(**kwargs)
        cross_attention_states = None
        cross_attention_mask = None
        kv_range_for_decode = None

        # For 1) text-only prefill and decode, 2) image-present decode.
        if image_inputs is None:
            full_text_row_masked_out_mask = (
                attn_metadata.encoder_seq_lens_tensor
                != 0).reshape(-1, 1).to(input_ids.device)
            skip_cross_attention = attn_metadata.max_encoder_seq_len == 0

        # For image-present prefill.
        else:
            skip_cross_attention = False

            num_tiles = [t.tolist() for t in kwargs.pop("num_tiles")]
            num_tokens_per_tile = calc_token_per_chunk(self.image_size)

            actual_encoder_seq_lens = self._get_and_validate_encoder_lens(
                attn_metadata.encoder_seq_lens,
                num_tiles,
                num_tokens_per_tile,
            )

            cross_attention_states = self.get_cross_attention_states(
                image_inputs, attn_metadata, actual_encoder_seq_lens)

            full_text_row_masked_out_mask = \
                self.get_full_text_row_masked_out_mask(
                    attn_metadata, input_ids.device)

            cross_attention_mask, kv_range_for_decode = \
                self.get_cross_attention_mask(
                    input_ids, attn_metadata, num_tiles,
                    num_tokens_per_tile, cross_attention_states.dtype)

        outputs = self.language_model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            kv_range_for_decode=kv_range_for_decode,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            skip_cross_attention=skip_cross_attention,
        )

        return outputs

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights through ``AutoWeightsLoader`` with the HF name mapper."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector",
            tower_model="vision_model")
def skip_attention_mask(sparse_mask: list[list[list[int]]]) -> bool:
    """Tell whether the custom cross-attention mask can be skipped.

    Args:
        sparse_mask: One entry per sample; each entry is a list of
            ``[start, end]`` pairs, one per image (empty for text-only
            samples).

    Returns:
        ``True`` when every sample is either text-only or has exactly one
        image whose pair is ``[0, -1]``; ``False`` otherwise.
    """
    for mask in sparse_mask:
        # Skip text-only samples.
        if len(mask) == 0:
            continue
        # If the sample contains more than 1 images,
        # we can't skip mask.
        if len(mask) != 1:
            return False
        # If the sample contains only 1 image,
        # but the image is not the leading one,
        # we can't skip mask.
        start, end = mask[0][0], mask[0][1]
        if start != 0 or end != -1:
            return False
    return True
def convert_sparse_cross_attention_mask_to_dense(
    sparse_mask: list[list[list[int]]],
    num_tiles: list[list[int]],
    lengths: list[int],
) -> tuple[np.ndarray, list[tuple[int, int]]]:
    """Expand per-image ``[start, end]`` token ranges into a dense 0/1 mask.

    Args:
        sparse_mask: One entry per sequence; each entry holds one
            ``[start, end]`` pair per image (``end == -1`` means "until the
            end of the sequence"). Text-only sequences have an empty entry.
        num_tiles: Tile counts per image, one list per sequence that has
            images (text-only sequences have no entry here).
        lengths: Token length of every sequence, text-only ones included.

    Returns:
        ``(dense_mask, tile_range_for_decode)`` where ``dense_mask`` has
        shape ``(sum(lengths), total_tiles)`` with 1 where a token may
        attend to a tile, and ``tile_range_for_decode`` holds one half-open
        ``(start, end)`` tile range per image sequence.
    """
    total_length = sum(lengths)
    total_tiles = sum(sum(tiles) for tiles in num_tiles)
    dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
    # A list of ranges, range[i] = (start, end) means that the i-th image
    # sequence will use tiles[start:end] for cross-attention decoding.
    tile_range_for_decode = []

    seq_start = 0
    tile_start = 0

    # sparse_mask has an [] entry for each sequence that does not have images,
    # but num_tiles does not have these entries...
    num_tiles_idx = 0
    for masks, length in zip(sparse_mask, lengths):
        if len(masks) == 0:
            # Text only: the sequence owns no tiles, but its rows are part
            # of dense_mask, so the row offset must still move past it.
            seq_start += length
            continue

        tiles = num_tiles[num_tiles_idx]
        num_tiles_idx += 1
        ts, td = -1, 0
        for mask, tile in zip(masks, tiles):
            if len(mask) != 2:
                continue
            start, end = mask
            end = min(end, length)
            if end == -1:
                end = length
            if end == length:
                # This image stays visible until the end of the sequence,
                # so its tiles are needed while decoding.
                if ts == -1:
                    ts = tile_start
                td += tile
            dense_mask[seq_start + start:seq_start + end,
                       tile_start:tile_start + tile] = 1
            tile_start += tile
        assert ts != -1
        assert td != 0
        tile_range_for_decode.append((ts, ts + td))
        seq_start += length
    assert num_tiles_idx == len(num_tiles)

    return dense_mask, tile_range_for_decode
def convert_dense_cross_attention_mask_to_tensor(
    cross_attention_token_mask: np.ndarray,
    num_tokens_per_tile: int,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Turn a 0/1 token-by-tile mask into an additive attention bias.

    Every tile column is repeated ``num_tokens_per_tile`` times. Allowed
    positions become 0 and blocked positions become the most negative
    finite value of ``dtype``. Rows in which every position is blocked are
    zeroed out completely.

    Returns:
        Tensor of shape (num_prompt_tokens, num_encoder_tokens).
    """
    lowest = torch.finfo(dtype).min

    # Expand tiles to encoder tokens along the column axis.
    allowed = torch.tensor(cross_attention_token_mask,
                           dtype=dtype,
                           device=device).repeat_interleave(
                               num_tokens_per_tile, dim=1)

    # 1 where attention is blocked, then replace those by `lowest`.
    blocked = 1.0 - allowed
    bias = blocked.masked_fill(blocked.to(torch.bool), lowest)

    # 1 for rows that can see at least one encoder token, else 0.
    row_has_visible = (bias != lowest).any(dim=-1).type_as(bias)[..., None]
    bias *= row_has_visible

    # (num_prompt_tokens, num_encoder_tokens)
    return bias
......@@ -147,10 +147,6 @@ _TEXT_GENERATION_MODELS = {
"TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
"XverseForCausalLM": ("llama", "LlamaForCausalLM"),
"Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
# [Encoder-decoder]
"BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
"MBartForConditionalGeneration": ("bart", "MBartForConditionalGeneration"),
}
_EMBEDDING_MODELS = {
......@@ -237,6 +233,7 @@ _MULTIMODAL_MODELS = {
"RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
"KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
"Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
"Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501
"LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501
......@@ -263,16 +260,12 @@ _MULTIMODAL_MODELS = {
"Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"),
"SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
"Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"), # noqa: E501
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501
"Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
# [Encoder-decoder]
"DonutForConditionalGeneration": ("donut", "DonutForConditionalGeneration"),
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
"Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501
"SkyworkR1VChatModel": ("skyworkr1v", "SkyworkR1VChatModel"),
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
}
......
......@@ -209,7 +209,7 @@ class MultiModalProfiler(Generic[_I]):
if processor.pad_dummy_encoder_prompt:
num_tokens_to_pad = max(total_len, seq_len) - total_len
encoder_prompt_token_ids.extend([0] * num_tokens_to_pad)
# NOTE: Whisper and Donut allows total_len > seq_len.
# NOTE: Whisper allows total_len > seq_len.
elif total_len > seq_len and not envs.VLLM_USE_V1:
# `max_num_batched_tokens` is defined by `SchedulerConfig`
logger.warning_once(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment