"tests/test_image_processing_common.py" did not exist on "01db72abd4859aa64d34fea3ae8cf27d71baee9b"
Unverified Commit c7f076a0 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

Adds VIP-llava to transformers (#27932)

* v1

* add-new-model-like

* revert

* fix forward and conversion script

* revert

* fix copies

* fixup

* fix

* Update docs/source/en/index.md

* Apply suggestions from code review

* push

* fix

* fixes here and there

* up

* fixup and fix tests

* Apply suggestions from code review

* add docs

* fixup

* fixes

* docstring

* add docstring

* fixup

* docstring

* fixup

* nit

* docs

* more copies

* fix copies

* nit

* update test
parent 371fb0b7
# coding=utf-8
# Copyright 2023 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" VipLlava model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
VIPLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"ybelkada/vip-llava-7b-hf": "https://huggingface.co/llava-hf/vip-llava-7b-hf/resolve/main/config.json",
}
class VipLlavaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`VipLlavaForConditionalGeneration`]. It is used to instantiate an
VipLlava model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the VipLlava-9B.
e.g. [ybelkada/vip-llava-7b-hf](https://huggingface.co/ybelkada/vip-llava-7b-hf)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (`VipLlavaVisionConfig`, *optional*):
Custom vision config or dict
text_config (`Union[AutoConfig, dict]`, *optional*):
The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 32000):
The image token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
The activation function used by the multimodal projector.
projector_layernorm_eps (`float`, *optional*, defaults to 1e-05):
The layer norm epsilon of the projector layernorm
vision_feature_layers (`List[int]`, *optional*, defaults to `[-2, -5, -8, -11, 6]`):
The list of layers to select the vision features from.
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the VipLlava model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`~VipLlavaForConditionalGeneration`]
Example:
```python
>>> from transformers import VipLlavaForConditionalGeneration, VipLlavaConfig, CLIPVisionConfig, LlamaConfig
>>> # Initializing a CLIP-vision config
>>> vision_config = CLIPVisionConfig()
>>> # Initializing a Llama config
>>> text_config = LlamaConfig()
>>> # Initializing a VipLlava vipllava-7b style configuration
>>> configuration = VipLlavaConfig(vision_config, text_config)
>>> # Initializing a model from the vipllava-7b style configuration
>>> model = VipLlavaForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "vipllava"
is_composition = False
def __init__(
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=32000,
projector_hidden_act="gelu",
projector_layernorm_eps=1e-5,
vision_feature_layers=[-2, -5, -8, -11, 6],
vocab_size=32000,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.projector_layernorm_eps = projector_layernorm_eps
self.vision_feature_layers = vision_feature_layers
self.vocab_size = vocab_size
self.vision_config = vision_config
if isinstance(self.vision_config, dict):
vision_config["model_type"] = (
vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
)
self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
elif vision_config is None:
self.vision_config = CONFIG_MAPPING["clip_vision_model"](
intermediate_size=4096,
hidden_size=1024,
patch_size=14,
image_size=336,
num_hidden_layers=24,
num_attention_heads=16,
vocab_size=32000,
projection_dim=768,
)
self.vocab_size = self.vocab_size
self.text_config = text_config
if isinstance(self.text_config, dict):
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
self.vocab_size = self.text_config.vocab_size
elif text_config is None:
self.text_config = CONFIG_MAPPING["llama"]()
super().__init__(**kwargs)
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
from huggingface_hub import hf_hub_download
from transformers import (
AddedToken,
AutoConfig,
AutoTokenizer,
CLIPImageProcessor,
LlavaProcessor,
VipLlavaConfig,
VipLlavaForConditionalGeneration,
)
KEYS_TO_MODIFY_MAPPING = {
"model.vision_tower.": "",
"model.mm_projector": "multi_modal_projector",
"model": "model.model",
"vision_model.model": "vision_model",
"lm_head": "language_model.lm_head",
"model.model": "language_model.model",
"multi_modal_projector.0": "multi_modal_projector.linear_1",
"multi_modal_projector.2": "multi_modal_projector.linear_2",
"final_linear.0": "linear_1",
"final_linear.2": "linear_2",
"multi_modal_projector.clip_layernorm": "multi_modal_projector.projector_layernorm",
}
# Copied from transformers.models.llava.convert_llava_weights_to_hf.convert_state_dict_to_hf
def convert_state_dict_to_hf(state_dict):
new_state_dict = {}
for key, value in state_dict.items():
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)
new_state_dict[key] = value
return new_state_dict
def convert_vipllava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
torch.set_default_dtype(torch.float16)
text_config = AutoConfig.from_pretrained(text_model_id)
tokenizer = AutoTokenizer.from_pretrained(text_model_id)
tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False))
tokenizer.add_special_tokens({"pad_token": "<pad>"})
image_processor = CLIPImageProcessor.from_pretrained(vision_model_id)
processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)
config = VipLlavaConfig(text_config=text_config)
config.pad_token_id = 32001
with torch.device("meta"):
model = VipLlavaForConditionalGeneration(config)
# Pad to 64 for performance reasons
pad_shape = 64
state_dict_path = hf_hub_download(old_state_dict_id, "model_state_dict_7b.bin")
state_dict = torch.load(state_dict_path, map_location="cpu")
state_dict = convert_state_dict_to_hf(state_dict)
model.load_state_dict(state_dict, strict=True, assign=True)
pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data
mu = torch.mean(pre_expansion_embeddings, dim=0).float()
n = pre_expansion_embeddings.size()[0]
sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma)
# We add an image token so we resize the model
model.resize_token_embeddings(config.text_config.vocab_size + 2, pad_shape)
model.language_model.model.embed_tokens.weight.data[32000:] = torch.stack(
tuple((dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[32000:].shape[0]))),
dim=0,
)
model.language_model.lm_head.weight.data[32000:] = torch.stack(
tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[32000:].shape[0]))),
dim=0,
)
model.config.vocab_size = model.config.vocab_size + pad_shape
model.config.text_config.vocab_size = model.config.text_config.vocab_size + pad_shape
model.push_to_hub(output_hub_path)
processor.push_to_hub(output_hub_path)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--text_model_id",
help="Hub location of the text model",
)
parser.add_argument(
"--vision_model_id",
help="Hub location of the vision model",
)
parser.add_argument(
"--output_hub_path",
help="Location on the hub of the converted model",
)
parser.add_argument(
"--old_state_dict_id",
help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`",
)
args = parser.parse_args()
convert_vipllava_llama_to_hf(
args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id
)
if __name__ == "__main__":
main()
This diff is collapsed.
...@@ -8320,6 +8320,23 @@ class ViltPreTrainedModel(metaclass=DummyObject): ...@@ -8320,6 +8320,23 @@ class ViltPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
VIPLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = None
class VipLlavaForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class VipLlavaPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class VisionEncoderDecoderModel(metaclass=DummyObject): class VisionEncoderDecoderModel(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch VipLlava model. """
import gc
import unittest
import requests
from transformers import (
AutoProcessor,
VipLlavaConfig,
VipLlavaForConditionalGeneration,
is_torch_available,
is_vision_available,
)
from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
import torch
else:
is_torch_greater_or_equal_than_2_0 = False
if is_vision_available():
from PIL import Image
# Copied from transformers.tests.models.llava.test_modeling_llava.LlavaVisionText2TextModelTester with Llava->VipLlava
class VipLlavaVisionText2TextModelTester:
# Ignore copy
def __init__(
self,
parent,
ignore_index=-100,
image_token_index=0,
projector_hidden_act="gelu",
seq_length=7,
vision_feature_layers=[0, 0, 1, 1, 0],
text_config={
"model_type": "llama",
"seq_length": 7,
"is_training": True,
"use_input_mask": True,
"use_token_type_ids": False,
"use_labels": True,
"vocab_size": 99,
"hidden_size": 32,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"intermediate_size": 37,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"attention_probs_dropout_prob": 0.1,
"max_position_embeddings": 512,
"type_vocab_size": 16,
"type_sequence_label_size": 2,
"initializer_range": 0.02,
"num_labels": 3,
"num_choices": 4,
"pad_token_id": 0,
},
is_training=True,
vision_config={
"batch_size": 12,
"image_size": 30,
"patch_size": 2,
"num_channels": 3,
"is_training": True,
"hidden_size": 32,
"projection_dim": 32,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"intermediate_size": 37,
"dropout": 0.1,
"attention_dropout": 0.1,
"initializer_range": 0.02,
},
):
self.parent = parent
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.vision_feature_layers = vision_feature_layers
self.text_config = text_config
self.vision_config = vision_config
self.seq_length = seq_length
self.num_hidden_layers = text_config["num_hidden_layers"]
self.vocab_size = text_config["vocab_size"]
self.hidden_size = text_config["hidden_size"]
self.num_attention_heads = text_config["num_attention_heads"]
self.is_training = is_training
self.batch_size = 3
self.num_channels = 3
self.image_size = 336
self.encoder_seq_length = 231
def get_config(self):
return VipLlavaConfig(
text_config=self.text_config,
vision_config=self.vision_config,
ignore_index=self.ignore_index,
image_token_index=self.image_token_index,
projector_hidden_act=self.projector_hidden_act,
vision_feature_layers=self.vision_feature_layers,
)
def prepare_config_and_inputs(self):
pixel_values = floats_tensor(
[
self.batch_size,
self.vision_config["num_channels"],
self.vision_config["image_size"],
self.vision_config["image_size"],
]
)
config = self.get_config()
return config, pixel_values
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
attention_mask = input_ids.ne(1).to(torch_device)
# we are giving 3 images let's make sure we pass in 3 image tokens
input_ids[:, 1] = config.image_token_index
inputs_dict = {
"pixel_values": pixel_values,
"input_ids": input_ids,
"attention_mask": attention_mask,
}
return config, inputs_dict
@require_torch
# Copied from transformers.tests.models.llava.test_modeling_llava.LlavaForConditionalGenerationModelTest with Llava->VipLlava
class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
"""
Model tester for `VipLlavaForConditionalGeneration`.
"""
all_model_classes = (VipLlavaForConditionalGeneration,) if is_torch_available() else ()
fx_compatible = False
test_pruning = False
test_resize_embeddings = True
test_head_masking = False
def setUp(self):
self.model_tester = VipLlavaVisionText2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=VipLlavaConfig, has_text_modality=False)
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@require_torch
class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
def setUp(self):
self.processor = AutoProcessor.from_pretrained("llava-hf/vip-llava-7b-hf")
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
@slow
@require_bitsandbytes
def test_small_model_integration_test(self):
model_id = "llava-hf/vip-llava-7b-hf"
model = VipLlavaForConditionalGeneration.from_pretrained(model_id, load_in_4bit=True)
processor = AutoProcessor.from_pretrained(model_id)
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-neg.png"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "USER: <image>\nCan you please describe this image?\nASSISTANT:"
inputs = processor(prompt, image, return_tensors="pt").to(torch_device, torch.float16)
outputs = model.generate(**inputs, max_new_tokens=10)
EXPECTED_OUTPUT = "USER: <image> \nCan you please describe this image?\nASSISTANT: The image features a brown and white cat sitting on"
self.assertEqual(processor.decode(outputs[0], skip_special_tokens=True), EXPECTED_OUTPUT)
...@@ -239,6 +239,7 @@ docs/source/en/model_doc/upernet.md ...@@ -239,6 +239,7 @@ docs/source/en/model_doc/upernet.md
docs/source/en/model_doc/van.md docs/source/en/model_doc/van.md
docs/source/en/model_doc/videomae.md docs/source/en/model_doc/videomae.md
docs/source/en/model_doc/vilt.md docs/source/en/model_doc/vilt.md
docs/source/en/model_doc/vipllava.md
docs/source/en/model_doc/vision-encoder-decoder.md docs/source/en/model_doc/vision-encoder-decoder.md
docs/source/en/model_doc/vision-text-dual-encoder.md docs/source/en/model_doc/vision-text-dual-encoder.md
docs/source/en/model_doc/visual_bert.md docs/source/en/model_doc/visual_bert.md
...@@ -847,6 +848,8 @@ src/transformers/models/videomae/configuration_videomae.py ...@@ -847,6 +848,8 @@ src/transformers/models/videomae/configuration_videomae.py
src/transformers/models/videomae/convert_videomae_to_pytorch.py src/transformers/models/videomae/convert_videomae_to_pytorch.py
src/transformers/models/vilt/configuration_vilt.py src/transformers/models/vilt/configuration_vilt.py
src/transformers/models/vilt/convert_vilt_original_to_pytorch.py src/transformers/models/vilt/convert_vilt_original_to_pytorch.py
src/transformers/models/vipllava/configuration_vipllava.py
src/transformers/models/vipllava/modeling_vipllava.py
src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py
src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py src/transformers/models/vision_text_dual_encoder/modeling_flax_vision_text_dual_encoder.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment