Unverified Commit 770ec602 authored by Chen Zhang, committed by GitHub
Browse files

[Model] Add support for the multi-modal Llama 3.2 model (#8811)


Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Chang Su <chang.s.su@oracle.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
parent 4f1ba084
......@@ -254,6 +254,11 @@ Multimodal Language Models
- Image\ :sup:`+`
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
-
* - :code:`MllamaForConditionalGeneration`
- Llama 3.2
- Image
- :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc.
-
* - :code:`PaliGemmaForConditionalGeneration`
- PaliGemma
- Image\ :sup:`E`
......
......@@ -242,6 +242,29 @@ def run_qwen2_vl(question, modality):
return llm, prompt, stop_token_ids
# LLama
def run_mllama(question, modality):
    """Create the Llama 3.2 Vision engine and prompt for the example script.

    Only the "image" modality is accepted. Returns a
    ``(llm, prompt, stop_token_ids)`` triple; ``stop_token_ids`` is ``None``.
    """
    assert modality == "image"

    # Note: The default setting of max_num_seqs (256) and
    # max_model_len (131072) for this model may cause OOM.
    # You may lower either to run this example on lower-end GPUs.
    # The configuration below has been confirmed to launch on a
    # single H100 GPU.
    engine_kwargs = dict(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
        max_num_seqs=16,
        enforce_eager=True,
    )
    llm = LLM(**engine_kwargs)

    # The image placeholder precedes <|begin_of_text|> in the prompt.
    prompt = f"<|image|><|begin_of_text|>{question}"
    return llm, prompt, None
model_example_map = {
"llava": run_llava,
"llava-next": run_llava_next,
......@@ -256,6 +279,7 @@ model_example_map = {
"internvl_chat": run_internvl,
"qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
"mllama": run_mllama,
}
......
......@@ -38,7 +38,7 @@ chat_completion_from_url = client.chat.completions.create(
"content": [
{
"type": "text",
"text": "Whats in this image?"
"text": "What's in this image?"
},
{
"type": "image_url",
......@@ -75,7 +75,7 @@ chat_completion_from_base64 = client.chat.completions.create(
"content": [
{
"type": "text",
"text": "Whats in this image?"
"text": "What's in this image?"
},
{
"type": "image_url",
......
......@@ -4,7 +4,7 @@ numpy < 2.0.0
requests
tqdm
py-cpuinfo
transformers >= 4.43.2 # Required for Chameleon and Llama 3.1 hotfix.
transformers >= 4.45.0 # Required for Llama 3.2.
tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer.
fastapi < 0.113.0; python_version < '3.9'
......
from typing import List, Optional, Tuple, Type, overload
import pytest
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
BatchEncoding)
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
from ....utils import multi_gpu_test
from ...utils import check_logprobs_close
# Upper bound on images per prompt; passed to the vLLM runner as
# limit_mm_per_prompt={"image": ...}.
_LIMIT_IMAGE_PER_PROMPT = 1
# Per-asset prompts. Every prompt starts with the <|image|> placeholder
# followed by <|begin_of_text|>.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"<|image|><|begin_of_text|>The meaning of the image is",
"cherry_blossom":
"<|image|><|begin_of_text|>The city is",
})
# Prompts without an image placeholder, used by run_test for text-only cases
# (a `None` entry in `sizes`, or an empty `sizes` list).
text_only_prompts = [
"The color of the sky is blue but sometimes it can also be",
]
# Model checkpoints the tests in this module are parametrized over.
models = [
"meta-llama/Llama-3.2-11B-Vision-Instruct",
]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
                                         Optional[SampleLogprobs]],
                      model: str):
    """Sanitize vllm output to be comparable with hf output.

    Args:
        vllm_output: ``(token_ids, text, logprobs)`` as produced by the vLLM
            runner for one prompt.
        model: Model name or path; used to load the config (for the image
            token index) and the tokenizer (for the EOS token).

    Returns:
        ``(hf_output_ids, hf_output_str, out_logprobs)`` where consecutive
        image tokens are collapsed to one, the leading space of the text is
        removed and the decoded EOS token is appended to the text when the
        ids end with EOS. The logprobs are passed through unchanged.
    """
    output_ids, output_str, out_logprobs = vllm_output

    config = AutoConfig.from_pretrained(model)
    image_token_id = config.image_token_index

    tokenizer = AutoTokenizer.from_pretrained(model)
    eos_token_id = tokenizer.eos_token_id

    # Keep a token unless it is an image token directly preceded by another
    # image token. The explicit `idx == 0` check matters: without it
    # `output_ids[idx - 1]` wraps around to the LAST token for the first
    # position and could wrongly drop a leading image token.
    hf_output_ids = [
        token_id for idx, token_id in enumerate(output_ids)
        if (token_id != image_token_id or idx == 0
            or output_ids[idx - 1] != image_token_id)
    ]

    # `startswith` also fails cleanly (AssertionError, not IndexError) when
    # the generated text is empty.
    assert output_str.startswith(" "), (
        f"Expected vLLM output to start with a space, got {output_str!r}")
    hf_output_str = output_str[1:]

    if hf_output_ids and hf_output_ids[-1] == eos_token_id:
        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

    return hf_output_ids, hf_output_str, out_logprobs
@overload
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    ...


@overload
def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    sizes: List[Tuple[int, int]],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    ...


def run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: Optional[List[float]] = None,
    sizes: Optional[List[Tuple[int, int]]] = None,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Build per-image (prompts, images) batches and hand them to _run_test.

    Exactly one of ``size_factors`` / ``sizes`` is expected; when both are
    given ``size_factors`` takes precedence.

    * ``size_factors``: every asset is rescaled by each factor.
    * ``sizes``: every asset is resized to each size; a ``None`` entry yields
      a text-only request. An empty list adds one text-only batch.

    Raises:
        ValueError: if neither ``size_factors`` nor ``sizes`` is provided.
    """
    images = [asset.pil_image for asset in image_assets]

    if size_factors is None and sizes is None:
        raise ValueError("You must provide either `size_factors` or `sizes`")

    inputs_per_image = []
    if size_factors is not None:
        for image, prompt in zip(images, HF_IMAGE_PROMPTS):
            batch_prompts = [prompt] * len(size_factors)
            batch_images = [
                rescale_image_size(image, factor) for factor in size_factors
            ]
            inputs_per_image.append((batch_prompts, batch_images))
    else:
        for image, prompt in zip(images, HF_IMAGE_PROMPTS):
            batch_prompts = []
            batch_images = []
            for size in sizes:
                if size is None:
                    # Text-only entry inside an otherwise multimodal batch.
                    batch_prompts.append(text_only_prompts[0])
                    batch_images.append(None)
                else:
                    batch_prompts.append(prompt)
                    batch_images.append(image.resize(size))
            inputs_per_image.append((batch_prompts, batch_images))

        if not sizes:
            # No sizes at all: run the pure text-only prompts.
            inputs_per_image.append(
                (text_only_prompts, [None] * len(text_only_prompts)))

    _run_test(
        hf_runner,
        vllm_runner,
        inputs_per_image,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=tensor_parallel_size,
        distributed_executor_backend=distributed_executor_backend,
    )
def _run_test(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    inputs: List[Tuple[List[str], PromptImageInput]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be the same between hf and vllm.

    All the image fixtures for the test are from IMAGE_ASSETS.
    For huggingface runner, we provide the PIL images as input.
    For vllm runner, we provide MultiModalDataDict objects
    and corresponding MultiModalConfig as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.

    Args:
        hf_runner: Context-manager factory for the HuggingFace runner.
        vllm_runner: Context-manager factory for the vLLM runner.
        inputs: One ``(prompts, images)`` pair per batch.
        model: Model name or path.
        dtype: Model dtype passed to both runners.
        max_tokens: Number of tokens to generate greedily.
        num_logprobs: Number of top logprobs to compare per token.
        tensor_parallel_size: Tensor parallel degree for vLLM.
        distributed_executor_backend: Optional vLLM executor backend.
    """
    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

    # max_model_len should be greater than image_feature_size
    with vllm_runner(model,
                     dtype=dtype,
                     max_num_seqs=16,
                     max_model_len=4096,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True,
                     limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
                                          }) as vllm_model:
        vllm_outputs_per_image = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                images=images)
            for prompts, images in inputs
        ]

    def process(hf_inputs: BatchEncoding):
        # Identity post-processing hook for the HF runner.
        return hf_inputs

    from transformers import AutoConfig
    from transformers.models.mllama import MllamaConfig as MllamaConfigHf

    # use transformer's MllamaConfig for hf_runner
    # and vllm's MllamaConfig for vllm_runner
    AutoConfig.register("mllama", MllamaConfigHf, exist_ok=True)
    try:
        with hf_runner(model,
                       dtype=dtype,
                       postprocess_inputs=process,
                       auto_cls=AutoModelForVision2Seq) as hf_model:
            hf_outputs_per_image = [
                hf_model.generate_greedy_logprobs_limit(
                    prompts,
                    max_tokens,
                    num_logprobs=num_logprobs,
                    images=images) for prompts, images in inputs
            ]
    finally:
        # AutoConfig registration is process-global state. Restore vLLM's
        # config class even when the HF run raises, so a failure here does
        # not leave the HF class registered for later tests in the process.
        from vllm.transformers_utils.configs.mllama import MllamaConfig
        AutoConfig.register("mllama", MllamaConfig, exist_ok=True)

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
                                        vllm_outputs_per_image):
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=[
                vllm_to_hf_output(vllm_output, model)
                for vllm_output in vllm_outputs
            ],
            name_0="hf",
            name_1="vllm",
        )
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "sizes",
    [
        # mllama has 8 possible aspect ratios, carefully set the sizes
        # to cover all of them
        # NOTE(review): (512, 2028) looks like it may be a typo for
        # (512, 2048), mirroring (2048, 512) -- confirm before changing.
        # Text only
        [],
        # Single-size
        [(512, 512)],
        # Single-size, batched
        [(512, 512), (512, 512), (512, 512)],
        # Multi-size, batched
        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
         (1024, 1024), (512, 1536), (512, 2028)],
        # Multi-size, batched, including text only
        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
         (1024, 1024), (512, 1536), (512, 2028), None],
    ],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype,
                max_tokens, num_logprobs) -> None:
    """Compare HF and vLLM greedy logprobs on a single GPU."""
    run_kwargs = dict(
        sizes=sizes,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
    run_test(hf_runner, vllm_runner, image_assets, model, **run_kwargs)
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "sizes",
    [
        # Multi-size batch that also contains one text-only request (None).
        # NOTE(review): (512, 2028) looks like it may be a typo for
        # (512, 2048), mirroring (2048, 512) -- confirm before changing.
        [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
         (1024, 1024), (512, 1536), (512, 2028), None],
    ],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_distributed(hf_runner, vllm_runner, image_assets, model, sizes,
                            dtype, max_tokens, num_logprobs) -> None:
    """Compare HF and vLLM greedy logprobs with tensor parallel size 2."""
    run_kwargs = dict(
        sizes=sizes,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=2,
    )
    run_test(hf_runner, vllm_runner, image_assets, model, **run_kwargs)
......@@ -576,7 +576,9 @@ class ModelConfig:
@property
def is_encoder_decoder_model(self) -> bool:
"""Extract the HF encoder/decoder model flag."""
return getattr(self.hf_config, "is_encoder_decoder", False)
return getattr(self.hf_config, "is_encoder_decoder", False) or (
(hasattr(self.hf_config, "text_config") and getattr(
self.hf_config.text_config, "is_encoder_decoder", False)))
@property
def is_embedding_model(self) -> bool:
......
......@@ -1734,7 +1734,11 @@ class LLMEngine:
def _validate_model_inputs(self, inputs: Union[LLMInputs,
EncoderDecoderLLMInputs]):
if self.is_encoder_decoder_model():
if self.model_config.is_multimodal_model:
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
prompt_ids = inputs.get("prompt_token_ids")
elif self.is_encoder_decoder_model():
prompt_ids = inputs.get("encoder_prompt_token_ids")
else:
prompt_ids = inputs.get("prompt_token_ids")
......
......@@ -159,6 +159,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"):
return "<image>"
if model_type == "mllama":
return "<|image|>"
if model_type == "qwen2_vl":
return "<|vision_start|><|image_pad|><|vision_end|>"
......@@ -358,6 +360,7 @@ _TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}
def _parse_chat_message_content_parts(
......@@ -368,7 +371,11 @@ def _parse_chat_message_content_parts(
texts: List[str] = []
mm_parser = mm_tracker.create_parser()
keep_multimodal_content = \
mm_tracker._model_config.hf_config.model_type in \
MODEL_KEEP_MULTI_MODAL_CONTENT
has_image = False
for part in parts:
part_type = part["type"]
if part_type == "text":
......@@ -383,6 +390,7 @@ def _parse_chat_message_content_parts(
"will be ignored.")
mm_parser.parse_image(image_url["url"])
has_image = True
elif part_type == "audio_url":
audio_url = _AudioParser(part)["audio_url"]
......@@ -394,12 +402,20 @@ def _parse_chat_message_content_parts(
raise NotImplementedError(f"Unknown part type: {part_type}")
text_prompt = "\n".join(texts)
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
text_prompt)
return [ConversationMessage(role=role, content=text_prompt)]
if keep_multimodal_content:
text_prompt = "\n".join(texts)
role_content = [{'type': 'text', 'text': text_prompt}]
if has_image:
role_content = [{'type': 'image'}] + role_content
return [ConversationMessage(role=role,
content=role_content)] # type: ignore
else:
mm_placeholder_counts = mm_parser.mm_placeholder_counts()
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(
mm_placeholder_counts, text_prompt)
return [ConversationMessage(role=role, content=text_prompt)]
# No need to validate using Pydantic again
......
......@@ -309,6 +309,8 @@ class OpenAIServingChat(OpenAIServing):
async for res in result_generator:
if res.prompt_token_ids is not None:
num_prompt_tokens = len(res.prompt_token_ids)
if res.encoder_prompt_token_ids is not None:
num_prompt_tokens += len(res.encoder_prompt_token_ids)
# We need to do it here, because if there are exceptions in
# the result_generator, it needs to be sent as the FIRST
......
......@@ -139,6 +139,12 @@ class EncoderDecoderLLMInputs(LLMInputs):
available.
"""
encoder_multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
"""
Optional multi-modal data to pass to the encoder model,
if the model supports it.
"""
_T1 = TypeVar("_T1",
bound=SingletonPromptInputs,
......
......@@ -128,6 +128,7 @@ class InputPreprocessor:
def _prepare_decoder_input_ids_for_generation(
self,
decoder_input_ids: Optional[List[int]],
force_bos: bool = True,
) -> List[int]:
"""
Prepares `decoder_input_ids` for generation with encoder-decoder models.
......@@ -157,8 +158,8 @@ class InputPreprocessor:
# use decoder_start_token_id as decoder_input_ids
decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
if (len(decoder_input_ids) == 0
or decoder_input_ids[0] != decoder_start_token_id):
if force_bos and (len(decoder_input_ids) == 0
or decoder_input_ids[0] != decoder_start_token_id):
decoder_input_ids = [decoder_start_token_id] + decoder_input_ids
return decoder_input_ids
......@@ -295,18 +296,25 @@ class InputPreprocessor:
encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps
if encoder_mm_data is not None or decoder_mm_data is not None:
raise ValueError("Multi-modal encoder-decoder models are "
"not supported yet")
if decoder_mm_data is not None:
raise ValueError(
"Multi-modality decoder inputs of encoder-decoder models are "
"not supported yet")
decoder_prompt_ids = (
self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))
# For Multi-Modal models (e.g., mllama), the text input can be
# <|image|><|begin_of_text|>hello world. And we should not add
# another <|begin_of_text|> to the beginning.
decoder_prompt_ids = (self._prepare_decoder_input_ids_for_generation(
decoder_prompt_ids,
force_bos=(encoder_mm_data is None and decoder_mm_data is None)))
return EncoderDecoderLLMInputs(
prompt_token_ids=decoder_prompt_ids,
prompt=decoder_prompt,
multi_modal_data=decoder_mm_data,
encoder_prompt_token_ids=encoder_prompt_ids,
encoder_prompt=encoder_prompt,
encoder_multi_modal_data=encoder_mm_data,
)
def _process_encoder_decoder_prompt(
......
......@@ -112,6 +112,8 @@ class InputRegistry:
def __init__(self) -> None:
self._dummy_factories_by_model_type: Dict[Type[nn.Module],
DummyDataFactory] = {}
self._dummy_encoder_factories_by_model_type: Dict[
Type[nn.Module], DummyDataFactory] = {}
self._input_processors_by_model_type: Dict[Type[nn.Module],
InputProcessor] = {}
......@@ -162,11 +164,44 @@ class InputRegistry:
return self._dummy_factories_by_model_type \
.get(model_cls, self._default_dummy_data_factory)
def register_dummy_encoder_data(self, factory: DummyDataFactory):
    """
    Register a dummy encoder data factory to a model class

    This is similar to :meth:`~register_dummy_data`, but for encoder input.

    Returns a class decorator that stores ``factory`` for the decorated
    model class and hands the class back unchanged. Registering a second
    factory for the same class logs a warning and replaces the first.
    """

    def wrapper(model_cls: N) -> N:
        registry = self._dummy_encoder_factories_by_model_type
        already_registered = model_cls in registry
        if already_registered:
            logger.warning(
                "Model class %s already has dummy encoder data "
                "registered to %s. It is overwritten by the new one.",
                model_cls, self)

        registry[model_cls] = factory
        return model_cls

    return wrapper
def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]):
    """Return the dummy encoder data factory registered for ``model_cls``.

    When no encoder-specific factory was registered, a warning is logged and
    the model's regular dummy data factory is returned instead.
    """
    registry = self._dummy_encoder_factories_by_model_type
    if model_cls not in registry:
        logger.warning(
            "No dummy encoder data factory registered to %s. "
            "Using the dummy data factory for the model instead.",
            model_cls)
        return self._get_dummy_data_factory(model_cls)

    return registry[model_cls]
def dummy_data_for_profiling(
self,
model_config: "ModelConfig",
seq_len: int,
mm_registry: "MultiModalRegistry",
is_encoder_data: bool = False,
) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]:
"""
Create dummy data for profiling the memory usage of a model.
......@@ -184,8 +219,10 @@ class InputRegistry:
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
dummy_factory = self._get_dummy_data_factory(model_cls)
if is_encoder_data:
dummy_factory = self._get_dummy_encoder_data_factory(model_cls)
else:
dummy_factory = self._get_dummy_data_factory(model_cls)
mm_counts = mm_registry.get_mm_limits_per_prompt(model_config)
mm_processor_kwargs = get_allowed_kwarg_only_overrides(
dummy_factory, overrides=model_config.mm_processor_kwargs)
......@@ -196,10 +233,15 @@ class InputRegistry:
# Having more tokens is over-conservative but otherwise fine
num_tokens = seq_data.prompt_token_ids
assert len(num_tokens) >= seq_len, (
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but found {len(num_tokens)} tokens instead.")
if len(num_tokens) < seq_len:
if is_encoder_data:
logger.warning(
"Expected at least %d dummy encoder tokens for profiling, "
"but found %d tokens instead.", seq_len, len(num_tokens))
else:
raise AssertionError(
f"Expected at least {seq_len} dummy tokens for profiling, "
f"but found {len(num_tokens)} tokens instead.")
if mm_data is not None:
for k, v in mm_data.items():
num_items = len(v) if isinstance(v, list) else 1
......
......@@ -101,6 +101,8 @@ _MULTIMODAL_MODELS = {
"Qwen2VLForConditionalGeneration": ("qwen2_vl",
"Qwen2VLForConditionalGeneration"),
"UltravoxModel": ("ultravox", "UltravoxModel"),
"MllamaForConditionalGeneration": ("mllama",
"MllamaForConditionalGeneration"),
}
_CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"),
......
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""
import math
from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers.models.mllama.configuration_mllama as config_mllama
from PIL import Image
from torch import nn
from transformers.modeling_outputs import (BaseModelOutput,
CausalLMOutputWithPast)
from transformers.models.mllama.image_processing_mllama import (
get_optimal_tiled_canvas)
import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
from .clip import CLIPMLP
from .interfaces import SupportsMultiModal
from .llama import LlamaDecoderLayer, LlamaMLP
logger = init_logger(__name__)
# Token id written into the encoder prompt and the profiling dummy data for
# every image position. NOTE(review): presumably the tokenizer id of
# MLLAMA_IMAGE_TOKEN -- confirm against the checkpoint's tokenizer.
MLLAMA_IMAGE_TOKEN_ID = 128256
# Textual form of the image placeholder; repeated to build the encoder prompt.
MLLAMA_IMAGE_TOKEN = "<|image|>"
class MllamaImagePixelInputs(TypedDict):
    """Pixel-value image inputs for Mllama, together with tiling metadata."""

    type: Literal["pixel_values"]

    # Shape:
    # (batch_size, max_num_image, max_num_chunk, num_channel, height, width)
    data: torch.Tensor

    # Shape: `(batch_size, max_num_image)`
    aspect_ratio_ids: torch.Tensor

    # Shape: `(batch_size, max_num_image, max_num_tiles)`
    aspect_ratio_mask: torch.Tensor
# TODO: support LlamaImageEmbeddingInputs
def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
    """Rewrite encoder/decoder inputs so the encoder prompt matches the image.

    The user-supplied encoder prompt is moved to the decoder prompt when no
    decoder prompt is set. The encoder prompt is then replaced by one image
    token per vision token: ``num_tiles * tokens_per_tile``, where the tile
    count comes from ``get_optimal_tiled_canvas`` for every image. Text-only
    requests get an empty encoder prompt and empty encoder multi-modal data.

    Args:
        ctx: Input context; ``ctx.model_config.hf_config.vision_config`` is
            read for ``image_size``, ``patch_size`` and ``max_num_tiles``.
        llm_inputs: Encoder/decoder inputs; modified in place.

    Returns:
        The same ``llm_inputs`` mapping after modification.
    """
    # move encoder_prompt to prompt
    if llm_inputs.get("prompt") is None:
        llm_inputs["prompt"] = llm_inputs["encoder_prompt"]
        llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"]

    # process multi-modal data
    assert "decoder_multi_modal_data" not in llm_inputs, \
        "multi-modal data should be put in encoder message of mllama"
    multi_modal_data = llm_inputs.get("encoder_multi_modal_data")

    if multi_modal_data is None or "image" not in multi_modal_data \
        or multi_modal_data["image"] is None:
        # text-only
        llm_inputs["encoder_prompt"] = ""
        llm_inputs["encoder_prompt_token_ids"] = []
        llm_inputs["encoder_multi_modal_data"] = {}
        return llm_inputs

    # Normalize a single image to a one-element list.
    if isinstance(multi_modal_data['image'], Image.Image):
        multi_modal_data['image'] = [multi_modal_data['image']]

    vision_config = ctx.model_config.hf_config.vision_config
    # Loop-invariant: each tile is a square of `image_size` pixels.
    tile_size = vision_config.image_size
    # Use the configured patch size (the vision modules in this file derive
    # their patch count from `config.patch_size` as well) instead of a
    # hard-coded 14.
    patch_size = vision_config.patch_size
    assert tile_size % patch_size == 0, \
        f"chunk size should be multiple of {patch_size}"

    # get num_tiles
    num_tiles = 0
    for image in multi_modal_data["image"]:
        width, height = image.size
        canvas_height, canvas_width = get_optimal_tiled_canvas(
            image_height=height,
            image_width=width,
            max_image_tiles=vision_config.max_num_tiles,
            tile_size=tile_size,
        )
        num_tiles += (canvas_height // tile_size) * (canvas_width //
                                                     tile_size)

    # set encoder prompt based on num_tiles
    # One token per patch plus one extra token per tile.
    token_per_chunk = (tile_size // patch_size)**2 + 1
    num_tokens = num_tiles * token_per_chunk
    llm_inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens
    llm_inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID
                                              ] * num_tokens

    return llm_inputs
def get_max_mllama_image_tokens(ctx: InputContext) -> int:
    """Return the largest number of encoder tokens one image can produce.

    This is ``max_num_tiles`` times the tokens per tile, where a tile yields
    one token per patch plus one extra token. The patch size is read from the
    vision config rather than hard-coded to 14, matching how the vision
    modules in this file compute their patch count.
    """
    vision_config = ctx.model_config.hf_config.vision_config
    patches_per_side = vision_config.image_size // vision_config.patch_size
    token_per_chunk = patches_per_side**2 + 1
    return vision_config.max_num_tiles * token_per_chunk
def dummy_decoder_seq_data(seq_len: int, num_images: int):
    """Build dummy decoder tokens: one image token per image, then zeros.

    The result has ``seq_len`` tokens in total:
    <|image|> * num_images + 0 * (seq_len - num_images)
    """
    assert seq_len >= num_images, \
        "seq_len should be greater than or equal to num_images"

    num_padding = seq_len - num_images
    token_ids = array(
        VLLM_TOKEN_ID_ARRAY_TYPE,
        [MLLAMA_IMAGE_TOKEN_ID] * num_images + [0] * num_padding,
    )
    return SequenceData(token_ids)
def dummy_encoder_seq_data(ctx: InputContext, num_images: int):
    """Build dummy encoder tokens: the per-image maximum for every image."""
    total_tokens = num_images * get_max_mllama_image_tokens(ctx)
    image_tokens = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                         [MLLAMA_IMAGE_TOKEN_ID] * total_tokens)
    return SequenceData(image_tokens)
def dummy_image(num_images: int, ):
    """Return dummy multi-modal data holding black 1024x1024 RGB images.

    A single image is returned bare under the "image" key; any other count
    yields a list that repeats the same image object.
    """
    side = 1024
    blank = Image.new("RGB", (side, side), color=0)
    if num_images == 1:
        return {"image": blank}
    return {"image": [blank] * num_images}
def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int,
                                  mm_counts: Mapping[str, int]):
    """Dummy decoder data factory: token sequence only, no multi-modal data.

    The number of image tokens is taken from ``mm_counts["image"]``.
    """
    decoder_seq = dummy_decoder_seq_data(seq_len, mm_counts["image"])
    return decoder_seq, None
def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int,
                                  mm_counts: Mapping[str, int]):
    """Dummy encoder data factory: image tokens plus matching dummy images.

    ``seq_len`` is not used; the token count follows from
    ``mm_counts["image"]`` and the per-image maximum.
    """
    image_count = mm_counts["image"]
    encoder_seq = dummy_encoder_seq_data(ctx, image_count)
    images = dummy_image(image_count)
    return encoder_seq, images
def _prepare_aspect_ratio_attention_mask(
    aspect_ratio_mask: torch.Tensor,
    num_patches: int,
    target_length: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Expand a per-tile validity mask into an additive attention mask.

    Args:
        aspect_ratio_mask: 2D tensor ``(batch_size, max_num_tiles)``; 1 marks
            a real tile, 0 a padding tile.
        num_patches: Number of real positions per tile.
        target_length: Padded number of positions per tile.
        dtype: Floating dtype of the returned mask.

    Returns:
        Tensor of shape ``(batch_size, 1, max_num_tiles * target_length,
        max_num_tiles * target_length)``. An entry is ``torch.finfo(dtype).min``
        where both positions are padding (padding tile or padded patch) and
        zero elsewhere.
    """
    # Expand aspect ratio mask to target_length
    batch_size, max_num_tiles = aspect_ratio_mask.shape
    attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1,
                                            1).to(dtype)
    attention_mask = attention_mask.repeat(1, 1, target_length, 1)

    # Mask padding patches
    pad_patches = target_length - num_patches
    # Guard the no-padding case: with pad_patches == 0 the slice `-0:` is
    # `0:` and would zero the whole mask, i.e. mark every position as padding.
    if pad_patches > 0:
        attention_mask[:, :, -pad_patches:] = 0

    # Invert the mask (0 -> 1, 1 -> 0)
    attention_mask = 1 - attention_mask

    # Reshape to 2D and create 4D attention mask
    # (batch_size, 1, max_num_tiles*target_length, max_num_tiles*target_length)
    attention_mask = attention_mask.reshape(batch_size,
                                            max_num_tiles * target_length, 1)
    attention_mask = attention_mask @ attention_mask.transpose(
        -1, -2) * torch.finfo(dtype).min
    attention_mask = attention_mask.unsqueeze(1)

    return attention_mask
class ColumnParallelConv2dPatch(torch.nn.Module):
    """Conv2D Patching layer with model parallelism.

    The input is unfolded into flattened patches which are then projected by
    a column-parallel linear layer.

    Arguments:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Size of convolution kernel (int or (h, w) pair).
        stride: Stride for the patch extraction (required).
        bias (default False): Use bias in the projection.
    Input: (bsz, in_channels, width, height)
    Output: (bsz, num_tokens, out_channels)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]],
        bias: bool = False,
    ) -> None:
        super().__init__()
        # Normalize an int kernel size to an (h, w) pair.
        kernel_hw = ((kernel_size, kernel_size) if isinstance(
            kernel_size, int) else kernel_size)
        self._unfold = torch.nn.Unfold(kernel_size=kernel_hw, stride=stride)

        # Each unfolded patch has in_channels * kh * kw features.
        patch_features = in_channels * kernel_hw[0] * kernel_hw[1]
        self._linear = ColumnParallelLinear(
            patch_features,
            out_channels,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unfold gives (bsz, features, num_tokens); move tokens to dim 1.
        patches = self._unfold(x).permute(0, 2, 1)
        projected, _ = self._linear(patches)
        return projected
class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
    """Adds a learned per-tile embedding selected by aspect ratio id.

    With ``is_gated=True`` the embedding is scaled by ``tanh(gate)``, where
    ``gate`` is a learnable scalar initialised to zero.
    """

    def __init__(self,
                 config: config_mllama.MllamaVisionConfig,
                 is_gated: bool = True):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        self.is_gated = is_gated

        # One row per aspect ratio id, holding an embedding for every tile.
        num_ids = self.max_aspect_ratio_id + 1
        row_width = self.max_num_tiles * self.hidden_size
        self.embedding = nn.Embedding(num_ids, row_width)
        if is_gated:
            self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, hidden_state: torch.Tensor,
                aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
        # (..., max_num_tiles * hidden) -> (-1, max_num_tiles, 1, hidden) so
        # the embedding broadcasts over the patch dimension.
        tile_embeddings = self.embedding(aspect_ratio_ids).reshape(
            -1, self.max_num_tiles, 1, self.hidden_size)

        if self.is_gated:
            tile_embeddings = tile_embeddings * self.gate.tanh()

        return hidden_state + tile_embeddings
class MllamaPrecomputedPositionEmbedding(nn.Module):
    """Gated mix of a shared patch position embedding and tile embeddings.

    The shared position embedding is weighted by ``1 - tanh(gate)`` and the
    aspect-ratio-specific tile position embedding by ``tanh(gate)``.
    """

    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        # Patches per tile plus one extra position.
        self.num_patches = (config.image_size // config.patch_size)**2 + 1
        self.hidden_size = config.hidden_size
        self.scale = config.hidden_size**-0.5

        self.gate = nn.Parameter(torch.zeros(1))

        # position embedding
        initial_positions = torch.randn(self.num_patches, self.hidden_size)
        self.embedding = nn.Parameter(self.scale * initial_positions)

        # tile position embedding
        tile_row_width = (self.max_num_tiles * self.num_patches *
                          self.hidden_size)
        self.tile_embedding = nn.Embedding(self.max_aspect_ratio_id + 1,
                                           tile_row_width)

    def forward(self, hidden_state: torch.Tensor,
                aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
        gate = self.gate.tanh()

        # position embeddings, shared by every tile
        shared_positions = ((1 - gate) * self.embedding).view(
            1, 1, self.num_patches, self.hidden_size)
        hidden_state = hidden_state + shared_positions

        # precomputed tile position embeddings
        batch_size = hidden_state.shape[0]
        tile_positions = self.tile_embedding(aspect_ratio_ids).reshape(
            batch_size, self.max_num_tiles, self.num_patches,
            self.hidden_size)

        return hidden_state + gate * tile_positions
# TODO: support other attention backends for attention in vision model
class MllamaVisionSdpaAttention(nn.Module):
    """Self-attention for the vision encoder using torch SDPA.

    The query/key/value and output projections are tensor-parallel layers;
    each rank handles ``num_heads // tensor_parallel_world_size`` heads.
    """

    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()

        tp_size = get_tensor_model_parallel_world_size()
        self.embed_dim = config.hidden_size
        self.num_heads = config.attention_heads
        self.head_dim = config.hidden_size // config.attention_heads
        self.num_local_heads = self.num_heads // tp_size
        # Query and key/value slices have the same width on every rank.
        local_width = self.num_local_heads * self.head_dim
        self.q_size = local_width
        self.kv_size = local_width

        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.num_heads,
            bias=False,
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.embed_dim,
            bias=False,
            input_is_parallel=True,
        )

    def _to_heads(self, tensor: torch.Tensor) -> torch.Tensor:
        # (bsz, seq, local_heads * head_dim) -> (bsz, local_heads, seq, head_dim)
        bsz, seq_len = tensor.shape[0], tensor.shape[1]
        return tensor.view(bsz, seq_len, self.num_local_heads,
                           self.head_dim).transpose(1, 2)

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_state)
        q, k, v = (self._to_heads(part) for part in qkv.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1))

        # TODO: remove padding in image encoder
        attn_output = F.scaled_dot_product_attention(q,
                                                     k,
                                                     v,
                                                     attn_mask=attention_mask,
                                                     dropout_p=0.0)

        # Back to (bsz, seq, local_heads * head_dim) for the output projection.
        merged = attn_output.transpose(1, 2).contiguous()
        merged = merged.reshape(merged.shape[0], merged.shape[1], -1)
        output, _ = self.o_proj(merged)
        return output
class MllamaVisionEncoderLayer(nn.Module):
    """Pre-norm transformer layer of the vision encoder.

    When ``is_gated`` is set, the attention and MLP branches are each scaled
    by ``tanh`` of a learnable scalar before being added to the residual.
    """

    def __init__(self,
                 config: config_mllama.MllamaVisionConfig,
                 is_gated: bool = False):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.attention_heads
        self.is_gated = is_gated
        self.intermediate_size = config.intermediate_size

        self.self_attn = MllamaVisionSdpaAttention(config)
        self.mlp = CLIPMLP(config)

        self.input_layernorm = nn.LayerNorm(self.hidden_size,
                                            eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(self.hidden_size,
                                                     eps=config.norm_eps)

        # there used to be an if else here, no code path
        if is_gated:
            initial_gate = math.pi / 4
            self.gate_attn = nn.Parameter(torch.ones(1) * initial_gate)
            self.gate_ffn = nn.Parameter(torch.ones(1) * initial_gate)

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        # Self Attention
        attn_out = self.self_attn(self.input_layernorm(hidden_state),
                                  attention_mask=attention_mask)
        attn_scale = self.gate_attn.tanh() if self.is_gated else 1
        hidden_state = hidden_state + attn_scale * attn_out

        # Feed forward
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_state))
        mlp_scale = self.gate_ffn.tanh() if self.is_gated else 1
        return hidden_state + mlp_scale * mlp_out
class MllamaVisionEncoder(nn.Module):
    """Stack of vision encoder layers that can expose intermediate states.

    ``output_hidden_states`` lists layer indices; for each listed index the
    hidden state fed INTO that layer is collected.
    """

    def __init__(self,
                 config: config_mllama.MllamaVisionConfig,
                 num_layers=32,
                 is_gated=False,
                 output_hidden_states=None):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [MllamaVisionEncoderLayer(config, is_gated)
             for _ in range(num_layers)])
        self.output_hidden_states = output_hidden_states or []

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        wanted = self.output_hidden_states
        collected = []

        for layer_idx, layer in enumerate(self.layers):
            if layer_idx in wanted:
                # State before this layer runs.
                collected.append(hidden_states)
            hidden_states = layer(hidden_states, attention_mask)

        # The final output is additionally collected when the last layer's
        # index is requested.
        if len(self.layers) - 1 in wanted:
            collected.append(hidden_states)

        return hidden_states, tuple(collected)
class MllamaVisionModel(nn.Module):
    """Tiled vision encoder producing per-patch features for each tile.

    Images arrive split into tiles. Every tile is patch-embedded, receives a
    class token plus tile/position embeddings, and passes through a local
    transformer followed by a gated "global" transformer. The final
    activations are concatenated with the intermediate activations collected
    from the local transformer.
    """

    def __init__(self, config: config_mllama.MllamaVisionConfig):
        """Build embeddings, layer norms and both transformers."""
        super().__init__()

        # Scalar settings copied from the config.
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.in_channels = config.num_channels
        self.intermediate_layers_indices = config.intermediate_layers_indices

        # Patches per tile, plus one for the class token.
        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2 + 1
        self.scale = config.hidden_size**-0.5

        self.patch_embedding = ColumnParallelConv2dPatch(
            in_channels=config.num_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        class_embedding_init = self.scale * torch.randn(self.hidden_size)
        self.class_embedding = nn.Parameter(class_embedding_init)

        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
            config)
        self.pre_tile_positional_embedding = \
            MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True)
        self.post_tile_positional_embedding = \
            MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True)

        # Layer norms applied before the local and after it (i.e. before
        # the global) transformer.
        self.layernorm_pre = nn.LayerNorm(self.hidden_size)
        self.layernorm_post = nn.LayerNorm(self.hidden_size)

        # Local transformer (ungated, exposes intermediate layers) and
        # global transformer (gated).
        self.transformer = MllamaVisionEncoder(
            config,
            config.num_hidden_layers,
            is_gated=False,
            output_hidden_states=config.intermediate_layers_indices)
        self.global_transformer = MllamaVisionEncoder(
            config, config.num_global_layers, is_gated=True)

    def apply_class_embedding(self,
                              hidden_state: torch.Tensor) -> torch.Tensor:
        """Prepend the learned class token along dim 1.

        Args:
            hidden_state: A 3-D tensor; the first and last sizes are used to
                expand the class embedding.

        Returns:
            ``hidden_state`` with one extra entry at the front of dim 1.
        """
        num_rows, _, width = hidden_state.shape
        cls_token = self.class_embedding.expand(num_rows, 1, width)
        return torch.cat([cls_token, hidden_state], dim=1)

    def forward(self, pixel_values: torch.Tensor,
                aspect_ratio_ids: torch.Tensor,
                aspect_ratio_mask: torch.Tensor) -> torch.Tensor:
        """Encode tiled images.

        Args:
            pixel_values: 6-D tensor unpacked as ``(batch_size,
                num_concurrent_media, num_tiles, num_channels, height,
                width)``.
            aspect_ratio_ids: Reshaped to ``(batch_size *
                num_concurrent_media, -1)`` and passed to the tile/position
                embeddings.
            aspect_ratio_mask: Reshaped the same way and turned into the
                attention mask shared by both transformers.

        Returns:
            Tensor of shape ``(batch_size, num_concurrent_media, num_tiles,
            num_patches, D)`` where the last dim concatenates the final
            activations with the stacked intermediate activations.
        """
        (batch_size, num_concurrent_media, num_tiles, num_channels, height,
         width) = pixel_values.shape
        # Batch and media axes are processed as one flat axis throughout.
        num_media = batch_size * num_concurrent_media
        compute_dtype = self.layernorm_pre.weight.dtype

        tiles = pixel_values.reshape(num_media * num_tiles, num_channels,
                                     height, width)
        aspect_ratio_ids = aspect_ratio_ids.reshape(num_media, -1)

        # Patch embedding, then gather the tensor-parallel shards.
        hidden_state = self.patch_embedding(tiles.to(compute_dtype))
        hidden_state = ps.get_tp_group().all_gather(hidden_state)

        # Tile embeddings.
        _, num_patches, dim = hidden_state.shape
        hidden_state = hidden_state.reshape(num_media, num_tiles, -1, dim)
        hidden_state = self.pre_tile_positional_embedding(
            hidden_state, aspect_ratio_ids)

        # Class token: one extra patch per tile.
        hidden_state = hidden_state.reshape(num_media * num_tiles,
                                            num_patches, dim)
        hidden_state = self.apply_class_embedding(hidden_state)
        num_patches += 1

        # Position embeddings.
        hidden_state = hidden_state.reshape(num_media, num_tiles,
                                            num_patches, dim)
        hidden_state = self.gated_positional_embedding(hidden_state,
                                                       aspect_ratio_ids)

        hidden_state = self.layernorm_pre(hidden_state)

        # Zero-pad the patch axis (dim -2) up to the next multiple of 8.
        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
        hidden_state = F.pad(hidden_state, (0, 0, 0, num_padding_patches),
                             mode="constant",
                             value=0)
        # ``None`` keeps everything when no padding was added; a bare
        # ``-0`` slice end would drop the whole axis instead.
        slice_index = None
        if num_padding_patches > 0:
            slice_index = -num_padding_patches
        padded_patches = num_patches + num_padding_patches

        attention_mask = _prepare_aspect_ratio_attention_mask(
            aspect_ratio_mask=aspect_ratio_mask.reshape(num_media, -1),
            num_patches=self.num_patches,
            target_length=hidden_state.shape[2],
            dtype=compute_dtype,
        )

        # Local transformer over all tiles of a media item at once.
        hidden_state = hidden_state.view(num_media, -1, dim)
        hidden_state, layer_outputs = self.transformer(
            hidden_state,
            attention_mask=attention_mask,
        )
        intermediate_hidden_states = torch.stack(layer_outputs, dim=-1)

        # Global transformer.
        hidden_state = self.layernorm_post(hidden_state)
        hidden_state = hidden_state.reshape(num_media, num_tiles,
                                            padded_patches, dim)
        hidden_state = self.post_tile_positional_embedding(
            hidden_state, aspect_ratio_ids)
        hidden_state = hidden_state.reshape(num_media,
                                            num_tiles * padded_patches, dim)
        hidden_state = self.global_transformer(
            hidden_state, attention_mask=attention_mask)[0]

        # Drop the padding patches again and restore the batch/media axes.
        hidden_state = hidden_state.reshape(num_media, num_tiles,
                                            padded_patches, dim)
        hidden_state = hidden_state[:, :, :slice_index]
        hidden_state = hidden_state.reshape(batch_size, num_concurrent_media,
                                            num_tiles, num_patches, dim)

        # Same un-padding for the collected intermediate layer outputs.
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            num_media, num_tiles, padded_patches, -1)
        intermediate_hidden_states = \
            intermediate_hidden_states[:, :, :slice_index]
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, -1)

        return torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
class MllamaTextRMSNorm(nn.Module):
    """Root-mean-square layer norm (equivalent to T5LayerNorm).

    The statistics are computed in float32; the normalized tensor is cast
    back to the input dtype before being multiplied by the learned weight.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """Create a weight vector of ones with ``hidden_size`` entries.

        Args:
            hidden_size: Size of the normalized (last) dimension.
            eps: Value added to the mean square before the reciprocal
                square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize ``hidden_states`` over its last dimension."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square +
                                          self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        """Return the weight shape and epsilon for ``repr(module)``."""
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.variance_epsilon}"
class MllamaTextCrossAttention(nn.Module):
    """Cross-attention from text hidden states to vision states.

    Queries come from the decoder hidden states; keys and values come from
    ``cross_attention_states`` when they are supplied. Queries and keys are
    RMS-normalized per head before attention.
    """

    def __init__(
        self,
        config: Optional[config_mllama.MllamaTextConfig] = None,
        layer_idx: Optional[int] = None,
    ):
        """Set up projections, per-head norms and the attention backend.

        Args:
            config: Text config. Although the default is ``None``, its
                attributes are read unconditionally.
            layer_idx: Index of the owning decoder layer (stored only).
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.model_parallel_size = get_tensor_model_parallel_world_size()

        # Global head counts and their per-rank ("local") shares.
        self.num_heads = self.config.num_attention_heads
        self.num_local_heads = self.num_heads // self.model_parallel_size
        self.num_key_value_heads = self.config.num_key_value_heads
        self.num_local_key_value_heads = \
            self.num_key_value_heads // self.model_parallel_size
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // self.num_heads

        # Widths of the q and k/v slices in this rank's fused QKV output.
        self.q_local_size = self.num_local_heads * self.head_dim
        self.kv_local_size = self.num_local_key_value_heads * self.head_dim

        # TODO: change to Q/KV separate linear after #7448 is merged
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.num_heads,
            self.num_key_value_heads,
            bias=False,
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
        )

        # vllm.model_executor.layers.layernorm.RMSNorm has a precision
        # issue here, so the Hugging Face-style norm is used instead.
        norm_eps = config.rms_norm_eps
        self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=norm_eps)
        self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=norm_eps)

        self.scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_local_heads,
            self.head_dim,
            self.scaling,
            self.num_local_key_value_heads,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        cross_attention_states: Optional[torch.Tensor],
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Attend from ``hidden_states`` to ``cross_attention_states``.

        Args:
            hidden_states: Decoder activations; source of the queries.
            attention_mask: Accepted for interface compatibility; not read
                by this method.
            cross_attention_states: Vision states; source of keys/values.
                When ``None``, ``None`` is passed as key and value to the
                attention backend.
            kv_cache: Cache tensor forwarded to the attention backend.
            attn_metadata: Metadata forwarded to the attention backend.

        Returns:
            The output-projected attention result.
        """
        split_sizes = [
            self.q_local_size, self.kv_local_size, self.kv_local_size
        ]

        # Only the query slice of the decoder-side projection is used.
        decoder_qkv = self.qkv_proj(hidden_states)[0]
        q, _, _ = decoder_qkv.split(split_sizes, dim=-1)

        k = None
        v = None
        if cross_attention_states is not None:
            # Only the key/value slices of the encoder-side projection
            # are used.
            encoder_qkv = self.qkv_proj(cross_attention_states)[0]
            _, k, v = encoder_qkv.split(split_sizes, dim=-1)
            kv_shape = (-1, self.num_local_key_value_heads, self.head_dim)
            k = k.view(*kv_shape)
            v = v.view(*kv_shape)
            k = self.k_norm(k)

        q = q.view(-1, self.num_local_heads, self.head_dim)
        q = self.q_norm(q)

        attn_out = self.attn(q,
                             k,
                             v,
                             kv_cache,
                             attn_metadata,
                             attn_type=AttentionType.ENCODER_DECODER)
        projected, _ = self.o_proj(attn_out)
        return projected
class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
    """Decoder block whose attention is tanh-gated cross-attention.

    Both the cross-attention branch and the MLP branch are multiplied by
    ``full_text_row_masked_out_mask`` and by ``tanh`` of a learnable scalar
    gate (initialized to zero) before being added to the residual stream.
    """

    def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int) \
        -> None:
        """Create the attention, MLP, norms and gate parameters."""
        super().__init__()
        self.layer_idx = layer_idx
        norm_eps = config.rms_norm_eps

        # Attention branch.
        self.cross_attn = MllamaTextCrossAttention(
            config=config,
            layer_idx=layer_idx,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1))

        # Feed-forward branch.
        self.mlp = LlamaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)
        self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: torch.Tensor,
        cross_attention_mask: torch.Tensor,
        full_text_row_masked_out_mask: torch.Tensor,
        kv_cache: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply gated cross-attention followed by a gated MLP.

        Args:
            hidden_states: Residual-stream activations.
            cross_attention_states: Vision states to attend to.
            cross_attention_mask: Passed to the attention module as
                ``attention_mask``.
            full_text_row_masked_out_mask: Multiplied into both branch
                outputs before the residual add.
            kv_cache: Cache forwarded to the attention module.
            attn_metadata: Metadata forwarded to the attention module.

        Returns:
            Updated residual-stream activations.
        """
        # Cross-attention branch.
        attn_branch = self.cross_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=cross_attention_mask,
            cross_attention_states=cross_attention_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        attn_branch = full_text_row_masked_out_mask * attn_branch
        attn_gate = self.cross_attn_attn_gate.tanh()
        hidden_states = hidden_states + attn_gate * attn_branch

        # Feed-forward branch.
        mlp_branch = self.mlp(self.post_attention_layernorm(hidden_states))
        mlp_branch = full_text_row_masked_out_mask * mlp_branch
        mlp_gate = self.cross_attn_mlp_gate.tanh()
        hidden_states = hidden_states + mlp_gate * mlp_branch

        return hidden_states
class MllamaTextModel(nn.Module):
    """Text decoder that interleaves Llama self-attention layers with
    cross-attention layers at the indices in ``cross_attention_layers``."""

    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "model"

    def __init__(self, config: config_mllama.MllamaTextConfig,
                 cache_config: Optional[CacheConfig],
                 quant_config: Optional[QuantizationConfig]):
        """Build the token embedding, the mixed layer stack and final norm.

        Args:
            config: Text model config.
            cache_config: Forwarded to the self-attention layers.
            quant_config: Forwarded to the self-attention layers.
        """
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        # The embedding table has 8 more rows than ``vocab_size``.
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8,
                                                   config.hidden_size)
        self.cross_attention_layers = config.cross_attention_layers

        def build_layer(layer_idx):
            # Cross-attention at the configured indices, Llama elsewhere.
            if layer_idx in self.cross_attention_layers:
                return MllamaCrossAttentionDecoderLayer(config, layer_idx)
            # TODO: force LlamaDecoderLayer to config.attention_bias=False
            return LlamaDecoderLayer(config,
                                     cache_config=cache_config,
                                     quant_config=quant_config)

        self.layers = nn.ModuleList(
            [build_layer(idx) for idx in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                      torch.Tensor]],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run the decoder stack.

        Cross-attention layers are bypassed entirely when
        ``skip_cross_attention`` is true. Each layer that runs receives
        ``kv_caches[idx]`` for its own position in the stack.

        Returns:
            The final-normed hidden states.

        Raises:
            ValueError: If a layer is neither a cross-attention layer nor a
                ``LlamaDecoderLayer``.
        """
        hidden_states = self.embed_tokens(input_ids)

        for idx, decoder_layer in enumerate(self.layers):
            if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
                if skip_cross_attention:
                    continue
                hidden_states = decoder_layer(
                    hidden_states=hidden_states,
                    cross_attention_states=cross_attention_states,
                    cross_attention_mask=cross_attention_mask,
                    full_text_row_masked_out_mask=
                    full_text_row_masked_out_mask,
                    kv_cache=kv_caches[idx],
                    attn_metadata=attn_metadata,
                )
                continue

            if not isinstance(decoder_layer, LlamaDecoderLayer):
                raise ValueError(
                    f"Unknown decoder layer type {type(decoder_layer)}")

            # The Llama layer returns its output and residual separately;
            # they are summed here before the next layer.
            layer_out, residual = decoder_layer(
                positions=positions,
                hidden_states=hidden_states,
                kv_cache=kv_caches[idx],
                attn_metadata=attn_metadata,
                residual=None,
            )
            hidden_states = layer_out + residual

        return self.norm(hidden_states)
class MllamaForCausalLM(nn.Module):
    """Text decoder plus LM head.

    ``forward`` returns the decoder's hidden states; the ``lm_head`` created
    here is not applied inside ``forward``.
    """

    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "language_model"
    _no_split_modules = [
        "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"
    ]

    def __init__(self, config: config_mllama.MllamaTextConfig,
                 cache_config: Optional[CacheConfig],
                 quant_config: Optional[QuantizationConfig]):
        """Create the text model and its output head.

        Args:
            config: Text model config.
            cache_config: Forwarded to the text model.
            quant_config: Forwarded to the text model and the LM head.
        """
        super().__init__()
        vocab_size = config.vocab_size
        self.vocab_size = vocab_size
        self.model = MllamaTextModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(
            vocab_size,
            config.hidden_size,
            org_num_embeddings=vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                      torch.Tensor]],
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        """Delegate to the wrapped text model and return its hidden states.

        All arguments are passed through unchanged by keyword.
        """
        return self.model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            skip_cross_attention=skip_cross_attention,
        )
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_mllama_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_decoder_data_for_mllama)
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
    """Vision encoder + projector + cross-attending text decoder.

    Image tiles are encoded by the vision model, projected to the text
    hidden size and handed to the language model as cross-attention states.
    """

    def __init__(self,
                 config: config_mllama.MllamaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Build the vision model, language model, projector and sampler.

        Args:
            config: Combined config with ``vision_config`` and
                ``text_config`` sub-configs.
            multimodal_config: Accepted for interface compatibility; not
                read in this constructor.
            cache_config: Forwarded to the language model.
            quant_config: Forwarded to the language model.
        """
        super().__init__()
        self.vocab_size = config.text_config.vocab_size
        self.hidden_size = config.text_config.hidden_size
        self.max_num_tiles = config.vision_config.max_num_tiles
        self.vision_output_dim = config.vision_config.vision_output_dim
        # Fall back to -1 when the config does not define a pad token.
        self.pad_token_id = \
            config.pad_token_id if config.pad_token_id is not None else -1
        self.image_size = config.vision_config.image_size

        self.vision_model = MllamaVisionModel(config.vision_config)
        self.language_model = MllamaForCausalLM(
            config.text_config,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.multi_modal_projector = nn.Linear(
            config.vision_config.vision_output_dim,
            config.text_config.hidden_size,
            bias=True,
        )
        self.logits_processor = LogitsProcessor(config.output_hidden_states,
                                                config.text_config.vocab_size)
        self.sampler = Sampler()

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn hidden states into logits using the language model's head."""
        logits = self.logits_processor(self.language_model.lm_head,
                                       hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def _parse_and_validate_image_input(self, **kwargs: object):
        """Extract image inputs from ``kwargs`` and pad them into tensors.

        Returns:
            ``None`` when neither ``pixel_values`` nor ``image_embeds`` is
            present, otherwise a ``MllamaImagePixelInputs`` whose tensors
            are padded to the largest image and tile counts in the batch.

        Raises:
            ValueError: If both pixel values and image embeds are given, if
                the aspect-ratio ids or mask are missing, or if no images
                are provided.
            NotImplementedError: If only ``image_embeds`` is given.
        """
        # tensor with the same shape will be batched together by
        # MultiModalInputs.batch, so pixel_values here can be:
        #   - List[List[torch.Tensor]]:
        #       with shape (num_tiles, 3, image_res, image_res)
        #   - List[torch.Tensor]:
        #       with shape (num_image, num_tiles, 3, image_res, image_res)
        #   - torch.Tensor:
        #       with shape (bs, num_image, num_tiles, 3, image_res, image_res)
        pixel_values: Optional[Union[List[List[torch.Tensor]],
                                     List[torch.Tensor],
                                     torch.Tensor]] = kwargs.pop(
                                         "pixel_values", None)
        image_embeds: Optional[Union[List[List[torch.Tensor]],
                                     List[torch.Tensor],
                                     torch.Tensor]] = kwargs.pop(
                                         "image_embeds", None)
        aspect_ratio_ids: Optional[Union[List[List[torch.Tensor]],
                                         List[torch.Tensor],
                                         torch.Tensor]] = kwargs.pop(
                                             "aspect_ratio_ids", None)
        aspect_ratio_mask: Optional[Union[List[List[torch.Tensor]],
                                          List[torch.Tensor],
                                          torch.Tensor]] = kwargs.pop(
                                              "aspect_ratio_mask", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None and image_embeds is not None:
            raise ValueError(
                "Both pixel values and image embeds are provided.")

        if pixel_values is not None:
            # Explicit raises instead of ``assert``: asserts are stripped
            # under ``python -O`` and this is request-input validation.
            if aspect_ratio_ids is None:
                raise ValueError(
                    "aspect_ratio_ids must be provided with pixel values.")
            if aspect_ratio_mask is None:
                raise ValueError(
                    "aspect_ratio_mask must be provided with pixel values.")

            max_num_images = max(len(x[0]) for x in pixel_values)
            if max_num_images == 0:
                raise ValueError("No images provided.")
            max_num_tiles = max(
                max(len(x) for x in y[0]) for y in pixel_values)

            device = self.multi_modal_projector.weight.device
            bsz = len(pixel_values)
            # Padded outputs: images and mask default to zeros, ids to ones.
            out_images = torch.zeros(
                bsz,
                max_num_images,
                max_num_tiles,
                3,
                self.image_size,
                self.image_size,
                dtype=torch.float32,
                device=device,
            )
            out_ar_ids = torch.ones(bsz,
                                    max_num_images,
                                    dtype=torch.int64,
                                    device=device)
            out_ar_mask = torch.zeros(bsz,
                                      max_num_images,
                                      max_num_tiles,
                                      dtype=torch.int64,
                                      device=device)
            for b, sample in enumerate(pixel_values):
                for i, img in enumerate(sample[0]):
                    # Only the first ``img.shape[0]`` tile slots are filled.
                    out_images[b, i, :img.shape[0]] = img
                    out_ar_ids[b, i] = aspect_ratio_ids[b][0][i]
                    out_ar_mask[b, i] = aspect_ratio_mask[b][0][i]

            return MllamaImagePixelInputs(
                type="pixel_values",
                data=out_images,
                aspect_ratio_ids=out_ar_ids,
                aspect_ratio_mask=out_ar_mask,
            )

        if image_embeds is not None:
            raise NotImplementedError

        raise AssertionError("This line should be unreachable.")

    def flat_encoder_result(self, cross_attention_states: torch.Tensor,
                            attn_metadata: AttentionMetadata):
        """Pack per-sequence vision tokens into one flat tensor.

        Args:
            cross_attention_states: Batched vision states, iterated along
                dim 0 in step with ``attn_metadata.encoder_seq_lens``.
            attn_metadata: Supplies ``encoder_seq_lens``,
                ``seq_lens_tensor`` and ``num_prefill_tokens``.

        Returns:
            ``(flat_states, full_text_row_masked_out_mask)``. The first has
            ``sum(encoder_seq_lens)`` rows. The second is a boolean column
            with one row per prefill token, ``False`` for the tokens of
            sequences whose encoder length is 0.
        """
        cross_attention_states_flat = torch.zeros(
            sum(attn_metadata.encoder_seq_lens),
            cross_attention_states.shape[-1],
            device=cross_attention_states.device,
            dtype=cross_attention_states.dtype)
        start_pos = 0
        for seq_len, vision_token_in_batch in zip(
                attn_metadata.encoder_seq_lens, cross_attention_states):
            end_pos = start_pos + seq_len
            # Keep only the first ``seq_len`` tokens of each batch entry.
            cross_attention_states_flat[
                start_pos:end_pos] = vision_token_in_batch[:seq_len]
            start_pos = end_pos
        cross_attention_states = cross_attention_states_flat

        # The mask is built on the default device and moved afterwards.
        full_text_row_masked_out_mask = torch.ones(
            (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool)
        start_pos = 0
        for seq_len, encoder_seq_len in zip(
                attn_metadata.seq_lens_tensor.cpu(),
                attn_metadata.encoder_seq_lens):
            if encoder_seq_len == 0:
                full_text_row_masked_out_mask[start_pos:start_pos +
                                              seq_len] = False
            start_pos += seq_len
        full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
            cross_attention_states.device)

        return cross_attention_states, full_text_row_masked_out_mask

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        **kwargs: object,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the vision model (if images are present) and the decoder.

        Raises:
            ValueError: If the batch mixes prefill and decode tokens.
        """
        if attn_metadata.num_prefill_tokens > 0 and \
            attn_metadata.num_decode_tokens > 0:
            raise ValueError("Chunk prefill not supported")

        image_inputs = self._parse_and_validate_image_input(**kwargs)
        if image_inputs is None:
            # No new images in this step: rows whose sequence has no
            # encoder tokens are masked out, and cross-attention is skipped
            # when no sequence in the batch has encoder tokens.
            cross_attention_mask = None
            full_text_row_masked_out_mask = (
                attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to(
                    input_ids.device)
            cross_attention_states = None
            skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0
        else:
            # NOTE: llama's reference implementation runs vision model on CPU
            pixel_values = image_inputs['data']
            aspect_ratio_ids = image_inputs['aspect_ratio_ids']
            aspect_ratio_mask = image_inputs['aspect_ratio_mask']
            cross_attention_states = self.vision_model(pixel_values,
                                                       aspect_ratio_ids,
                                                       aspect_ratio_mask)
            cross_attention_states = self.multi_modal_projector(
                cross_attention_states)

            # Collapse the media, tile and patch axes into one token axis.
            bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
            cross_attention_states = cross_attention_states.view(
                bsz, -1, image_token_dim)

            cross_attention_states, full_text_row_masked_out_mask = \
                self.flat_encoder_result(cross_attention_states, attn_metadata)
            skip_cross_attention = False
            # TODO: support multi-image by this mask
            cross_attention_mask = None

        outputs = self.language_model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            skip_cross_attention=skip_cross_attention,
        )

        return outputs

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Separate q/k/v and gate/up checkpoint tensors are loaded as shards
        of the fused ``qkv_proj`` / ``gate_up_proj`` parameters. The patch
        embedding weight is renamed to ``patch_embedding._linear.weight``
        and flattened to 2-D before loading.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if 'patch_embedding.weight' in name:
                name = name.replace('patch_embedding.weight',
                                    'patch_embedding._linear.weight')
                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a fused shard: load directly, using the parameter's
                # own loader when it defines one.
                param = params_dict.pop(name)
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
......@@ -54,6 +54,12 @@ class MultiModalInputs(_MultiModalInputsBase):
if isinstance(nested_tensors, torch.Tensor):
return nested_tensors
if isinstance(nested_tensors, np.ndarray):
return torch.from_numpy(nested_tensors)
if isinstance(nested_tensors, (int, float)):
return torch.tensor(nested_tensors)
stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors]
if not is_list_of(stacked, torch.Tensor, check="all"):
# Only tensors (not lists) can be stacked.
......
......@@ -2,6 +2,7 @@ from functools import lru_cache
import torch
from PIL import Image
from transformers.image_processing_base import BatchFeature
from vllm.config import ModelConfig
from vllm.inputs.registry import InputContext
......@@ -39,6 +40,10 @@ class ImagePlugin(MultiModalPlugin):
) -> MultiModalInputs:
model_config = ctx.model_config
# Processed by input processor
if isinstance(data, BatchFeature):
return MultiModalInputs(data.data)
# PIL image
if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
image_processor = self._get_hf_image_processor(model_config)
......
......@@ -13,6 +13,7 @@ from typing import Set, Tuple, Union, cast
import msgspec
import torch
from vllm.inputs import EncoderDecoderLLMInputs, LLMInputs
from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
......@@ -21,7 +22,6 @@ from vllm.sampling_params import SamplingParams
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
if TYPE_CHECKING:
from vllm.inputs import LLMInputs
from vllm.multimodal.base import MultiModalDataDict
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
......@@ -471,7 +471,15 @@ class Sequence:
@property
def multi_modal_data(self) -> "MultiModalDataDict":
return self.inputs.get("multi_modal_data") or {}
if self.inputs.get("multi_modal_data") and self.inputs.get(
"encoder_multi_modal_data"):
raise ValueError(
"Multi-modal data in both encoder and decoder is not supported."
)
inputs = self.inputs
return self.inputs.get("multi_modal_data") or (cast(
EncoderDecoderLLMInputs,
inputs).get("encoder_multi_modal_data")) or {}
@property
def lora_int_id(self) -> int:
......
......@@ -22,9 +22,10 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
EAGLEConfig, ExaoneConfig,
GraniteConfig, InternVLChatConfig,
JAISConfig, MedusaConfig,
MLPSpeculatorConfig, MPTConfig,
NemotronConfig, RWConfig,
SolarConfig, UltravoxConfig)
MllamaConfig, MLPSpeculatorConfig,
MPTConfig, NemotronConfig,
RWConfig, SolarConfig,
UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.utils import check_gguf_file
......@@ -37,6 +38,10 @@ MISTRAL_CONFIG_NAME = "params.json"
logger = init_logger(__name__)
_CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, Type[PretrainedConfig]] = {
"mllama": MllamaConfig
}
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"chatglm": ChatGLMConfig,
"dbrx": DbrxConfig,
......@@ -55,11 +60,15 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
# Granite can be removed from here once we have upgraded to
# transformers 4.45+
"granite": GraniteConfig,
**_CONFIG_REGISTRY_OVERRIDE_HF
}
for name, cls in _CONFIG_REGISTRY.items():
with contextlib.suppress(ValueError):
AutoConfig.register(name, cls)
if name in _CONFIG_REGISTRY_OVERRIDE_HF:
AutoConfig.register(name, cls, exist_ok=True)
else:
AutoConfig.register(name, cls)
class ConfigFormat(str, enum.Enum):
......
......@@ -10,6 +10,7 @@ from vllm.transformers_utils.configs.granite import GraniteConfig
from vllm.transformers_utils.configs.internvl import InternVLChatConfig
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.medusa import MedusaConfig
from vllm.transformers_utils.configs.mllama import MllamaConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
from vllm.transformers_utils.configs.nemotron import NemotronConfig
......@@ -26,6 +27,7 @@ __all__ = [
"MedusaConfig",
"EAGLEConfig",
"ExaoneConfig",
"MllamaConfig",
"MLPSpeculatorConfig",
"NemotronConfig",
"SolarConfig",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment