Unverified Commit c0d8f163, authored by Jungho Christopher Cho, committed by GitHub
Browse files

[Model] SiglipVisionModel ported from transformers (#6942)


Co-authored-by: Roger Wang <ywang@roblox.com>
parent cc08fc72
......@@ -65,7 +65,8 @@ def run_phi3v(question):
# PaliGemma
def run_paligemma(question):
    """Create the PaliGemma engine and the prompt to send to it.

    Args:
        question: Not used; the prompt is always the fixed string
            ``"caption en"``.

    Returns:
        A ``(llm, prompt)`` tuple.
    """
    # PaliGemma has special prompt format for VQA
    prompt = "caption en"
    llm = LLM(model="google/paligemma-3b-mix-224")
    return llm, prompt
......
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
import torch
from PIL import Image
from torch import nn
from transformers import PaliGemmaConfig, SiglipVisionConfig, SiglipVisionModel
from transformers import PaliGemmaConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
......@@ -18,9 +17,11 @@ from vllm.model_executor.models.gemma import GemmaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsVision
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import merge_vision_embeddings
logger = init_logger(__name__)
......@@ -32,55 +33,22 @@ _KEYS_TO_MODIFY_MAPPING = {
def get_max_paligemma_image_tokens(ctx: InputContext):
    """Return ``num_image_tokens`` from the PaliGemma text config.

    The ``PaliGemmaConfig`` is looked up through ``ctx.get_hf_config`` and
    the value is read from its ``text_config`` attribute.
    """
    return ctx.get_hf_config(PaliGemmaConfig).text_config.num_image_tokens
def dummy_seq_data_for_paligemma(
    hf_config: PaliGemmaConfig,
    seq_len: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    """Build placeholder sequence data: image tokens followed by zeros.

    Args:
        hf_config: Config whose ``text_config.num_image_tokens`` supplies the
            number of image tokens when no override is given.
        seq_len: Target length; zeros fill the positions after the image
            tokens.
        image_token_id: Token id repeated for every image position.
        image_feature_size_override: When not ``None``, used as the number of
            image tokens instead of the config value.

    Returns:
        A ``SequenceData`` wrapping the generated token ids.
    """
    num_image_tokens = (hf_config.text_config.num_image_tokens
                        if image_feature_size_override is None else
                        image_feature_size_override)
    image_part = [image_token_id] * num_image_tokens
    padding_part = [0] * (seq_len - num_image_tokens)
    return SequenceData(image_part + padding_part)
# Build a blank RGB placeholder image and return it under the "image" key.
# NOTE(review): this span comes from a rendered diff whose +/- markers and
# indentation were lost; statements from more than one revision appear to be
# interleaved below -- confirm against the repository before relying on it.
def dummy_image_for_paligemma(
hf_config: SiglipVisionConfig,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
# Both dimensions default to `hf_config.image_size`.
width = height = hf_config.image_size
# Each override, when given, replaces only its own dimension.
if image_width_override is not None:
width = image_width_override
if image_height_override is not None:
height = image_height_override
# NOTE(review): `vision_config` is not used when building the image; it is
# only read by the final return of this span. Presumably this line belongs to
# get_max_paligemma_image_tokens rather than to this function -- confirm.
vision_config = hf_config.vision_config
# The image is created in "RGB" mode with color=0.
image = Image.new("RGB", (width, height), color=0)
return {"image": image}
# NOTE(review): this statement follows an unconditional return; presumably it
# is the return of get_max_paligemma_image_tokens from the other revision --
# confirm.
return get_max_siglip_image_tokens(vision_config)
def dummy_data_for_paligemma(ctx: InputContext, seq_len: int):
    """Create dummy sequence data and a dummy image for PaliGemma.

    Both pieces are produced by the SigLIP helpers from the vision part of
    the PaliGemma config.

    Args:
        ctx: Input context used to look up the ``PaliGemmaConfig``.
        seq_len: Length passed through to ``dummy_seq_data_for_siglip``.

    Returns:
        A ``(seq_data, mm_data)`` tuple.
    """
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config

    # The image placeholder token id comes from the top-level PaliGemma
    # config, not from the vision config.
    seq_data = dummy_seq_data_for_siglip(
        vision_config,
        seq_len,
        image_token_id=hf_config.image_token_index,
    )

    mm_data = dummy_image_for_siglip(vision_config)
    return seq_data, mm_data
......@@ -208,30 +176,37 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
data=self._validate_pixel_values(pixel_values),
)
def _image_pixels_to_features(
    self,
    vision_tower: SiglipVisionModel,
    pixel_values: torch.Tensor,
) -> torch.Tensor:
    """Run the vision tower once on ``pixel_values`` and return its output.

    Args:
        vision_tower: Model that is called on the pixel values.
        pixel_values: Image tensor; cast to the dtype of the tower's
            input-embedding weights before the call.

    Returns:
        Whatever ``vision_tower`` returns for the cast pixel values.
        NOTE(review): assumes the ``SiglipVisionModel`` imported from
        ``.siglip`` returns the feature tensor directly -- confirm.
    """
    # Match the tower's parameter dtype so the forward pass does not mix
    # precisions.
    target_dtype = vision_tower.get_input_embeddings().weight.dtype
    image_features = vision_tower(pixel_values.to(dtype=target_dtype))
    return image_features
def _process_image_pixels(
    self,
    inputs: PaliGemmaImagePixelInputs,
) -> torch.Tensor:
    """Turn pixel inputs into vision-tower features.

    Args:
        inputs: Mapping whose ``"data"`` entry holds the pixel values.

    Returns:
        The result of ``_image_pixels_to_features`` for those pixel values.
    """
    assert self.vision_tower is not None

    pixel_values = inputs["data"]

    return self._image_pixels_to_features(
        self.vision_tower,
        pixel_values,
    )
def _process_image_input(
    self,
    image_input: PaliGemmaImageInputs,
) -> torch.Tensor:
    """Compute image features and pass them through the projector.

    Args:
        image_input: Image inputs forwarded to ``_process_image_pixels``.

    Returns:
        The output of ``self.multi_modal_projector`` applied to the image
        features.
    """
    assert self.vision_tower is not None
    # Compute the features exactly once; each call runs the vision tower.
    image_features = self._process_image_pixels(image_input)

    return self.multi_modal_projector(image_features)
......
This diff is collapsed.
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment