Unverified Commit 44b5506d authored by Younes Belkada, committed by GitHub

[`Llava`] Add Llava to transformers (#27662)

* add model like

* logits match

* minor fixes

* fixes

* up

* up

* add todo

* llava processor

* keep the processor simple

* add conversion script

* fixup

* fix copies

* up

* add to index

* fix config + logits

* fix

* refactor

* more refactor

* more refactor

* fix copies

* add authors

* v1 tests

* add `LlavaProcessor` in init

* remove unneeded import

* up

* up

* docs

* up

* fix CI

* fix CI

* add attention mask in test

* make fixup

* remove the vision model

* that's the dirty way to do it

* nits

* nits

* updates

* add more tests

* add input tests

* fixup

* more styling

* nits

* updates and cleanup

* fixup the generation expected results

* fix the testing script

* some cleanup and simplification which does not work yet but almost there!

* make correct dispatch operations

* vectorize works for batch of images and text

* last todos

* nits

* update test and modeling code

* remove useless function for now

* fix few issues

* fix generation

* some nits

* add bakllava

* nits

* remove duplicated code

* finish merge

* cleanup

* missed this line

* fill the todos

* add left padding offset

* add left and right padding logic

* bool to properly index

* make sure

* more cleanups

* batch is fixed 😉

* add correct device for tensor creation

* fix some dtype mismatch

* ruff

* update conversion script

* Update src/transformers/__init__.py

* fa 2 support + fix conversion script

* more

* correct reshaping

* fix test dict

* fix copies by ignoring

* fix nit

* skip clip vision model

* fixup

* fixup

* LlavaForVisionText2Text -> LlavaForCausalLM

* update

* fix

* raise correct errors

* fix

* docs

* nuke for now

* nits here and there

* fixup

* fix remaining tests

* update LlavaForConditionalGeneration instead of CausalLM

* fixups

* pipeline support

* slow and pipeline tests

* supports batch

* nits

* cleanup

* fix first integration tests

* add pad token where needed

* correct tests

* fixups

* update pipeline tests

* fix quality

* nits

* revert unneeded change

* nit

* use BatchFeature

* from ...feature_extraction_utils import BatchFeature

* nits

* nits

* properly update

* more f*** nits

* fix copies

* comment

* keep slow test slow

* Update src/transformers/models/llava/processing_llava.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* add pipeline example

* add pixel values in docstring

* update pr doctest

* fix

* fix slow tests

* remove hack

* fixup

* small note

* forward contrib credits from PR25789

* forward contrib credits from original implementation and work

* add arthur

* Update src/transformers/models/llava/processing_llava.py
Co-authored-by: Lysandre Debut <hi@lysand.re>

* update docstring

* nit

* move to not doctested because of timeout issues

* fixup

* add description

* more

* fix-copies

* fix docs

* add beam search

* add more comments

* add typehints on processor

* add speedup plot

* update slow tests and docs

* push test

* push batched test

* fix batched generation with different number of images

* remove benchmark due to a bug

* fix test

* fix copies

* add gcolab demo

---------
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: shauray8 <shauray8@users.noreply.github.com>
Co-authored-by: haotian-liu <haotian-liu@users.noreply.github.com>
Co-authored-by: Lysandre Debut <hi@lysand.re>
parent 0410a29a
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for Llava.
"""
import warnings
from typing import Callable, List, Optional, Union
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType
class LlavaProcessor(ProcessorMixin):
r"""
Constructs a Llava processor which wraps a Llava image processor and a Llava tokenizer into a single processor.
[`LlavaProcessor`] offers all the functionalities of [`CLIPImageProcessor`] and [`LlamaTokenizerFast`]. See the
[`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information.
Args:
image_processor ([`CLIPImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "CLIPImageProcessor"
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.__init__
def __init__(self, image_processor=None, tokenizer=None, **kwargs):
feature_extractor = None
if "feature_extractor" in kwargs:
warnings.warn(
"The `feature_extractor` argument is deprecated and will be removed in v5, use `image_processor`"
" instead.",
FutureWarning,
)
feature_extractor = kwargs.pop("feature_extractor")
image_processor = image_processor if image_processor is not None else feature_extractor
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
raise ValueError("You need to specify a `tokenizer`.")
super().__init__(image_processor, tokenizer)
def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
images: ImageInput = None,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
transform: Callable = None,
max_length=None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
) -> BatchFeature:
"""
Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
of the above two methods for more information.
Args:
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is the
number of channels and H and W are the image height and width.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
lengths).
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
truncation (`bool`, *optional*):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
transform (`Callable`, *optional*):
A custom transform function that accepts a single image can be passed for training. For example,
`torchvision.transforms.Compose` can be used to compose multiple functions. If `None`, a preset
inference-specific set of transforms will be applied to the images.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
if images is not None:
pixel_values = self.image_processor(images, transform=transform, return_tensors=return_tensors)[
"pixel_values"
]
else:
pixel_values = None
text_inputs = self.tokenizer(
text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
)
return BatchFeature(data={**text_inputs, "pixel_values": pixel_values})
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
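For orientation, here is a minimal usage sketch of the processor above, assuming the `llava-hf/llava-1.5-7b-hf` checkpoint exercised in the slow tests further down and a placeholder image path:

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
prompt = "USER: <image>\nWhat is shown in this image?\nASSISTANT:"
image = Image.open("view.jpg")  # placeholder path; any PIL image, NumPy array or torch tensor works
# the text is routed to LlamaTokenizerFast, the image to CLIPImageProcessor
inputs = processor(prompt, image, return_tensors="pt")
print(inputs.keys())  # input_ids, attention_mask, pixel_values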
...@@ -4641,6 +4641,30 @@ class LlamaPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
LLAVA_PRETRAINED_MODEL_ARCHIVE_LIST = None
class LlavaForConditionalGeneration(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LlavaPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LlavaProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
...
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Llava model. """
import gc
import unittest
import requests
from transformers import (
AutoProcessor,
LlavaConfig,
LlavaForConditionalGeneration,
is_torch_available,
is_vision_available,
)
from transformers.testing_utils import require_bitsandbytes, require_torch, slow, torch_device
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available():
import torch
else:
is_torch_greater_or_equal_than_2_0 = False
if is_vision_available():
from PIL import Image
class LlavaVisionText2TextModelTester:
def __init__(
self,
parent,
ignore_index=-100,
image_token_index=0,
projector_hidden_act="gelu",
seq_length=7,
vision_feature_select_strategy="default",
vision_feature_layer=-1,
text_config={
"model_type": "llama",
"seq_length": 7,
"is_training": True,
"use_input_mask": True,
"use_token_type_ids": False,
"use_labels": True,
"vocab_size": 99,
"hidden_size": 32,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"intermediate_size": 37,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"attention_probs_dropout_prob": 0.1,
"max_position_embeddings": 512,
"type_vocab_size": 16,
"type_sequence_label_size": 2,
"initializer_range": 0.02,
"num_labels": 3,
"num_choices": 4,
"pad_token_id": 0,
},
is_training=True,
vision_config={
"batch_size": 12,
"image_size": 30,
"patch_size": 2,
"num_channels": 3,
"is_training": True,
"hidden_size": 32,
"projection_dim": 32,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"intermediate_size": 37,
"dropout": 0.1,
"attention_dropout": 0.1,
"initializer_range": 0.02,
},
):
self.parent = parent
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.projector_hidden_act = projector_hidden_act
self.vision_feature_select_strategy = vision_feature_select_strategy
self.vision_feature_layer = vision_feature_layer
self.text_config = text_config
self.vision_config = vision_config
self.seq_length = seq_length
self.num_hidden_layers = text_config["num_hidden_layers"]
self.vocab_size = text_config["vocab_size"]
self.hidden_size = text_config["hidden_size"]
self.num_attention_heads = text_config["num_attention_heads"]
self.is_training = is_training
self.batch_size = 3
self.num_channels = 3
self.image_size = 336
self.encoder_seq_length = 231
def get_config(self):
return LlavaConfig(
text_config=self.text_config,
vision_config=self.vision_config,
ignore_index=self.ignore_index,
image_token_index=self.image_token_index,
projector_hidden_act=self.projector_hidden_act,
vision_feature_select_strategy=self.vision_feature_select_strategy,
vision_feature_layer=self.vision_feature_layer,
)
def prepare_config_and_inputs(self):
pixel_values = floats_tensor(
[
self.batch_size,
self.vision_config["num_channels"],
self.vision_config["image_size"],
self.vision_config["image_size"],
]
)
config = self.get_config()
return config, pixel_values
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
attention_mask = input_ids.ne(1).to(torch_device)
# we are giving 3 images, so let's make sure we pass in 3 image tokens
input_ids[:, 1] = config.image_token_index
inputs_dict = {
"pixel_values": pixel_values,
"input_ids": input_ids,
"attention_mask": attention_mask,
}
return config, inputs_dict
@require_torch
class LlavaForConditionalGenerationModelTest(ModelTesterMixin, unittest.TestCase):
"""
Model tester for `LlavaForConditionalGeneration`.
"""
all_model_classes = (LlavaForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = {"image-to-text": LlavaForConditionalGeneration} if is_torch_available() else {}
fx_compatible = False
test_pruning = False
test_resize_embeddings = True
test_head_masking = False
def setUp(self):
self.model_tester = LlavaVisionText2TextModelTester(self)
self.config_tester = ConfigTester(self, config_class=LlavaConfig, has_text_modality=False)
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@require_torch
class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
def setUp(self):
self.processor = AutoProcessor.from_pretrained("llava-hf/bakLlava-v1-hf")
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
@slow
@require_bitsandbytes
def test_small_model_integration_test(self):
# Let's make sure we test the preprocessing to replace what is used
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/bakLlava-v1-hf", load_in_4bit=True)
prompt = "<image>\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT:"
image_file = "https://llava-vl.github.io/static/images/view.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = self.processor(prompt, raw_image, return_tensors="pt")
EXPECTED_INPUT_IDS = torch.tensor([[1, 32000, 28705, 13, 11123, 28747, 1824, 460, 272, 1722, 315, 1023, 347, 13831, 925, 684, 739, 315, 3251, 456, 1633, 28804, 13, 4816, 8048, 12738, 28747]]) # fmt: skip
self.assertTrue(torch.equal(inputs["input_ids"], EXPECTED_INPUT_IDS))
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = "\nUSER: What are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, there are a few things one should be cautious about. Firstly," # fmt: skip
self.assertEqual(
self.processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
@slow
@require_bitsandbytes
def test_small_model_integration_test_llama(self):
# Let's make sure we test the preprocessing to replace what is used
model_id = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", load_in_4bit=True)
processor = AutoProcessor.from_pretrained(model_id)
prompt = "USER: <image>\nWhat are the things I should be cautious about when I visit this place?\nASSISTANT:"
image_file = "https://llava-vl.github.io/static/images/view.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(prompt, raw_image, return_tensors="pt").to(torch_device, torch.float16)
output = model.generate(**inputs, max_new_tokens=900, do_sample=False)
EXPECTED_DECODED_TEXT = "USER: \nWhat are the things I should be cautious about when I visit this place?\nASSISTANT: When visiting this place, which is a pier or dock extending over a body of water, there are a few things to be cautious about. First, be aware of the weather conditions, as sudden changes in weather can make the pier unsafe to walk on. Second, be mindful of the water depth and any potential hazards, such as submerged rocks or debris, that could cause accidents or injuries. Additionally, be cautious of the presence of wildlife, such as birds or fish, and avoid disturbing their natural habitats. Lastly, be aware of any local regulations or guidelines for the use of the pier, as some areas may be restricted or prohibited for certain activities." # fmt: skip
self.assertEqual(
processor.decode(output[0], skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
@slow
@require_bitsandbytes
def test_small_model_integration_test_llama_batched(self):
# Let's make sure we test the preprocessing to replace what is used
model_id = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", load_in_4bit=True)
processor = AutoProcessor.from_pretrained(model_id, pad_token="<pad>")
prompts = [
"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
"USER: <image>\nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: <image>\nAnd this?\nASSISTANT:",
]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(prompts, images=[image1, image2, image1], return_tensors="pt", padding=True)
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = ['USER: \nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT: the water is calm and clear\n\nThe image shows a wooden pier on a lake, with a', 'USER: \nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: \nAnd this?\nASSISTANT: A cat sleeping on a bed.'] # fmt: skip
self.assertEqual(processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
@slow
@require_bitsandbytes
def test_small_model_integration_test_batch(self):
# Let's make sure we test the preprocessing to replace what is used
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/bakLlava-v1-hf", load_in_4bit=True)
# The first prompt is longer in terms of text, but only has 1 image. The second will be padded in text, but the first will be padded because images take more space.
prompts = [
"USER: <image>\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nASSISTANT:",
"USER: <image>\nWhat is this?\nASSISTANT: Two cats lying on a bed!\nUSER: <image>\nAnd this?\nASSISTANT:",
]
image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = self.processor(prompts, images=[image1, image2, image1], return_tensors="pt", padding=True)
output = model.generate(**inputs, max_new_tokens=20)
EXPECTED_DECODED_TEXT = ['\nUSER: What are the things I should be cautious about when I visit this place? What should I bring with me\nASSISTANT: When visiting this place, bring a camera, Park rules may apply, Per person, Sunrise over', '\nUSER: What is this?\nASSISTANT: Two cats lying on a bed!\nUSER: And this? a dock on a lake with two cats on it (Photo credit: Jocelyn R.'] # fmt: skip
self.assertEqual(self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT)
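Outside the test harness, the flow exercised above reduces to a short snippet. A hedged sketch, assuming a CUDA GPU with enough memory to hold the 7B checkpoint in float16 (the tests instead rely on 4-bit loading via bitsandbytes):

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
processor = AutoProcessor.from_pretrained(model_id)

prompt = "USER: <image>\nWhat are the things I should be cautious about when I visit this place?\nASSISTANT:"
image_url = "https://llava-vl.github.io/static/images/view.jpg"
raw_image = Image.open(requests.get(image_url, stream=True).raw)

# cast the floating point inputs (pixel_values) to float16 and move everything to the GPU
inputs = processor(prompt, raw_image, return_tensors="pt").to("cuda", torch.float16)
output = model.generate(**inputs, max_new_tokens=50, do_sample=False)
print(processor.decode(output[0], skip_special_tokens=True))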
...@@ -252,3 +252,24 @@ class ImageToTextPipelineTests(unittest.TestCase):
[{"generated_text": "a cat laying on a blanket next to a cat laying on a bed "}],
],
)
@slow
@require_torch
def test_conditional_generation_llava(self):
pipe = pipeline("image-to-text", model="llava-hf/bakLlava-v1-hf")
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = (
"<image>\nUSER: What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud?\nASSISTANT:"
)
outputs = pipe(image, prompt=prompt, generate_kwargs={"max_new_tokens": 200})
self.assertEqual(
outputs,
[
{
"generated_text": "<image> \nUSER: What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud?\nASSISTANT: Lava"
}
],
)
...@@ -673,6 +673,7 @@ MODELS_NOT_IN_README = [
"TimmBackbone",
"Vision Encoder decoder",
"VisionTextDualEncoder",
"CLIPVisionModel",
]
# Template for new entries to add in the main README when we have missing models.
...
...@@ -171,6 +171,7 @@ MODEL_NAMES_WITH_SAME_CONFIG = {
"XLS-R": "Wav2Vec2",
"XLSR-Wav2Vec2": "Wav2Vec2",
}
MODEL_NAMES_TO_IGNORE = ["CLIPVisionModel"]
def get_model_table_from_auto_modules() -> str:
...@@ -243,6 +244,8 @@ def get_model_table_from_auto_modules() -> str:
check = {True: "✅", False: "❌"}
for name in model_names:
if name in MODEL_NAMES_TO_IGNORE:
continue
if name in MODEL_NAMES_WITH_SAME_CONFIG.keys():
prefix = model_name_to_prefix[MODEL_NAMES_WITH_SAME_CONFIG[name]]
else:
...
...@@ -146,6 +146,7 @@ docs/source/en/model_doc/levit.md
docs/source/en/model_doc/lilt.md
docs/source/en/model_doc/llama.md
docs/source/en/model_doc/llama2.md
docs/source/en/model_doc/llava.md
docs/source/en/model_doc/longformer.md
docs/source/en/model_doc/longt5.md
docs/source/en/model_doc/luke.md
...@@ -294,7 +295,7 @@ docs/source/en/serialization.md
docs/source/en/tasks/asr.md
docs/source/en/tasks/audio_classification.md
docs/source/en/tasks/document_question_answering.md
docs/source/en/tasks/idefics.md # causes other tests to fail docs/source/en/tasks/idefics.md
docs/source/en/tasks/image_captioning.md
docs/source/en/tasks/image_classification.md
docs/source/en/tasks/language_modeling.md
...@@ -432,7 +433,7 @@ src/transformers/models/blip/modeling_blip_text.py
src/transformers/models/blip/modeling_tf_blip_text.py
src/transformers/models/blip_2/configuration_blip_2.py
src/transformers/models/blip_2/convert_blip_2_original_to_pytorch.py
src/transformers/models/blip_2/modeling_blip_2.py # causes other tests to fail src/transformers/models/blip_2/modeling_blip_2.py
src/transformers/models/bloom/convert_bloom_original_checkpoint_to_pytorch.py
src/transformers/models/bloom/modeling_bloom.py
src/transformers/models/bloom/modeling_flax_bloom.py
...@@ -634,6 +635,8 @@ src/transformers/models/lilt/configuration_lilt.py
src/transformers/models/llama/configuration_llama.py
src/transformers/models/llama/convert_llama_weights_to_hf.py
src/transformers/models/llama/modeling_llama.py
src/transformers/models/llava/configuration_llava.py
src/transformers/models/llava/modeling_llava.py
src/transformers/models/longformer/configuration_longformer.py
src/transformers/models/longformer/convert_longformer_original_pytorch_lightning_to_pytorch.py
src/transformers/models/longt5/configuration_longt5.py
...