Unverified Commit 5b40a37b authored by Sayak Paul, committed by GitHub

Add TF ViT MAE (#16255)



* ported TFViTMAEIntermediate and TFViTMAEOutput.

* added TFViTMAEModel and TFViTMAEDecoder.

* feat: added a noise argument in the implementation for reproducibility.

* feat: vit mae models with an additional noise argument for reproducibility.
Co-authored-by: ariG23498 <aritra.born2fly@gmail.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 7a9ef818
...@@ -260,7 +260,7 @@ Flax), PyTorch, and/or TensorFlow.
| VisionTextDualEncoder | ❌ | ❌ | ✅ | ❌ | ✅ |
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
| ViT | ❌ | ❌ | ✅ | ✅ | ✅ |
| ViTMAE | ❌ | ❌ | ✅ | ✅ | ❌ |
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ |
| WavLM | ❌ | ❌ | ✅ | ❌ | ❌ |
| XGLM | ✅ | ✅ | ✅ | ❌ | ✅ |
......
...@@ -41,13 +41,16 @@ fine-tuning, one can directly plug in the weights into a [`ViTForImageClassifica
- Note that the encoder of MAE is only used to encode the visual patches. The encoded patches are then concatenated with mask tokens, which the decoder (which also
consists of Transformer blocks) takes as input. Each mask token is a shared, learned vector that indicates the presence of a missing patch to be predicted. Fixed
sin/cos position embeddings are added both to the input of the encoder and the decoder.
- For a visual understanding of how MAEs work, you can check out this [post](https://keras.io/examples/vision/masked_image_modeling/); a short usage sketch follows right after this list.
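The sketch below is an editorial illustration, not a line from the diff: it shows how the `noise` argument introduced in this PR makes the random masking reproducible, assuming the `facebook/vit-mae-base` checkpoint (224x224 inputs, patch size 16, hence 196 patches).

```python
import numpy as np
from transformers import AutoFeatureExtractor, TFViTMAEForPreTraining

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/vit-mae-base")
model = TFViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")

# any RGB image works; a random one keeps the example self-contained
image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
inputs = feature_extractor(images=image, return_tensors="tf")

# fixing the noise fixes which patches are masked: shape (batch_size, num_patches) = (1, 196)
noise = np.random.uniform(size=(1, 196))
outputs_1 = model(**inputs, noise=noise)
outputs_2 = model(**inputs, noise=noise)
assert np.allclose(outputs_1.mask.numpy(), outputs_2.mask.numpy())  # same mask on both calls
```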
<img src="https://user-images.githubusercontent.com/11435359/146857310-f258c86c-fde6-48e8-9cee-badd2b21bd2c.png"
alt="drawing" width="600"/>
<small> MAE architecture. Taken from the <a href="https://arxiv.org/abs/2111.06377">original paper.</a> </small>
This model was contributed by [nielsr](https://huggingface.co/nielsr). The TensorFlow version of the model was contributed by [sayakpaul](https://github.com/sayakpaul) and
[ariG23498](https://github.com/ariG23498) (equal contribution). The original code can be found [here](https://github.com/facebookresearch/mae).
## ViTMAEConfig
...@@ -64,3 +67,15 @@ This model was contributed by [nielsr](https://huggingface.co/nielsr). The origi
[[autodoc]] transformers.ViTMAEForPreTraining
- forward
## TFViTMAEModel
[[autodoc]] TFViTMAEModel
- call
## TFViTMAEForPreTraining
[[autodoc]] transformers.TFViTMAEForPreTraining
- call
...@@ -2135,6 +2135,13 @@ if is_tf_available():
"TFViTPreTrainedModel",
]
)
_import_structure["models.vit_mae"].extend(
[
"TFViTMAEForPreTraining",
"TFViTMAEModel",
"TFViTMAEPreTrainedModel",
]
)
_import_structure["models.wav2vec2"].extend( _import_structure["models.wav2vec2"].extend(
[ [
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -4170,6 +4177,7 @@ if TYPE_CHECKING: ...@@ -4170,6 +4177,7 @@ if TYPE_CHECKING:
) )
from .models.vision_encoder_decoder import TFVisionEncoderDecoderModel from .models.vision_encoder_decoder import TFVisionEncoderDecoderModel
from .models.vit import TFViTForImageClassification, TFViTModel, TFViTPreTrainedModel from .models.vit import TFViTForImageClassification, TFViTModel, TFViTPreTrainedModel
from .models.vit_mae import TFViTMAEForPreTraining, TFViTMAEModel, TFViTMAEPreTrainedModel
from .models.wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWav2Vec2ForCTC,
......
...@@ -70,6 +70,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("blenderbot", "TFBlenderbotModel"),
("blenderbot-small", "TFBlenderbotSmallModel"),
("vit", "TFViTModel"),
("vit_mae", "TFViTMAEModel"),
("wav2vec2", "TFWav2Vec2Model"), ("wav2vec2", "TFWav2Vec2Model"),
("hubert", "TFHubertModel"), ("hubert", "TFHubertModel"),
] ]
...@@ -100,6 +101,7 @@ TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict( ...@@ -100,6 +101,7 @@ TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
("tapas", "TFTapasForMaskedLM"), ("tapas", "TFTapasForMaskedLM"),
("funnel", "TFFunnelForPreTraining"), ("funnel", "TFFunnelForPreTraining"),
("mpnet", "TFMPNetForMaskedLM"), ("mpnet", "TFMPNetForMaskedLM"),
("vit_mae", "TFViTMAEForPreTraining"),
]
)
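For context, an editorial sketch (not part of the diff): once these two `vit_mae` entries exist, the TF auto classes backed by these mappings can resolve the new model directly from a checkpoint name.

```python
from transformers import TFAutoModel, TFAutoModelForPreTraining

# both calls resolve through the "vit_mae" entries registered above
encoder = TFAutoModel.from_pretrained("facebook/vit-mae-base")  # -> TFViTMAEModel
pretrainer = TFAutoModelForPreTraining.from_pretrained("facebook/vit-mae-base")  # -> TFViTMAEForPreTraining
```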
......
...@@ -33,6 +33,12 @@ if is_torch_available():
"ViTMAEPreTrainedModel",
]
if is_tf_available():
_import_structure["modeling_tf_vit_mae"] = [
"TFViTMAEForPreTraining",
"TFViTMAEModel",
"TFViTMAEPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_vit_mae import VIT_MAE_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTMAEConfig
...@@ -46,6 +52,9 @@ if TYPE_CHECKING:
ViTMAEPreTrainedModel,
)
if is_tf_available():
from .modeling_tf_vit_mae import TFViTMAEForPreTraining, TFViTMAEModel, TFViTMAEPreTrainedModel
else:
import sys
......
# coding=utf-8
# Copyright 2022 Facebook AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 ViT MAE (masked autoencoder) model."""
import collections.abc
import math
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...file_utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from ...modeling_tf_outputs import TFBaseModelOutput
from ...modeling_tf_utils import (
TFModelInputType,
TFPreTrainedModel,
get_initializer,
keras_serializable,
unpack_inputs,
)
from ...tf_utils import shape_list
from ...utils import logging
from .configuration_vit_mae import ViTMAEConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "ViTMAEConfig"
_CHECKPOINT_FOR_DOC = "facebook/vit-mae-base"
@dataclass
class TFViTMAEModelOutput(ModelOutput):
"""
Class for TFViTMAEModel's outputs, with potential hidden states and attentions.
Args:
last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Tensor indicating which patches are masked (1) and which are not (0).
ids_restore (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Tensor containing the original index of the (shuffled) masked patches.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
the initial embedding outputs.
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
the self-attention heads.
"""
last_hidden_state: tf.Tensor = None
mask: tf.Tensor = None
ids_restore: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFViTMAEDecoderOutput(ModelOutput):
"""
Class for TFViTMAEDecoder's outputs, with potential hidden states and attentions.
Args:
logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
Pixel reconstruction logits.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
the initial embedding outputs.
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
the self-attention heads.
"""
logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
@dataclass
class TFViTMAEForPreTrainingOutput(ModelOutput):
"""
Class for TFViTMAEForPreTraining's outputs, with potential hidden states and attentions.
Args:
loss (`tf.Tensor` of shape `(1,)`):
Pixel reconstruction loss.
logits (`tf.Tensor` of shape `(batch_size, sequence_length, patch_size ** 2 * num_channels)`):
Pixel reconstruction logits.
mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Tensor indicating which patches are masked (1) and which are not (0).
ids_restore (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Tensor containing the original index of the (shuffled) masked patches.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
`(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus
the initial embedding outputs.
attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
the self-attention heads.
"""
loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
mask: tf.Tensor = None
ids_restore: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
# copied from transformers.models.vit.modeling_tf_vit.to_2tuple
def to_2tuple(x):
if isinstance(x, collections.abc.Iterable):
return x
return (x, x)
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
"""
Create 2D sin/cos positional embeddings.
Args:
embed_dim (`int`):
Embedding dimension.
grid_size (`int`):
The grid height and width.
add_cls_token (`bool`, *optional*, defaults to `False`):
Whether or not to add a classification (CLS) token.
Returns:
(`tf.Tensor` of shape `(grid_size*grid_size, embed_dim)` or `(1+grid_size*grid_size, embed_dim)`): the position
embeddings (with or without classification token)
"""
grid_h = tf.range(grid_size, dtype=tf.float32)
grid_w = tf.range(grid_size, dtype=tf.float32)
grid = tf.meshgrid(grid_w, grid_h) # here w goes first
grid = tf.stack(grid, axis=0)
grid = tf.reshape(grid, [2, 1, grid_size, grid_size])
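# grid has shape (2, 1, grid_size, grid_size): one entry per spatial axis; each axis is
# embedded with half of embed_dim below and the two halves are concatenated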
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if add_cls_token:
pos_embed = tf.concat([tf.zeros((1, embed_dim)), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be even")
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = tf.concat([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded, of size (M,)
out: (M, D)
"""
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be even")
omega = tf.range(embed_dim // 2, dtype="float32")
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = tf.reshape(pos, [-1]) # (M,)
out = tf.einsum("m,d->md", pos, omega) # (M, D/2), outer product
# half of the positions get sinusoidal pattern and the rest gets
# cosine pattern and then they are concatenated
emb_sin = tf.sin(out) # (M, D/2)
emb_cos = tf.cos(out) # (M, D/2)
emb = tf.concat([emb_sin, emb_cos], axis=1) # (M, D)
return emb
class TFViTMAEEmbeddings(tf.keras.layers.Layer):
"""
Construct the CLS token, position and patch embeddings.
"""
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
self.patch_embeddings = TFPatchEmbeddings(config, name="patch_embeddings")
self.num_patches = self.patch_embeddings.num_patches
self.config = config
def build(self, input_shape: tf.TensorShape):
self.cls_token = self.add_weight(
shape=(1, 1, self.config.hidden_size),
initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
trainable=True,
name="cls_token",
)
self.position_embeddings = self.add_weight(
shape=(1, self.num_patches + 1, self.config.hidden_size),
initializer="zeros",
trainable=False, # fixed sin-cos embedding
name="position_embeddings",
)
pos_embed = get_2d_sincos_pos_embed(
self.position_embeddings.shape[-1],
int(self.patch_embeddings.num_patches**0.5),
add_cls_token=True,
)[None, ...]
self.position_embeddings.assign(pos_embed)
super().build(input_shape)
def random_masking(self, sequence: tf.Tensor, noise: Optional[tf.Tensor] = None):
"""
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
noise.
Args:
sequence (`tf.Tensor` of shape `(batch_size, sequence_length, dim)`)
noise (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
mainly used for testing purposes to control randomness and ensure reproducibility
"""
batch_size, seq_length, dim = shape_list(sequence)
len_keep = int(seq_length * (1 - self.config.mask_ratio))
if noise is None:
noise = tf.random.uniform(shape=(batch_size, seq_length), minval=0.0, maxval=1.0) # noise in [0, 1)
# sort noise for each sample
ids_shuffle = tf.argsort(noise, axis=1) # ascend: small is keep, large is remove
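# argsort of the argsort gives the inverse permutation: gathering with ids_restore puts a
# shuffled sequence back into the original patch order (used for the mask below and for the decoder inputs)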
ids_restore = tf.argsort(ids_shuffle, axis=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
sequence_masked = tf.gather(
sequence,
axis=1,
batch_dims=1,
indices=ids_keep,
)
# generate the binary mask: 0 is keep, 1 is remove
# this hack is needed because TF's EagerTensors don't support
# assignment
mask_keep = tf.zeros((batch_size, len_keep))
mask_remove = tf.ones((batch_size, seq_length - len_keep))
mask = tf.concat([mask_keep, mask_remove], axis=-1)
# unshuffle to get the binary mask
mask = tf.gather(mask, axis=1, batch_dims=1, indices=ids_restore)
return sequence_masked, mask, ids_restore
def call(self, pixel_values: tf.Tensor, noise: tf.Tensor = None) -> tf.Tensor:
embeddings = self.patch_embeddings(pixel_values)
# add position embeddings w/o cls token
embeddings = embeddings + self.position_embeddings[:, 1:, :]
# masking: length -> length * config.mask_ratio
embeddings, mask, ids_restore = self.random_masking(embeddings, noise)
# append cls token
cls_token = self.cls_token + self.position_embeddings[:, :1, :]
cls_tokens = tf.tile(cls_token, (shape_list(embeddings)[0], 1, 1))
embeddings = tf.concat([cls_tokens, embeddings], axis=1)
return embeddings, mask, ids_restore
# Based on timm implementation, which can be found here:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class TFPatchEmbeddings(tf.keras.layers.Layer):
"""
Image to Patch Embedding.
"""
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
image_size = to_2tuple(config.image_size)
patch_size = to_2tuple(config.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = num_patches
self.num_channels = config.num_channels
self.embed_dim = config.hidden_size
self.config = config
self.projection = tf.keras.layers.Conv2D(
filters=self.embed_dim,
kernel_size=self.patch_size,
strides=self.patch_size,
padding="valid",
data_format="channels_last",
kernel_initializer="glorot_uniform", # following torch.nn.Linear
bias_initializer="zeros",
name="projection",
)
def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
batch_size, num_channels, height, width = shape_list(pixel_values)
if getattr(height, "numpy", None) and getattr(width, "numpy", None):
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
)
# When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
# So change the input format from `NCHW` to `NHWC`.
# shape = (batch_size, in_height, in_width, in_channels=num_channels)
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
projection = self.projection(pixel_values)
# Change the 2D spatial dimensions to a single temporal dimension.
# shape = (batch_size, num_patches, out_channels=embed_dim)
num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
x = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))
return x
# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfAttention with ViT->ViTMAE
class TFViTMAESelfAttention(tf.keras.layers.Layer):
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number "
f"of attention heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
self.query = tf.keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
)
self.key = tf.keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
)
self.value = tf.keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
)
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
# Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
return tf.transpose(tensor, perm=[0, 2, 1, 3])
def call(
self,
hidden_states: tf.Tensor,
head_mask: tf.Tensor,
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
batch_size = shape_list(hidden_states)[0]
mixed_query_layer = self.query(inputs=hidden_states)
mixed_key_layer = self.key(inputs=hidden_states)
mixed_value_layer = self.value(inputs=hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
# Take the dot product between "query" and "key" to get the raw attention scores.
# (batch size, num_heads, seq_len_q, seq_len_k)
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
attention_scores = tf.divide(attention_scores, dk)
# Normalize the attention scores to probabilities.
attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(inputs=attention_probs, training=training)
# Mask heads if we want to
if head_mask is not None:
attention_probs = tf.multiply(attention_probs, head_mask)
attention_output = tf.matmul(attention_probs, value_layer)
attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
# (batch_size, seq_len_q, all_head_size)
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
return outputs
# Copied from transformers.models.vit.modeling_tf_vit.TFViTSelfOutput with ViT->ViTMAE
class TFViTMAESelfOutput(tf.keras.layers.Layer):
"""
The residual connection is defined in TFViTMAELayer instead of here (as is the case with other models), due to the
layernorm applied before each block.
"""
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(inputs=hidden_states, training=training)
return hidden_states
# Copied from transformers.models.vit.modeling_tf_vit.TFViTAttention with ViT->ViTMAE
class TFViTMAEAttention(tf.keras.layers.Layer):
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
self.self_attention = TFViTMAESelfAttention(config, name="attention")
self.dense_output = TFViTMAESelfOutput(config, name="output")
def prune_heads(self, heads):
raise NotImplementedError
def call(
self,
input_tensor: tf.Tensor,
head_mask: tf.Tensor,
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
self_outputs = self.self_attention(
hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training
)
attention_output = self.dense_output(
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
# Copied from transformers.models.vit.modeling_tf_vit.TFViTIntermediate with ViT->ViTMAE
class TFViTMAEIntermediate(tf.keras.layers.Layer):
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
# Copied from transformers.models.vit.modeling_tf_vit.TFViTOutput with ViT->ViTMAE
class TFViTMAEOutput(tf.keras.layers.Layer):
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(inputs=hidden_states, training=training)
hidden_states = hidden_states + input_tensor
return hidden_states
# Copied from transformers.models.vit.modeling_tf_vit.TFViTLayer with ViT->ViTMAE
class TFViTMAELayer(tf.keras.layers.Layer):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
self.attention = TFViTMAEAttention(config, name="attention")
self.intermediate = TFViTMAEIntermediate(config, name="intermediate")
self.vit_output = TFViTMAEOutput(config, name="output")
self.layernorm_before = tf.keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="layernorm_before"
)
self.layernorm_after = tf.keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="layernorm_after"
)
def call(
self,
hidden_states: tf.Tensor,
head_mask: tf.Tensor,
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
attention_outputs = self.attention(
# in ViTMAE, layernorm is applied before self-attention
input_tensor=self.layernorm_before(inputs=hidden_states),
head_mask=head_mask,
output_attentions=output_attentions,
training=training,
)
attention_output = attention_outputs[0]
# first residual connection
hidden_states = attention_output + hidden_states
# in ViTMAE, layernorm is also applied after self-attention
layer_output = self.layernorm_after(inputs=hidden_states)
intermediate_output = self.intermediate(hidden_states=layer_output)
# second residual connection is done here
layer_output = self.vit_output(
hidden_states=intermediate_output, input_tensor=hidden_states, training=training
)
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
return outputs
# Copied from transformers.models.vit.modeling_tf_vit.TFViTEncoder with ViT->ViTMAE
class TFViTMAEEncoder(tf.keras.layers.Layer):
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
self.layer = [TFViTMAELayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
def call(
self,
hidden_states: tf.Tensor,
head_mask: tf.Tensor,
output_attentions: bool,
output_hidden_states: bool,
return_dict: bool,
training: bool = False,
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
hidden_states=hidden_states,
head_mask=head_mask[i],
output_attentions=output_attentions,
training=training,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
@keras_serializable
class TFViTMAEMainLayer(tf.keras.layers.Layer):
config_class = ViTMAEConfig
def __init__(self, config: ViTMAEConfig, **kwargs):
super().__init__(**kwargs)
self.config = config
self.embeddings = TFViTMAEEmbeddings(config, name="embeddings")
self.encoder = TFViTMAEEncoder(config, name="encoder")
self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
def get_input_embeddings(self) -> tf.keras.layers.Layer:
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
raise NotImplementedError
@unpack_inputs
def call(
self,
pixel_values: Optional[TFModelInputType] = None,
noise: tf.Tensor = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
**kwargs,
) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
embedding_output, mask, ids_restore = self.embeddings(
pixel_values=pixel_values, training=training, noise=noise
)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
raise NotImplementedError
else:
head_mask = [None] * self.config.num_hidden_layers
encoder_outputs = self.encoder(
embedding_output,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(inputs=sequence_output)
if not return_dict:
return (sequence_output, mask, ids_restore) + encoder_outputs[1:]
return TFViTMAEModelOutput(
last_hidden_state=sequence_output,
mask=mask,
ids_restore=ids_restore,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class TFViTMAEPreTrainedModel(TFPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = ViTMAEConfig
base_model_prefix = "vit"
main_input_name = "pixel_values"
@property
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
"""
Dummy inputs to build the network.
Returns:
`Dict[str, tf.Tensor]`: The dummy inputs.
"""
VISION_DUMMY_INPUTS = tf.random.uniform(
shape=(3, self.config.num_channels, self.config.image_size, self.config.image_size),
dtype=tf.float32,
)
return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
@tf.function(
input_signature=[
{
"pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
}
]
)
def serving(self, inputs):
"""
Method used for serving the model.
Args:
inputs (`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
return self.call(inputs)
VIT_MAE_START_DOCSTRING = r"""
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
etc.)
This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage and
behavior.
<Tip>
TF 2.0 models accept two formats as inputs:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional argument.
This second option is useful when using the [`tf.keras.Model.fit`] method, which currently requires having all the
tensors in the first argument of the model call function: `model(inputs)`.
</Tip>
Args:
config ([`ViTMAEConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
VIT_MAE_INPUTS_DOCSTRING = r"""
Args:
pixel_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]`, `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]`, with each example of shape `(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using [`AutoFeatureExtractor`]. See
[`AutoFeatureExtractor.__call__`] for details.
head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail. This argument can be used only in eager mode; in graph mode the value in the
config will be used instead.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail. This argument can be used only in eager mode; in graph mode the value in the config will be
used instead.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple. This argument can be used
in eager mode; in graph mode the value will always be set to `True`.
training (`bool`, *optional*, defaults to `False`):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
"""
@add_start_docstrings(
"The bare ViTMAE Model transformer outputting raw hidden-states without any specific head on top.",
VIT_MAE_START_DOCSTRING,
)
class TFViTMAEModel(TFViTMAEPreTrainedModel):
def __init__(self, config: ViTMAEConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.vit = TFViTMAEMainLayer(config, name="vit")
def get_input_embeddings(self):
return self.vit.get_input_embeddings()
@unpack_inputs
@add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFViTMAEModelOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
pixel_values: Optional[TFModelInputType] = None,
noise: tf.Tensor = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
**kwargs,
) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
r"""
Returns:
Examples:
```python
>>> from transformers import AutoFeatureExtractor, TFViTMAEModel
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/vit-mae-base")
>>> model = TFViTMAEModel.from_pretrained("facebook/vit-mae-base")
>>> inputs = feature_extractor(images=image, return_tensors="tf")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```"""
outputs = self.vit(
pixel_values=pixel_values,
noise=noise,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
return outputs
class TFViTMAEDecoder(tf.keras.layers.Layer):
def __init__(self, config, num_patches, **kwargs):
super().__init__(**kwargs)
self.decoder_embed = tf.keras.layers.Dense(config.decoder_hidden_size, name="decoder_embed")
decoder_config = deepcopy(config)
decoder_config.hidden_size = config.decoder_hidden_size
decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
decoder_config.num_attention_heads = config.decoder_num_attention_heads
decoder_config.intermediate_size = config.decoder_intermediate_size
self.decoder_layers = [
TFViTMAELayer(decoder_config, name=f"decoder_layers.{j}") for j in range(config.decoder_num_hidden_layers)
]
self.decoder_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="decoder_norm")
self.decoder_pred = tf.keras.layers.Dense(
config.patch_size**2 * config.num_channels, name="decoder_pred"
) # encoder to decoder
self.config = config
self.num_patches = num_patches
def build(self, input_shape: tf.TensorShape):
self.mask_token = self.add_weight(
shape=(1, 1, self.config.decoder_hidden_size),
initializer=tf.random_normal_initializer(stddev=self.config.initializer_range),
trainable=True,
name="mask_token",
)
self.decoder_pos_embed = self.add_weight(
shape=(1, self.num_patches + 1, self.config.decoder_hidden_size),
initializer="zeros",
trainable=False,
name="decoder_pos_embed",
)
decoder_pos_embed = get_2d_sincos_pos_embed(
self.decoder_pos_embed.shape[-1],
int(self.num_patches**0.5),
add_cls_token=True,
)[None, ...]
self.decoder_pos_embed.assign(decoder_pos_embed)
super().build(input_shape)
def call(
self,
hidden_states,
ids_restore,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
# embed tokens
x = self.decoder_embed(hidden_states)
# append mask tokens to sequence
mask_tokens = tf.tile(
self.mask_token,
(shape_list(x)[0], shape_list(ids_restore)[1] + 1 - shape_list(x)[1], 1),
)
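# mask_tokens covers exactly the positions that were dropped during random masking; after the
# concat and gather below, the sequence is restored to its full, unshuffled length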
x_ = tf.concat([x[:, 1:, :], mask_tokens], axis=1) # no cls token
x_ = tf.gather(x_, axis=1, batch_dims=1, indices=ids_restore) # unshuffle
x = tf.concat([x[:, :1, :], x_], axis=1) # append cls token
# add pos embed
hidden_states = x + self.decoder_pos_embed
# apply Transformer layers (blocks)
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.decoder_layers):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
hidden_states,
head_mask=None,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states = self.decoder_norm(hidden_states)
# predictor projection
logits = self.decoder_pred(hidden_states)
# remove cls token
logits = logits[:, 1:, :]
if not return_dict:
return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None)
return TFViTMAEDecoderOutput(logits=logits, hidden_states=all_hidden_states, attentions=all_self_attentions)
@add_start_docstrings(
"The ViTMAE Model transformer with the decoder on top for self-supervised pre-training.",
VIT_MAE_START_DOCSTRING,
)
class TFViTMAEForPreTraining(TFViTMAEPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.vit = TFViTMAEMainLayer(config, name="vit")
self.decoder = TFViTMAEDecoder(
config,
num_patches=self.vit.embeddings.num_patches,
name="decoder",
)
def get_input_embeddings(self):
return self.vit.get_input_embeddings()
def _prune_heads(self, heads_to_prune):
raise NotImplementedError
def patchify(self, imgs):
"""
imgs: (batch_size, height, width, 3)
x: (batch_size, num_patches, patch_size**2 * 3)
"""
imgs = tf.cond(
tf.math.equal(shape_list(imgs)[1], 3), lambda: tf.transpose(imgs, perm=(0, 2, 3, 1)), lambda: imgs
)
p = self.vit.embeddings.patch_embeddings.patch_size[0]
tf.debugging.assert_equal(shape_list(imgs)[1], shape_list(imgs)[2])
tf.debugging.assert_equal(shape_list(imgs)[1] % p, 0)
h = w = shape_list(imgs)[2] // p
x = tf.reshape(imgs, (shape_list(imgs)[0], h, p, w, p, 3))
x = tf.einsum("nhpwqc->nhwpqc", x)
x = tf.reshape(x, (shape_list(imgs)[0], h * w, p**2 * 3))
return x
def unpatchify(self, x):
"""
x: (batch_size, num_patches, patch_size**2 * 3)
imgs: (batch_size, height, width, 3)
"""
p = self.vit.embeddings.patch_embeddings.patch_size[0]
h = w = int(shape_list(x)[1] ** 0.5)
tf.debugging.assert_equal(h * w, shape_list(x)[1])
x = tf.reshape(x, (shape_list(x)[0], h, w, p, p, 3))
x = tf.einsum("nhwpqc->nhpwqc", x)
imgs = tf.reshape(x, (shape_list(x)[0], h * p, h * p, 3))
return imgs
def forward_loss(self, imgs, pred, mask):
"""
imgs: [batch_size, height, width, 3]
pred: [batch_size, num_patches, patch_size**2 * 3]
mask: [batch_size, num_patches], 0 is keep, 1 is remove
"""
target = self.patchify(imgs)
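# when norm_pix_loss is enabled, each target patch is normalized to zero mean and unit variance
# before the reconstruction loss is computed (as in the original MAE)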
if self.config.norm_pix_loss:
mean = tf.reduce_mean(target, axis=-1, keepdims=True)
var = tf.math.reduce_variance(target, axis=-1, keepdims=True)
target = (target - mean) / (var + 1.0e-6) ** 0.5
loss = (pred - target) ** 2
loss = tf.reduce_mean(loss, axis=-1) # [N, L], mean loss per patch
loss = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask) # mean loss on removed patches
return loss
@unpack_inputs
@add_start_docstrings_to_model_forward(VIT_MAE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFViTMAEForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
pixel_values: Optional[TFModelInputType] = None,
noise: tf.Tensor = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
**kwargs,
) -> Union[TFViTMAEForPreTrainingOutput, Tuple[tf.Tensor]]:
r"""
Returns:
Examples:
```python
>>> from transformers import AutoFeatureExtractor, TFViTMAEForPreTraining
>>> from PIL import Image
>>> import requests
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/vit-mae-base")
>>> model = TFViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
>>> inputs = feature_extractor(images=image, return_tensors="tf")
>>> outputs = model(**inputs)
>>> loss = outputs.loss
>>> mask = outputs.mask
>>> ids_restore = outputs.ids_restore
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.vit(
pixel_values=pixel_values,
noise=noise,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
latent = outputs.last_hidden_state
ids_restore = outputs.ids_restore
mask = outputs.mask
decoder_outputs = self.decoder(latent, ids_restore) # [batch_size, num_patches, patch_size**2*3]
logits = decoder_outputs.logits
loss = self.forward_loss(pixel_values, logits, mask)
if not return_dict:
output = (logits, mask, ids_restore) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TFViTMAEForPreTrainingOutput(
loss=loss,
logits=logits,
mask=mask,
ids_restore=ids_restore,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
...@@ -240,17 +240,20 @@ class ViTMAEEmbeddings(nn.Module):
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
torch.nn.init.normal_(self.cls_token, std=self.config.initializer_range)
def random_masking(self, sequence, noise=None):
"""
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
noise.
Args:
sequence (`torch.LongTensor` of shape `(batch_size, sequence_length, dim)`)
noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
mainly used for testing purposes to control randomness and maintain the reproducibility
""" """
batch_size, seq_length, dim = sequence.shape batch_size, seq_length, dim = sequence.shape
len_keep = int(seq_length * (1 - self.config.mask_ratio)) len_keep = int(seq_length * (1 - self.config.mask_ratio))
if noise is None:
noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1] noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1]
# sort noise for each sample # sort noise for each sample
...@@ -269,7 +272,7 @@ class ViTMAEEmbeddings(nn.Module): ...@@ -269,7 +272,7 @@ class ViTMAEEmbeddings(nn.Module):
return sequence_masked, mask, ids_restore return sequence_masked, mask, ids_restore
def forward(self, pixel_values): def forward(self, pixel_values, noise=None):
batch_size, num_channels, height, width = pixel_values.shape
embeddings = self.patch_embeddings(pixel_values)
...@@ -277,7 +280,7 @@ class ViTMAEEmbeddings(nn.Module):
embeddings = embeddings + self.position_embeddings[:, 1:, :]
# masking: length -> length * config.mask_ratio
embeddings, mask, ids_restore = self.random_masking(embeddings, noise)
# append cls token
cls_token = self.cls_token + self.position_embeddings[:, :1, :]
...@@ -668,6 +671,7 @@ class ViTMAEModel(ViTMAEPreTrainedModel):
def forward(
self,
pixel_values=None,
noise=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
...@@ -709,7 +713,7 @@ class ViTMAEModel(ViTMAEPreTrainedModel):
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output, mask, ids_restore = self.embeddings(pixel_values, noise=noise)
encoder_outputs = self.encoder(
embedding_output,
...@@ -910,6 +914,7 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
def forward(
self,
pixel_values=None,
noise=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
...@@ -941,6 +946,7 @@ class ViTMAEForPreTraining(ViTMAEPreTrainedModel):
outputs = self.vit(
pixel_values,
noise=noise,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
......
...@@ -1987,6 +1987,27 @@ class TFViTPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["tf"])
class TFViTMAEForPreTraining(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFViTMAEModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFViTMAEPreTrainedModel(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow ViTMAE model. """
import copy
import inspect
import json
import math
import os
import tempfile
import unittest
from importlib import import_module
import numpy as np
from transformers import ViTMAEConfig
from transformers.file_utils import cached_property, is_tf_available, is_vision_available
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_vision, slow, torch_device
from ..test_configuration_common import ConfigTester
from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
if is_tf_available():
import tensorflow as tf
from transformers import TFViTMAEForPreTraining, TFViTMAEModel
from transformers.models.vit_mae.modeling_tf_vit_mae import to_2tuple
if is_vision_available():
from PIL import Image
from transformers import ViTFeatureExtractor
class TFViTMAEModelTester:
def __init__(
self,
parent,
batch_size=13,
image_size=30,
patch_size=2,
num_channels=3,
is_training=True,
use_labels=True,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
type_sequence_label_size=10,
initializer_range=0.02,
num_labels=3,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.is_training = is_training
self.use_labels = use_labels
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.scope = scope
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return ViTMAEConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
)
def create_and_check_model(self, config, pixel_values, labels):
model = TFViTMAEModel(config=config)
result = model(pixel_values, training=False)
# expected sequence length = (num_patches + 1) * (1 - config.mask_ratio), rounded above
# (we add 1 for the [CLS] token)
image_size = to_2tuple(self.image_size)
patch_size = to_2tuple(self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
expected_seq_len = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, self.hidden_size))
def create_and_check_for_pretraining(self, config, pixel_values, labels):
model = TFViTMAEForPreTraining(config)
result = model(pixel_values, training=False)
# expected sequence length = num_patches
image_size = to_2tuple(self.image_size)
patch_size = to_2tuple(self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
expected_seq_len = num_patches
expected_num_channels = self.patch_size**2 * self.num_channels
self.parent.assertEqual(result.logits.shape, (self.batch_size, expected_seq_len, expected_num_channels))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, pixel_values, labels) = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_tf
class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_tf_common.py, as ViTMAE does not use input_ids,
inputs_embeds, attention_mask and seq_length.
"""
all_model_classes = (TFViTMAEModel, TFViTMAEForPreTraining) if is_tf_available() else ()
test_pruning = False
test_onnx = False
test_resize_embeddings = False
test_head_masking = False
def setUp(self):
self.model_tester = TFViTMAEModelTester(self)
self.config_tester = ConfigTester(self, config_class=ViTMAEConfig, has_text_modality=False, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip(reason="ViTMAE does not use inputs_embeds")
def test_inputs_embeds(self):
# ViTMAE does not use inputs_embeds
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (tf.keras.layers.Layer))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, tf.keras.layers.Layer))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_pretraining(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_pretraining(*config_and_inputs)
# overwrite from common since TFViTMAEForPreTraining has random masking; we need to fix the noise
# to generate deterministic masks during the test
def test_keyword_and_dict_args(self):
# make the mask reproducible
np.random.seed(2)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
num_patches = int((config.image_size // config.patch_size) ** 2)
noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
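# the same fixed noise is passed to both the dict-style and keyword-style calls below, so the
# random masking is identical and the two outputs can be compared elementwise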
for model_class in self.all_model_classes:
model = model_class(config)
inputs = self._prepare_for_class(inputs_dict, model_class)
outputs_dict = model(inputs, noise=noise)
inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
outputs_keywords = model(**inputs_keywords, noise=noise)
output_dict = outputs_dict[0].numpy()
output_keywords = outputs_keywords[0].numpy()
self.assertLess(np.sum(np.abs(output_dict - output_keywords)), 1e-6)
# overwrite from common since TFViTMAEForPreTraining has random masking; we need to fix the noise
# to generate deterministic masks during the test
def test_numpy_arrays_inputs(self):
# make the mask reproducible
np.random.seed(2)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
num_patches = int((config.image_size // config.patch_size) ** 2)
noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
def prepare_numpy_arrays(inputs_dict):
inputs_np_dict = {}
for k, v in inputs_dict.items():
if tf.is_tensor(v):
inputs_np_dict[k] = v.numpy()
else:
inputs_np_dict[k] = np.array(v)
return inputs_np_dict
for model_class in self.all_model_classes:
model = model_class(config)
inputs = self._prepare_for_class(inputs_dict, model_class)
inputs_np = prepare_numpy_arrays(inputs)
output_for_dict_input = model(inputs_np, noise=noise)
output_for_kw_input = model(**inputs_np, noise=noise)
self.assert_outputs_same(output_for_dict_input, output_for_kw_input)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
# in ViTMAE, the seq_len equals (number of patches + 1) * (1 - mask_ratio), rounded up
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_len = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
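# e.g. with the library's default ViTMAEConfig (image_size=224, patch_size=16, mask_ratio=0.75):
# num_patches = (224 // 16) ** 2 = 196, so seq_len = ceil(0.25 * 197) = 50 (49 visible patches + CLS)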
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
chunk_length = getattr(self.model_tester, "chunk_length", None)
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
if chunk_length is not None:
self.assertListEqual(
list(attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
out_len = len(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder:
added_hidden_states = 2
else:
added_hidden_states = 1
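# enabling output_hidden_states should add exactly `added_hidden_states` entries to the outputs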
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
if chunk_length is not None:
self.assertListEqual(
list(self_attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
# ViTMAE has a different seq_length
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_length = int(math.ceil((1 - config.mask_ratio) * (num_patches + 1)))
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
# overwrite from common since TFViTMAEForPreTraining has random masking; we need to fix the noise
# to generate masks during the test
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self):
import torch
import transformers
# make masks reproducible
np.random.seed(2)
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
num_patches = int((config.image_size // config.patch_size) ** 2)
noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
pt_noise = torch.from_numpy(noise).to(device=torch_device)
tf_noise = tf.constant(noise)
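# the identical numpy noise is fed to both frameworks, so the PT and TF models mask the same patches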
def prepare_pt_inputs_from_tf_inputs(tf_inputs_dict):
pt_inputs_dict = {}
for name, key in tf_inputs_dict.items():
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
return pt_inputs_dict
def check_outputs(tf_outputs, pt_outputs, model_class, names):
"""
Args:
model_class: The class of the model that is currently being tested. For example, `TFBertModel`,
`TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Currently unused, but it could make
debugging easier and faster.
names: A string, or a tuple of strings. These specify what tf_outputs/pt_outputs represent in the model outputs.
Currently unused, but in the future, we could use this information to make the error message clearer
by giving the name(s) of the output tensor(s) with large difference(s) between PT and TF.
"""
# Allow `list` because `(TF)TransfoXLModelOutput.mems` is a list of tensors.
if type(tf_outputs) in [tuple, list]:
self.assertEqual(type(tf_outputs), type(pt_outputs))
self.assertEqual(len(tf_outputs), len(pt_outputs))
if type(names) == tuple:
for tf_output, pt_output, name in zip(tf_outputs, pt_outputs, names):
check_outputs(tf_output, pt_output, model_class, names=name)
elif type(names) == str:
for idx, (tf_output, pt_output) in enumerate(zip(tf_outputs, pt_outputs)):
check_outputs(tf_output, pt_output, model_class, names=f"{names}_{idx}")
else:
raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
elif isinstance(tf_outputs, tf.Tensor):
self.assertTrue(isinstance(pt_outputs, torch.Tensor))
tf_outputs = tf_outputs.numpy()
if isinstance(tf_outputs, np.float32):
tf_outputs = np.array(tf_outputs, dtype=np.float32)
pt_outputs = pt_outputs.detach().to("cpu").numpy()
tf_nans = np.isnan(tf_outputs)
pt_nans = np.isnan(pt_outputs)
pt_outputs[tf_nans] = 0
tf_outputs[tf_nans] = 0
pt_outputs[pt_nans] = 0
tf_outputs[pt_nans] = 0
max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
self.assertLessEqual(max_diff, 1e-5)
else:
raise ValueError(
f"`tf_outputs` should be a `tuple` or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."
)
def check_pt_tf_models(tf_model, pt_model):
# we are not preparing a model with labels because of how the ViT MAE model is designed:
# the pre-training loss is computed internally from the pixel values, so no labels are needed
# send pytorch model to the correct device
pt_model.to(torch_device)
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
pt_model.eval()
pt_inputs_dict = prepare_pt_inputs_from_tf_inputs(tf_inputs_dict)
# send pytorch inputs to the correct device
pt_inputs_dict = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
}
# Original test: check without `labels`
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs_dict, noise=pt_noise)
tf_outputs = tf_model(tf_inputs_dict, noise=tf_noise)
tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(tf_keys, pt_keys)
check_outputs(tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=tf_keys)
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# Output all for aggressive testing
config.output_hidden_states = True
if self.has_attentions:
config.output_attentions = True
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
tf_model = model_class(config)
pt_model = pt_model_class(config)
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
# Check we can load pt model in tf and vice-versa with model => model functions
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
check_pt_tf_models(tf_model, pt_model)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
check_pt_tf_models(tf_model, pt_model)
# overwrite from common since TFViTMAEForPreTraining outputs loss along with
# logits and mask indices. loss and mask indices are not suitable for integration
# with other Keras modules.
def test_compile_tf_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
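# the optimizer/loss/metric are placeholders used only to exercise `compile()` below; no training is run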
for model_class in self.all_model_classes:
# `pixel_values` implies that the input is an image
inputs = tf.keras.Input(
batch_shape=(
3,
self.model_tester.num_channels,
self.model_tester.image_size,
self.model_tester.image_size,
),
name="pixel_values",
dtype="float32",
)
# Prepare our model
model = model_class(config)
model(self._prepare_for_class(inputs_dict, model_class)) # Model must be called before saving.
# Let's load it from the disk to be sure we can use pretrained weights
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=False)
model = model_class.from_pretrained(tmpdirname)
outputs_dict = model(inputs)
hidden_states = outputs_dict[0]
# `TFViTMAEForPreTraining` outputs are not recommended to be used for
# downstream applications. This is just to check if the outputs of
# `TFViTMAEForPreTraining` can be integrated with other Keras modules.
if model_class.__name__ == "TFViTMAEForPreTraining":
hidden_states = outputs_dict["logits"]
# Add a dense layer on top to test integration with other keras modules
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
# Compile extended model
extended_model = tf.keras.Model(inputs=[inputs], outputs=[outputs])
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
# overwrite from common since TFViTMAEForPreTraining has random masking; we need to fix the noise
# to generate masks during the test
def test_keras_save_load(self):
# make mask reproducible
np.random.seed(2)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
tf_main_layer_classes = set(
module_member
for model_class in self.all_model_classes
for module in (import_module(model_class.__module__),)
for module_member_name in dir(module)
if module_member_name.endswith("MainLayer")
# This condition is required, since `modeling_tf_clip.py` has 3 classes whose names end with `MainLayer`.
and module_member_name[: -len("MainLayer")] == model_class.__name__[: -len("Model")]
for module_member in (getattr(module, module_member_name),)
if isinstance(module_member, type)
and tf.keras.layers.Layer in module_member.__bases__
and getattr(module_member, "_keras_serializable", False)
)
num_patches = int((config.image_size // config.patch_size) ** 2)
noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
noise = tf.convert_to_tensor(noise)
inputs_dict.update({"noise": noise})
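# passing `noise` as a regular input makes it a symbolic input of the functional Keras model below,
# which keeps the mask deterministic across the save/load round trip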
for main_layer_class in tf_main_layer_classes:
main_layer = main_layer_class(config)
symbolic_inputs = {
name: tf.keras.Input(tensor.shape[1:], dtype=tensor.dtype) for name, tensor in inputs_dict.items()
}
model = tf.keras.Model(symbolic_inputs, outputs=main_layer(symbolic_inputs))
outputs = model(inputs_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
filepath = os.path.join(tmpdirname, "keras_model.h5")
model.save(filepath)
model = tf.keras.models.load_model(
filepath, custom_objects={main_layer_class.__name__: main_layer_class}
)
assert isinstance(model, tf.keras.Model)
after_outputs = model(inputs_dict)
self.assert_outputs_same(after_outputs, outputs)
# overwrite from common since TFViTMAEForPreTraining has random masking; we need to fix the noise
# to generate masks during the test
def test_save_load(self):
# make mask reproducible
np.random.seed(2)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
num_patches = int((config.image_size // config.patch_size) ** 2)
noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
for model_class in self.all_model_classes:
model = model_class(config)
model_input = self._prepare_for_class(inputs_dict, model_class)
outputs = model(model_input, noise=noise)
if model_class.__name__ == "TFViTMAEModel":
out_2 = outputs.last_hidden_state.numpy()
out_2[np.isnan(out_2)] = 0
else:
out_2 = outputs.logits.numpy()
out_2[np.isnan(out_2)] = 0
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=True)
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
model = tf.keras.models.load_model(saved_model_dir)
after_outputs = model(model_input, noise=noise)
if model_class.__name__ == "TFViTMAEModel":
out_1 = after_outputs["last_hidden_state"].numpy()
out_1[np.isnan(out_1)] = 0
else:
out_1 = after_outputs["logits"].numpy()
out_1[np.isnan(out_1)] = 0
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
# overwrite from common since TFViTMAEForPreTraining has random masking; we need to fix the noise
# to generate masks during the test
def test_save_load_config(self):
# make mask reproducible
np.random.seed(2)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
num_patches = int((config.image_size // config.patch_size) ** 2)
noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
for model_class in self.all_model_classes:
model = model_class(config)
model_inputs = self._prepare_for_class(inputs_dict, model_class)
outputs = model(model_inputs, noise=noise)
model_config = model.get_config()
# make sure that the returned config is JSON-serializable, which is required by Keras
json.dumps(model_config)
new_model = model_class.from_config(model.get_config())
# make sure it also accepts a normal config
_ = model_class.from_config(model.config)
_ = new_model(model_inputs) # Build model
new_model.set_weights(model.get_weights())
after_outputs = new_model(model_inputs, noise=noise)
self.assert_outputs_same(after_outputs, outputs)
@unittest.skip(
reason="""ViTMAE returns a random mask + ids_restore in each forward pass. See test_save_load
to get deterministic results."""
)
def test_determinism(self):
pass
@unittest.skip(reason="""ViTMAE returns a random mask + ids_restore in each forward pass. See test_save_load""")
def test_model_outputs_equivalence(self):
pass
@slow
def test_model_from_pretrained(self):
model = TFViTMAEModel.from_pretrained("facebook/vit-mae-base")
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_tf
@require_vision
class TFViTMAEModelIntegrationTest(unittest.TestCase):
@cached_property
def default_feature_extractor(self):
return ViTFeatureExtractor.from_pretrained("facebook/vit-mae-base") if is_vision_available() else None
@slow
def test_inference_for_pretraining(self):
# make the random mask reproducible across the PT and TF models
np.random.seed(2)
model = TFViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="tf")
# prepare a noise vector that will be also used for testing the TF model
# (this way we can ensure that the PT and TF models operate on the same inputs)
vit_mae_config = ViTMAEConfig()
num_patches = int((vit_mae_config.image_size // vit_mae_config.patch_size) ** 2)
noise = np.random.uniform(size=(1, num_patches))
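# the PyTorch integration test uses the same seed and noise shape, so both frameworks
# produce the same mask and can be checked against the same expected logits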
# forward pass
outputs = model(**inputs, noise=noise)
# verify the logits
expected_shape = tf.convert_to_tensor([1, 196, 768])
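# 196 = (224 // 16) ** 2 patches; each patch is reconstructed as 16 * 16 * 3 = 768 pixel values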
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = tf.convert_to_tensor(
[[-0.0548, -1.7023, -0.9325], [0.3721, -0.5670, -0.2233], [0.8235, -1.3878, -0.3524]]
)
tf.debugging.assert_near(outputs.logits[0, :3, :3], expected_slice, atol=1e-4)
...@@ -17,13 +17,14 @@
import inspect
import math
import os
import tempfile
import unittest
import numpy as np
from transformers import ViTMAEConfig
from transformers.testing_utils import require_torch, require_vision, slow, torch_device from transformers.testing_utils import is_pt_tf_cross_test, require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available
from ..test_configuration_common import ConfigTester
...@@ -139,11 +140,7 @@ class ViTMAEModelTester:
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
( config, pixel_values, labels = config_and_inputs
config,
pixel_values,
labels,
) = config_and_inputs
inputs_dict = {"pixel_values": pixel_values} inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict return config, inputs_dict
...@@ -322,6 +319,153 @@ class ViTMAEModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -322,6 +319,153 @@ class ViTMAEModelTest(ModelTesterMixin, unittest.TestCase):
check_hidden_states_output(inputs_dict, config, model_class) check_hidden_states_output(inputs_dict, config, model_class)
# overwrite from common since ViTMAEForPreTraining has random masking; we need to fix the noise
# to generate masks during the test
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self):
import numpy as np
import tensorflow as tf
import transformers
# make masks reproducible
np.random.seed(2)
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
num_patches = int((config.image_size // config.patch_size) ** 2)
noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
pt_noise = torch.from_numpy(noise).to(device=torch_device)
tf_noise = tf.constant(noise)
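# identical noise for PT and TF guarantees identical masks, so their outputs can be compared directly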
def prepare_tf_inputs_from_pt_inputs(pt_inputs_dict):
tf_inputs_dict = {}
for key, tensor in pt_inputs_dict.items():
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
return tf_inputs_dict
def check_outputs(tf_outputs, pt_outputs, model_class, names):
"""
Args:
model_class: The class of the model that is currently being tested. For example, `TFBertModel`,
`TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Currently unused, but it could make
debugging easier and faster.
names: A string, or a tuple of strings. These specify what tf_outputs/pt_outputs represent in the model outputs.
Currently unused, but in the future, we could use this information to make the error message clearer
by giving the name(s) of the output tensor(s) with large difference(s) between PT and TF.
"""
# Allow `list` because `(TF)TransfoXLModelOutput.mems` is a list of tensors.
if type(tf_outputs) in [tuple, list]:
self.assertEqual(type(tf_outputs), type(pt_outputs))
self.assertEqual(len(tf_outputs), len(pt_outputs))
if type(names) == tuple:
for tf_output, pt_output, name in zip(tf_outputs, pt_outputs, names):
check_outputs(tf_output, pt_output, model_class, names=name)
elif type(names) == str:
for idx, (tf_output, pt_output) in enumerate(zip(tf_outputs, pt_outputs)):
check_outputs(tf_output, pt_output, model_class, names=f"{names}_{idx}")
else:
raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
elif isinstance(tf_outputs, tf.Tensor):
self.assertTrue(isinstance(pt_outputs, torch.Tensor))
tf_outputs = tf_outputs.numpy()
if isinstance(tf_outputs, np.float32):
tf_outputs = np.array(tf_outputs, dtype=np.float32)
pt_outputs = pt_outputs.detach().to("cpu").numpy()
tf_nans = np.isnan(tf_outputs)
pt_nans = np.isnan(pt_outputs)
pt_outputs[tf_nans] = 0
tf_outputs[tf_nans] = 0
pt_outputs[pt_nans] = 0
tf_outputs[pt_nans] = 0
max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
self.assertLessEqual(max_diff, 1e-5)
else:
raise ValueError(
f"`tf_outputs` should be a `tuple` or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."
)
def check_pt_tf_models(tf_model, pt_model, pt_inputs_dict):
# we are not preparing a model with labels because of how the ViT MAE model is designed:
# the pre-training loss is computed internally from the pixel values, so no labels are needed
# send pytorch model to the correct device
pt_model.to(torch_device)
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
pt_model.eval()
tf_inputs_dict = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
# send pytorch inputs to the correct device
pt_inputs_dict = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
}
# Original test: check without `labels`
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs_dict, noise=pt_noise)
tf_outputs = tf_model(tf_inputs_dict, noise=tf_noise)
tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(tf_keys, pt_keys)
check_outputs(tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=tf_keys)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
tf_model_class = getattr(transformers, tf_model_class_name)
tf_model = tf_model_class(config)
pt_model = model_class(config)
# make sure we only forward TF inputs that actually exist in the function args
tf_input_keys = set(inspect.signature(tf_model.call).parameters.keys())
# remove all head masks
tf_input_keys.discard("head_mask")
tf_input_keys.discard("cross_attn_head_mask")
tf_input_keys.discard("decoder_head_mask")
pt_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs_dict = {k: v for k, v in pt_inputs_dict.items() if k in tf_input_keys}
# Check we can load pt model in tf and vice-versa with model => model functions
tf_inputs_dict = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
pt_model = pt_model.to(torch_device)
check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
def test_save_load(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...@@ -400,11 +544,8 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_for_pretraining(self):
# make random mask reproducible # make random mask reproducible across the PT and TF model
# note that the same seed on CPU and on GPU doesn’t mean they spew the same random number sequences, np.random.seed(2)
# as they both have fairly different PRNGs (for efficiency reasons).
# source: https://discuss.pytorch.org/t/random-seed-that-spans-across-devices/19735
torch.manual_seed(2)
model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
...@@ -412,22 +553,22 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
# prepare a noise vector that will be also used for testing the TF model
# (this way we can ensure that the PT and TF models operate on the same inputs)
vit_mae_config = ViTMAEConfig()
num_patches = int((vit_mae_config.image_size // vit_mae_config.patch_size) ** 2)
noise = np.random.uniform(size=(1, num_patches))
# forward pass
with torch.no_grad():
outputs = model(**inputs) outputs = model(**inputs, noise=torch.from_numpy(noise))
# verify the logits
expected_shape = torch.Size((1, 196, 768))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice_cpu = torch.tensor( expected_slice = torch.tensor(
[[0.7366, -1.3663, -0.2844], [0.7919, -1.3839, -0.3241], [0.4313, -0.7168, -0.2878]] [[-0.0548, -1.7023, -0.9325], [0.3721, -0.5670, -0.2233], [0.8235, -1.3878, -0.3524]]
)
expected_slice_gpu = torch.tensor(
[[0.8948, -1.0680, 0.0030], [0.9758, -1.1181, -0.0290], [1.0602, -1.1522, -0.0528]]
) )
# set expected slice depending on device
expected_slice = expected_slice_cpu if torch_device == "cpu" else expected_slice_gpu
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice.to(torch_device), atol=1e-4))