Unverified Commit c1e139c2 authored by Naman Garg, committed by GitHub

Adding hiera (#30356)



* initialized Structure

* Updated variable names

* Added Config class, basic HF setup, convert_to_hf

* Fixed Convert function, added hiera to HF files, Initialized test files

* better naming for x in forward pass

* Moved utils to hiera

* Change hiera -> hiera_model

* Fixed integration into transformers

* Fix: Convert Checkpoint

* added documentation for hiera

* added documentation for hiera

* added Docstrings to models, Transformers based changes

* make style and quality

* make style and quality

* Integration & Block tests running

* Fixed bugs

* Removed timm dependency

* added HieraBlock

* fixed: Model name

* added tests for HieraModel, HieraBlock

* fixed imports

* fixed quality & copies

* Fixes

* Update docs/source/en/model_doc/hiera.md

Fix name
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/model_doc/hiera.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/model_doc/hiera.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/transformers/models/hiera/configuration_hiera.py
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/transformers/models/hiera/configuration_hiera.py
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/transformers/models/hiera/modeling_hiera.py
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/transformers/models/hiera/modeling_hiera.py
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Fixed formatting

* Code quality & Import differences

* quality and repo-consistency fix

* fixed no torch error

* Docstring fix

* Docstring fix

* doc string fix

* fixed example usage

* Resolved issues in modeling_hiera

* Removed Hiera MAE

* Added test and resolved bug

* fixed doc string

* First commit

* Finished conversion script and model forward working

* Resolved all issues

* nits

* Improving tests

* Nits

* More nits

* Improving HieraForMaskedImageModeling

* More improvements and nits

* Fixed docstrings of outputs

* More fixes

* More improvements

* Updated conversion script

* Fixed docstrings

* Improved tests

* Fixed attention outputs test

* All tests green

* Removed unnecessary file

* contribution attribution

* Resolved a few issues

* Resolved Comments

* Updated model repo id and fixed bugs

* Removed loss print

* Make tests green

* Updated docstrings

* Fix style

* Fixed num_heads in config

* Removed unnecessary video checkpoint related code in the conversion script

* Fix style

* Changed atol in conversion script

* HieraConfig

* Fix copies

* Fixed typo

* Resolved few issues

* make

* converted conv_nd -> nn.Module

* Removed video complexities

* Removed video complexities

* fix style

* Addressing comments

* Update src/transformers/models/hiera/modeling_hiera.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/hiera/modeling_hiera.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/hiera/modeling_hiera.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fix style

* Fixed tests

* Fixed typo

* Fixed interpolate test

* Made torch fx compatible

* Made sure image processor is correct

* Addressed comments

* Noise directly as torch

* Remove unnecessary attr

* Added return_dict

* Update src/transformers/models/hiera/__init__.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Updated checkpoints

* [run_slow] hiera

* Fixed device mismatch

* [run_slow] hiera

* Fixed GPU tests

* [run_slow] hiera

---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-29-50.us-east-2.compute.internal>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Eduardo Pacheco <eduardo.pach@hotmail.com>
Co-authored-by: Eduardo Pacheco <69953243+EduardoPach@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 574e68d5
@@ -603,6 +603,8 @@
title: FocalNet
- local: model_doc/glpn
title: GLPN
- local: model_doc/hiera
title: Hiera
- local: model_doc/imagegpt
title: ImageGPT
- local: model_doc/levit
@@ -680,6 +682,8 @@
title: CLAP
- local: model_doc/encodec
title: EnCodec
- local: model_doc/hiera
title: Hiera
- local: model_doc/hubert
title: Hubert
- local: model_doc/mctct
...
@@ -159,6 +159,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Grounding DINO](model_doc/grounding-dino) | ✅ | ❌ | ❌ |
| [GroupViT](model_doc/groupvit) | ✅ | ✅ | ❌ |
| [HerBERT](model_doc/herbert) | ✅ | ✅ | ✅ |
| [Hiera](model_doc/hiera) | ✅ | ❌ | ❌ |
| [Hubert](model_doc/hubert) | ✅ | ✅ | ❌ |
| [I-BERT](model_doc/ibert) | ✅ | ❌ | ❌ |
| [IDEFICS](model_doc/idefics) | ✅ | ✅ | ❌ |
...
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Hiera
## Overview
Hiera was proposed in [Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles](https://arxiv.org/abs/2306.00989) by Chaitanya Ryali, Yuan-Ting Hu, Daniel Bolya, Chen Wei, Haoqi Fan, Po-Yao Huang, Vaibhav Aggarwal, Arkabandhu Chowdhury, Omid Poursaeed, Judy Hoffman, Jitendra Malik, Yanghao Li, Christoph Feichtenhofer.
The paper introduces "Hiera," a hierarchical Vision Transformer that simplifies the architecture of modern hierarchical vision transformers by removing unnecessary components without compromising on accuracy or efficiency. Unlike traditional transformers that add complex vision-specific components to improve supervised classification performance, Hiera demonstrates that such additions, often termed "bells-and-whistles," are not essential for high accuracy. By leveraging a strong visual pretext task (MAE) for pretraining, Hiera retains simplicity and achieves superior accuracy and speed both in inference and training across various image and video recognition tasks. The approach suggests that spatial biases required for vision tasks can be effectively learned through proper pretraining, eliminating the need for added architectural complexity.
The abstract from the paper is the following:
*Modern hierarchical vision transformers have added several vision-specific components in the pursuit of supervised classification performance. While these components lead to effective accuracies and attractive FLOP counts, the added complexity actually makes these transformers slower than their vanilla ViT counterparts. In this paper, we argue that this additional bulk is unnecessary. By pretraining with a strong visual pretext task (MAE), we can strip out all the bells-and-whistles from a state-of-the-art multi-stage vision transformer without losing accuracy. In the process, we create Hiera, an extremely simple hierarchical vision transformer that is more accurate than previous models while being significantly faster both at inference and during training. We evaluate Hiera on a variety of tasks for image and video recognition. Our code and models are available at https://github.com/facebookresearch/hiera.*
This model was a joint contribution by [EduardoPacheco](https://huggingface.co/EduardoPacheco) and [namangarg110](https://huggingface.co/namangarg110). The original code can be found [here](https://github.com/facebookresearch/hiera).
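## Usage example

The snippet below is a minimal sketch of image classification with Hiera. It assumes the `facebook/hiera-tiny-224-in1k-hf` checkpoint referenced in the docstring constants of `modeling_hiera.py`; since Hiera is registered with `BitImageProcessor`, `AutoImageProcessor` can be used for preprocessing.

```python
import requests
from PIL import Image

from transformers import AutoImageProcessor, HieraForImageClassification

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Checkpoint name taken from the docstring constants in modeling_hiera.py
image_processor = AutoImageProcessor.from_pretrained("facebook/hiera-tiny-224-in1k-hf")
model = HieraForImageClassification.from_pretrained("facebook/hiera-tiny-224-in1k-hf")

inputs = image_processor(images=image, return_tensors="pt")
logits = model(**inputs).logits

# Predicted ImageNet-1k class
predicted_class_idx = logits.argmax(-1).item()
print(model.config.id2label[predicted_class_idx])
```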
## HieraConfig
[[autodoc]] HieraConfig
## HieraModel
[[autodoc]] HieraModel
- forward
## HieraForPreTraining
[[autodoc]] HieraForPreTraining
- forward
## HieraForImageClassification
[[autodoc]] HieraForImageClassification
- forward
@@ -462,6 +462,7 @@ _import_structure = {
"GroupViTVisionConfig",
],
"models.herbert": ["HerbertTokenizer"],
"models.hiera": ["HieraConfig"],
"models.hubert": ["HubertConfig"],
"models.ibert": ["IBertConfig"],
"models.idefics": ["IdeficsConfig"],
@@ -2285,6 +2286,15 @@ else:
"GroupViTVisionModel",
]
)
_import_structure["models.hiera"].extend(
[
"HieraBackbone",
"HieraForImageClassification",
"HieraForPreTraining",
"HieraModel",
"HieraPreTrainedModel",
]
)
_import_structure["models.hubert"].extend(
[
"HubertForCTC",
@@ -5112,6 +5122,7 @@ if TYPE_CHECKING:
GroupViTVisionConfig,
)
from .models.herbert import HerbertTokenizer
from .models.hiera import HieraConfig
from .models.hubert import HubertConfig
from .models.ibert import IBertConfig
from .models.idefics import (
@@ -6795,6 +6806,13 @@ if TYPE_CHECKING:
GroupViTTextModel,
GroupViTVisionModel,
)
from .models.hiera import (
HieraBackbone,
HieraForImageClassification,
HieraForPreTraining,
HieraModel,
HieraPreTrainedModel,
)
from .models.hubert import (
HubertForCTC,
HubertForSequenceClassification,
...
@@ -105,6 +105,7 @@ from . import (
grounding_dino,
groupvit,
herbert,
hiera,
hubert,
ibert,
idefics,
...
@@ -122,6 +122,7 @@ CONFIG_MAPPING_NAMES = OrderedDict(
("graphormer", "GraphormerConfig"),
("grounding-dino", "GroundingDinoConfig"),
("groupvit", "GroupViTConfig"),
("hiera", "HieraConfig"),
("hubert", "HubertConfig"),
("ibert", "IBertConfig"),
("idefics", "IdeficsConfig"),
@@ -403,6 +404,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("grounding-dino", "Grounding DINO"),
("groupvit", "GroupViT"),
("herbert", "HerBERT"),
("hiera", "Hiera"),
("hubert", "Hubert"),
("ibert", "I-BERT"),
("idefics", "IDEFICS"),
...
@@ -85,6 +85,7 @@ else:
("glpn", ("GLPNImageProcessor",)),
("grounding-dino", ("GroundingDinoImageProcessor",)),
("groupvit", ("CLIPImageProcessor",)),
("hiera", ("BitImageProcessor",)),
("idefics", ("IdeficsImageProcessor",)),
("idefics2", ("Idefics2ImageProcessor",)),
("imagegpt", ("ImageGPTImageProcessor",)),
...
@@ -119,6 +119,7 @@ MODEL_MAPPING_NAMES = OrderedDict(
("graphormer", "GraphormerModel"),
("grounding-dino", "GroundingDinoModel"),
("groupvit", "GroupViTModel"),
("hiera", "HieraModel"),
("hubert", "HubertModel"),
("ibert", "IBertModel"),
("idefics", "IdeficsModel"),
@@ -295,6 +296,7 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("gpt2", "GPT2LMHeadModel"),
("gpt_bigcode", "GPTBigCodeForCausalLM"),
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
("hiera", "HieraForPreTraining"),
("ibert", "IBertForMaskedLM"),
("idefics", "IdeficsForVisionText2Text"),
("idefics2", "Idefics2ForConditionalGeneration"),
@@ -535,6 +537,7 @@ MODEL_FOR_IMAGE_MAPPING_NAMES = OrderedDict(
("efficientnet", "EfficientNetModel"),
("focalnet", "FocalNetModel"),
("glpn", "GLPNModel"),
("hiera", "HieraModel"),
("imagegpt", "ImageGPTModel"),
("levit", "LevitModel"),
("mobilenet_v1", "MobileNetV1Model"),
@@ -610,6 +613,7 @@ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
),
("efficientnet", "EfficientNetForImageClassification"),
("focalnet", "FocalNetForImageClassification"),
("hiera", "HieraForImageClassification"),
("imagegpt", "ImageGPTForImageClassification"),
(
"levit",
@@ -1258,6 +1262,7 @@ MODEL_FOR_BACKBONE_MAPPING_NAMES = OrderedDict(
("dinat", "DinatBackbone"),
("dinov2", "Dinov2Backbone"),
("focalnet", "FocalNetBackbone"),
("hiera", "HieraBackbone"),
("maskformer-swin", "MaskFormerSwinBackbone"),
("nat", "NatBackbone"),
("pvt_v2", "PvtV2Backbone"),
...
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)
_import_structure = {"configuration_hiera": ["HieraConfig"]}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_hiera"] = [
"HieraForImageClassification",
"HieraForPreTraining",
"HieraBackbone",
"HieraModel",
"HieraPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_hiera import HieraConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_hiera import (
HieraBackbone,
HieraForImageClassification,
HieraForPreTraining,
HieraModel,
HieraPreTrainedModel,
)
else:
import sys
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hiera model configuration"""
from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...utils.backbone_utils import BackboneConfigMixin, get_aligned_output_features_output_indices
logger = logging.get_logger(__name__)
class HieraConfig(BackboneConfigMixin, PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`HieraModel`]. It is used to instantiate a Hiera
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Hiera
[facebook/hiera-base-224](https://huggingface.co/facebook/hiera-base-224) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
embed_dim (`int`, *optional*, defaults to 96):
Dimensionality of patch embedding.
image_size (`list(int)`, *optional*, defaults to `[224, 224]`):
The size (resolution) of input in the format (height, width) for images
and (frames, height, width) for videos.
patch_size (`list(int)`, *optional*, defaults to `[7, 7]`):
The size (resolution) of each patch.
patch_stride (`list(int)`, *optional*, defaults to `[4, 4]`):
The stride of the patch.
patch_padding (`list(int)`, *optional*, defaults to `[3, 3]`):
The padding of the patch.
mlp_ratio (`float`, *optional*, defaults to 4.0):
The ratio of mlp hidden dim to embedding dim.
depths (`list(int)`, *optional*, defaults to `[2, 3, 16, 3]`):
Depth of each layer in the Transformer encoder.
num_heads (`list(int)`, *optional*, defaults to `[1, 2, 4, 8]`):
Number of attention heads in each layer of the Transformer encoder.
embed_dim_multiplier (`float`, *optional*, defaults to 2.0):
The multiplier to the dimensionality of patch embedding in each layer of the Transformer encoder.
num_query_pool (`int`, *optional*, defaults to 3):
The number of query pool stages.
query_stride (`list(int)`, *optional*, defaults to `[2, 2]`):
The stride of the query pool.
masked_unit_size (`list(int)`, *optional*, defaults to `[8, 8]`):
The size of the masked unit.
masked_unit_attention (`list(bool)`, *optional*, defaults to `[True, True, False, False]`):
Whether to use masked unit attention in each layer of the Transformer encoder.
drop_path_rate (`float`, *optional*, defaults to 0.0):
The drop path rate.
num_channels (`int`, *optional*, defaults to 3):
The number of input channels.
hidden_act (`str`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder. If string, `"gelu"`, `"relu"`,
`"selu"` and `"gelu_new"` are supported.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices and
the zero_initializer for initializing all bias vectors.
layer_norm_init (`float`, *optional*, defaults to 1.0):
The initial weight value for layer normalization layers.
layer_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the layer normalization layers.
decoder_hidden_size (`int`, *optional*):
Dimensionality of decoder embeddings for MAE pretraining.
decoder_depth (`int`, *optional*):
Depth of the decoder for MAE pretraining.
decoder_num_heads (`int`, *optional*):
Number of attention heads in each layer of the decoder for MAE pretraining.
normalize_pixel_loss (`bool`, *optional*, defaults to `True`):
Whether to normalize the pixel loss by the number of pixels.
mask_ratio (`float`, *optional*, defaults to 0.6):
The ratio of masked tokens in the input.
out_features (`List[str]`, *optional*):
If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
(depending on how many stages the model has). If unset and `out_indices` is set, will default to the
corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
out_indices (`List[int]`, *optional*):
If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
If unset and `out_features` is unset, will default to the last stage. Must be in the
same order as defined in the `stage_names` attribute.
Example:
```python
>>> from transformers import HieraConfig, HieraModel
>>> # Initializing a Hiera hiera-base-patch16-224 style configuration
>>> configuration = HieraConfig()
>>> # Initializing a model (with random weights) from the hiera-base-patch16-224 style configuration
>>> model = HieraModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "hiera"
attribute_map = {"num_hidden_layers": "num_layers"}
def __init__(
self,
embed_dim=96,
image_size=[224, 224],
patch_size=[7, 7],
patch_stride=[4, 4],
patch_padding=[3, 3],
mlp_ratio=4.0,
depths=[2, 3, 16, 3],
num_heads=[1, 2, 4, 8],
embed_dim_multiplier=2.0,
num_query_pool=3,
query_stride=[2, 2],
masked_unit_size=[8, 8],
masked_unit_attention=[True, True, False, False],
drop_path_rate=0.0,
num_channels=3,
hidden_act="gelu",
initializer_range=0.02,
layer_norm_init=1.0,
layer_norm_eps=1e-6,
decoder_hidden_size=None,
decoder_depth=None,
decoder_num_heads=None,
normalize_pixel_loss=True,
mask_ratio=0.6,
out_features=None,
out_indices=None,
**kwargs,
):
super().__init__(**kwargs)
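# Sanity-check example with the defaults: masked_unit_size[0] = 8 and
# query_stride[0] ** (len(depths) - 1) = 2 ** 3 = 8, so 8 % 8 == 0 and the check below passes.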
if masked_unit_size[0] % query_stride[0] ** (len(depths) - 1) != 0:
raise ValueError(
f"masked_unit_size[0] ({masked_unit_size[0]}) must be divisible by query_stride[0] ({query_stride[0]}) "
f"raised to the power of the number of layers ({len(depths) - 1})"
)
if num_query_pool >= len(depths):
raise ValueError(
f"num_query_pool ({num_query_pool}) must be less than the number of layers ({len(depths)})"
)
self.embed_dim = embed_dim
self.image_size = image_size
self.patch_size = patch_size
self.patch_stride = patch_stride
self.patch_padding = patch_padding
self.mlp_ratio = mlp_ratio
self.depths = depths
self.num_heads = num_heads
self.num_layers = len(depths)
self.embed_dim_multiplier = embed_dim_multiplier
self.num_query_pool = num_query_pool
self.query_stride = query_stride
self.masked_unit_size = masked_unit_size
self.masked_unit_attention = masked_unit_attention
self.drop_path_rate = drop_path_rate
self.num_channels = num_channels
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.layer_norm_init = layer_norm_init
self.layer_norm_eps = layer_norm_eps
self.decoder_hidden_size = decoder_hidden_size
self.decoder_depth = decoder_depth
self.decoder_num_heads = decoder_num_heads
self.normalize_pixel_loss = normalize_pixel_loss
self.mask_ratio = mask_ratio
# we set the hidden_size attribute in order to make Hiera work with VisionEncoderDecoderModel
# this indicates the channel dimension after the last stage of the model
self.hidden_size = int(embed_dim * embed_dim_multiplier ** (len(depths) - 1))
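# e.g. int(96 * 2.0 ** 3) = 768 with the default configuration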
self.stage_names = ["stem"] + [f"stage{idx}" for idx in range(1, len(depths) + 1)]
self._out_features, self._out_indices = get_aligned_output_features_output_indices(
out_features=out_features, out_indices=out_indices, stage_names=self.stage_names
)
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Hiera checkpoints from the original repository.
URL: https://github.com/facebookresearch/hiera
"""
import argparse
import json
import math
from typing import Dict, Tuple
import requests
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from transformers import BitImageProcessor, HieraConfig, HieraForImageClassification, HieraForPreTraining, HieraModel
from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from transformers.utils import logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config: HieraConfig, base_model: bool, mae_model: bool):
rename_keys = []
# fmt: off
num_stages = len(config.depths)
# embedding dimensions for input and stages
dims = [config.embed_dim] + [int(config.embed_dim * config.embed_dim_multiplier**i) for i in range(num_stages)]
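# e.g. with the default base config (embed_dim=96, embed_dim_multiplier=2.0, 4 stages): [96, 96, 192, 384, 768]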
global_layer_idx = 0
for stage_idx in range(num_stages):
dim_in = dims[stage_idx]
dim_out = dims[stage_idx + 1]
for layer_idx in range(config.depths[stage_idx]):
rename_keys.append((f"blocks.{global_layer_idx}.norm1.weight", f"hiera.encoder.stages.{stage_idx}.layers.{layer_idx}.layernorm_before.weight"))
rename_keys.append((f"blocks.{global_layer_idx}.norm1.bias", f"hiera.encoder.stages.{stage_idx}.layers.{layer_idx}.layernorm_before.bias"))
rename_keys.append((f"blocks.{global_layer_idx}.attn.qkv.weight", f"hiera.encoder.stages.{stage_idx}.layers.{layer_idx}.attn.qkv.weight"))
rename_keys.append((f"blocks.{global_layer_idx}.attn.qkv.bias", f"hiera.encoder.stages.{stage_idx}.layers.{layer_idx}.attn.qkv.bias"))
rename_keys.append((f"blocks.{global_layer_idx}.attn.proj.weight", f"hiera.encoder.stages.{stage_idx}.layers.{layer_idx}.attn.proj.weight"))
rename_keys.append((f"blocks.{global_layer_idx}.attn.proj.bias", f"hiera.encoder.stages.{stage_idx}.layers.{layer_idx}.attn.proj.bias"))
rename_keys.append((f"blocks.{global_layer_idx}.norm2.weight", f"hiera.encoder.stages.{stage_idx}.layers.{layer_idx}.layernorm_after.weight"))
rename_keys.append((f"blocks.{global_layer_idx}.norm2.bias", f"hiera.encoder.stages.{stage_idx}.layers.{layer_idx}.layernorm_after.bias"))
rename_keys.append((f"blocks.{global_layer_idx}.mlp.fc1.weight", f"hiera.encoder.stages.{stage_idx}.layers.{layer_idx}.mlp.fc1.weight"))
rename_keys.append((f"blocks.{global_layer_idx}.mlp.fc1.bias", f"hiera.encoder.stages.{stage_idx}.layers.{layer_idx}.mlp.fc1.bias"))
rename_keys.append((f"blocks.{global_layer_idx}.mlp.fc2.weight", f"hiera.encoder.stages.{stage_idx}.layers.{layer_idx}.mlp.fc2.weight"))
rename_keys.append((f"blocks.{global_layer_idx}.mlp.fc2.bias", f"hiera.encoder.stages.{stage_idx}.layers.{layer_idx}.mlp.fc2.bias"))
# projection layer only for the first layer of each stage boundary (except the first stage)
if dim_out != dim_in and layer_idx == 0:
rename_keys.append((f"blocks.{global_layer_idx}.proj.weight", f"hiera.encoder.stages.{stage_idx}.layers.{layer_idx}.proj.weight"))
rename_keys.append((f"blocks.{global_layer_idx}.proj.bias", f"hiera.encoder.stages.{stage_idx}.layers.{layer_idx}.proj.bias"))
global_layer_idx += 1
# projection layer + position embeddings
rename_keys.extend(
[
("patch_embed.proj.weight", "hiera.embeddings.patch_embeddings.projection.weight"),
("patch_embed.proj.bias", "hiera.embeddings.patch_embeddings.projection.bias")
]
)
rename_keys.append(("pos_embed", "hiera.embeddings.position_embeddings"))
if base_model:
# layernorm + pooler
rename_keys.extend([("norm.weight", "pooler.layernorm.weight"), ("norm.bias", "pooler.layernorm.bias")])
# if just the base model, we should remove "hiera" from all keys that start with "hiera"
rename_keys = [(pair[0], pair[1][6:]) if pair[1].startswith("hiera") else pair for pair in rename_keys]
elif mae_model:
rename_keys.extend(
[
("encoder_norm.weight", "encoder_norm.weight"),
("encoder_norm.bias", "encoder_norm.bias"),
("mask_token", "decoder.mask_token"),
("decoder_pos_embed", "decoder.decoder_position_embeddings"),
("decoder_norm.weight", "decoder.decoder_norm.weight"),
("decoder_norm.bias", "decoder.decoder_norm.bias"),
("decoder_pred.weight", "decoder.decoder_pred.weight"),
("decoder_pred.bias", "decoder.decoder_pred.bias"),
("decoder_embed.weight", "decoder.decoder_embeddings.weight"),
("decoder_embed.bias", "decoder.decoder_embeddings.bias")
]
)
for i in range(config.decoder_depth):
rename_keys.extend(
[
(f"decoder_blocks.{i}.norm1.weight", f"decoder.decoder_block.layers.{i}.layernorm_before.weight"),
(f"decoder_blocks.{i}.norm1.bias", f"decoder.decoder_block.layers.{i}.layernorm_before.bias"),
(f"decoder_blocks.{i}.attn.qkv.weight", f"decoder.decoder_block.layers.{i}.attn.qkv.weight"),
(f"decoder_blocks.{i}.attn.qkv.bias", f"decoder.decoder_block.layers.{i}.attn.qkv.bias"),
(f"decoder_blocks.{i}.attn.proj.weight", f"decoder.decoder_block.layers.{i}.attn.proj.weight"),
(f"decoder_blocks.{i}.attn.proj.bias", f"decoder.decoder_block.layers.{i}.attn.proj.bias"),
(f"decoder_blocks.{i}.norm2.weight", f"decoder.decoder_block.layers.{i}.layernorm_after.weight"),
(f"decoder_blocks.{i}.norm2.bias", f"decoder.decoder_block.layers.{i}.layernorm_after.bias"),
(f"decoder_blocks.{i}.mlp.fc1.weight", f"decoder.decoder_block.layers.{i}.mlp.fc1.weight"),
(f"decoder_blocks.{i}.mlp.fc1.bias", f"decoder.decoder_block.layers.{i}.mlp.fc1.bias"),
(f"decoder_blocks.{i}.mlp.fc2.weight", f"decoder.decoder_block.layers.{i}.mlp.fc2.weight"),
(f"decoder_blocks.{i}.mlp.fc2.bias", f"decoder.decoder_block.layers.{i}.mlp.fc2.bias"),
]
)
for i in range(config.num_query_pool):
rename_keys.extend(
[
(f"multi_scale_fusion_heads.{i}.weight", f"multiscale_fusion.multi_scale_fusion_heads.{i}.weight"),
(f"multi_scale_fusion_heads.{i}.bias", f"multiscale_fusion.multi_scale_fusion_heads.{i}.bias")
]
)
else:
# layernorm + classification head
rename_keys.extend(
[
("norm.weight", "hiera.pooler.layernorm.weight"),
("norm.bias", "hiera.pooler.layernorm.bias"),
("head.projection.weight", "classifier.weight"),
("head.projection.bias", "classifier.bias"),
]
)
# fmt: on
return rename_keys
def remove_classification_head_(state_dict):
ignore_keys = ["head.projection.weight", "head.projection.bias"]
for k in ignore_keys:
state_dict.pop(k, None)
def rename_key(dct, old, new):
val = dct.pop(old)
dct[new] = val
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
im = Image.open(requests.get(url, stream=True).raw)
return im
def get_labels_for_classifier(model_name: str) -> Tuple[Dict[int, str], Dict[str, int], int]:
repo_id = "huggingface/label-files"
filename = "imagenet-1k-id2label.json"
id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
id2label = {int(k): v for k, v in id2label.items()}
label2id = {v: k for k, v in id2label.items()}
num_labels = len(id2label)
return id2label, label2id, num_labels
def get_hiera_config(model_name: str, base_model: bool, mae_model: bool) -> HieraConfig:
if model_name == "hiera-tiny-224":
config = HieraConfig(depths=[1, 2, 7, 2])
elif model_name == "hiera-small-224":
config = HieraConfig(depths=[1, 2, 11, 2])
elif model_name == "hiera-base-224":
config = HieraConfig()
elif model_name == "hiera-base-plus-224":
config = HieraConfig(embed_dim=112, num_heads=[2, 4, 8, 16])
elif model_name == "hiera-large-224":
config = HieraConfig(embed_dim=144, num_heads=[2, 4, 8, 16], depths=[2, 6, 36, 4])
elif model_name == "hiera-huge-224":
config = HieraConfig(embed_dim=256, num_heads=[4, 8, 16, 32], depths=[2, 6, 36, 4])
else:
raise ValueError(f"Unrecognized model name: {model_name}")
if base_model:
pass
elif mae_model:
config.num_query_pool = 2
config.decoder_hidden_size = 512
config.decoder_depth = 8
config.decoder_num_heads = 16
# Table 3b from Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles
config.mask_ratio = 0.6
else:
id2label, label2id, num_labels = get_labels_for_classifier(model_name)
config.id2label = id2label
config.label2id = label2id
config.num_labels = num_labels
return config
@torch.no_grad()
def convert_hiera_checkpoint(args):
model_name = args.model_name
base_model = args.base_model
pytorch_dump_folder_path = args.pytorch_dump_folder_path
push_to_hub = args.push_to_hub
mae_model = args.mae_model
config = get_hiera_config(model_name, base_model, mae_model)
# Load original hiera model
original_model_name = model_name.replace("-", "_")
original_model_name = f"mae_{original_model_name}" if mae_model else original_model_name
original_checkpoint_name = "mae_in1k_ft_in1k" if not (base_model or mae_model) else "mae_in1k"
original_model = torch.hub.load(
"facebookresearch/hiera",
model=original_model_name,
pretrained=True,
checkpoint=original_checkpoint_name,
)
original_model.eval()
original_state_dict = original_model.state_dict()
# Don't need to remove head for MAE because original implementation doesn't have it on MAE
if base_model:
remove_classification_head_(original_state_dict)
# Rename keys
new_state_dict = original_state_dict.copy()
rename_keys = create_rename_keys(config, base_model, mae_model)
for src, dest in rename_keys:
rename_key(new_state_dict, src, dest)
# Load HF hiera model
if base_model:
model = HieraModel(config)
elif mae_model:
model = HieraForPreTraining(config)
else:
model = HieraForImageClassification(config)
model.eval()
missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
print("Missing keys:", missing_keys)
print("Unexpected keys:", unexpected_keys)
input_image = prepare_img()
original_image_preprocessor = transforms.Compose(
[
transforms.Resize(int((256 / 224) * 224), interpolation=transforms.functional.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
]
)
image_processor = BitImageProcessor(
image_mean=IMAGENET_DEFAULT_MEAN, image_std=IMAGENET_DEFAULT_STD, size={"shortest_edge": 256}
)
inputs = image_processor(images=input_image, return_tensors="pt")
expected_pixel_values = original_image_preprocessor(input_image).unsqueeze(0)
assert torch.allclose(inputs.pixel_values, expected_pixel_values, atol=1e-4)
print("Pixel values look good!")
print(f"{inputs.pixel_values[0, :3, :3, :3]=}")
# If it is MAE we pass noise to generate a random mask
mask_spatial_shape = [
i // s // ms for i, s, ms in zip(config.image_size, config.patch_stride, config.masked_unit_size)
]
num_windows = math.prod(mask_spatial_shape)
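# e.g. with the default config: 224 // 4 // 8 = 7 per spatial dimension, i.e. 7 * 7 = 49 mask units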
torch.manual_seed(2)
noise = torch.rand(1, num_windows)
outputs = model(**inputs) if not mae_model else model(noise=noise, **inputs)
# original implementation returns logits.softmax(dim=-1)
if base_model:
expected_prob, expected_intermediates = original_model(expected_pixel_values, return_intermediates=True)
expected_last_hidden = expected_intermediates[-1]
batch_size, _, _, hidden_dim = expected_last_hidden.shape
expected_last_hidden = expected_last_hidden.reshape(batch_size, -1, hidden_dim)
assert torch.allclose(outputs.last_hidden_state, expected_last_hidden, atol=1e-3)
print("Base Model looks good as hidden states match original implementation!")
print(f"{outputs.last_hidden_state[0, :3, :3]=}")
elif mae_model:
# get mask from noise to be able to compare outputs
mask, _ = model.hiera.embeddings.patch_embeddings.random_masking(expected_pixel_values, noise)
expected_loss, _, _, _ = original_model(expected_pixel_values, mask=mask.bool())
assert torch.allclose(outputs.loss, expected_loss, atol=1e-3)
print("MAE Model looks good as loss matches original implementation!")
else:
expected_prob = original_model(expected_pixel_values)
assert torch.allclose(outputs.logits.softmax(dim=-1), expected_prob, atol=1e-3)
print("Classifier looks good as probs match original implementation")
print(f"{outputs.logits[:, :5]=}")
if pytorch_dump_folder_path is not None:
print(f"Saving model and processor for {model_name} to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
image_processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
hub_name = model_name
if base_model:
hub_name = model_name
elif mae_model:
hub_name = f"{model_name}-mae"
else:
hub_name = f"{model_name}-in1k"
repo_id = f"EduardoPacheco/{hub_name}"
print(f"Pushing model and processor for {model_name} to hub at {repo_id}")
model.push_to_hub(repo_id)
image_processor.push_to_hub(repo_id)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--model-name",
default="hiera-tiny-224",
type=str,
choices=[
"hiera-tiny-224",
"hiera-small-224",
"hiera-base-224",
"hiera-base-plus-224",
"hiera-large-224",
"hiera-huge-224",
],
help="Name of the Hiera model you'd like to convert.",
)
parser.add_argument(
"--pytorch-dump-folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--verify-logits",
action="store_true",
help="Whether or not to verify the logits against the original implementation.",
)
parser.add_argument(
"--push-to-hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
)
parser.add_argument(
"--base-model",
action="store_true",
help="Whether to only convert the base model (no projection head weights).",
)
parser.add_argument(
"--mae-model", action="store_true", help="Whether to convert to MAE checkpoint to HieraForPreTraining."
)
args = parser.parse_args()
convert_hiera_checkpoint(args)
# coding=utf-8
# Copyright 2024 Meta and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Hiera model."""
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...modeling_outputs import (
BackboneOutput,
BaseModelOutput,
BaseModelOutputWithPooling,
ImageClassifierOutput,
ModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from ...utils.backbone_utils import BackboneMixin
from .configuration_hiera import HieraConfig
logger = logging.get_logger(__name__)
# General docstring
_CONFIG_FOR_DOC = "HieraConfig"
# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/hiera-tiny-224-hf"
_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/hiera-tiny-224-in1k-hf"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
@dataclass
class HieraEncoderOutput(ModelOutput):
"""
Hiera encoder's outputs, with potential hidden states and attentions.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, sequence_length, hidden_size)`. These are the unrolled hidden states of the model.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.
Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
include the spatial dimensions.
"""
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class HieraModelOutput(ModelOutput):
"""
Hiera model's outputs that also contains a pooling of the last hidden states.
Args:
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
Average pooling of the last layer hidden-state.
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
Tensor indicating which patches are masked (0) and which are not (1).
ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Tensor containing the original index of the (shuffled) masked patches.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, sequence_length, hidden_size)`. These are the unrolled hidden states of the model.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.
Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
include the spatial dimensions.
"""
last_hidden_state: torch.FloatTensor = None
pooler_output: Optional[torch.FloatTensor] = None
bool_masked_pos: torch.BoolTensor = None
ids_restore: torch.LongTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class HieraForImageClassificationOutput(ImageClassifierOutput):
"""
Hiera image classification outputs.
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
Loss value for the training task.
logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`):
Prediction scores of the classification head (logits of the output layer).
hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, sequence_length, hidden_size)`. These are the unrolled hidden states of the model.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
shape `(batch_size, height, width, hidden_size)`. These are the reshaped and re-rolled hidden states of the model.
Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
include the spatial dimensions.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class HieraForPreTrainingOutput(ModelOutput):
"""
Class for HieraForPreTraining's outputs, with potential hidden states and attentions.
Args:
loss (`torch.FloatTensor` of shape `(1,)`):
Pixel reconstruction loss.
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
Pixel reconstruction logits.
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
Tensor indicating which patches are masked (0) and which are not (1).
ids_restore (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Tensor containing the original index of the (shuffled) masked patches.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
plus the initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
the self-attention heads.
reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
shape `(batch_size, height, width, hidden_size)`. Hidden-states of the model at the output of each layer
plus the initial embedding outputs reshaped to include the spatial dimensions.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
bool_masked_pos: torch.BoolTensor = None
ids_restore: torch.LongTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
class HieraPatchEmbeddings(nn.Module):
"""
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
Transformer.
"""
def __init__(self, config, is_mae: bool = False):
super().__init__()
# Support any number of spatial dimensions
self.spatial_dims = len(config.patch_size)
if self.spatial_dims != 2:
raise ValueError(f"The number of dimensions of the input image should be 2, but got {self.spatial_dims}.")
self.num_channels = config.num_channels
self.image_size = config.image_size[-2:]
self.tokens_spatial_shape = [i // s for i, s in zip(config.image_size, config.patch_stride)]
self.mask_spatial_shape = [i // s for i, s in zip(self.tokens_spatial_shape, config.masked_unit_size)]
self.mask_ratio = config.mask_ratio
self.is_mae = is_mae
self.projection = nn.Conv2d(
self.num_channels,
config.embed_dim,
kernel_size=config.patch_size,
stride=config.patch_stride,
padding=config.patch_padding,
)
def masked_conv(
self, pixel_values: torch.FloatTensor, bool_masked_pos: Optional[torch.BoolTensor] = None
) -> torch.Tensor:
"""Zero-out the masked regions of the input before conv.
Prevents leakage of masked regions when using overlapping kernels.
"""
if bool_masked_pos is None:
return self.projection(pixel_values)
target_size = pixel_values.shape[2:]
# Reshape bool_masked_pos to (batch_size, 1, mask_unit_height, mask_unit_width)
bool_masked_pos = bool_masked_pos.view(pixel_values.shape[0], 1, *self.mask_spatial_shape)
bool_masked_pos = nn.functional.interpolate(bool_masked_pos.float(), size=target_size)
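# nn.functional.interpolate defaults to nearest-neighbor, upsampling the mask-unit grid to the
# full pixel resolution so that masked pixels are zeroed out before the convolution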
return self.projection(pixel_values * bool_masked_pos)
def random_masking(
self, pixel_values: torch.FloatTensor, noise: Optional[torch.FloatTensor] = None
) -> Tuple[torch.BoolTensor, torch.LongTensor]:
"""
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
noise.
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`)
noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*) which is
mainly used for testing purposes to control randomness and maintain reproducibility
"""
batch_size = pixel_values.shape[0]
# Tokens selected for masking at mask unit level
num_windows = math.prod(self.mask_spatial_shape)
len_keep = int(num_windows * (1 - self.mask_ratio))
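# e.g. 49 mask units with the default mask_ratio=0.6 keep int(49 * 0.4) = 19 units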
if noise is None:
noise = torch.rand(batch_size, num_windows, device=pixel_values.device)
# Sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1)
# ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1).to(pixel_values.device)
# Generate the binary bool_masked_pos: 1 is *keep*, 0 is *remove*
# Note this is opposite to original MAE
bool_masked_pos = torch.zeros([batch_size, num_windows], device=pixel_values.device)
bool_masked_pos[:, :len_keep] = 1
# Unshuffle to get the binary bool_masked_pos
bool_masked_pos = torch.gather(bool_masked_pos, dim=1, index=ids_restore).bool()
return bool_masked_pos, ids_restore
def forward(
self,
pixel_values: torch.FloatTensor,
noise: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.BoolTensor], Optional[torch.LongTensor]]:
(bool_masked_pos, ids_restore) = (
self.random_masking(pixel_values, noise=noise) if self.is_mae else (None, None)
)
embeddings = self.masked_conv(pixel_values, bool_masked_pos)
embeddings = embeddings.flatten(2).transpose(2, 1)
return embeddings, bool_masked_pos, ids_restore
class HieraEmbeddings(nn.Module):
"""
Construct position and patch embeddings.
"""
def __init__(self, config: HieraConfig, is_mae: bool = False) -> None:
super().__init__()
self.patch_stride = config.patch_stride
tokens_spatial_shape = [i // s for i, s in zip(config.image_size, config.patch_stride)]
self.mask_spatial_shape = [i // s for i, s in zip(tokens_spatial_shape, config.masked_unit_size)]
self.num_tokens = math.prod(tokens_spatial_shape)
self.is_mae = is_mae
self.patch_embeddings = HieraPatchEmbeddings(config, is_mae=is_mae)
self.position_embeddings = nn.Parameter(torch.zeros(1, self.num_tokens, config.embed_dim))
def interpolate_pos_encoding(
self, embeddings: torch.Tensor, pos_embeds: torch.Tensor, height: int, width: int
) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Adapted from:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1]
num_positions = pos_embeds.shape[1]
if num_patches == num_positions and height == width:
return pos_embeds
dim = embeddings.shape[-1]
h0 = height // self.patch_stride[0]
w0 = width // self.patch_stride[1]
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
pos_embeds = pos_embeds.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
pos_embeds = pos_embeds.permute(0, 3, 1, 2)
pos_embeds = nn.functional.interpolate(
pos_embeds,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
if int(h0) != pos_embeds.shape[-2] or int(w0) != pos_embeds.shape[-1]:
raise ValueError("The interpolated position encoding does not have the right size")
pos_embeds = pos_embeds.permute(0, 2, 3, 1).view(1, -1, dim)
return pos_embeds
def get_position_embedding(
self, embeddings: torch.Tensor, height: int, width: int, interpolate_pos_encoding: bool
) -> torch.FloatTensor:
position_embeddings = self.position_embeddings
position_embeddings = (
self.interpolate_pos_encoding(embeddings, position_embeddings, height, width)
if interpolate_pos_encoding
else position_embeddings
)
return position_embeddings
def forward(
self,
pixel_values: torch.FloatTensor,
noise: Optional[torch.FloatTensor] = None,
interpolate_pos_encoding: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.BoolTensor], Optional[torch.LongTensor]]:
height, width = pixel_values.shape[-2:]
embeddings, bool_masked_pos, ids_restore = self.patch_embeddings(pixel_values, noise=noise)
embeddings = embeddings + self.get_position_embedding(embeddings, height, width, interpolate_pos_encoding)
return embeddings, bool_masked_pos, ids_restore
class HieraMaskUnitAttention(nn.Module):
"""
Computes either Mask Unit or Global Attention. Also is able to perform query pooling.
Note: this assumes the tokens have already been flattened and unrolled into mask units.
"""
def __init__(
self,
hidden_size: int,
hidden_size_output: int,
num_heads: int,
query_stride: int = 1,
window_size: int = 0,
use_mask_unit_attn: bool = False,
) -> None:
super().__init__()
self.num_heads = num_heads
self.query_stride = query_stride
self.hidden_size_output = hidden_size_output
self.head_dim = hidden_size_output // num_heads
self.scale = (self.head_dim) ** -0.5
self.qkv = nn.Linear(hidden_size, 3 * hidden_size_output)
self.proj = nn.Linear(hidden_size_output, hidden_size_output)
self.window_size = window_size
self.use_mask_unit_attn = use_mask_unit_attn
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input should be of shape [batch, tokens, channels]."""
batch_size, seq_len, _ = hidden_states.shape
num_windows = 1
if self.use_mask_unit_attn:
num_windows = seq_len // (self.query_stride * self.window_size)
qkv = self.qkv(hidden_states)
qkv = qkv.reshape(batch_size, -1, num_windows, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(3, 0, 4, 2, 1, 5)
query, key, value = qkv.unbind(0)
if self.query_stride > 1:
# Refer to unroll to see how this performs a maxpool-Nd
query = query.view(batch_size, self.num_heads, num_windows, self.query_stride, -1, self.head_dim)
query = query.max(dim=3).values
attn_weights = (query * self.scale) @ key.transpose(-1, -2)
attn_weights = attn_weights.softmax(dim=-1)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = attn_weights @ value
attn_output = attn_output.transpose(1, 3).reshape(batch_size, -1, self.hidden_size_output)
attn_output = self.proj(attn_output)
return (attn_output, attn_weights) if output_attentions else (attn_output, None)
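# Illustrative sketch (not part of the model): shapes produced by mask unit attention with query
# pooling. With 128 tokens, a window size of 16 and a query stride of 4, the tokens split into
# 128 // (4 * 16) = 2 windows and the output sequence is pooled down to 128 // 4 = 32 tokens.
# All sizes are arbitrary toy values.
def _example_mask_unit_attention_shapes():
    attention = HieraMaskUnitAttention(
        hidden_size=8,
        hidden_size_output=16,
        num_heads=2,
        query_stride=4,
        window_size=16,
        use_mask_unit_attn=True,
    )
    hidden_states = torch.randn(2, 128, 8)
    attn_output, _ = attention(hidden_states)
    # Query pooling divides the sequence length by the query stride.
    assert list(attn_output.shape) == [2, 128 // 4, 16]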
# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
argument.
"""
if drop_prob == 0.0 or not training:
return input
keep_prob = 1 - drop_prob
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
random_tensor.floor_() # binarize
output = input.div(keep_prob) * random_tensor
return output
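# Illustrative sketch (not part of the model): stochastic depth as implemented by `drop_path`.
# In eval mode the input passes through unchanged; in training mode each sample in the batch is
# either zeroed out entirely or rescaled by 1 / keep_prob so the expected value is preserved.
def _example_drop_path_behavior():
    hidden_states = torch.ones(4, 3, 2)
    # Evaluation mode: identity.
    assert torch.equal(drop_path(hidden_states, drop_prob=0.5, training=False), hidden_states)
    dropped = drop_path(hidden_states, drop_prob=0.5, training=True)
    # Training mode: every sample is either all zeros or all 1 / 0.5 = 2.
    assert all(row.eq(0).all() or row.eq(2.0).all() for row in dropped.flatten(1))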
# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Hiera
class HieraDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob: Optional[float] = None) -> None:
super().__init__()
self.drop_prob = drop_prob
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return drop_path(hidden_states, self.drop_prob, self.training)
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
class HieraMlp(nn.Module):
def __init__(self, config, dim: int) -> None:
super().__init__()
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(dim, int(dim * config.mlp_ratio))
self.fc2 = nn.Linear(int(dim * config.mlp_ratio), dim)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class HieraLayer(nn.Module):
def __init__(
self,
config,
hidden_size: int,
hidden_size_output: int,
num_heads: int,
drop_path: float = 0.0,
query_stride: int = 1,
window_size: int = 0,
use_mask_unit_attn: bool = False,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.hidden_size_output = hidden_size_output
self.query_stride = query_stride
self.layernorm_before = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
self.attn = HieraMaskUnitAttention(
hidden_size=hidden_size,
hidden_size_output=hidden_size_output,
num_heads=num_heads,
query_stride=query_stride,
window_size=window_size,
use_mask_unit_attn=use_mask_unit_attn,
)
self.layernorm_after = nn.LayerNorm(hidden_size_output, eps=config.layer_norm_eps)
self.mlp = HieraMlp(config, hidden_size_output)
self.drop_path = HieraDropPath(drop_path) if drop_path > 0 else nn.Identity()
if hidden_size != hidden_size_output:
self.proj = nn.Linear(hidden_size, hidden_size_output)
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
batch_size, seq_len, _ = hidden_states.shape
# Attention + Q Pooling
hidden_states_norm = self.layernorm_before(hidden_states)
if self.hidden_size != self.hidden_size_output:
hidden_states = self.proj(hidden_states_norm)
# Refer to unroll to see how this performs a maxpool-Nd
hidden_states = (
hidden_states.view(batch_size, self.query_stride, -1, self.hidden_size_output).max(dim=1).values
)
(hidden_states_norm, attn_weights) = self.attn(
hidden_states_norm, head_mask, output_attentions=output_attentions
)
hidden_states = hidden_states + self.drop_path(hidden_states_norm)
residual = hidden_states
hidden_states = self.layernorm_after(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + self.drop_path(hidden_states)
return (hidden_states, attn_weights)
class HieraStage(nn.Module):
def __init__(
self,
config,
depth: int,
hidden_size: int,
hidden_size_output: int,
num_heads: int,
drop_path: List[float],
query_stride: List[int],
window_size: int,
use_mask_unit_attn: bool,
stage_num: Optional[int] = None,
) -> None:
super().__init__()
# We need to know whether the previous stage used mask unit attention or global attention.
# The switch to global attention lags by one layer, so that it is applied after pooling,
# i.e. on the lower resolution tokens.
previous_stage_used_masked_attention = False
if stage_num is not None:
previous_stage_used_masked_attention = config.masked_unit_attention[stage_num - 1 if stage_num > 0 else 0]
self.layers = nn.ModuleList(
[
HieraLayer(
config=config,
hidden_size=hidden_size if i == 0 else hidden_size_output,
hidden_size_output=hidden_size_output,
num_heads=num_heads,
drop_path=drop_path[i],
query_stride=query_stride[i],
window_size=window_size,
use_mask_unit_attn=use_mask_unit_attn or (previous_stage_used_masked_attention and i == 0),
)
for i in range(depth)
]
)
def forward(
self, hidden_states: torch.Tensor, head_mask: Optional[torch.FloatTensor], output_attentions: bool = False
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
(hidden_states, attn_weights) = layer_module(
hidden_states, layer_head_mask, output_attentions=output_attentions
)
return hidden_states, attn_weights
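# Illustrative sketch (not part of the model): a stage applies `depth` HieraLayer blocks in sequence;
# with no query pooling (query stride 1) and global attention the sequence length and width are
# preserved. The default `HieraConfig()` values and the toy sizes below are assumptions.
def _example_hiera_stage_shapes():
    config = HieraConfig()
    stage = HieraStage(
        config=config,
        depth=2,
        hidden_size=8,
        hidden_size_output=8,
        num_heads=2,
        drop_path=[0.0, 0.0],
        query_stride=[1, 1],
        window_size=0,
        use_mask_unit_attn=False,
    )
    hidden_states = torch.randn(2, 16, 8)
    output, _ = stage(hidden_states, head_mask=None)
    assert output.shape == hidden_states.shape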
def undo_windowing(hidden_states: torch.Tensor, shape: List[int], mask_unit_shape: List[int]) -> torch.Tensor:
"""
Restore spatial organization by undoing windowed organization of mask units.
Args:
hidden_states (`torch.Tensor`): The hidden states tensor of shape `[batch_size, num_mask_unit_height*num_mask_unit_width, mask_unit_height, mask_unit_width, hidden_size]`.
shape (`List[int]`): The original shape of the hidden states tensor before windowing.
mask_unit_shape (`List[int]`): The shape of the mask units used for windowing.
Returns:
torch.Tensor: The restored hidden states tensor of shape [batch_size, num_mask_unit_height*mask_unit_height, num_mask_unit_width*mask_unit_width, hidden_size].
"""
batch_size, hidden_size = hidden_states.shape[0], hidden_states.shape[-1]
# From: [batch_size, num_mask_unit_height*num_mask_unit_width, hidden_size]
# To: [batch_size, num_mask_unit_height, num_mask_unit_width, mask_unit_height, mask_unit_width, hidden_size]
num_mask_units = [s // mu for s, mu in zip(shape, mask_unit_shape)]
hidden_states = hidden_states.view(batch_size, *num_mask_units, *mask_unit_shape, hidden_size)
# From: [batch_size, num_mask_unit_height, num_mask_unit_width, mask_unit_height, mask_unit_width, hidden_size]
# To: [batch_size, num_mask_unit_height*mask_unit_height, num_mask_unit_width*mask_unit_width, hidden_size]
hidden_states = hidden_states.permute(0, 1, 3, 2, 4, 5)
hidden_states = hidden_states.reshape(batch_size, *shape, hidden_size)
return hidden_states
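# Illustrative sketch (not part of the model): `undo_windowing` takes tokens grouped by mask unit,
# shaped [batch, num_mask_units, mask_unit_height, mask_unit_width, hidden], and restores the plain
# spatial layout [batch, height, width, hidden]. The 4x4 token grid with 2x2 mask units is a toy size.
def _example_undo_windowing_shapes():
    batch_size, hidden_size = 2, 3
    shape, mask_unit_shape = [4, 4], [2, 2]
    num_mask_units = (shape[0] // mask_unit_shape[0]) * (shape[1] // mask_unit_shape[1])
    windowed = torch.randn(batch_size, num_mask_units, *mask_unit_shape, hidden_size)
    spatial = undo_windowing(windowed, shape, mask_unit_shape)
    assert list(spatial.shape) == [batch_size, *shape, hidden_size]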
class HieraEncoder(nn.Module):
def __init__(self, config: HieraConfig) -> None:
super().__init__()
total_depth = sum(config.depths)
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, total_depth)]
# query strides rule
cumulative_depths = torch.tensor(config.depths).cumsum(0).tolist()
query_pool_layer = cumulative_depths[: config.num_query_pool]
query_strides = [math.prod(config.query_stride) if i in query_pool_layer else 1 for i in range(total_depth)]
# Transformer blocks
self.stages = nn.ModuleList()
hidden_size = config.embed_dim
stage_ends = [0] + cumulative_depths
masked_unit_area = math.prod(config.masked_unit_size)
query_stride_area = math.prod(config.query_stride)
for idx_stage, depth in enumerate(config.depths):
hidden_size_output = int(config.embed_dim * config.embed_dim_multiplier**idx_stage)
stage = HieraStage(
config=config,
depth=depth,
hidden_size=hidden_size,
hidden_size_output=hidden_size_output,
num_heads=config.num_heads[idx_stage],
drop_path=dpr[stage_ends[idx_stage] : stage_ends[idx_stage + 1]],
query_stride=query_strides[stage_ends[idx_stage] : stage_ends[idx_stage + 1]],
window_size=int(masked_unit_area * query_stride_area**-idx_stage),
use_mask_unit_attn=config.masked_unit_attention[idx_stage],
stage_num=idx_stage,
)
hidden_size = hidden_size_output
self.stages.append(stage)
# Setting reroll schedule
# The first stage has to reverse everything
# The next stage has to reverse all but the first unroll, etc.
stage_size = [i // s for i, s in zip(config.image_size, config.patch_stride)]
unroll_schedule = [config.query_stride] * len(config.depths[:-1])
self.schedule = {}
for idx_stage in range(len(config.depths)):
self.schedule[idx_stage] = unroll_schedule, stage_size
if idx_stage < config.num_query_pool:
stage_size = [i // s for i, s in zip(stage_size, config.query_stride)]
unroll_schedule = unroll_schedule[1:]
self.gradient_checkpointing = False
def reroll(
self, hidden_states: torch.Tensor, stage_idx: int, bool_masked_pos: Optional[torch.BoolTensor] = None
) -> torch.Tensor:
"""
Roll the given tensor back up to spatial order assuming it's from the given block.
If no bool_masked_pos is provided returns:
- [batch_size, height, width, hidden_size]
If a bool_masked_pos is provided returns:
- [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size]
"""
schedule, size = self.schedule[stage_idx]
batch_size, seq_len, hidden_size = hidden_states.shape
num_dim = len(size)
mask_unit_shape = [1] * num_dim
for strides in schedule:
# Extract the current patch from seq_len
hidden_states = hidden_states.view(
batch_size, *strides, seq_len // math.prod(strides), *mask_unit_shape, hidden_size
)
# Move that patch into the current MU
# Input: [batch_size, stride, stride, seq_len//(stride*stride), mask_unit_height, mask_unit_width, hidden_size]
# Output: [batch_size, seq_len//(stride*stride), stride, mask_unit_height, stride, mask_unit_width, hidden_size]
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5, 6)
# Reshape to [batch_size, seq_len//(stride*stride), *mask_units, hidden_size]
for i in range(num_dim):
mask_unit_shape[i] *= strides[i]
hidden_states = hidden_states.reshape(batch_size, -1, *mask_unit_shape, hidden_size)
seq_len = hidden_states.shape[1]
# Current shape (e.g., 2d: [batch_size, #num_mask_units_height*#num_mask_units_width, mask_unit_height, mask_unit_width, hidden_size])
hidden_states = hidden_states.view(batch_size, seq_len, *mask_unit_shape, hidden_size)
# If masked, return [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size]
if bool_masked_pos is not None:
return hidden_states
# If not masked, we can return [batch_size, height, width, hidden_size]
hidden_states = undo_windowing(hidden_states, size, mask_unit_shape)
return hidden_states
def forward(
self,
hidden_states: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
output_attentions: bool = False,
output_hidden_states: bool = False,
return_dict: bool = True,
) -> Union[tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None
all_reshaped_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
reshaped_hidden_states = self.reroll(hidden_states, stage_idx=0, bool_masked_pos=bool_masked_pos)
all_reshaped_hidden_states = all_reshaped_hidden_states + (reshaped_hidden_states,)
for i, stage_module in enumerate(self.stages):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
stage_module.__call__, hidden_states, layer_head_mask, output_attentions
)
else:
layer_outputs = stage_module(hidden_states, layer_head_mask, output_attentions)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
reshaped_hidden_states = self.reroll(hidden_states, stage_idx=i, bool_masked_pos=bool_masked_pos)
all_reshaped_hidden_states = all_reshaped_hidden_states + (reshaped_hidden_states,)
if not return_dict:
return tuple(
v
for v in [hidden_states, all_hidden_states, all_self_attentions, all_reshaped_hidden_states]
if v is not None
)
return HieraEncoderOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
reshaped_hidden_states=all_reshaped_hidden_states,
)
def unroll(
hidden_states: torch.Tensor, image_shape: Tuple[int, int], patch_stride: Tuple[int, int], schedule: List[List[int]]
) -> torch.Tensor:
"""
Reorders the tokens such that patches are contiguous in memory.
E.g., given [batch_size, (height, width), hidden_size] and stride of (stride, stride), this will re-order the tokens as
[batch_size, (stride, stride, height // stride, width // stride), hidden_size]
This allows operations like Max2d to be computed as x.view(batch_size, stride*stride, -1, hidden_size).max(dim=1).
Not only is this faster, but it also makes it easy to support inputs of arbitrary
dimensions in addition to patch-wise sparsity.
Performing this operation multiple times in sequence puts entire windows contiguous
in memory. For instance, if you applied the stride (2, 2) 3 times, entire windows of
size 8x8 would be contiguous in memory, allowing operations like mask unit attention to be
computed easily and efficiently, while also allowing max to be applied sequentially.
Note: This means that intermediate values of the model are not in height x width order, so they
need to be re-rolled if you want to use the intermediate values as a height x width feature map.
The last block of the network is fine though, since by then the strides are all consumed.
"""
batch_size, _, hidden_size = hidden_states.shape
size = [i // s for i, s in zip(image_shape, patch_stride)]
current_size = size
hidden_states = hidden_states.view(*([batch_size] + current_size + [hidden_size]))
for strides in schedule:
# Move patches with the given strides to the batch dimension
# Create a view of the tensor with the patch stride as separate dims
# For example in 2d: [batch_size, height // stride, stride, width // stride, stride, C]
current_size = [i // s for i, s in zip(current_size, strides)]
# initialize new_shape with [height // stride, stride, width // stride, stride]
new_shape = [item for pair in zip(current_size, strides) for item in pair]
# add batch_size and hidden_size to new_shape
new_shape = [batch_size] + new_shape + [hidden_size]
hidden_states = hidden_states.view(new_shape)
# Move the patch stride into the batch dimension
# For example in 2d: [batch_size, stride, stride, height // stride, width // stride, hidden_size]
num_dims = len(new_shape)
permute = [0] + list(range(2, num_dims - 1, 2)) + list(range(1, num_dims - 1, 2)) + [num_dims - 1]
hidden_states = hidden_states.permute(permute)
# Now finally flatten the relevant dims into the batch dimension
hidden_states = hidden_states.flatten(0, len(strides))
batch_size *= math.prod(strides)
hidden_states = hidden_states.reshape(-1, math.prod(size), hidden_size)
return hidden_states
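# Illustrative sketch (not part of the model): after `unroll` with a (2, 2) stride schedule, the four
# tokens of every 2x2 spatial block are contiguous along the sequence, so a 2x2 max pooling reduces to
# a view followed by `max(dim=1)` and matches `nn.functional.max_pool2d`. The 8x8 token grid and a
# patch stride of 1 are toy values.
def _example_unroll_enables_viewed_maxpool():
    batch_size, height, width, hidden_size = 2, 8, 8, 3
    tokens = torch.randn(batch_size, height * width, hidden_size)
    unrolled = unroll(tokens, image_shape=(height, width), patch_stride=(1, 1), schedule=[[2, 2]])
    # Max over the 4 "phase" tokens of each 2x2 block.
    pooled = unrolled.view(batch_size, 4, -1, hidden_size).max(dim=1).values
    # Reference: standard 2x2 max pooling on the channels-first layout.
    reference = nn.functional.max_pool2d(
        tokens.view(batch_size, height, width, hidden_size).permute(0, 3, 1, 2), kernel_size=2
    )
    reference = reference.permute(0, 2, 3, 1).reshape(batch_size, -1, hidden_size)
    assert torch.allclose(pooled, reference)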
class HieraPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = HieraConfig
base_model_prefix = "hiera"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _init_weights(self, module) -> None:
"""Initialize the weights"""
std = self.config.initializer_range
if isinstance(module, HieraEmbeddings):
nn.init.trunc_normal_(module.position_embeddings, std=std)
elif isinstance(module, HieraDecoder):
nn.init.trunc_normal_(module.mask_token, std=std)
nn.init.trunc_normal_(module.decoder_position_embeddings, std=std)
elif isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
nn.init.trunc_normal_(module.weight, std=std)
if module.bias is not None:
nn.init.constant_(module.bias, std)
elif isinstance(module, nn.LayerNorm):
nn.init.constant_(module.bias, std)
nn.init.constant_(module.weight, self.config.layer_norm_init)
HIERA_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`HieraConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
HIERA_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`BitImageProcessor.__call__`]
for details.
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class HieraPooler(nn.Module):
def __init__(self, config: HieraConfig):
super().__init__()
num_features = int(config.embed_dim * config.embed_dim_multiplier ** (len(config.depths) - 1))
self.layernorm = nn.LayerNorm(num_features, eps=config.layer_norm_eps)
self.pooler = nn.AdaptiveAvgPool1d(1)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = hidden_states.transpose(1, 2)
pooled_output = self.pooler(hidden_states)
pooled_output = torch.flatten(pooled_output, 1)
pooled_output = self.layernorm(pooled_output)
return pooled_output
@add_start_docstrings(
"The bare Hiera Model transformer outputting raw hidden-states without any specific head on top.",
HIERA_START_DOCSTRING,
"""
add_pooling_layer (`bool`, *optional*, defaults to `True`):
Whether or not to apply pooling layer.
is_mae (`bool`, *optional*, defaults to `False`):
Whether or not to run the model on MAE mode.
""",
)
class HieraModel(HieraPreTrainedModel):
def __init__(self, config: HieraConfig, add_pooling_layer: bool = True, is_mae: bool = False):
super().__init__(config)
self.num_features = int(config.embed_dim * config.embed_dim_multiplier ** (len(config.depths) - 1))
self.embeddings = HieraEmbeddings(config, is_mae=is_mae)
self.encoder = HieraEncoder(config)
self.unroll_schedule = [config.query_stride] * len(config.depths[:-1])
self.pooler = HieraPooler(config) if add_pooling_layer else None
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> HieraPatchEmbeddings:
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
@add_start_docstrings_to_model_forward(HIERA_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=HieraModelOutput,
config_class=_CONFIG_FOR_DOC,
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
noise: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*):
Noise tensor mainly used for testing purposes, to control randomness and maintain
reproducibility when `is_mae` is set to `True`.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
# Prepare head mask if needed
# 1.0 in head_mask indicates we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, len(self.config.depths))
embedding_output, bool_masked_pos, ids_restore = self.embeddings(
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, noise=noise
)
image_shape = (pixel_values.shape[-2], pixel_values.shape[-1])
hidden_states = unroll(
embedding_output,
image_shape=image_shape,
patch_stride=self.config.patch_stride,
schedule=self.unroll_schedule,
)
# Discard masked tokens if bool_masked_pos is provided
if bool_masked_pos is not None:
mask_unit_area = math.prod(self.config.masked_unit_size)
batch_size, _, hidden_size = hidden_states.shape
positions = bool_masked_pos.unsqueeze(-1).tile(1, mask_unit_area, hidden_size)
hidden_states = hidden_states[positions]
hidden_states = hidden_states.view(batch_size, -1, hidden_size)
encoder_outputs = self.encoder(
hidden_states,
bool_masked_pos=bool_masked_pos,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
pooled_output = None
if self.pooler is not None:
pooled_output = self.pooler(sequence_output)
if not return_dict:
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
head_outputs = (
head_outputs + (bool_masked_pos, ids_restore) if bool_masked_pos is not None else head_outputs
)
return head_outputs + encoder_outputs[1:]
return HieraModelOutput(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
bool_masked_pos=bool_masked_pos,
ids_restore=ids_restore,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
)
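# Illustrative sketch (not part of the model): how visible tokens are selected in `HieraModel.forward`
# when `bool_masked_pos` is provided. After `unroll` the sequence is ordered phase-first (for each
# intra-unit position, all mask units in a row), so tiling the per-unit keep mask `mask_unit_area`
# times lines it up with the token sequence. Here `True` marks a kept (visible) mask unit and all
# sizes are toy values.
def _example_discard_masked_tokens():
    batch_size, num_mask_units, mask_unit_area, hidden_size = 1, 4, 4, 3
    hidden_states = torch.randn(batch_size, mask_unit_area * num_mask_units, hidden_size)
    bool_masked_pos = torch.tensor([[True, False, True, False]])  # keep units 0 and 2
    positions = bool_masked_pos.unsqueeze(-1).tile(1, mask_unit_area, hidden_size)
    visible = hidden_states[positions].view(batch_size, -1, hidden_size)
    # Half of the mask units are kept, so half of the tokens survive.
    assert list(visible.shape) == [batch_size, 2 * mask_unit_area, hidden_size]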
class HieraDecoder(nn.Module):
def __init__(self, config: HieraConfig):
super().__init__()
num_features = int(config.embed_dim * config.embed_dim_multiplier ** (len(config.depths) - 1))
tokens_spatial_shape = [i // s for i, s in zip(config.image_size, config.patch_stride)]
self.tokens_spatial_shape_final = [
i // s ** (config.num_query_pool) for i, s in zip(tokens_spatial_shape, config.query_stride)
]
self.mask_unit_spatial_shape_final = [
i // s ** (config.num_query_pool) for i, s in zip(config.masked_unit_size, config.query_stride)
]
self.decoder_embeddings = nn.Linear(num_features, config.decoder_hidden_size)
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
self.decoder_position_embeddings = nn.Parameter(
torch.zeros(1, math.prod(self.tokens_spatial_shape_final), config.decoder_hidden_size)
)
self.decoder_block = HieraStage(
config=config,
hidden_size=config.decoder_hidden_size,
hidden_size_output=config.decoder_hidden_size,
num_heads=config.decoder_num_heads,
depth=config.decoder_depth,
use_mask_unit_attn=False,
drop_path=[0.0] * config.decoder_depth,
query_stride=[1] * config.decoder_depth,
window_size=0,
)
self.decoder_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
# patch stride of prediction
self.pred_stride = config.patch_stride[-1] * (config.query_stride[-1] ** config.num_query_pool)
pred_dim = (self.pred_stride ** len(config.query_stride)) * config.num_channels
self.decoder_pred = nn.Linear(config.decoder_hidden_size, pred_dim)
def forward(
self,
encoder_hidden_states: torch.Tensor,
bool_masked_pos: torch.BoolTensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, torch.BoolTensor]:
# Embed tokens
hidden_states = self.decoder_embeddings(encoder_hidden_states)
# Combine visible and bool_masked_pos tokens
# hidden_states : [batch_size, num_mask_units_visible, *mask_unit_spatial_shape_final, decoder_hidden_size]
# bool_masked_pos: [batch_size, num_mask_units]
mask_unit_height, mask_unit_width, decoder_hidden_size = hidden_states.shape[2:]
batch_size, num_mask_units = bool_masked_pos.shape
decoder_hidden_states = torch.zeros(
batch_size,
num_mask_units,
mask_unit_height,
mask_unit_width,
decoder_hidden_size,
device=hidden_states.device,
dtype=hidden_states.dtype,
)
mask_tokens = self.mask_token.view(1, 1, 1, 1, -1)
bool_masked_pos = bool_masked_pos.reshape(batch_size, num_mask_units, 1, 1, 1)
bool_masked_pos = bool_masked_pos.expand(-1, -1, mask_unit_height, mask_unit_width, decoder_hidden_size)
decoder_hidden_states[bool_masked_pos] = hidden_states.flatten()
decoder_hidden_states = (
1 - bool_masked_pos.float()
) * mask_tokens + bool_masked_pos.float() * decoder_hidden_states
# Get back spatial order
hidden_states = undo_windowing(
decoder_hidden_states,
self.tokens_spatial_shape_final,
self.mask_unit_spatial_shape_final,
)
bool_masked_pos = undo_windowing(
bool_masked_pos[..., 0:1],
self.tokens_spatial_shape_final,
self.mask_unit_spatial_shape_final,
)
# Flatten
hidden_states = hidden_states.reshape(hidden_states.shape[0], -1, hidden_states.shape[-1])
bool_masked_pos = bool_masked_pos.view(hidden_states.shape[0], -1)
# Add pos embed
hidden_states = hidden_states + self.decoder_position_embeddings
# Apply decoder blocks
hidden_states, attn_weights = self.decoder_block(
hidden_states, head_mask=head_mask, output_attentions=output_attentions
)
hidden_states = self.decoder_norm(hidden_states)
# Predictor projection
hidden_states = self.decoder_pred(hidden_states)
return hidden_states, bool_masked_pos
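# Illustrative sketch (not part of the model): how `HieraDecoder.forward` scatters the encoded visible
# mask units into a full grid and fills the masked units with the learned mask token. `True` marks a
# visible unit; all sizes and the 0.5 mask-token value are toy values.
def _example_fill_masked_units_with_mask_token():
    batch_size, num_mask_units, mu_height, mu_width, hidden_size = 1, 4, 2, 2, 3
    bool_masked_pos = torch.tensor([[True, False, True, False]])
    visible = torch.randn(batch_size, 2, mu_height, mu_width, hidden_size)
    mask_token = torch.full((hidden_size,), 0.5)
    full = torch.zeros(batch_size, num_mask_units, mu_height, mu_width, hidden_size)
    keep = bool_masked_pos.view(batch_size, num_mask_units, 1, 1, 1).expand(
        -1, -1, mu_height, mu_width, hidden_size
    )
    # Scatter the visible encodings, then fill the remaining units with the mask token.
    full[keep] = visible.flatten()
    full = (1 - keep.float()) * mask_token + keep.float() * full
    assert torch.equal(full[0, 0], visible[0, 0])  # visible unit keeps its encoding
    assert torch.equal(full[0, 1], mask_token.expand(mu_height, mu_width, hidden_size))  # masked unit -> mask token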
class HieraMultiScaleHead(nn.Module):
def __init__(self, config: HieraConfig):
super().__init__()
self.mask_unit_spatial_shape_final = [
i // s ** (config.num_query_pool) for i, s in zip(config.masked_unit_size, config.query_stride)
]
self.stage_dimensions = [
int(config.embed_dim * config.embed_dim_multiplier**i) for i in range(len(config.depths))
]
current_masked_unit_size = config.masked_unit_size
self.multi_scale_fusion_heads = nn.ModuleList()
for idx in range(config.num_query_pool):
kernel = [i // s for i, s in zip(current_masked_unit_size, self.mask_unit_spatial_shape_final)]
current_masked_unit_size = [i // s for i, s in zip(current_masked_unit_size, config.query_stride)]
self.multi_scale_fusion_heads.append(
nn.Conv2d(
self.stage_dimensions[idx],
self.stage_dimensions[-1],
kernel_size=kernel,
stride=kernel,
)
)
self.multi_scale_fusion_heads.append(nn.Identity())
def apply_fusion_head(self, head: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
if isinstance(head, nn.Identity):
return hidden_states
# Doing this explicitly to avoid problems with torch.fx
batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size = hidden_states.shape
# From: [batch_size, num_mask_units, mask_unit_height, mask_unit_width, hidden_size]
# To: head([batch_size * num_mask_units, hidden_size, mask_unit_height, mask_unit_width])
hidden_states = hidden_states.reshape(
batch_size * num_mask_units, mask_unit_height, mask_unit_width, hidden_size
)
hidden_states = hidden_states.permute(0, 3, 1, 2)
hidden_states = head(hidden_states)
# Restore original layout
hidden_states = hidden_states.permute(0, 2, 3, 1)
mask_unit_height_final, mask_unit_width_final, hidden_size = hidden_states.shape[1:]
hidden_states = hidden_states.reshape(
batch_size, num_mask_units, mask_unit_height_final, mask_unit_width_final, hidden_size
)
return hidden_states
def forward(self, feature_maps: List[torch.Tensor]) -> torch.Tensor:
# Multi-scale fusion
hidden_states = 0.0
for head, feature_map in zip(self.multi_scale_fusion_heads, feature_maps):
hidden_states = hidden_states + self.apply_fusion_head(head, feature_map)
return hidden_states
@add_start_docstrings(
"""The Hiera Model transformer with the decoder on top for self-supervised pre-training.
<Tip>
Note that we provide a script to pre-train this model on custom data in our [examples
directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
</Tip>
""",
HIERA_START_DOCSTRING,
)
class HieraForPreTraining(HieraPreTrainedModel):
def __init__(self, config: HieraConfig) -> None:
super().__init__(config)
# Encoder
self.hiera = HieraModel(config, add_pooling_layer=False, is_mae=True)
self.encoder_norm = nn.LayerNorm(self.hiera.num_features, eps=config.layer_norm_eps)
# Multi-scale fusion heads
self.multiscale_fusion = HieraMultiScaleHead(config)
# Decoder
self.decoder = HieraDecoder(config)
self.pred_stride = self.decoder.pred_stride
# Initialize weights and apply final processing
self.post_init()
def get_pixel_label_2d(self, pixel_values: torch.Tensor, bool_masked_pos: torch.BoolTensor) -> torch.Tensor:
# bool_masked_pos (boolean tensor): True means *masked*
pixel_values = pixel_values.permute(0, 2, 3, 1)
size = self.pred_stride
label = pixel_values.unfold(1, size, size).unfold(2, size, size)
label = label.flatten(1, 2).flatten(2)
label = label[bool_masked_pos]
if self.config.normalize_pixel_loss:
mean = label.mean(dim=-1, keepdim=True)
var = label.var(dim=-1, keepdim=True)
label = (label - mean) / (var + 1.0e-6) ** 0.5
return label
def forward_loss(self, pixel_values: torch.Tensor, logits: torch.Tensor, bool_masked_pos: torch.BoolTensor):
# We invert the bool_masked_pos such that 1.0 is *masked*
bool_masked_pos = ~bool_masked_pos
label = self.get_pixel_label_2d(pixel_values, bool_masked_pos)
logits = logits[bool_masked_pos]
loss = (logits - label) ** 2
loss = loss.mean()
return loss
@add_start_docstrings_to_model_forward(HIERA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=HieraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
noise: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, HieraForPreTrainingOutput]:
r"""
noise (`torch.FloatTensor` of shape `(batch_size, num_mask_units)`, *optional*):
Noise tensor mainly used for testing purposes, to control randomness and maintain
reproducibility when `is_mae` is set to `True`.
Returns:
Examples:
```python
>>> from transformers import AutoImageProcessor, HieraForPreTraining
>>> import torch
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("facebook/hiera-tiny-224-mae-hf")
>>> model = HieraForPreTraining.from_pretrained("facebook/hiera-tiny-224-mae-hf")
>>> inputs = image_processor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> loss = outputs.loss
>>> print(list(logits.shape))
[1, 196, 768]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
outputs = self.hiera(
pixel_values,
noise=noise,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=True,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
feature_maps = outputs[-1]
bool_masked_pos = outputs[1]
ids_to_restore = outputs[2]
# Take only the query pooled and last hidden states
feature_maps = feature_maps[1 : self.hiera.config.num_query_pool + 1] + (feature_maps[-1],)
fused_hidden_states = self.multiscale_fusion(feature_maps)
fused_hidden_states = self.encoder_norm(fused_hidden_states)
# Reconstruct pixel values
logits, bool_masked_pos = self.decoder(
fused_hidden_states,
bool_masked_pos=bool_masked_pos,
head_mask=head_mask,
output_attentions=output_attentions,
)
loss = self.forward_loss(pixel_values, logits, bool_masked_pos)
if not return_dict:
output = (logits, bool_masked_pos, ids_to_restore)
if output_hidden_states:
output = output + (outputs[3],)
if output_attentions:
output = output + (outputs[4],)
if output_hidden_states:
output = output + (outputs[-1],)
return ((loss,) + output) if loss is not None else output
return HieraForPreTrainingOutput(
loss=loss,
logits=logits,
bool_masked_pos=bool_masked_pos,
ids_restore=ids_to_restore,
hidden_states=outputs.hidden_states if output_hidden_states else None,
attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states if output_hidden_states else None,
)
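# Illustrative sketch (not part of the model): `get_pixel_label_2d` builds the reconstruction targets
# by cutting the image into non-overlapping pred_stride x pred_stride patches with `unfold`. An 8x8
# RGB image with a prediction stride of 4 (toy values) yields 4 patches of 3 * 4 * 4 pixel values each.
def _example_pixel_labels_from_patches():
    pixel_values = torch.randn(1, 3, 8, 8)
    size = 4  # pred_stride
    label = pixel_values.permute(0, 2, 3, 1).unfold(1, size, size).unfold(2, size, size)
    label = label.flatten(1, 2).flatten(2)
    assert list(label.shape) == [1, 4, 3 * size * size]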
@add_start_docstrings(
"""
Hiera Model transformer with an image classification head on top (a linear layer on top of the final hidden state with
average pooling) e.g. for ImageNet.
<Tip>
Note that it's possible to fine-tune Hiera on higher resolution images than the ones it has been trained on, by
setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
position embeddings to the higher resolution.
</Tip>
""",
HIERA_START_DOCSTRING,
)
class HieraForImageClassification(HieraPreTrainedModel):
def __init__(self, config: HieraConfig) -> None:
super().__init__(config)
self.num_labels = config.num_labels
self.hiera = HieraModel(config, add_pooling_layer=True, is_mae=False)
# Classifier head
self.classifier = (
nn.Linear(self.hiera.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(HIERA_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_IMAGE_CLASS_CHECKPOINT,
output_type=HieraForImageClassificationOutput,
config_class=_CONFIG_FOR_DOC,
expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
)
def forward(
self,
pixel_values,
head_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[tuple, HieraForImageClassificationOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
outputs = self.hiera(
pixel_values,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
pooled_output = outputs[1]
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return HieraForImageClassificationOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
reshaped_hidden_states=outputs.reshaped_hidden_states,
)
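# Illustrative sketch (not part of the model): the loss selection used in the classification head
# above. Integer labels with more than one class trigger cross-entropy; a single continuous target per
# image would use MSE; multi-hot float targets use BCE-with-logits. Toy tensors only.
def _example_classification_loss_selection():
    logits = torch.randn(2, 10)
    int_labels = torch.tensor([3, 7])
    # single_label_classification
    single_label_loss = CrossEntropyLoss()(logits.view(-1, 10), int_labels.view(-1))
    # multi_label_classification
    multi_hot_labels = torch.zeros(2, 10).scatter_(1, int_labels.unsqueeze(1), 1.0)
    multi_label_loss = BCEWithLogitsLoss()(logits, multi_hot_labels)
    assert single_label_loss.ndim == 0 and multi_label_loss.ndim == 0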
@add_start_docstrings(
"""
Hiera backbone, to be used with frameworks like DETR and MaskFormer.
""",
HIERA_START_DOCSTRING,
)
class HieraBackbone(HieraPreTrainedModel, BackboneMixin):
def __init__(self, config: HieraConfig):
super().__init__(config)
super()._init_backbone(config)
self.num_features = [config.embed_dim] + [
int(config.embed_dim * config.embed_dim_multiplier**i) for i in range(len(config.depths))
]
self.embeddings = HieraEmbeddings(config, is_mae=False)
self.encoder = HieraEncoder(config)
# Add layer norms to hidden states of out_features
hidden_states_norms = {}
for stage, num_channels in zip(self._out_features, self.channels):
hidden_states_norms[stage] = nn.LayerNorm(num_channels)
self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
def forward(
self,
pixel_values: torch.Tensor,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> BackboneOutput:
"""
Returns:
Examples:
```python
>>> from transformers import AutoImageProcessor, AutoBackbone
>>> import torch
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> processor = AutoImageProcessor.from_pretrained("facebook/hiera-tiny-224-hf")
>>> model = AutoBackbone.from_pretrained(
... "facebook/hiera-tiny-224-hf", out_features=["stage1", "stage2", "stage3", "stage4"]
... )
>>> inputs = processor(image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> feature_maps = outputs.feature_maps
>>> list(feature_maps[-1].shape)
[1, 768, 7, 7]
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
embedding_output, _, _ = self.embeddings(pixel_values)
outputs = self.encoder(
embedding_output,
head_mask=None,
output_attentions=output_attentions,
output_hidden_states=True,
return_dict=return_dict,
)
hidden_states = outputs[-1]
feature_maps = ()
for stage, hidden_state in zip(self.stage_names, hidden_states):
if stage in self.out_features:
batch_size, height, width, num_channels = hidden_state.shape
hidden_state = hidden_state.view(batch_size, height * width, num_channels)
hidden_state = self.hidden_states_norms[stage](hidden_state)
hidden_state = hidden_state.view(batch_size, height, width, num_channels)
hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
feature_maps += (hidden_state,)
if not return_dict:
output = (feature_maps,)
if output_hidden_states:
output += (outputs[1],)
if output_attentions:
output += (outputs[2],)
return output
return BackboneOutput(
feature_maps=feature_maps,
hidden_states=outputs[1] if output_hidden_states else None,
attentions=outputs[2] if output_attentions else None,
)
@@ -4583,6 +4583,41 @@ class GroupViTVisionModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
class HieraBackbone(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class HieraForImageClassification(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class HieraForPreTraining(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class HieraModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class HieraPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class HubertForCTC(metaclass=DummyObject):
_backends = ["torch"]
...
@@ -138,6 +138,7 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
"gpt2",
"gpt_neo",
"gptj",
"hiera",
"hubert",
"layoutlm",
"llama",
...
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing suite for the PyTorch Hiera model."""
import math
import unittest
from typing import Dict, List, Tuple
from transformers import HieraConfig
from transformers.testing_utils import (
require_torch,
require_vision,
slow,
torch_device,
)
from transformers.utils import (
cached_property,
is_torch_available,
is_vision_available,
)
from ...test_backbone_common import BackboneTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from torch import nn
from transformers import HieraBackbone, HieraForImageClassification, HieraForPreTraining, HieraModel
if is_vision_available():
from PIL import Image
from transformers import AutoImageProcessor
class HieraModelTester:
def __init__(
self,
parent,
batch_size=13,
image_size=[64, 64],
mlp_ratio=1.0,
num_channels=3,
depths=[1, 1, 1, 1],
patch_stride=[4, 4],
patch_size=[7, 7],
patch_padding=[3, 3],
masked_unit_size=[8, 8],
num_heads=[1, 1, 1, 1],
embed_dim_multiplier=2.0,
is_training=True,
use_labels=True,
embed_dim=8,
hidden_act="gelu",
decoder_hidden_size=2,
decoder_depth=1,
decoder_num_heads=1,
initializer_range=0.02,
scope=None,
type_sequence_label_size=10,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.mlp_ratio = mlp_ratio
self.num_channels = num_channels
self.depths = depths
self.patch_stride = patch_stride
self.patch_size = patch_size
self.patch_padding = patch_padding
self.masked_unit_size = masked_unit_size
self.num_heads = num_heads
self.embed_dim_multiplier = embed_dim_multiplier
self.is_training = is_training
self.use_labels = use_labels
self.embed_dim = embed_dim
self.hidden_act = hidden_act
self.decoder_hidden_size = decoder_hidden_size
self.decoder_depth = decoder_depth
self.decoder_num_heads = decoder_num_heads
self.initializer_range = initializer_range
self.scope = scope
self.type_sequence_label_size = type_sequence_label_size
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size[0], self.image_size[1]])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return HieraConfig(
embed_dim=self.embed_dim,
image_size=self.image_size,
patch_stride=self.patch_stride,
patch_size=self.patch_size,
patch_padding=self.patch_padding,
masked_unit_size=self.masked_unit_size,
mlp_ratio=self.mlp_ratio,
num_channels=self.num_channels,
depths=self.depths,
num_heads=self.num_heads,
embed_dim_multiplier=self.embed_dim_multiplier,
hidden_act=self.hidden_act,
decoder_hidden_size=self.decoder_hidden_size,
decoder_depth=self.decoder_depth,
decoder_num_heads=self.decoder_num_heads,
initializer_range=self.initializer_range,
)
def create_and_check_model(self, config, pixel_values, labels):
model = HieraModel(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
tokens_spatial_shape = [i // s for i, s in zip(self.image_size, config.patch_stride)]
expected_seq_len = math.prod(tokens_spatial_shape) // math.prod(config.query_stride) ** (config.num_query_pool)
expected_dim = int(config.embed_dim * config.embed_dim_multiplier ** (len(config.depths) - 1))
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
def create_and_check_backbone(self, config, pixel_values, labels):
model = HieraBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify hidden states
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
num_patches = config.image_size[0] // config.patch_stride[0] // config.masked_unit_size[0]
self.parent.assertListEqual(
list(result.feature_maps[0].shape), [self.batch_size, model.channels[0], num_patches, num_patches]
)
# verify channels
self.parent.assertEqual(len(model.channels), len(config.out_features))
# verify backbone works with out_features=None
config.out_features = None
model = HieraBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), 1)
self.parent.assertListEqual(
list(result.feature_maps[0].shape), [self.batch_size, model.channels[-1], num_patches, num_patches]
)
# verify channels
self.parent.assertEqual(len(model.channels), 1)
def create_and_check_for_pretraining(self, config, pixel_values, labels):
model = HieraForPreTraining(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
pred_stride = config.patch_stride[-1] * (config.query_stride[-1] ** config.num_query_pool)
num_patches = self.image_size[0] // pred_stride
self.parent.assertEqual(
result.logits.shape, (self.batch_size, num_patches**2, self.num_channels * pred_stride**2)
)
# test greyscale images
config.num_channels = 1
model = HieraForPreTraining(config)
model.to(torch_device)
model.eval()
pixel_values = floats_tensor([self.batch_size, 1, self.image_size[0], self.image_size[0]])
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, num_patches**2, pred_stride**2))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
model = HieraForImageClassification(config)
model.to(torch_device)
model.eval()
result = model(pixel_values, labels=labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
# test greyscale images
config.num_channels = 1
model = HieraForImageClassification(config)
model.to(torch_device)
model.eval()
pixel_values = floats_tensor([self.batch_size, 1, self.image_size[0], self.image_size[0]])
result = model(pixel_values)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
pixel_values,
labels,
) = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class HieraModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as Hiera does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (
(
HieraModel,
HieraBackbone,
HieraForImageClassification,
HieraForPreTraining,
)
if is_torch_available()
else ()
)
pipeline_model_mapping = (
{"image-feature-extraction": HieraModel, "image-classification": HieraForImageClassification}
if is_torch_available()
else {}
)
fx_compatible = True
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
def setUp(self):
self.model_tester = HieraModelTester(self)
self.config_tester = ConfigTester(self, config_class=HieraConfig, has_text_modality=False)
def test_config(self):
self.config_tester.create_and_test_config_to_json_string()
self.config_tester.create_and_test_config_to_json_file()
self.config_tester.create_and_test_config_from_and_save_pretrained()
self.config_tester.create_and_test_config_with_num_labels()
self.config_tester.check_config_can_be_init_without_params()
self.config_tester.check_config_arguments_init()
# Overriding as Hiera `get_input_embeddings` returns HieraPatchEmbeddings
def test_model_get_set_embeddings(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear))
# Overriding as attention shape depends on patch_stride and mask_unit_size
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
expected_num_attentions = len(self.model_tester.depths)
self.assertEqual(len(attentions), expected_num_attentions)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
seq_len = math.prod([i // s for i, s in zip(config.image_size, config.patch_stride)])
mask_unit_area = math.prod(config.masked_unit_size)
num_windows = seq_len // mask_unit_area
if model_class.__name__ == "HieraForPreTraining":
num_windows = int(num_windows * (1 - config.mask_ratio))
seq_len = int(num_windows * mask_unit_area)
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), expected_num_attentions)
self.assertListEqual(
list(attentions[0].shape[-4:]),
[self.model_tester.num_heads[0], num_windows, mask_unit_area, seq_len // num_windows],
)
out_len = len(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
# also another +1 for reshaped_hidden_states
added_hidden_states = 1 if model_class.__name__ == "HieraBackbone" else 2
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.attentions
self.assertEqual(len(self_attentions), expected_num_attentions)
self.assertListEqual(
list(self_attentions[0].shape[-4:]),
[self.model_tester.num_heads[0], num_windows, mask_unit_area, seq_len // num_windows],
)
# Overriding as attention shape depends on patch_stride and mask_unit_size
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class, image_size):
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
# Hiera has a different seq_length
patch_size = config.patch_stride
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
if model_class.__name__ == "HieraForPreTraining":
mask_unit_area = math.prod(config.masked_unit_size)
num_windows = num_patches // mask_unit_area
num_windows = int(num_windows * (1 - config.mask_ratio))
num_patches = int(num_windows * mask_unit_area)
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[num_patches, self.model_tester.embed_dim],
)
if not model_class.__name__ == "HieraBackbone":
reshaped_hidden_states = outputs.reshaped_hidden_states
self.assertEqual(len(reshaped_hidden_states), expected_num_layers)
batch_size = reshaped_hidden_states[0].shape[0]
num_channels = reshaped_hidden_states[0].shape[-1]
reshaped_hidden_states = reshaped_hidden_states[0].view(batch_size, -1, num_channels)
self.assertListEqual(
list(reshaped_hidden_states.shape[-2:]),
[num_patches, self.model_tester.embed_dim],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
image_size = self.model_tester.image_size
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class, image_size)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class, image_size)
# Overriding since HieraForPreTraining outputs bool_masked_pos which has to be converted to float in the msg
def test_model_outputs_equivalence(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
with torch.no_grad():
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(
tuple_object.values(), dict_object.values()
):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object.float() - dict_object.float()))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
),
)
recursive_check(tuple_output, dict_output)
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
additional_kwargs = {}
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
additional_kwargs["output_hidden_states"] = True
check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs)
if self.has_attentions:
# Removing "output_hidden_states"
del additional_kwargs["output_hidden_states"]
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
additional_kwargs["output_attentions"] = True
check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs)
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
additional_kwargs["output_hidden_states"] = True
check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs)
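# Note: the `.float()` casts in the assertion message above are needed because HieraForPreTraining
# returns `bool_masked_pos`; boolean tensors cannot be subtracted directly when computing the
# reported maximum difference.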
@unittest.skip(reason="Hiera Transformer does not use feedforward chunking")
def test_feed_forward_chunking(self):
pass
@unittest.skip(reason="Hiera does not use inputs_embeds")
def test_inputs_embeds(self):
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_backbone(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_backbone(*config_and_inputs)
def test_for_pretraining(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_pretraining(*config_and_inputs)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in ["facebook/hiera-tiny-224-hf"]:
model = HieraModel.from_pretrained(model_name)
self.assertIsNotNone(model)
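# Minimal inference sketch for the checkpoint loaded above (illustrative only, mirroring the
# integration tests below):
#   image_processor = AutoImageProcessor.from_pretrained("facebook/hiera-tiny-224-hf")
#   model = HieraModel.from_pretrained("facebook/hiera-tiny-224-hf")
#   inputs = image_processor(images=image, return_tensors="pt")
#   last_hidden_state = model(**inputs).last_hidden_state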
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_torch
@require_vision
class HieraModelIntegrationTest(unittest.TestCase):
@cached_property
def default_image_processor(self):
return AutoImageProcessor.from_pretrained("facebook/hiera-tiny-224-in1k-hf") if is_vision_available() else None
@slow
def test_inference_image_classification_head(self):
model = HieraForImageClassification.from_pretrained("facebook/hiera-tiny-224-in1k-hf").to(torch_device)
image_processor = self.default_image_processor
image = prepare_img()
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
expected_pixel_values = torch.tensor(
[
[[0.2967, 0.4679, 0.4508], [0.3309, 0.4337, 0.3309], [0.3309, 0.3823, 0.3309]],
[[-1.5455, -1.4930, -1.5455], [-1.5280, -1.4755, -1.5980], [-1.5630, -1.5280, -1.4755]],
[[-0.6367, -0.4973, -0.5321], [-0.7936, -0.6715, -0.6715], [-0.8284, -0.7413, -0.5670]],
]
).to(torch_device)
self.assertTrue(torch.allclose(inputs.pixel_values[0, :3, :3, :3], expected_pixel_values, atol=1e-4))
# forward pass
with torch.no_grad():
outputs = model(**inputs)
# verify the logits
expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([[0.8028, 0.2409, -0.2254, -0.3712, -0.2848]]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :5], expected_slice, atol=1e-4))
@slow
def test_inference_interpolate_pos_encoding(self):
model = HieraModel.from_pretrained("facebook/hiera-tiny-224-hf").to(torch_device)
image_processor = AutoImageProcessor.from_pretrained(
"facebook/hiera-tiny-224-hf", size={"shortest_edge": 448}, crop_size={"height": 448, "width": 448}
)
image = prepare_img()
inputs = image_processor(images=image, return_tensors="pt")
pixel_values = inputs.pixel_values.to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(pixel_values, interpolate_pos_encoding=True)
# verify the logits
expected_shape = torch.Size((1, 196, 768))
self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
expected_slice = torch.tensor(
[[1.8522, 0.1532, 0.3849], [2.7352, -0.1941, 0.1848], [1.5859, -0.0773, 0.0168]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
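# Why (1, 196, 768): with a 448x448 input and patch_stride 4 the embedding produces a
# 112x112 token grid; assuming the hiera-tiny layout of three 2x2 poolings between stages,
# this shrinks to 14x14 = 196 tokens with a final hidden size of 768, matching the
# last_hidden_state shape checked above.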
@slow
def test_inference_for_pretraining(self):
# make random mask reproducible
torch.manual_seed(2)
model = HieraForPreTraining.from_pretrained("facebook/hiera-tiny-224-mae-hf").to(torch_device)
image_processor = self.default_image_processor
image = prepare_img()
inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
config = model.config
mask_spatial_shape = [
i // s // ms for i, s, ms in zip(config.image_size, config.patch_stride, config.masked_unit_size)
]
num_windows = math.prod(mask_spatial_shape)
noise = torch.rand(1, num_windows).to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(**inputs, noise=noise)
# verify the logits
expected_shape = torch.Size((1, 196, 768))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor(
[
[1.6407, 1.6506, 1.6541, 1.6617, 1.6703],
[1.9730, 1.9842, 1.9848, 1.9896, 1.9947],
[1.5949, 1.8262, 1.2602, 1.4801, 1.4448],
[1.2341, 1.7907, 0.8618, 1.5202, 1.4523],
[2.0140, 1.9846, 1.9434, 1.9019, 1.8648],
]
)
self.assertTrue(torch.allclose(outputs.logits[0, :5, :5], expected_slice.to(torch_device), atol=1e-4))
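# Noise-shape sketch (assuming the released hiera-tiny-224-mae config of image_size 224,
# patch_stride 4 and masked_unit_size 8): mask_spatial_shape = [224 // 4 // 8] * 2 = [7, 7],
# so num_windows = 49 and the reproducible noise passed to the model has shape (1, 49).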
@require_torch
class HieraBackboneTest(unittest.TestCase, BackboneTesterMixin):
all_model_classes = (HieraBackbone,) if is_torch_available() else ()
config_class = HieraConfig
def setUp(self):
self.model_tester = HieraModelTester(self)
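# Usage sketch for HieraBackbone (illustrative only; the out_features stage names and the
# dummy input are assumptions, not taken from this PR):
#   pixel_values = torch.randn(1, 3, 224, 224)
#   config = HieraConfig(out_features=["stage1", "stage2", "stage3", "stage4"])
#   backbone = HieraBackbone(config)
#   feature_maps = backbone(pixel_values).feature_maps  # one feature map per requested stage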
@@ -16,6 +16,7 @@ import collections
 import copy
 import gc
 import inspect
+import math
 import os
 import os.path
 import random
@@ -55,6 +56,7 @@ from transformers.models.auto.modeling_auto import (
     MODEL_FOR_MASKED_LM_MAPPING_NAMES,
     MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
     MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES,
+    MODEL_FOR_PRETRAINING_MAPPING_NAMES,
     MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
     MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES,
     MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
@@ -194,6 +196,14 @@ class ModelTesterMixin:
             }
         elif model_class.__name__ in get_values(MODEL_FOR_AUDIO_XVECTOR_MAPPING_NAMES):
             inputs_dict.pop("attention_mask")
+        elif model_class.__name__ == MODEL_FOR_PRETRAINING_MAPPING_NAMES["hiera"]:
+            config = self.model_tester.get_config()
+            mask_spatial_shape = [
+                i // s // ms for i, s, ms in zip(config.image_size, config.patch_stride, config.masked_unit_size)
+            ]
+            num_windows = math.prod(mask_spatial_shape)
+            torch.manual_seed(0)
+            inputs_dict["noise"] = torch.rand(self.model_tester.batch_size, num_windows)
         if return_labels:
             if model_class.__name__ in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
@@ -1163,6 +1173,7 @@ class ModelTesterMixin:
             "token_type_ids",
             "visual_feats",
             "visual_pos",
+            "noise",
         ]
         labels = inputs.get("labels", None)
...
@@ -997,6 +997,7 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
     "DinatBackbone",
     "Dinov2Backbone",
     "FocalNetBackbone",
+    "HieraBackbone",
     "MaskFormerSwinBackbone",
     "MaskFormerSwinConfig",
     "MaskFormerSwinModel",
...