Unverified Commit 1ad86acf authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Model] Initial support for BLIP-2 (#5920)


Co-authored-by: default avatarywang96 <ywang@roblox.com>
parent ecb33a28
......@@ -7,6 +7,8 @@ vLLM supports a variety of generative Transformer models in `HuggingFace Transfo
The following is the list of model architectures that are currently supported by vLLM.
Alongside each architecture, we include some popular models that use it.
----
Decoder-only Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. list-table::
......@@ -186,6 +188,10 @@ Vision Language Models
- Models
- Example HuggingFace Models
- :ref:`LoRA <lora>`
* - :code:`Blip2ForConditionalGeneration`
- BLIP-2
- :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc.
-
* - :code:`ChameleonForConditionalGeneration`
- Chameleon
- :code:`facebook/chameleon-7b` etc.
......@@ -215,6 +221,8 @@ Vision Language Models
- :code:`openbmb/MiniCPM-V-2`, :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
-
----
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` and :ref:`Enabling Multimodal Inputs <enabling_multimodal_inputs>`
for instructions on how to implement support for your model.
......
......@@ -106,6 +106,16 @@ def run_minicpmv(question):
return llm, prompt
# BLIP-2
def run_blip2(question):
# BLIP-2 prompt format is inaccurate on HuggingFace model repository.
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa
prompt = f"Question: {question} Answer:"
llm = LLM(model="Salesforce/blip2-opt-2.7b")
return llm, prompt
model_example_map = {
"llava": run_llava,
"llava-next": run_llava_next,
......@@ -114,6 +124,7 @@ model_example_map = {
"paligemma": run_paligemma,
"chameleon": run_chameleon,
"minicpmv": run_minicpmv,
"blip-2": run_blip2,
}
......
{%- for message in messages -%}
{%- if message['role'] == 'user' -%}
{{- 'Question: ' + message['content'] + ' ' -}}
{%- elif message['role'] == 'assistant' -%}
{{- 'Answer: ' + message['content'] + ' ' -}}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{- 'Answer:' -}}
{% endif %}
from typing import List, Optional, Tuple
import pytest
from transformers import AutoTokenizer
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from ..conftest import IMAGE_ASSETS
from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"Question: What's the content of the image? Answer:",
"cherry_blossom":
"Question: What is the season? Answer:",
})
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
model: str):
"""Sanitize vllm output to be comparable with hf output."""
_, output_str, out_logprobs = vllm_output
hf_output_str = output_str + "\n"
tokenizer = AutoTokenizer.from_pretrained(model)
hf_output_ids = tokenizer.encode(hf_output_str)
assert hf_output_ids[0] == tokenizer.bos_token_id
hf_output_ids = hf_output_ids[1:]
return hf_output_ids, hf_output_str, out_logprobs
@pytest.mark.parametrize("model", ["Salesforce/blip2-opt-2.7b"])
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype: str, max_tokens: int, num_logprobs: int) -> None:
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalData objects and corresponding
vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
# max_model_len should be greater than image_feature_size
with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]
with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)
......@@ -77,8 +77,8 @@ def run_test(
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=vllm_images)
for prompts, vllm_images in inputs_per_image
images=images)
for prompts, images in inputs_per_image
]
with hf_runner(model, dtype=dtype) as hf_model:
......@@ -89,9 +89,9 @@ def run_test(
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=hf_images,
images=images,
eos_token_id=eos_token_id)
for prompts, hf_images in inputs_per_image
for prompts, images in inputs_per_image
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
......
......@@ -88,9 +88,9 @@ def run_test(
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=vllm_images,
images=images,
stop_token_ids=stop_token_ids)
for prompts, vllm_images in inputs_per_image
for prompts, images in inputs_per_image
]
with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
......@@ -114,9 +114,9 @@ def run_test(
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=hf_images,
images=images,
tokenizer=tokenizer)
for prompts, hf_images in inputs_per_image
for prompts, images in inputs_per_image
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
......
......@@ -101,8 +101,8 @@ def run_test(
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=vllm_images)
for prompts, vllm_images in inputs_per_image
images=images)
for prompts, images in inputs_per_image
]
# use eager mode for hf runner, since phi3_v didn't work with flash_attn
......@@ -114,9 +114,9 @@ def run_test(
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=hf_images,
images=images,
eos_token_id=eos_token_id)
for prompts, hf_images in inputs_per_image
for prompts, images in inputs_per_image
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
......
......@@ -16,6 +16,8 @@ _GENERATION_MODELS = {
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
"Blip2ForConditionalGeneration":
("blip2", "Blip2ForConditionalGeneration"),
"ChameleonForConditionalGeneration":
("chameleon", "ChameleonForConditionalGeneration"),
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
......@@ -56,8 +58,8 @@ _GENERATION_MODELS = {
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"PaliGemmaForConditionalGeneration":
("paligemma", "PaliGemmaForConditionalGeneration"),
"PaliGemmaForConditionalGeneration": ("paligemma",
"PaliGemmaForConditionalGeneration"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
......
"""Minimal implementation of BlipVisionModel intended to be only used
within a vision language model."""
from typing import Optional, Union
import torch
import torch.nn as nn
from PIL import Image
from transformers import Blip2VisionConfig, BlipVisionConfig
from transformers.models.blip.modeling_blip import BlipAttention
from vllm.config import ModelConfig
from vllm.inputs import LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.sequence import SequenceData
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
return image_size // patch_size
def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
grid_length = get_blip_patch_grid_length(image_size=image_size,
patch_size=patch_size)
return grid_length * grid_length
def get_blip_image_feature_size(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int:
return get_blip_num_patches(image_size=hf_config.image_size,
patch_size=hf_config.patch_size)
def get_max_blip_image_tokens(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int:
return get_blip_image_feature_size(hf_config)
def dummy_seq_data_for_blip(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
seq_len: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = get_blip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size
token_ids += [0] * (seq_len - image_feature_size)
return SequenceData(token_ids)
def dummy_image_for_blip(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
width = height = hf_config.image_size
if image_width_override is not None:
width = image_width_override
if image_height_override is not None:
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return {"image": image}
def input_processor_for_blip(
model_config: ModelConfig,
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
llm_inputs: LLMInputs,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None:
image_feature_size = get_blip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=image_token_id,
repeat_count=image_feature_size,
)
# NOTE: Create a defensive copy of the original inputs
return LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
class BlipVisionEmbeddings(nn.Module):
def __init__(self, config: BlipVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
self.patch_embedding = nn.Conv2d(
in_channels=3,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
)
self.num_patches = get_blip_num_patches(image_size=self.image_size,
patch_size=self.patch_size)
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Parameter(
torch.randn(1, self.num_positions, self.embed_dim))
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(
dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
position_embeds = self.position_embedding.to(target_dtype)
embeddings = embeddings + position_embeds[:, :embeddings.size(1), :]
return embeddings
class BlipMLP(nn.Module):
def __init__(self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config)
self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
class BlipEncoderLayer(nn.Module):
def __init__(self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.self_attn = BlipAttention(config)
self.layer_norm1 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.mlp = BlipMLP(config, quant_config=quant_config)
self.layer_norm2 = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class BlipEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self
attention layers. Each layer is a [`BlipEncoderLayer`].
Args:
config: BlipConfig
"""
def __init__(self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
super().__init__()
self.config = config
if num_hidden_layers_override is None:
num_hidden_layers = config.num_hidden_layers
else:
num_hidden_layers = num_hidden_layers_override
self.layers = nn.ModuleList([
BlipEncoderLayer(config=config, quant_config=quant_config)
for _ in range(num_hidden_layers)
])
def forward(self, inputs_embeds: torch.Tensor):
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(hidden_states)
return hidden_states
class BlipVisionModel(nn.Module):
config_class = BlipVisionConfig
main_input_name = "pixel_values"
def __init__(self,
config: BlipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None):
super().__init__()
self.config = config
self.embeddings = BlipVisionEmbeddings(config)
self.encoder = BlipEncoder(
config=config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
)
self.post_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
hidden_states = self.embeddings(pixel_values)
hidden_states = self.encoder(inputs_embeds=hidden_states)
return self.post_layernorm(hidden_states)
This diff is collapsed.
......@@ -237,14 +237,19 @@ class OPTDecoder(nn.Module):
for _ in range(config.num_hidden_layers)
])
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
inputs_embeds = self.embed_tokens(input_ids)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
pos_embeds = self.embed_positions(positions)
if self.project_in is not None:
inputs_embeds, _ = self.project_in(inputs_embeds)
......@@ -272,14 +277,22 @@ class OPTModel(nn.Module):
super().__init__()
self.decoder = OPTDecoder(config, cache_config, quant_config)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.decoder.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.decoder(input_ids, positions, kv_caches, attn_metadata)
return self.decoder(input_ids,
positions,
kv_caches,
attn_metadata,
inputs_embeds=inputs_embeds)
class OPTForCausalLM(nn.Module):
......
import sys
from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from typing import (Any, Callable, Dict, List, Optional, Type, TypedDict,
TypeVar, Union, cast)
from typing import Any, Callable, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Type, TypedDict, TypeVar, Union, cast
import torch
import torch.types
......@@ -15,13 +16,13 @@ from vllm.logger import init_logger
logger = init_logger(__name__)
NestedTensors = Union[List[torch.Tensor], torch.Tensor]
NestedTensors = Union[GenericSequence[torch.Tensor], torch.Tensor]
"""
Use a list instead of a tensor if the dimensions of each element do not match.
Currently only supports up to singly nested list of tensors.
"""
BatchedTensors = Union[List[NestedTensors], NestedTensors]
BatchedTensors = Union[GenericSequence[NestedTensors], NestedTensors]
"""
If each input tensor in the batch has the same size, this is a single batched
tensor; otherwise, this is a list of :class:`NestedTensors` with one element
......@@ -53,7 +54,7 @@ class MultiModalInputs(_MultiModalInputsBase):
# may be list rather than tensors
if isinstance(tensors[0], list):
return [[t.to(device=device) for t in tensor[0]]
for tensor in tensors]
for tensor in cast(List[List[torch.Tensor]], tensors)]
tensors_ = cast(List[torch.Tensor], tensors)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment