Unverified Commit edf309eb authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[VLM] Support multimodal inputs for Florence-2 models (#13320)

parent 788f284b
......@@ -715,6 +715,13 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
- * `Florence2ForConditionalGeneration`
* Florence-2
* T + I
* `microsoft/Florence-2-base`, `microsoft/Florence-2-large` etc.
*
*
*
- * `FuyuForCausalLM`
* Fuyu
* T + I
......
# SPDX-License-Identifier: Apache-2.0
'''
"""
Demonstrate prompting of text-to-text
encoder/decoder models, specifically Florence-2
'''
"""
# TODO(Isotr0py):
# Move to offline_inference/vision_language.py
# after porting vision backbone
from vllm import LLM, SamplingParams
dtype = "float"
from vllm.assets.image import ImageAsset
# Create a Florence-2 encoder/decoder model instance
llm = LLM(
model="microsoft/Florence-2-base",
tokenizer="facebook/bart-base",
dtype=dtype,
model="microsoft/Florence-2-large",
tokenizer="facebook/bart-large",
max_num_seqs=8,
trust_remote_code=True,
)
prompts = [
"<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>",
"<CAPTION_TO_PHRASE_GROUNDING>", "<OD>", "<DENSE_REGION_CAPTION>",
"<REGION_PROPOSAL>", "<OCR>", "<OCR_WITH_REGION>"
{ # implicit prompt with task token
"prompt": "<DETAILED_CAPTION>",
"multi_modal_data": {
"image": ImageAsset("stop_sign").pil_image
},
},
{ # explicit encoder/decoder prompt
"encoder_prompt": {
"prompt": "Describe in detail what is shown in the image.",
"multi_modal_data": {
"image": ImageAsset("cherry_blossom").pil_image
},
},
"decoder_prompt": "",
},
]
# Create a sampling params object.
sampling_params = SamplingParams(
temperature=0,
top_p=1.0,
min_tokens=0,
max_tokens=20,
max_tokens=128,
)
# Generate output tokens from the prompts. The output is a list of
......@@ -38,9 +49,5 @@ outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
encoder_prompt = output.encoder_prompt
generated_text = output.outputs[0].text
print(f"Encoder prompt: {encoder_prompt!r}, "
f"Decoder prompt: {prompt!r}, "
f"Generated text: {generated_text!r}")
print(f"Generated text: {generated_text!r}")
......@@ -82,6 +82,22 @@ def run_deepseek_vl2(question: str, modality: str):
return llm, prompt, stop_token_ids
# Florence2
def run_florence2(question: str, modality: str):
assert modality == "image"
llm = LLM(model="microsoft/Florence-2-large",
tokenizer="facebook/bart-large",
max_num_seqs=8,
trust_remote_code=True,
dtype="bfloat16",
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
prompt = "<MORE_DETAILED_CAPTION>"
stop_token_ids = None
return llm, prompt, stop_token_ids
# Fuyu
def run_fuyu(question: str, modality: str):
assert modality == "image"
......@@ -571,6 +587,7 @@ model_example_map = {
"blip-2": run_blip2,
"chameleon": run_chameleon,
"deepseek_vl_v2": run_deepseek_vl2,
"florence2": run_florence2,
"fuyu": run_fuyu,
"glm4v": run_glm4v,
"h2ovl_chat": run_h2ovl,
......
......@@ -600,8 +600,8 @@ class HfRunner:
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]
encoder_input_ids = self.wrap_device(
self.processor(**processor_kwargs).input_ids,
encoder_inputs = self.wrap_device(
self.processor(**processor_kwargs),
device=self.model.device.type,
)
......@@ -615,13 +615,13 @@ class HfRunner:
)
output = self.model.generate(
encoder_input_ids,
decoder_input_ids=decoder_input_ids,
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
output_hidden_states=True,
return_dict_in_generate=True,
**encoder_inputs,
**kwargs,
)
......
......@@ -15,7 +15,7 @@ from ....conftest import HfRunner, VllmRunner
from ....utils import RemoteOpenAIServer
from ...utils import check_logprobs_close
MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
MODEL_NAME = "fixie-ai/ultravox-v0_4"
AudioTuple = Tuple[np.ndarray, int]
......@@ -187,7 +187,7 @@ def run_multi_audio_test(
@pytest.mark.core_model
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("vllm_kwargs", [
......
# SPDX-License-Identifier: Apache-2.0
from functools import partial
from typing import List, Optional, Tuple, Type
from typing import Optional, Type
import pytest
from PIL import Image
from vllm.inputs.data import ExplicitEncoderDecoderPrompt
from vllm.inputs.data import ExplicitEncoderDecoderPrompt, TextPrompt
from vllm.multimodal.image import rescale_image_size
from vllm.sequence import SampleLogprobs
from ....conftest import HfRunner, VllmRunner
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from ...utils import check_logprobs_close
Florence2Prompt = partial(ExplicitEncoderDecoderPrompt,
decoder_prompt=None,
mm_processor_kwargs=None)
MODELS = ["microsoft/Florence-2-base"]
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
TOKENIZER = "facebook/bart-base"
PROMPTS = [
Florence2Prompt(encoder_prompt="<CAPTION>"),
Florence2Prompt(encoder_prompt="<DETAILED_CAPTION>"),
Florence2Prompt(encoder_prompt="<MORE_DETAILED_CAPTION>"),
Florence2Prompt(encoder_prompt="<CAPTION_TO_PHRASE_GROUNDING>"),
Florence2Prompt(encoder_prompt="<DENSE_REGION_CAPTION>"),
Florence2Prompt(encoder_prompt="<REGION_PROPOSAL>"),
Florence2Prompt(encoder_prompt="<OCR_WITH_REGION>"),
Florence2Prompt(encoder_prompt="<OCR>"),
Florence2Prompt(encoder_prompt="<OD>"),
]
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"<CAPTION>", # special task token
"cherry_blossom":
"Describe in detail what is shown in the image.",
})
def get_hf_images_prompts(
prompts_: list[ExplicitEncoderDecoderPrompt[str, TextPrompt]],
) -> tuple[list[ExplicitEncoderDecoderPrompt[str, str]], list[Image.Image]]:
prompts, images = [], []
for prompt in prompts_:
encoder_prompt = prompt["encoder_prompt"]
prompts.append(
ExplicitEncoderDecoderPrompt(
encoder_prompt=encoder_prompt["prompt"],
decoder_prompt=None,
))
images.append(encoder_prompt["multi_modal_data"]["image"])
return prompts, images
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]], ):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output
def hf_to_vllm_output(hf_output: tuple[list[int], str,
Optional[SampleLogprobs]]):
"""Sanitize hf output to be comparable with vllm output."""
output_ids, output_str, out_logprobs = hf_output
hf_output_str = "</s><s>" + output_str + "</s>"
output_str = output_str.replace("</s>", "").replace("<s>", "")
output_ids = [ids for ids in output_ids if ids not in [0, 2]]
return output_ids, hf_output_str, out_logprobs
return output_ids, output_str, out_logprobs
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
prompts: List[ExplicitEncoderDecoderPrompt],
inputs: list[list[ExplicitEncoderDecoderPrompt]],
model: str,
*,
dtype: str,
......@@ -56,46 +63,76 @@ def run_test(
distributed_executor_backend: Optional[str] = None,
) -> None:
with vllm_runner(model,
max_num_seqs=8,
tokenizer_name=TOKENIZER,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
prompts, max_tokens, num_logprobs)
vllm_outputs_per_case = [
vllm_model.generate_encoder_decoder_greedy_logprobs(
prompts, max_tokens, num_logprobs=num_logprobs)
for prompts in inputs
]
hf_inputs = [get_hf_images_prompts(prompts) for prompts in inputs]
# Florence-2 processors require image inputs
dummy_image = Image.new(mode="RGB", size=(2, 2))
with hf_runner(model, dtype=dtype, skip_tokenizer_init=True) as hf_model:
hf_model.model.get_output_embeddings = lambda: \
hf_model.model.language_model.lm_head
hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
prompts,
max_tokens,
num_logprobs,
images=[dummy_image] * len(prompts),
))
hf_outputs_per_case = [
hf_model.generate_encoder_decoder_greedy_logprobs_limit(
prompts, max_tokens, num_logprobs=num_logprobs, images=images)
for prompts, images in hf_inputs
]
for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
vllm_outputs_per_case):
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output) for vllm_output in vllm_outputs
],
outputs_0_lst=[hf_to_vllm_output(output) for output in hf_outputs],
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
@pytest.mark.core_model
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, model, dtype, max_tokens,
num_logprobs) -> None:
def test_models(hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets, model: str,
size_factors: list[int], dtype: str, max_tokens: int,
num_logprobs: int) -> None:
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [[
ExplicitEncoderDecoderPrompt(
encoder_prompt=TextPrompt(
prompt=prompt,
multi_modal_data={"image": rescale_image_size(image, factor)}),
decoder_prompt=None,
) for factor in size_factors
] for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
run_test(
hf_runner,
vllm_runner,
PROMPTS,
inputs_per_image,
model,
dtype=dtype,
max_tokens=max_tokens,
......
......@@ -29,8 +29,8 @@ def _test_processing_correctness(
model_config = ModelConfig(
model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
tokenizer=model_info.tokenizer or model_id,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
seed=0,
dtype="float16",
......@@ -151,6 +151,7 @@ def _test_processing_correctness(
"Salesforce/blip2-opt-2.7b",
"facebook/chameleon-7b",
"deepseek-ai/deepseek-vl2-tiny",
"microsoft/Florence-2-base",
"adept/fuyu-8b",
"THUDM/glm-4v-9b",
"h2oai/h2ovl-mississippi-800m",
......
......@@ -193,11 +193,6 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
# [Encoder-decoder]
"BartModel": _HfExamplesInfo("facebook/bart-base"),
"BartForConditionalGeneration": _HfExamplesInfo("facebook/bart-large-cnn"),
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
"Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501
tokenizer="facebook/bart-base",
trust_remote_code=True), # noqa: E501
}
_EMBEDDING_EXAMPLE_MODELS = {
......@@ -288,6 +283,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
extras={"v0.5": "fixie-ai/ultravox-v0_5-llama-3_2-1b"}, # noqa: E501
trust_remote_code=True),
# [Encoder-decoder]
# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer
# Therefore, we borrow the BartTokenizer from the original Bart model
"Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501
tokenizer="facebook/bart-base",
trust_remote_code=True), # noqa: E501
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501
}
......
......@@ -588,8 +588,12 @@ class BartEncoder(nn.Module):
self.layernorm_embedding = nn.LayerNorm(embed_dim)
def forward(self, input_ids: torch.Tensor,
positions: torch.Tensor) -> torch.Tensor:
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""
Args:
input_ids
......@@ -602,6 +606,7 @@ class BartEncoder(nn.Module):
Decoder output torch.Tensor
"""
# retrieve input_ids and inputs_embeds
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(positions)
......@@ -661,9 +666,13 @@ class BartDecoder(nn.Module):
self.layernorm_embedding = nn.LayerNorm(config.d_model)
def forward(self, decoder_input_ids: torch.Tensor,
def forward(
self,
decoder_input_ids: torch.Tensor,
decoder_positions: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor]) -> torch.Tensor:
encoder_hidden_states: Optional[torch.Tensor],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""
Args:
decoder_input_ids
......@@ -677,8 +686,10 @@ class BartDecoder(nn.Module):
Returns:
Decoder output torch.Tensor
"""
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(decoder_input_ids)
else:
decoder_positions = inputs_embeds[:, -1]
# embed positions
embed_pos = self.embed_positions(decoder_positions)
......
This diff is collapsed.
......@@ -105,7 +105,6 @@ _TEXT_GENERATION_MODELS = {
# [Encoder-decoder]
"BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
}
_EMBEDDING_MODELS = {
......@@ -182,6 +181,7 @@ _MULTIMODAL_MODELS = {
"Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"),
# [Encoder-decoder]
"Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501
"MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501
"WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"), # noqa: E501
}
......
......@@ -1303,6 +1303,14 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
"""
raise NotImplementedError
def create_decoder_prompt(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
) -> Union[str, list[int]]:
"""Create input prompt for the decoder."""
return prompt
def apply(
self,
prompt: Union[str, list[int]],
......@@ -1323,17 +1331,15 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
hf_processor_mm_kwargs,
)
# We assumed the decoder prompt text is copied from
# the original encoder prompt without extra process
tokenizer = self.info.get_tokenizer()
if isinstance(prompt, str):
decoder_prompt = prompt
decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
if isinstance(decoder_prompt, str):
decoder_prompt_ids = encode_tokens(tokenizer,
prompt,
decoder_prompt,
add_special_tokens=False)
else:
decoder_prompt = decode_tokens(tokenizer, prompt)
decoder_prompt_ids = prompt
decoder_prompt_ids = decoder_prompt
decoder_prompt = decode_tokens(tokenizer, decoder_prompt)
mm_inputs = MultiModalEncDecInputs(
encoder_prompt=encoder_inputs["prompt"],
......
......@@ -204,9 +204,11 @@ class MultiModalProfiler(Generic[_I]):
"and/or reduce `mm_counts`.", seq_len, total_len,
total_placeholders_by_modality)
num_tokens_to_pad = max(total_len, seq_len) - total_len
prompt_token_ids.extend([0] * num_tokens_to_pad)
return DummyData(
seq_data=SequenceData.from_prompt_token_counts(
(0, max(seq_len, total_len))),
seq_data=SequenceData.from_seqs(prompt_token_ids),
multi_modal_data=None,
multi_modal_placeholders=None,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment