Unverified Commit bd9f4d79 authored by Raushan Turganbay's avatar Raushan Turganbay Committed by GitHub
Browse files

Add Video Llava (#29733)



* add model draft

* update docstring

* add tests

* support image and video as input

* update for better handling of mixed input and clean-up a bit

* bug when mixed inputs & add tests

* Update README.md
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Merge remote-tracking branch 'upstream/main' into video_llava

* link to abstract of paper in README

* fix test

* fix-copies

* make tests happy

* skip docstest for now

* do not run doctest for now

* Update src/transformers/models/video_llava/processing_video_llava.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/video_llava/image_processing_video_llava.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/video_llava/image_processing_video_llava.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/video_llava/image_processing_video_llava.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/video_llava/image_processing_video_llava.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/video_llava/test_modeling_video_llava.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/video_llava/image_processing_video_llava.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* address review comments

* failing tests

* Fix vocab_size in common tests for VLMs

* codestyle

* Update src/transformers/models/video_llava/configuration_video_llava.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/video_llava/configuration_video_llava.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/video_llava/modeling_video_llava.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/video_llava/modeling_video_llava.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update docs/source/en/model_doc/video_llava.md
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update docs/source/en/model_doc/video_llava.md
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/video_llava/image_processing_video_llava.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update docs/source/en/model_doc/video_llava.md
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/video_llava/processing_video_llava.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/video_llava/test_modeling_video_llava.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/video_llava/test_modeling_video_llava.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/video_llava/test_modeling_video_llava.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* PR suggestions

* fix-copies

* Update src/transformers/models/video_llava/configuration_video_llava.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/video_llava/configuration_video_llava.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add full example in docs

* clean-up with new model-id

* [run-slow] video_llava

* update docstring

* [run-slow] video_llava

* remove all achive maps

* fix some tests

* test was supposed to be skipped for llava :)

---------
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent b8aee2e9
...@@ -288,7 +288,6 @@ Suivez les pages d'installation de Flax, PyTorch ou TensorFlow pour voir comment ...@@ -288,7 +288,6 @@ Suivez les pages d'installation de Flax, PyTorch ou TensorFlow pour voir comment
Nombre actuel de points de contrôle : ![](https://img.shields.io/endpoint?url=https://huggingface.co/api/shields/models&color=brightgreen) Nombre actuel de points de contrôle : ![](https://img.shields.io/endpoint?url=https://huggingface.co/api/shields/models&color=brightgreen)
🤗 Transformers fournit actuellement les architectures suivantes: consultez [ici](https://huggingface.co/docs/transformers/model_summary) pour un résumé global de chacune d'entre elles. 🤗 Transformers fournit actuellement les architectures suivantes: consultez [ici](https://huggingface.co/docs/transformers/model_summary) pour un résumé global de chacune d'entre elles.
Pour vérifier si chaque modèle a une implémentation en Flax, PyTorch ou TensorFlow, ou s'il a un tokenizer associé pris en charge par la bibliothèque 🤗 Tokenizers, consultez [ce tableau](https://huggingface.co/docs/transformers/index#supported-frameworks). Pour vérifier si chaque modèle a une implémentation en Flax, PyTorch ou TensorFlow, ou s'il a un tokenizer associé pris en charge par la bibliothèque 🤗 Tokenizers, consultez [ce tableau](https://huggingface.co/docs/transformers/index#supported-frameworks).
......
...@@ -293,7 +293,6 @@ Flax, PyTorch లేదా TensorFlow యొక్క ఇన్‌స్టా ...@@ -293,7 +293,6 @@ Flax, PyTorch లేదా TensorFlow యొక్క ఇన్‌స్టా
🤗 ట్రాన్స్‌ఫార్మర్లు ప్రస్తుతం కింది ఆర్కిటెక్చర్‌లను అందజేస్తున్నాయి: వాటిలో ప్రతి ఒక్కటి ఉన్నత స్థాయి సారాంశం కోసం [ఇక్కడ](https://huggingface.co/docs/transformers/model_summary) చూడండి. 🤗 ట్రాన్స్‌ఫార్మర్లు ప్రస్తుతం కింది ఆర్కిటెక్చర్‌లను అందజేస్తున్నాయి: వాటిలో ప్రతి ఒక్కటి ఉన్నత స్థాయి సారాంశం కోసం [ఇక్కడ](https://huggingface.co/docs/transformers/model_summary) చూడండి.
ఈ అమలులు అనేక డేటాసెట్‌లలో పరీక్షించబడ్డాయి (ఉదాహరణ స్క్రిప్ట్‌లను చూడండి) మరియు అసలైన అమలుల పనితీరుతో సరిపోలాలి. మీరు [డాక్యుమెంటేషన్](https://github.com/huggingface/transformers/tree/main/examples) యొక్క ఉదాహరణల విభాగంలో పనితీరుపై మరిన్ని వివరాలను కనుగొనవచ్చు. ఈ అమలులు అనేక డేటాసెట్‌లలో పరీక్షించబడ్డాయి (ఉదాహరణ స్క్రిప్ట్‌లను చూడండి) మరియు అసలైన అమలుల పనితీరుతో సరిపోలాలి. మీరు [డాక్యుమెంటేషన్](https://github.com/huggingface/transformers/tree/main/examples) యొక్క ఉదాహరణల విభాగంలో పనితీరుపై మరిన్ని వివరాలను కనుగొనవచ్చు.
## ఇంకా నేర్చుకో ## ఇంకా నేర్చుకో
......
...@@ -806,6 +806,8 @@ ...@@ -806,6 +806,8 @@
title: TVP title: TVP
- local: model_doc/udop - local: model_doc/udop
title: UDOP title: UDOP
- local: model_doc/video_llava
title: VideoLlava
- local: model_doc/vilt - local: model_doc/vilt
title: ViLT title: ViLT
- local: model_doc/vipllava - local: model_doc/vipllava
......
...@@ -304,6 +304,7 @@ Flax), PyTorch, and/or TensorFlow. ...@@ -304,6 +304,7 @@ Flax), PyTorch, and/or TensorFlow.
| [UnivNet](model_doc/univnet) | ✅ | ❌ | ❌ | | [UnivNet](model_doc/univnet) | ✅ | ❌ | ❌ |
| [UPerNet](model_doc/upernet) | ✅ | ❌ | ❌ | | [UPerNet](model_doc/upernet) | ✅ | ❌ | ❌ |
| [VAN](model_doc/van) | ✅ | ❌ | ❌ | | [VAN](model_doc/van) | ✅ | ❌ | ❌ |
| [VideoLlava](model_doc/video_llava) | ✅ | ❌ | ❌ |
| [VideoMAE](model_doc/videomae) | ✅ | ❌ | ❌ | | [VideoMAE](model_doc/videomae) | ✅ | ❌ | ❌ |
| [ViLT](model_doc/vilt) | ✅ | ❌ | ❌ | | [ViLT](model_doc/vilt) | ✅ | ❌ | ❌ |
| [VipLlava](model_doc/vipllava) | ✅ | ❌ | ❌ | | [VipLlava](model_doc/vipllava) | ✅ | ❌ | ❌ |
......
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Video-LLaVA
## Overview
Video-LLaVa is an open-source multimodal LLM trained by fine-tuning LlamA/Vicuna on multimodal instruction-following data generated by Llava1.5 and VideChat. It is an auto-regressive language model, based on the transformer architecture. Video-LLaVa unifies visual representations to the language feature space, and enables an LLM to perform visual reasoning capabilities on both images and videos simultaneously.
The Video-LLaVA model was proposed in [Video-LLaVA: Learning United Visual Representation by Alignment Before Projection](https://arxiv.org/abs/2311.10122) by Bin Lin, Yang Ye, Bin Zhu, Jiaxi Cui, Munang Ning, Peng Jin, Li Yuan.
The abstract from the paper is the following:
*The Large Vision-Language Model (LVLM) has enhanced the performance of various downstream tasks in
visual-language understanding. Most existing approaches
encode images and videos into separate feature spaces,
which are then fed as inputs to large language models.
However, due to the lack of unified tokenization for images and videos, namely misalignment before projection, it
becomes challenging for a Large Language Model (LLM)
to learn multi-modal interactions from several poor projection layers. In this work, we unify visual representation into the language feature space to advance the foundational LLM towards a unified LVLM. As a result, we establish a simple but robust LVLM baseline, Video-LLaVA,
which learns from a mixed dataset of images and videos,
mutually enhancing each other. Video-LLaVA achieves superior performances on a broad range of 9 image benchmarks across 5 image question-answering datasets and 4
image benchmark toolkits. Additionally, our Video-LLaVA
also outperforms Video-ChatGPT by 5.8%, 9.9%, 18.6%,
and 10.1% on MSRVTT, MSVD, TGIF, and ActivityNet, respectively. Notably, extensive experiments demonstrate that
Video-LLaVA mutually benefits images and videos within
a unified visual representation, outperforming models designed specifically for images or videos. We aim for this
work to provide modest insights into the multi-modal inputs
for the LLM*
Tips:
- We advise users to use padding_side="left" when computing batched generation as it leads to more accurate results. Simply make sure to call processor.tokenizer.padding_side = "left" before generating.
- Note the model has not been explicitly trained to process multiple images/videos in the same prompt, although this is technically possible, you may experience inaccurate results.
- For better results, we recommend users prompt the model with the correct prompt format:
```python
import av
import torch
import numpy as np
import requests
from PIL import Image
from transformers import VideoLlavaForConditionalGeneration, VideoLlavaProcessor
def read_video_pyav(container, indices):
'''
Decode the video with PyAV decoder.
Args:
container (`av.container.input.InputContainer`): PyAV container.
indices (`List[int]`): List of frame indices to decode.
Returns:
result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
'''
frames = []
container.seek(0)
start_index = indices[0]
end_index = indices[-1]
for i, frame in enumerate(container.decode(video=0)):
if i > end_index:
break
if i >= start_index and i in indices:
frames.append(frame)
return np.stack([x.to_ndarray(format="rgb24") for x in frames])
model = VideoLlavaForConditionalGeneration.from_pretrained("RaushanTurganbay/video-llava-7b-hf", device_map="auto")
processor = VideoLlavaProcessor.from_pretrained("RaushanTurganbay/video-llava-7b-hf")
video_path = hf_hub_download(repo_id="raushan-testing-hf/videos-test", filename="sample_demo_1.mp4", repo_type="dataset")
container = av.open(video_path)
total_frames = container.streams.video[0].frames
indices = np.arange(0, total_frames, total_frames / 8).astype(int)
video = read_video_pyav(container, indices)
prompt = "USER: <video>Why is this funny? ASSISTANT:"
inputs = processor(text=prompt, videos=video, return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=40)
print(processor.batch_decode(out, skip_special_tokens=True, clean_up_tokenization_spaces=True))
```
For multiple turns conversation change the prompt to:
```bash
"USER: <video>What do you see in this video? ASSISTANT: A baby reading a book. USER: Why is the it funny? ASSISTANT:"
```
- Note that the video inputs should have exactly 8 frames at the input, since the models were trained in that setting.
This model was contributed by [RaushanTurganbay](https://huggingface.co/RaushanTurganbay).
The original code can be found [here](https://github.com/PKU-YuanGroup/Video-LLaVA).
## VideoLlavaConfig
[[autodoc]] VideoLlavaConfig
## VideoLlavaImageProcessor
[[autodoc]] VideoLlavaImageProcessor
## VideoLlavaProcessor
[[autodoc]] VideoLlavaProcessor
## VideoLlavaForConditionalGeneration
[[autodoc]] VideoLlavaForConditionalGeneration
- forward
...@@ -56,6 +56,7 @@ FlashAttention-2 is currently supported for the following architectures: ...@@ -56,6 +56,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava) * [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
* [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next) * [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next)
* [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava) * [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)
* [VideoLlava](https://huggingface.co/docs/transformers/model_doc/video_llava)
* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100) * [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel) * [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
......
...@@ -733,6 +733,7 @@ _import_structure = { ...@@ -733,6 +733,7 @@ _import_structure = {
"UnivNetFeatureExtractor", "UnivNetFeatureExtractor",
], ],
"models.upernet": ["UperNetConfig"], "models.upernet": ["UperNetConfig"],
"models.video_llava": ["VideoLlavaConfig"],
"models.videomae": ["VideoMAEConfig"], "models.videomae": ["VideoMAEConfig"],
"models.vilt": [ "models.vilt": [
"ViltConfig", "ViltConfig",
...@@ -3239,6 +3240,14 @@ else: ...@@ -3239,6 +3240,14 @@ else:
"UperNetPreTrainedModel", "UperNetPreTrainedModel",
] ]
) )
_import_structure["models.video_llava"].extend(
[
"VideoLlavaForConditionalGeneration",
"VideoLlavaImageProcessor",
"VideoLlavaPreTrainedModel",
"VideoLlavaProcessor",
]
)
_import_structure["models.videomae"].extend( _import_structure["models.videomae"].extend(
[ [
"VideoMAEForPreTraining", "VideoMAEForPreTraining",
...@@ -5309,6 +5318,7 @@ if TYPE_CHECKING: ...@@ -5309,6 +5318,7 @@ if TYPE_CHECKING:
UnivNetFeatureExtractor, UnivNetFeatureExtractor,
) )
from .models.upernet import UperNetConfig from .models.upernet import UperNetConfig
from .models.video_llava import VideoLlavaConfig
from .models.videomae import VideoMAEConfig from .models.videomae import VideoMAEConfig
from .models.vilt import ( from .models.vilt import (
ViltConfig, ViltConfig,
...@@ -7425,6 +7435,12 @@ if TYPE_CHECKING: ...@@ -7425,6 +7435,12 @@ if TYPE_CHECKING:
UperNetForSemanticSegmentation, UperNetForSemanticSegmentation,
UperNetPreTrainedModel, UperNetPreTrainedModel,
) )
from .models.video_llava import (
VideoLlavaForConditionalGeneration,
VideoLlavaImageProcessor,
VideoLlavaPreTrainedModel,
VideoLlavaProcessor,
)
from .models.videomae import ( from .models.videomae import (
VideoMAEForPreTraining, VideoMAEForPreTraining,
VideoMAEForVideoClassification, VideoMAEForVideoClassification,
......
...@@ -65,6 +65,9 @@ ImageInput = Union[ ...@@ -65,6 +65,9 @@ ImageInput = Union[
] # noqa ] # noqa
VideoInput = Union[np.ndarray, "torch.Tensor", List[np.ndarray], List["torch.Tensor"]] # noqa
class ChannelDimension(ExplicitEnum): class ChannelDimension(ExplicitEnum):
FIRST = "channels_first" FIRST = "channels_first"
LAST = "channels_last" LAST = "channels_last"
......
...@@ -242,6 +242,7 @@ from . import ( ...@@ -242,6 +242,7 @@ from . import (
unispeech_sat, unispeech_sat,
univnet, univnet,
upernet, upernet,
video_llava,
videomae, videomae,
vilt, vilt,
vipllava, vipllava,
......
...@@ -255,6 +255,7 @@ CONFIG_MAPPING_NAMES = OrderedDict( ...@@ -255,6 +255,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("univnet", "UnivNetConfig"), ("univnet", "UnivNetConfig"),
("upernet", "UperNetConfig"), ("upernet", "UperNetConfig"),
("van", "VanConfig"), ("van", "VanConfig"),
("video_llava", "VideoLlavaConfig"),
("videomae", "VideoMAEConfig"), ("videomae", "VideoMAEConfig"),
("vilt", "ViltConfig"), ("vilt", "ViltConfig"),
("vipllava", "VipLlavaConfig"), ("vipllava", "VipLlavaConfig"),
...@@ -542,6 +543,7 @@ MODEL_NAMES_MAPPING = OrderedDict( ...@@ -542,6 +543,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("univnet", "UnivNet"), ("univnet", "UnivNet"),
("upernet", "UPerNet"), ("upernet", "UPerNet"),
("van", "VAN"), ("van", "VAN"),
("video_llava", "VideoLlava"),
("videomae", "VideoMAE"), ("videomae", "VideoMAE"),
("vilt", "ViLT"), ("vilt", "ViLT"),
("vipllava", "VipLlava"), ("vipllava", "VipLlava"),
......
...@@ -116,6 +116,7 @@ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict( ...@@ -116,6 +116,7 @@ IMAGE_PROCESSOR_MAPPING_NAMES = OrderedDict(
("udop", "LayoutLMv3ImageProcessor"), ("udop", "LayoutLMv3ImageProcessor"),
("upernet", "SegformerImageProcessor"), ("upernet", "SegformerImageProcessor"),
("van", "ConvNextImageProcessor"), ("van", "ConvNextImageProcessor"),
("video_llava", "VideoLlavaImageProcessor"),
("videomae", "VideoMAEImageProcessor"), ("videomae", "VideoMAEImageProcessor"),
("vilt", "ViltImageProcessor"), ("vilt", "ViltImageProcessor"),
("vipllava", "CLIPImageProcessor"), ("vipllava", "CLIPImageProcessor"),
......
...@@ -328,6 +328,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( ...@@ -328,6 +328,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("tvlt", "TvltForPreTraining"), ("tvlt", "TvltForPreTraining"),
("unispeech", "UniSpeechForPreTraining"), ("unispeech", "UniSpeechForPreTraining"),
("unispeech-sat", "UniSpeechSatForPreTraining"), ("unispeech-sat", "UniSpeechSatForPreTraining"),
("video_llava", "VideoLlavaForConditionalGeneration"),
("videomae", "VideoMAEForPreTraining"), ("videomae", "VideoMAEForPreTraining"),
("vipllava", "VipLlavaForConditionalGeneration"), ("vipllava", "VipLlavaForConditionalGeneration"),
("visual_bert", "VisualBertForPreTraining"), ("visual_bert", "VisualBertForPreTraining"),
...@@ -700,6 +701,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict( ...@@ -700,6 +701,7 @@ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
("llava_next", "LlavaNextForConditionalGeneration"), ("llava_next", "LlavaNextForConditionalGeneration"),
("paligemma", "PaliGemmaForConditionalGeneration"), ("paligemma", "PaliGemmaForConditionalGeneration"),
("pix2struct", "Pix2StructForConditionalGeneration"), ("pix2struct", "Pix2StructForConditionalGeneration"),
("video_llava", "VideoLlavaForConditionalGeneration"),
("vipllava", "VipLlavaForConditionalGeneration"), ("vipllava", "VipLlavaForConditionalGeneration"),
("vision-encoder-decoder", "VisionEncoderDecoderModel"), ("vision-encoder-decoder", "VisionEncoderDecoderModel"),
] ]
......
...@@ -90,6 +90,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict( ...@@ -90,6 +90,7 @@ PROCESSOR_MAPPING_NAMES = OrderedDict(
("tvp", "TvpProcessor"), ("tvp", "TvpProcessor"),
("unispeech", "Wav2Vec2Processor"), ("unispeech", "Wav2Vec2Processor"),
("unispeech-sat", "Wav2Vec2Processor"), ("unispeech-sat", "Wav2Vec2Processor"),
("video_llava", "VideoLlavaProcessor"),
("vilt", "ViltProcessor"), ("vilt", "ViltProcessor"),
("vipllava", "LlavaProcessor"), ("vipllava", "LlavaProcessor"),
("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"), ("vision-text-dual-encoder", "VisionTextDualEncoderProcessor"),
......
...@@ -470,6 +470,7 @@ else: ...@@ -470,6 +470,7 @@ else:
"T5TokenizerFast" if is_tokenizers_available() else None, "T5TokenizerFast" if is_tokenizers_available() else None,
), ),
), ),
("video_llava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("vilt", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("vilt", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
("vipllava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)), ("vipllava", ("LlamaTokenizer", "LlamaTokenizerFast" if is_tokenizers_available() else None)),
("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), ("visual_bert", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)),
......
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {
"configuration_video_llava": ["VideoLlavaConfig"],
"processing_video_llava": ["VideoLlavaProcessor"],
}
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["image_processing_video_llava"] = ["VideoLlavaImageProcessor"]
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_video_llava"] = [
"VideoLlavaPreTrainedModel",
"VideoLlavaForConditionalGeneration",
]
if TYPE_CHECKING:
from .configuration_video_llava import (
VideoLlavaConfig,
)
from .image_processing_video_llava import VideoLlavaProcessor
try:
if not is_vision_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .image_processing_video_llava import VideoLlavaImageProcessor
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_video_llava import (
VideoLlavaForConditionalGeneration,
VideoLlavaPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
# coding=utf-8
# Copyright 2024 Microsoft Research & University of Wisconsin-Madison and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" VideoLlava model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ..auto import CONFIG_MAPPING
logger = logging.get_logger(__name__)
class VideoLlavaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`VideoLlavaForConditionalGeneration`]. It is used to instantiate an
VideoLlava model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the like LanguageBind/Video-LLaVA-7B-hf.
e.g. [LanguageBind/Video-LLaVA-7B-hf](https://huggingface.co/LanguageBind/Video-LLaVA-7B-hf)
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vision_config (`VideoLlavaVisionConfig`, *optional*):
Custom vision config or dict. Defaults to `CLIPVisionConfig` if not indicated.
text_config (`Union[AutoConfig, dict]`, *optional*):
The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
Defaults to `LlamaConfig` if not indicated.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
image_token_index (`int`, *optional*, defaults to 32000):
The image token index to encode the image prompt.
video_token_index (`int`, *optional*, defaults to 32001):
The video token index to encode the image prompt.
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
The activation function used by the multimodal projector.
vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
The feature selection strategy used to select the vision feature from the CLIP backbone.
Can be either "full" to select all features or "default" to select features without `CLS`.
vision_feature_layer (`int`, *optional*, defaults to -2):
The index of the layer to select the vision feature.
Example:
```python
>>> from transformers import VideoLlavaForConditionalGeneration, VideoLlavaConfig, CLIPVisionConfig, LlamaConfig
>>> # Initializing a CLIP-vision config
>>> vision_config = CLIPVisionConfig()
>>> # Initializing a Llama config
>>> text_config = LlamaConfig()
>>> # Initializing a VideoLlava video_llava-1.5-7b style configuration
>>> configuration = VideoLlavaConfig(vision_config, text_config)
>>> # Initializing a model from the video_llava-1.5-7b style configuration
>>> model = VideoLlavaForConditionalGeneration(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "video_llava"
is_composition = False
def __init__(
self,
vision_config=None,
text_config=None,
ignore_index=-100,
image_token_index=32000,
video_token_index=32001,
projector_hidden_act="gelu",
vision_feature_select_strategy="default",
vision_feature_layer=-2,
**kwargs,
):
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.video_token_index = video_token_index
self.projector_hidden_act = projector_hidden_act
self.vision_feature_select_strategy = vision_feature_select_strategy
self.vision_feature_layer = vision_feature_layer
self.vision_config = vision_config
if isinstance(self.vision_config, dict):
if "model_type" not in vision_config:
vision_config["model_type"] = "clip_vision_model"
logger.warning("Key=`model_type` not found in vision config, setting it to `clip_vision_model`")
self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
elif vision_config is None:
self.vision_config = CONFIG_MAPPING["clip_vision_model"](
intermediate_size=4096,
hidden_size=1024,
patch_size=14,
image_size=224,
num_hidden_layers=24,
num_attention_heads=16,
vocab_size=32000,
projection_dim=768,
)
if isinstance(text_config, dict):
if "model_type" not in text_config:
text_config["model_type"] = "llama"
logger.warning("Key=`model_type` not found in text config, setting it to `llama`")
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
elif text_config is None:
text_config = CONFIG_MAPPING["llama"]()
self.text_config = text_config
super().__init__(**kwargs)
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
from huggingface_hub import hf_hub_download
from transformers import (
AddedToken,
AutoConfig,
AutoTokenizer,
VideoLlavaConfig,
VideoLlavaForConditionalGeneration,
VideoLlavaImageProcessor,
VideoLlavaProcessor,
)
EPILOG_TXT = """Example:
python transformers/src/transformers/models/video_llava/convert_video_llava_weights_to_hf.py --text_model_id lmsys/vicuna-7b-v1.5 --vision_model_id openai/clip-vit-large-patch14 --output_hub_path org/video_llava-7b --old_state_dict_id LanguageBind/Video-LLaVA-7B
Example for creating the old state dict file with Python:
import torch
from video_llava.model.language_model.video_llava import VideoLlavaForCausalLM
# load model
kwargs = {"device_map": "auto", "torch_dtype": torch.float16}
model = VideoLlavaForCausalLM.from_pretrained("LanguageBind/Video-LLaVA-7B-hf", low_cpu_mem_usage=True, **kwargs)
# load vision tower
model.get_vision_tower().load_model()
# Save state dict
torch.save(model.state_dict(), "tmp/hf_models/video_llava-7b/model_state_dict.bin")
"""
KEYS_TO_MODIFY_MAPPING = {
"model.video_tower.video_tower": "video_tower",
"model.image_tower.image_tower": "image_tower",
"model.mm_projector": "multi_modal_projector",
"model": "language_model.model",
"lm_head": "language_model.lm_head",
"video_tower": "video_tower.vision_model",
"image_tower": "image_tower.vision_model",
"multi_modal_projector.0": "multi_modal_projector.linear_1",
"multi_modal_projector.2": "multi_modal_projector.linear_2",
}
def convert_state_dict_to_hf(state_dict):
new_state_dict = {}
for key, value in state_dict.items():
if key.endswith(".inv_freq"):
continue
for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in key:
key = key.replace(key_to_modify, new_key)
new_state_dict[key] = value
return new_state_dict
def convert_video_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
torch.set_default_dtype(torch.float16)
text_config = AutoConfig.from_pretrained(text_model_id)
tokenizer = AutoTokenizer.from_pretrained(text_model_id)
tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
tokenizer.add_tokens(AddedToken("<video>", special=True, normalized=False), special_tokens=True)
tokenizer.add_special_tokens({"pad_token": "<pad>"})
tokenizer.padding_side = "left"
image_processor = VideoLlavaImageProcessor.from_pretrained(vision_model_id)
processor = VideoLlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)
config = VideoLlavaConfig(text_config=text_config)
config.pad_token_id = 32002
with torch.device("meta"):
model = VideoLlavaForConditionalGeneration(config)
model_state_dict = set(model.state_dict().keys())
# Pad to 64 for performance reasons
pad_shape = 64
state_dict_temp = "pytorch_model-0000{i}-of-00002.bin"
for shard in range(1, 3):
state_dict_path = hf_hub_download(old_state_dict_id, state_dict_temp.format(i=shard))
state_dict = torch.load(state_dict_path, map_location="cpu")
state_dict = convert_state_dict_to_hf(state_dict)
model.load_state_dict(state_dict, strict=False, assign=True)
model_state_dict -= set(state_dict.keys())
if len(model_state_dict) > 0:
raise RuntimeError(f"Missing keys in state dict: {model_state_dict}")
pre_expansion_embeddings = model.language_model.model.embed_tokens.weight.data
mu = torch.mean(pre_expansion_embeddings, dim=0).float()
n = pre_expansion_embeddings.size()[0]
sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma)
# We add an image and video token so we resize the model
model.resize_token_embeddings(config.text_config.vocab_size + 3, pad_shape)
model.language_model.model.embed_tokens.weight.data[32000:] = torch.stack(
tuple((dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[32000:].shape[0]))),
dim=0,
)
model.language_model.lm_head.weight.data[32000:] = torch.stack(
tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[32000:].shape[0]))),
dim=0,
)
model.push_to_hub(output_hub_path)
processor.push_to_hub(output_hub_path)
def main():
parser = argparse.ArgumentParser(
epilog=EPILOG_TXT,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
"--text_model_id",
help="Hub location of the text model",
)
parser.add_argument(
"--vision_model_id",
help="Hub location of the vision model",
)
parser.add_argument(
"--output_hub_path",
help="Location on the hub of the converted model",
)
parser.add_argument(
"--old_state_dict_id",
help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`",
)
args = parser.parse_args()
convert_video_llava_llama_to_hf(
args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id
)
if __name__ == "__main__":
main()
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Video-LLaVA."""
from typing import Dict, List, Optional, Union
import numpy as np
from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from ...image_transforms import (
convert_to_rgb,
get_resize_output_image_size,
resize,
to_channel_dimension_format,
)
from ...image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
VideoInput,
infer_channel_dimension_format,
is_scaled_image,
is_valid_image,
make_list_of_images,
to_numpy_array,
valid_images,
validate_kwargs,
validate_preprocess_arguments,
)
from ...utils import TensorType, is_vision_available, logging
logger = logging.get_logger(__name__)
if is_vision_available():
import PIL
def make_batched_videos(videos) -> List[VideoInput]:
if isinstance(videos, (list, tuple)) and isinstance(videos[0], (list, tuple)) and is_valid_image(videos[0][0]):
return videos
elif isinstance(videos, (list, tuple)) and is_valid_image(videos[0]) and len(videos[0].shape) == 4:
return [list(video) for video in videos]
elif is_valid_image(videos) and len(videos.shape) == 4:
return [list(videos)]
raise ValueError(f"Could not make batched video from {videos}")
class VideoLlavaImageProcessor(BaseImageProcessor):
r"""
Constructs a CLIP image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
`do_resize` in the `preprocess` method.
size (`Dict[str, int]` *optional*, defaults to `{"shortest_edge": 224}`):
Size of the image after resizing. The shortest edge of the image is resized to size["shortest_edge"], with
the longest edge resized to keep the input aspect ratio. Can be overridden by `size` in the `preprocess`
method.
resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to center crop the image to the specified `crop_size`. Can be overridden by `do_center_crop` in the
`preprocess` method.
crop_size (`Dict[str, int]` *optional*, defaults to 224):
Size of the output image after applying `center_crop`. Can be overridden by `crop_size` in the `preprocess`
method.
do_rescale (`bool`, *optional*, defaults to `True`):
Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by `do_normalize` in the `preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `[0.48145466, 0.4578275, 0.40821073]`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `[0.26862954, 0.26130258, 0.27577711]`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_center_crop: bool = True,
crop_size: Dict[str, int] = None,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"shortest_edge": 224}
size = get_size_dict(size, default_to_square=False)
crop_size = crop_size if crop_size is not None else {"height": 224, "width": 224}
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
self.do_convert_rgb = do_convert_rgb
self._valid_processor_keys = [
"images",
"videos",
"do_resize",
"size",
"resample",
"do_center_crop",
"crop_size",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"do_convert_rgb",
"return_tensors",
"data_format",
"input_data_format",
]
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
resized to keep the input aspect ratio.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use when resiizing the image.
data_format (`str` or `ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format of the input image. If not provided, it will be inferred.
"""
default_to_square = True
if "shortest_edge" in size:
size = size["shortest_edge"]
default_to_square = False
elif "height" in size and "width" in size:
size = (size["height"], size["width"])
else:
raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.")
output_size = get_resize_output_image_size(
image,
size=size,
default_to_square=default_to_square,
input_data_format=input_data_format,
)
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def preprocess(
self,
images: List[ImageInput] = None,
videos: List[VideoInput] = None,
do_resize: bool = None,
size: Dict[str, int] = None,
resample: PILImageResampling = None,
do_center_crop: bool = None,
crop_size: int = None,
do_rescale: bool = None,
rescale_factor: float = None,
do_normalize: bool = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`, *optional*):
List of images to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
videos (`VideoInput`, *optional*):
List of videos to preprocess. Expects a single or batch of videos with pixel values ranging from 0 to 255. If
passing in videos with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Size of the image after resizing. Shortest edge of the image is resized to size["shortest_edge"], with
the longest edge resized to keep the input aspect ratio.
resample (`int`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
has an effect if `do_resize` is set to `True`.
do_center_crop (`bool`, *optional*, defaults to `self.do_center_crop`):
Whether to center crop the image.
crop_size (`Dict[str, int]`, *optional*, defaults to `self.crop_size`):
Size of the center crop. Only has an effect if `do_center_crop` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image.
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
`True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
size = size if size is not None else self.size
size = get_size_dict(size, param_name="size", default_to_square=False)
resample = resample if resample is not None else self.resample
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
crop_size = crop_size if crop_size is not None else self.crop_size
crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
if images is not None:
images = make_list_of_images(images)
if videos is not None:
videos = make_batched_videos(videos)
validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys)
if (videos is not None and not valid_images(videos)) or (images is not None and not valid_images(images)):
raise ValueError(
"Invalid input type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
data = {}
if videos is not None:
pixel_values_videos = [
[
self._preprocess_image(
image=frame,
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_convert_rgb=do_convert_rgb,
data_format=data_format,
input_data_format=input_data_format,
)
for frame in video
]
for video in videos
]
data["pixel_values_videos"] = pixel_values_videos
if images is not None:
pixel_values_images = [
self._preprocess_image(
image=image,
do_resize=do_resize,
size=size,
resample=resample,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_convert_rgb=do_convert_rgb,
data_format=data_format,
input_data_format=input_data_format,
)
for image in images
]
data["pixel_values_images"] = pixel_values_images
encoded_outputs = BatchFeature(data, tensor_type=return_tensors)
return encoded_outputs
def _preprocess_image(
self,
image: ImageInput = None,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_center_crop: bool = None,
crop_size: int = None,
do_convert_rgb: bool = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
validate_preprocess_arguments(
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
size=size,
resample=resample,
)
# PIL RGBA images are converted to RGB
if do_convert_rgb:
image = convert_to_rgb(image)
# All transformations expect numpy arrays.
image = to_numpy_array(image)
if is_scaled_image(image) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images/video frames. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(image)
if do_resize:
image = self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
if do_center_crop:
image = self.center_crop(image=image, size=crop_size, input_data_format=input_data_format)
if do_rescale:
image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
if do_normalize:
image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
return image
This diff is collapsed.
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Processor class for VideoLlava.
"""
from typing import List, Optional, Union
from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
from ...processing_utils import ProcessorMixin
from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
from ...utils import TensorType
class VideoLlavaProcessor(ProcessorMixin):
r"""
Constructs a VideoLlava processor which wraps a VideoLlava image processor and a Llava tokenizer into a single processor.
[`VideoLlavaProcessor`] offers all the functionalities of [`VideoLlavaImageProcessor`] and [`LlamaTokenizerFast`]. See the
[`~VideoLlavaProcessor.__call__`] and [`~VideoLlavaProcessor.decode`] for more information.
Args:
image_processor ([`VideoLlavaImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input.
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = "VideoLlavaImageProcessor"
tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
def __init__(self, image_processor=None, tokenizer=None):
super().__init__(image_processor, tokenizer)
def __call__(
self,
text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
images: ImageInput = None,
videos: ImageInput = None,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length=None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to
VideoLlavaImageProcessor's [`~VideoLlavaImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring
of the above two methods for more information.
Args:
text (`TextInput`, `PreTokenizedInput`, `List[TextInput]`, `List[PreTokenizedInput]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
Video frames to preprocess. Expects a single or batch of video frames in NumPy array or PyTorch
tensor. Each video should be of shape (T, C, H, W), where T is number of frames, C is
number of channels, H and W are image height and width.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
- `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence if provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
lengths).
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
truncation (`bool`, *optional*):
Activates truncation to cut input sequences longer than `max_length` to `max_length`.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
data = {}
if images is not None or videos is not None:
encoded_images = self.image_processor(images=images, videos=videos, return_tensors=return_tensors)
data.update(encoded_images)
text_inputs = self.tokenizer(
text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
)
data.update(text_inputs)
return BatchFeature(data=data)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
@property
# Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment