"vscode:/vscode.git/clone" did not exist on "2263d44b688902aa3fd384ecdd7e3db3460b01e0"
Unverified Commit daed30c4 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Fix feature size calculation for LLaVA-NeXT (#6982)

parent 2f4e108f
from typing import List, Optional, Tuple, Type
from typing import List, Optional, Tuple, Type, overload
import pytest
from transformers import AutoConfig, AutoTokenizer
from transformers import AutoTokenizer
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
......@@ -50,6 +50,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
return hf_output_ids, hf_output_str, out_logprobs
@overload
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
......@@ -62,13 +63,55 @@ def run_test(
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...
@overload
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
sizes: List[Tuple[int, int]],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
...
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: Optional[List[float]] = None,
sizes: Optional[List[Tuple[int, int]]] = None,
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
if size_factors is not None:
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
elif sizes is not None:
inputs_per_image = [(
[prompt for _ in sizes],
[image.resize(size) for size in sizes],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
else:
raise ValueError("You must provide either `size_factors` or `sizes`")
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
......@@ -150,15 +193,24 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
)
@pytest.mark.parametrize("height_and_width_and_result", [(1669, 2560, 2144),
(183, 488, 776)])
def test_image_feature_size(height_and_width_and_result):
# Avoid initializing CUDA too early in distributed tests
from vllm.model_executor.models.llava_next import (
get_llava_next_image_feature_size)
height, width, result = height_and_width_and_result
config = AutoConfig.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
assert get_llava_next_image_feature_size(config,
input_height=height,
input_width=width) == result
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"sizes",
[[(1669, 2560), (2560, 1669), (183, 488), (488, 183)]],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes,
dtype, max_tokens, num_logprobs) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
sizes=sizes,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
......@@ -169,7 +169,7 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
raise TypeError(f"Invalid image type: {type(image_data)}")
# process prompts
prompt = llm_inputs["prompt"]
prompt = llm_inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"]
tokenizer = cached_get_tokenizer(model_config.model)
# dim0 is batch_size, dim1 is subseq_size which will always be 1
......
......@@ -20,7 +20,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.intern_vit import InternVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SamplerOutput
......@@ -43,7 +43,7 @@ MAX_IMAGE_FEATURE_SIZE_HEIGHT = 500
class InternVLImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: BatchedTensors
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
......@@ -193,7 +193,7 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
tokenizer = cached_get_tokenizer(model_config.tokenizer,
trust_remote_code=True)
prompt = llm_inputs["prompt"]
prompt = llm_inputs.get("prompt")
prompt_token_ids = llm_inputs["prompt_token_ids"]
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
......
......@@ -21,7 +21,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
......@@ -43,7 +43,7 @@ MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
class LlavaNextImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: BatchedTensors
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
......@@ -62,31 +62,26 @@ class LlavaNextImagePixelInputs(TypedDict):
LlavaNextImageInputs = LlavaNextImagePixelInputs
# Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91
# NOTE: new_height and new_width are further incremented to properly invert the
# floordiv operation: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79
def _get_llava_next_num_unpadded_features(
height: int,
width: int,
original_height: int,
original_width: int,
npatches: int,
num_patch_height: int,
num_patch_width: int,
) -> Tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
current_height = torch.tensor(current_height).to("cuda")
current_width = torch.tensor(current_width).to("cuda")
aspect_ratio: float = width / height
current_aspect_ratio: float = current_width / current_height
aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if aspect_ratio > current_aspect_ratio:
scale_factor = current_width / width
new_height = int(height * scale_factor)
new_height = (original_height * current_width) // original_width
padding = (current_height - new_height) // 2
current_height -= padding * 2
else:
scale_factor = current_height / height
new_width = int(width * scale_factor)
new_width = (original_width * current_height) // original_height
padding = (current_width - new_width) // 2
current_width -= padding * 2
......@@ -95,7 +90,7 @@ def _get_llava_next_num_unpadded_features(
return (unpadded_features, newline_features)
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106
def get_llava_next_image_feature_size(
hf_config: LlavaNextConfig,
*,
......@@ -111,9 +106,7 @@ def get_llava_next_image_feature_size(
)
base_feature_size = num_patches * num_patches
# Note: We follow the "wrong" width/height order
# [ref: PR huggingface/transformers#31588]
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_size=(input_height, input_width),
grid_pinpoints=hf_config.image_grid_pinpoints,
patch_size=vision_config.image_size,
......@@ -349,11 +342,12 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
if patch_embeddings.shape[0] > 1:
other_patch_embeds = patch_embeddings[1:]
# Move to CPU to avoid floating-point errors
orig_height, orig_width = image_size.tolist()
# image_aspect_ratio == "anyres"
# Note: We follow the "wrong" width/height order
# [ref: PR huggingface/transformers#31588]
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_size,
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
(orig_height, orig_width),
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
......@@ -365,7 +359,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
.permute(4, 0, 2, 1, 3).contiguous() \
.flatten(1, 2).flatten(2, 3)
other_patch_embeds = unpad_image(other_patch_embeds,
image_size)
(orig_height, orig_width))
other_patch_embeds = torch.cat((
other_patch_embeds,
self.image_newline[:, None, None] \
......@@ -398,7 +392,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
def _process_image_pixels(
self,
inputs: LlavaNextImagePixelInputs,
) -> BatchedTensors:
) -> Union[torch.Tensor, List[torch.Tensor]]:
assert self.vision_tower is not None
pixel_values = inputs["data"]
......@@ -425,7 +419,9 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
]
def _process_image_input(
self, image_input: LlavaNextImageInputs) -> BatchedTensors:
self,
image_input: LlavaNextImageInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
patch_embeddings = self._process_image_pixels(image_input)
image_sizes = image_input.get("image_sizes")
......
......@@ -36,7 +36,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SamplerOutput
......@@ -261,7 +261,7 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
class Phi3VImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: BatchedTensors
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment