Unverified Commit 6287537a authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Model] LLaVA model refactor (#4910)

parent b57e6c59
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import torch import torch
from torch import nn from torch import nn
...@@ -67,6 +67,21 @@ def _merge_vision_embeddings(input_ids: torch.Tensor, ...@@ -67,6 +67,21 @@ def _merge_vision_embeddings(input_ids: torch.Tensor,
return inputs_embeds return inputs_embeds
class LlavaImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""
class LlavaImageFeatureInputs(TypedDict):
type: Literal["image_features"]
data: torch.Tensor
"""Shape: (batch_size, image_feature_size, hidden_size)"""
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
class LlavaForConditionalGeneration(VisionLanguageModelBase): class LlavaForConditionalGeneration(VisionLanguageModelBase):
def __init__(self, def __init__(self,
...@@ -102,6 +117,90 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase): ...@@ -102,6 +117,90 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
config.vocab_size, logit_scale) config.vocab_size, logit_scale)
self.sampler = Sampler() self.sampler = Sampler()
def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
if list(data.shape[1:]) != list(
self.vision_language_config.image_input_shape[1:]):
raise ValueError(
f"The expected image tensor shape is batch dimension plus "
f"{self.vision_language_config.image_input_shape[1:]}. "
f"You supplied {data.shape}. "
f"If you are using vLLM's entrypoint, make sure your "
f"supplied image input is consistent with "
f"image_input_shape in engine args.")
return data
def _parse_and_validate_image_input(
self, data: object) -> Optional[LlavaImageInputs]:
expected_input_type = self.vision_language_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
if data is None:
return None
if expected_input_type == ImageInputType.PIXEL_VALUES:
if not isinstance(data, torch.Tensor):
raise TypeError("Image pixel vector should be a tensor, "
f"but received type: {type(data)}")
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_image_data(data),
)
elif expected_input_type == ImageInputType.IMAGE_FEATURES:
if not isinstance(data, torch.Tensor):
raise TypeError("Image feature vector should be a tensor, "
f"but received type: {type(data)}")
return LlavaImageFeatureInputs(
type="image_features",
data=self._validate_image_data(data),
)
return None
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
image_outputs = vision_tower(pixel_values.to(vision_tower.device),
output_hidden_states=True)
image_features = image_outputs.hidden_states[
self.config.vision_feature_layer]
return self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
)
def _process_image_pixels(self,
inputs: LlavaImagePixelInputs) -> torch.Tensor:
assert self.vision_tower is not None
pixel_values = inputs["data"]
return self._image_pixels_to_features(self.vision_tower, pixel_values)
def _process_image_input(self,
image_input: LlavaImageInputs) -> torch.Tensor:
if image_input["type"] == "pixel_values":
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
else:
image_features = image_input["data"]
return self.multi_modal_projector(image_features)
def forward(self, def forward(self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -144,42 +243,20 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase): ...@@ -144,42 +243,20 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
For PIXEL_VALUES, expecting [1, 3, 336, 336]. For PIXEL_VALUES, expecting [1, 3, 336, 336].
For IMAGE_FEATURES, expecting [1, 576, 1024]. For IMAGE_FEATURES, expecting [1, 576, 1024].
""" """
if image_input is not None: parsed_image_input = self._parse_and_validate_image_input(image_input)
if list(image_input.shape[1:]) != list(
self.vision_language_config.image_input_shape[1:]): if parsed_image_input is not None:
raise ValueError( vision_embeddings = self._process_image_input(parsed_image_input)
f"The expected image tensor shape is batch dimension "
f"plus "
f"{self.vision_language_config.image_input_shape[1:]}."
f" You supplied {image_input.shape}. "
f"If you are using vLLM's entrypoint, make sure your "
f"supplied image input is consistent with "
f"image_input_shape in engine args.")
if self.vision_tower is not None:
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
image_outputs = self.vision_tower(image_input,
output_hidden_states=True)
image_features = image_outputs.hidden_states[
self.config.vision_feature_layer]
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
if self.config.vision_feature_select_strategy == "default":
image_features = image_features[:, 1:]
elif self.config.vision_feature_select_strategy == "full":
image_features = image_features
else:
raise ValueError(
f"Unexpected select feature strategy: "
f"{self.config.vision_feature_select_strategy}")
else:
image_features = image_input
vision_embeddings = self.multi_modal_projector(image_features)
inputs_embeds = self.language_model.get_input_embeddings(input_ids) inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = _merge_vision_embeddings( inputs_embeds = _merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
self.vision_language_config.image_token_id) self.vision_language_config.image_token_id)
input_ids = None input_ids = None
else: else:
inputs_embeds = None inputs_embeds = None
hidden_states = self.language_model(input_ids, hidden_states = self.language_model(input_ids,
positions, positions,
kv_caches, kv_caches,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment