Unverified Commit be4a6c64 authored by Yih-Dar, committed by GitHub

Add TFViTModel (#13778)



* Start the work for TFViTModel

* Convert to TF code - need to check in the follow up commits

* Clean up model code

* Expose TFViTModel

* make style

* make quality

* Add test

* make style & quality

* Fix some imports

* fix wrong usage - *kwargs => **kwargs

* Fix Conv2D weight loading (PT->TF) issue

* Add tests for images with different sizes + fix model

* Fix some common tests for TFViTModel

* Use inputs instead of input_ids in test_compile_tf_model

* Add a comment about transpose and Conv2D in convert_tf_weight_name_to_pt_weight_name

* Avoid transpose in TFViT call

* Fix Conv2D issue in load_tf2_weights_in_pytorch_model

* Use tf.keras.layers.Conv2D instead of tf.nn.conv2d

* Using simpler heuristic to detect Conv2D layer

* Change convert_tf_weight_name_to_pt_weight_name to return TransposeType

* Check tf_weight_shape is not None before using it

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix missing comma

* fix input dtype
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 6326aa4b
...@@ -503,7 +503,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| ViT | ❌ | ❌ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
......
...@@ -223,6 +223,13 @@ TFAutoModelForCausalLM
:members:
TFAutoModelForImageClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFAutoModelForImageClassification
:members:
TFAutoModelForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
...@@ -120,6 +120,20 @@ ViTForImageClassification
:members: forward
TFViTModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFViTModel
:members: call
TFViTForImageClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFViTForImageClassification
:members: call
FlaxViTModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
...@@ -1396,6 +1396,7 @@ if is_tf_available():
_import_structure["models.auto"].extend(
[
"TF_MODEL_FOR_CAUSAL_LM_MAPPING",
"TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_MASKED_LM_MAPPING",
"TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
...@@ -1408,6 +1409,7 @@ if is_tf_available():
"TF_MODEL_WITH_LM_HEAD_MAPPING",
"TFAutoModel",
"TFAutoModelForCausalLM",
"TFAutoModelForImageClassification",
"TFAutoModelForMaskedLM",
"TFAutoModelForMultipleChoice",
"TFAutoModelForPreTraining",
...@@ -1734,6 +1736,13 @@ if is_tf_available():
"TFTransfoXLPreTrainedModel",
]
)
_import_structure["models.vit"].extend(
[
"TFViTForImageClassification",
"TFViTModel",
"TFViTPreTrainedModel",
]
)
_import_structure["models.wav2vec2"].extend(
[
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -3133,6 +3142,7 @@ if TYPE_CHECKING:
)
from .models.auto import (
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_MASKED_LM_MAPPING,
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
...@@ -3145,6 +3155,7 @@ if TYPE_CHECKING:
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel,
TFAutoModelForCausalLM,
TFAutoModelForImageClassification,
TFAutoModelForMaskedLM,
TFAutoModelForMultipleChoice,
TFAutoModelForPreTraining,
...@@ -3406,6 +3417,7 @@ if TYPE_CHECKING:
TFTransfoXLModel,
TFTransfoXLPreTrainedModel,
)
from .models.vit import TFViTForImageClassification, TFViTModel, TFViTPreTrainedModel
from .models.wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWav2Vec2ForCTC,
......
...@@ -21,13 +21,24 @@ import re
import numpy
from .file_utils import ExplicitEnum
from .utils import logging
logger = logging.get_logger(__name__)
class TransposeType(ExplicitEnum):
"""
Possible ...
"""
NO = "no"
SIMPLE = "simple"
CONV2D = "conv2d"
def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="", tf_weight_shape=None):
"""
Convert a TF 2.0 model variable name in a pytorch model weight name.
...@@ -39,8 +50,8 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="")
return tuple with:
- pytorch model weight name
- transpose: `TransposeType` member indicating whether and how TF2.0 and PyTorch weights matrices should be
transposed with regards to each other
"""
tf_name = tf_name.replace(":0", "") # device ids
tf_name = re.sub(
...@@ -56,11 +67,17 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="")
tf_name = tf_name[1:] # Remove level zero
# When should we transpose the weights
if tf_name[-1] == "kernel" and tf_weight_shape is not None and tf_weight_shape.rank == 4:
# A simple heuristic to detect conv layer using weight array shape
transpose = TransposeType.CONV2D
elif bool(
tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"]
or "emb_projs" in tf_name
or "out_projs" in tf_name
):
transpose = TransposeType.SIMPLE
else:
transpose = TransposeType.NO
# Convert standard TF2.0 names in PyTorch names
if tf_name[-1] == "kernel" or tf_name[-1] == "embeddings" or tf_name[-1] == "gamma":
...@@ -165,7 +182,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
for symbolic_weight in symbolic_weights:
sw_name = symbolic_weight.name
name, transpose = convert_tf_weight_name_to_pt_weight_name(
sw_name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=symbolic_weight.shape
)
# Find associated numpy array in pytorch model state dict
...@@ -182,7 +199,12 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
array = pt_state_dict[name].numpy()
if transpose is TransposeType.CONV2D:
# Conv2D weight:
# PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
# -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
array = numpy.transpose(array, axes=(2, 3, 1, 0))
elif transpose is TransposeType.SIMPLE:
array = numpy.transpose(array)
if len(symbolic_weight.shape) < len(array.shape):
...@@ -326,7 +348,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
tf_weights_map = {}
for tf_weight in tf_weights:
pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(
tf_weight.name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape
)
tf_weights_map[pt_name] = (tf_weight.numpy(), transpose)
...@@ -350,7 +372,12 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
array, transpose = tf_weights_map[pt_weight_name]
if transpose is TransposeType.CONV2D:
# Conv2D weight:
# TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
# -> PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
array = numpy.transpose(array, axes=(3, 2, 0, 1))
elif transpose is TransposeType.SIMPLE:
array = numpy.transpose(array)
if len(pt_weight.shape) < len(array.shape):
......
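Editorial note (not part of the diff): the axes used in the two `TransposeType.CONV2D` branches above can be sanity-checked with a few lines of numpy, since PyTorch stores Conv2D kernels as (num_out_channel, num_in_channel, kernel_h, kernel_w) while tf.keras.layers.Conv2D stores them as (kernel_h, kernel_w, num_in_channel, num_out_channel). The shapes below use the ViT-Base patch projection only as an assumed example.

import numpy

pt_kernel = numpy.zeros((768, 3, 16, 16))  # PT layout: (out, in, kH, kW)

# PT -> TF, as in load_pytorch_weights_in_tf2_model
tf_kernel = numpy.transpose(pt_kernel, axes=(2, 3, 1, 0))
assert tf_kernel.shape == (16, 16, 3, 768)  # TF layout: (kH, kW, in, out)

# TF -> PT, as in load_tf2_weights_in_pytorch_model
pt_roundtrip = numpy.transpose(tf_kernel, axes=(3, 2, 0, 1))
assert pt_roundtrip.shape == pt_kernel.shape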
...@@ -73,6 +73,7 @@ if is_torch_available():
if is_tf_available():
_import_structure["modeling_tf_auto"] = [
"TF_MODEL_FOR_CAUSAL_LM_MAPPING",
"TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_MASKED_LM_MAPPING",
"TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
...@@ -85,6 +86,7 @@ if is_tf_available():
"TF_MODEL_WITH_LM_HEAD_MAPPING",
"TFAutoModel",
"TFAutoModelForCausalLM",
"TFAutoModelForImageClassification",
"TFAutoModelForMaskedLM",
"TFAutoModelForMultipleChoice",
"TFAutoModelForPreTraining",
...@@ -175,6 +177,7 @@ if TYPE_CHECKING:
if is_tf_available():
from .modeling_tf_auto import (
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_MASKED_LM_MAPPING,
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
...@@ -187,6 +190,7 @@ if TYPE_CHECKING:
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel,
TFAutoModelForCausalLM,
TFAutoModelForImageClassification,
TFAutoModelForMaskedLM,
TFAutoModelForMultipleChoice,
TFAutoModelForPreTraining,
......
...@@ -64,6 +64,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("pegasus", "TFPegasusModel"),
("blenderbot", "TFBlenderbotModel"),
("blenderbot-small", "TFBlenderbotSmallModel"),
("vit", "TFViTModel"),
("wav2vec2", "TFWav2Vec2Model"),
("hubert", "TFHubertModel"),
]
...@@ -144,6 +145,13 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
]
)
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Image-classification
("vit", "TFViTForImageClassification"),
]
)
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
...@@ -302,6 +310,9 @@ TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
...@@ -352,6 +363,13 @@ class TFAutoModelForCausalLM(_BaseAutoModelClass):
TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling")
class TFAutoModelForImageClassification(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
TFAutoModelForImageClassification = auto_class_update(TFAutoModelForImageClassification, head_doc="image classification")
class TFAutoModelForMaskedLM(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING
......
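Editorial note: with the mapping and auto class above in place, the new TF auto class is used like the existing ones. A minimal sketch, assuming the `google/vit-base-patch16-224` checkpoint referenced in the docstrings later in this commit:

import requests
from PIL import Image
from transformers import ViTFeatureExtractor, TFAutoModelForImageClassification

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
model = TFAutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")  # resolves to TFViTForImageClassification

inputs = feature_extractor(images=image, return_tensors="tf")
logits = model(**inputs).logits  # shape (1, 1000) for the ImageNet-1k head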
...@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vision_available
_import_structure = {
...@@ -35,6 +35,12 @@ if is_torch_available():
"ViTPreTrainedModel",
]
if is_tf_available():
_import_structure["modeling_tf_vit"] = [
"TFViTForImageClassification",
"TFViTModel",
"TFViTPreTrainedModel",
]
if is_flax_available():
_import_structure["modeling_flax_vit"] = [
...@@ -57,6 +63,9 @@ if TYPE_CHECKING:
ViTPreTrainedModel,
)
if is_tf_available():
from .modeling_tf_vit import TFViTForImageClassification, TFViTModel, TFViTPreTrainedModel
if is_flax_available():
from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
......
# coding=utf-8
# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 ViT model. """
import collections.abc
import math
from typing import Dict, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from ...activations_tf import get_tf_activation
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from ...modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPooling, TFSequenceClassifierOutput
from ...modeling_tf_utils import (
TFModelInputType,
TFPreTrainedModel,
TFSequenceClassificationLoss,
get_initializer,
input_processing,
keras_serializable,
shape_list,
)
from ...utils import logging
from .configuration_vit import ViTConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "ViTConfig"
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224"
# Inspired by
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
# From PyTorch internals
def to_2tuple(x):
if isinstance(x, collections.abc.Iterable):
return x
return (x, x)
# Based on timm implementation, which can be found here:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class TFViTEmbeddings(tf.keras.layers.Layer):
"""
Construct the CLS token, position and patch embeddings.
"""
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
self.patch_embeddings = TFPatchEmbeddings(config, name="patch_embeddings")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
self.config = config
def build(self, input_shape: tf.TensorShape):
num_patches = self.patch_embeddings.num_patches
self.cls_token = self.add_weight(
shape=(1, 1, self.config.hidden_size), initializer="zeros", trainable=True, name="cls_token"
)
self.position_embeddings = self.add_weight(
shape=(1, num_patches + 1, self.config.hidden_size),
initializer="zeros",
trainable=True,
name="position_embeddings",
)
super().build(input_shape)
def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
batch_size, seq_len, dim = shape_list(embeddings)
npatch = seq_len - 1
_, N, _ = shape_list(self.position_embeddings)
N -= 1
if npatch == N and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, :1]
patch_pos_embed = self.position_embeddings[:, 1:]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
patch_pos_embed = tf.image.resize(
images=tf.reshape(patch_pos_embed, shape=(1, int(math.sqrt(N)), int(math.sqrt(N)), dim)),
size=(h0, w0),
method="bicubic",
)
shape = shape_list(patch_pos_embed)
assert h0 == shape[-3] and w0 == shape[-2]
patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
def call(
self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
) -> tf.Tensor:
batch_size, num_channels, height, width = shape_list(pixel_values)
embeddings = self.patch_embeddings(
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, training=training
)
# add the [CLS] token to the embedded patch tokens
cls_tokens = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
embeddings = tf.concat((cls_tokens, embeddings), axis=1)
# add positional encoding to each token
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings
embeddings = self.dropout(embeddings, training=training)
return embeddings
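Editorial note: `interpolate_pos_encoding` above is what lets the TF model run on images larger than the pre-training resolution. A hedged sketch (the random pixel values, the 384x384 size, and the in21k checkpoint are only illustrative):

import tensorflow as tf
from transformers import TFViTModel

model = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
pixel_values = tf.random.uniform(shape=(1, 3, 384, 384))  # larger than the 224x224 training size
outputs = model(pixel_values, interpolate_pos_encoding=True)
print(outputs.last_hidden_state.shape)  # (1, 577, 768): 1 CLS token + (384 // 16) ** 2 patches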
# Based on timm implementation, which can be found here:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class TFPatchEmbeddings(tf.keras.layers.Layer):
"""
Image to Patch Embedding.
"""
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
image_size = to_2tuple(config.image_size)
patch_size = to_2tuple(config.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_patches = num_patches
self.num_channels = config.num_channels
self.embed_dim = config.hidden_size
self.config = config
self.projection = tf.keras.layers.Conv2D(
filters=self.embed_dim,
kernel_size=patch_size,
strides=self.patch_size,
padding="valid",
data_format="channels_last",
use_bias=True,
kernel_initializer=get_initializer(self.config.initializer_range),
bias_initializer="zeros",
name="projection",
)
def call(
self, pixel_values: tf.Tensor, interpolate_pos_encoding: bool = False, training: bool = False
) -> tf.Tensor:
batch_size, num_channels, height, width = shape_list(pixel_values)
if not interpolate_pos_encoding:
if getattr(height, "numpy", None) and getattr(width, "numpy", None):
if height != self.image_size[0] or width != self.image_size[1]:
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
)
# When running on CPU, `tf.keras.layers.Conv2D` doesn't support `NCHW` format.
# So change the input format from `NCHW` to `NHWC`.
# shape = (batch_size, in_height, in_width, in_channels=num_channels)
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
projection = self.projection(pixel_values)
# Change the 2D spatial dimensions to a single temporal dimension.
# shape = (batch_size, num_patches, out_channels=embed_dim)
num_patches = (width // self.patch_size[1]) * (height // self.patch_size[0])
x = tf.reshape(tensor=projection, shape=(batch_size, num_patches, -1))
return x
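Editorial note: to make the shape bookkeeping in TFPatchEmbeddings concrete, the default ViT-Base configuration works out as follows (assumed values: 224x224 images, 16x16 patches, hidden size 768):

image_size, patch_size, hidden_size = 224, 16, 768
num_patches = (image_size // patch_size) ** 2  # 14 * 14 = 196 patch tokens

# pixel_values (NCHW):            (batch_size, 3, 224, 224)
# after tf.transpose to NHWC:     (batch_size, 224, 224, 3)
# after Conv2D, stride = kernel:  (batch_size, 14, 14, 768)
# after tf.reshape:               (batch_size, 196, 768)
# TFViTEmbeddings then prepends the CLS token -> (batch_size, 197, 768)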
class TFViTSelfAttention(tf.keras.layers.Layer):
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
f"The hidden size ({config.hidden_size}) is not a multiple of the number "
f"of attention heads ({config.num_attention_heads})"
)
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
self.query = tf.keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
)
self.key = tf.keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
)
self.value = tf.keras.layers.Dense(
units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
)
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)
def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
# Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size]
return tf.transpose(tensor, perm=[0, 2, 1, 3])
def call(
self,
hidden_states: tf.Tensor,
head_mask: tf.Tensor,
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
batch_size = shape_list(hidden_states)[0]
mixed_query_layer = self.query(inputs=hidden_states)
mixed_key_layer = self.key(inputs=hidden_states)
mixed_value_layer = self.value(inputs=hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
# Take the dot product between "query" and "key" to get the raw attention scores.
# (batch size, num_heads, seq_len_q, seq_len_k)
attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype)
attention_scores = tf.divide(attention_scores, dk)
# Normalize the attention scores to probabilities.
attention_probs = tf.nn.softmax(logits=attention_scores, axis=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(inputs=attention_probs, training=training)
# Mask heads if we want to
if head_mask is not None:
attention_probs = tf.multiply(attention_probs, head_mask)
attention_output = tf.matmul(attention_probs, value_layer)
attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
# (batch_size, seq_len_q, all_head_size)
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)
return outputs
class TFViTSelfOutput(tf.keras.layers.Layer):
"""
The residual connection is defined in TFViTLayer instead of here (as is the case with other models), due to the
layernorm applied before each block.
"""
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(inputs=hidden_states, training=training)
return hidden_states
class TFViTAttention(tf.keras.layers.Layer):
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
self.self_attention = TFViTSelfAttention(config, name="attention")
self.dense_output = TFViTSelfOutput(config, name="output")
def prune_heads(self, heads):
raise NotImplementedError
def call(
self,
input_tensor: tf.Tensor,
head_mask: tf.Tensor,
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
self_outputs = self.self_attention(
hidden_states=input_tensor, head_mask=head_mask, output_attentions=output_attentions, training=training
)
attention_output = self.dense_output(
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
)
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs
class TFViTIntermediate(tf.keras.layers.Layer):
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class TFViTOutput(tf.keras.layers.Layer):
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
)
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(inputs=hidden_states, training=training)
hidden_states = hidden_states + input_tensor
return hidden_states
class TFViTLayer(tf.keras.layers.Layer):
"""This corresponds to the Block class in the timm implementation."""
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
self.attention = TFViTAttention(config, name="attention")
self.intermediate = TFViTIntermediate(config, name="intermediate")
self.vit_output = TFViTOutput(config, name="output")
self.layernorm_before = tf.keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="layernorm_before"
)
self.layernorm_after = tf.keras.layers.LayerNormalization(
epsilon=config.layer_norm_eps, name="layernorm_after"
)
def call(
self,
hidden_states: tf.Tensor,
head_mask: tf.Tensor,
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
attention_outputs = self.attention(
# in ViT, layernorm is applied before self-attention
input_tensor=self.layernorm_before(inputs=hidden_states),
head_mask=head_mask,
output_attentions=output_attentions,
training=training,
)
attention_output = attention_outputs[0]
# first residual connection
hidden_states = attention_output + hidden_states
# in ViT, layernorm is also applied after self-attention
layer_output = self.layernorm_after(inputs=hidden_states)
intermediate_output = self.intermediate(hidden_states=layer_output)
# second residual connection is done here
layer_output = self.vit_output(
hidden_states=intermediate_output, input_tensor=hidden_states, training=training
)
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
return outputs
class TFViTEncoder(tf.keras.layers.Layer):
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
self.layer = [TFViTLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]
def call(
self,
hidden_states: tf.Tensor,
head_mask: tf.Tensor,
output_attentions: bool,
output_hidden_states: bool,
return_dict: bool,
training: bool = False,
) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
hidden_states=hidden_states,
head_mask=head_mask[i],
output_attentions=output_attentions,
training=training,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
return TFBaseModelOutput(
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
)
@keras_serializable
class TFViTMainLayer(tf.keras.layers.Layer):
config_class = ViTConfig
def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, **kwargs):
super().__init__(**kwargs)
self.config = config
self.embeddings = TFViTEmbeddings(config, name="embeddings")
self.encoder = TFViTEncoder(config, name="encoder")
self.layernorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layernorm")
self.pooler = TFViTPooler(config, name="pooler") if add_pooling_layer else None
def get_input_embeddings(self) -> tf.keras.layers.Layer:
return self.embeddings.patch_embeddings
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
raise NotImplementedError
def call(
self,
pixel_values: Optional[TFModelInputType] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
**kwargs,
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=pixel_values,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if "input_ids" in inputs:
inputs["pixel_values"] = inputs.pop("input_ids")
if inputs["pixel_values"] is None:
raise ValueError("You have to specify pixel_values")
embedding_output = self.embeddings(
pixel_values=inputs["pixel_values"],
interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
training=inputs["training"],
)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if inputs["head_mask"] is not None:
raise NotImplementedError
else:
inputs["head_mask"] = [None] * self.config.num_hidden_layers
encoder_outputs = self.encoder(
hidden_states=embedding_output,
head_mask=inputs["head_mask"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(inputs=sequence_output)
pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None
if not inputs["return_dict"]:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return TFBaseModelOutputWithPooling(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class TFViTPreTrainedModel(TFPreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = ViTConfig
base_model_prefix = "vit"
@property
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
"""
Dummy inputs to build the network.
Returns:
:obj:`Dict[str, tf.Tensor]`: The dummy inputs.
"""
VISION_DUMMY_INPUTS = tf.random.uniform(
shape=(3, self.config.num_channels, self.config.image_size, self.config.image_size), dtype=tf.float32
)
return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}
@tf.function(
input_signature=[
{
"pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
}
]
)
def serving(self, inputs):
"""
Method used for serving the model.
Args:
inputs (:obj:`Dict[str, tf.Tensor]`):
The input of the saved model as a dictionary of tensors.
"""
output = self.call(inputs)
return self.serving_output(output)
VIT_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
generic methods the library implements for all its models (such as downloading or saving, resizing the input
embeddings, pruning heads etc.)
This model is also a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ subclass. Use
it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage
and behavior.
.. note::
TF 2.0 models accept two formats as inputs:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional arguments.
This second option is useful when using :meth:`tf.keras.Model.fit` method which currently requires having all
the tensors in the first argument of the model call function: :obj:`model(inputs)`.
Args:
config (:class:`~transformers.ViTConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.TFPreTrainedModel.from_pretrained` method to load the
model weights.
"""
VIT_INPUTS_DOCSTRING = r"""
Args:
pixel_values (:obj:`np.ndarray`, :obj:`tf.Tensor`, :obj:`List[tf.Tensor]` :obj:`Dict[str, tf.Tensor]` or :obj:`Dict[str, np.ndarray]` and each example must have the shape :obj:`(batch_size, num_channels, height, width)`):
Pixel values. Pixel values can be obtained using :class:`~transformers.ViTFeatureExtractor`. See
:meth:`transformers.ViTFeatureExtractor.__call__` for details.
head_mask (:obj:`np.ndarray` or :obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
config will be used instead.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
used instead.
interpolate_pos_encoding (:obj:`bool`, `optional`):
Whether to interpolate the pre-trained position encodings.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. This
argument can be used in eager mode, in graph mode the value will always be set to True.
training (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
"""
@add_start_docstrings(
"The bare ViT Model transformer outputting raw hidden-states without any specific head on top.",
VIT_START_DOCSTRING,
)
class TFViTModel(TFViTPreTrainedModel):
def __init__(self, config: ViTConfig, *inputs, add_pooling_layer=True, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.vit = TFViTMainLayer(config, add_pooling_layer=add_pooling_layer, name="vit")
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
def call(
self,
pixel_values: Optional[TFModelInputType] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
**kwargs,
) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
r"""
Returns:
Examples::
>>> from transformers import ViTFeatureExtractor, TFViTModel
>>> from PIL import Image
>>> import requests
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
>>> model = TFViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
>>> inputs = feature_extractor(images=image, return_tensors="tf")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=pixel_values,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)
if "input_ids" in inputs:
inputs["pixel_values"] = inputs.pop("input_ids")
outputs = self.vit(
pixel_values=inputs["pixel_values"],
head_mask=inputs["head_mask"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
return outputs
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFBaseModelOutputWithPooling(
last_hidden_state=output.last_hidden_state,
pooler_output=output.pooler_output,
hidden_states=hs,
attentions=attns,
)
class TFViTPooler(tf.keras.layers.Layer):
def __init__(self, config: ViTConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
units=config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
activation="tanh",
name="dense",
)
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(inputs=first_token_tensor)
return pooled_output
@add_start_docstrings(
"""
ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
the [CLS] token) e.g. for ImageNet.
""",
VIT_START_DOCSTRING,
)
class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassificationLoss):
def __init__(self, config: ViTConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.vit = TFViTMainLayer(config, add_pooling_layer=False, name="vit")
# Classifier head
self.classifier = tf.keras.layers.Dense(
units=config.num_labels,
kernel_initializer=get_initializer(config.initializer_range),
name="classifier",
)
@add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
pixel_values: Optional[TFModelInputType] = None,
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
training: Optional[bool] = False,
**kwargs,
) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
r"""
labels (:obj:`tf.Tensor` or :obj:`np.ndarray` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
Examples::
>>> from transformers import ViTFeatureExtractor, TFViTForImageClassification
>>> from PIL import Image
>>> import requests
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
>>> model = TFViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
>>> inputs = feature_extractor(images=image, return_tensors="tf")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = int(tf.math.argmax(logits, axis=-1)[0])
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
"""
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=pixel_values,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
labels=labels,
training=training,
kwargs_call=kwargs,
)
if "input_ids" in inputs:
inputs["pixel_values"] = inputs.pop("input_ids")
outputs = self.vit(
pixel_values=inputs["pixel_values"],
head_mask=inputs["head_mask"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)
sequence_output = outputs[0]
logits = self.classifier(inputs=sequence_output[:, 0, :])
loss = None if inputs["labels"] is None else self.compute_loss(labels=inputs["labels"], logits=logits)
if not inputs["return_dict"]:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return TFSequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns)
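Editorial note: since the weight-conversion changes earlier in this commit exist mainly so that PyTorch ViT checkpoints (with their Conv2D patch projection) can be loaded into the TF model, a hedged cross-framework check could look like the sketch below; `from_pt=True` is the standard flag for loading PyTorch weights into a TF model, and the expected tolerance is only indicative.

import numpy as np
import requests
from PIL import Image
from transformers import ViTFeatureExtractor, ViTModel, TFViTModel

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

pt_model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
tf_model = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k", from_pt=True)

pt_out = pt_model(**feature_extractor(images=image, return_tensors="pt")).last_hidden_state
tf_out = tf_model(**feature_extractor(images=image, return_tensors="tf")).last_hidden_state
print(np.abs(pt_out.detach().numpy() - tf_out.numpy()).max())  # expected to be on the order of 1e-5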
...@@ -176,6 +176,9 @@ class TFAlbertPreTrainedModel:
TF_MODEL_FOR_CAUSAL_LM_MAPPING = None
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None
TF_MODEL_FOR_MASKED_LM_MAPPING = None
...@@ -224,6 +227,15 @@ class TFAutoModelForCausalLM:
requires_backends(cls, ["tf"])
class TFAutoModelForImageClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
class TFAutoModelForMaskedLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
...@@ -1971,6 +1983,29 @@ class TFTransfoXLPreTrainedModel:
requires_backends(cls, ["tf"])
class TFViTForImageClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFViTModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
class TFViTPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
...@@ -1476,6 +1476,8 @@ class ModelTesterMixin:
tf_inputs_dict[key] = tensor
elif key == "input_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
elif key == "pixel_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
else:
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)
...@@ -1525,6 +1527,8 @@ class ModelTesterMixin:
tf_inputs_dict[key] = tf.convert_to_tensor(tensor, dtype=tf.int32)
elif key == "input_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
elif key == "pixel_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
else:
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)
......
...@@ -49,6 +49,7 @@ if is_tf_available():
from transformers import (
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_MASKED_LM_MAPPING,
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
...@@ -126,7 +127,10 @@ class TFModelTesterMixin:
elif model_class in get_values(TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in [
*get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
*get_values(TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
]:
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in get_values(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING):
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
...@@ -460,6 +464,8 @@ class TFModelTesterMixin:
pt_inputs_dict[name] = key
elif name == "input_values":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
elif name == "pixel_values":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
...@@ -504,6 +510,8 @@ class TFModelTesterMixin:
pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long)
elif name == "input_values":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
elif name == "pixel_values":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
# need to rename encoder-decoder "inputs" for PyTorch
...@@ -605,7 +613,7 @@ class TFModelTesterMixin:
for model_class in self.all_model_classes:
if self.is_encoder_decoder:
inputs = {
"decoder_input_ids": tf.keras.Input(
batch_shape=(2, max_input),
name="decoder_input_ids",
...@@ -613,10 +621,22 @@ class TFModelTesterMixin:
),
"input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"),
}
# TODO: A better way to handle vision models
elif model_class.__name__ in ["TFViTModel", "TFViTForImageClassification"]:
inputs = tf.keras.Input(
batch_shape=(
3,
self.model_tester.num_channels,
self.model_tester.image_size,
self.model_tester.image_size,
),
name="pixel_values",
dtype="float32",
)
elif model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING): elif model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
input_ids = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32") inputs = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32")
else: else:
input_ids = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32") inputs = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32")
# Prepare our model # Prepare our model
model = model_class(config) model = model_class(config)
...@@ -626,14 +646,14 @@ class TFModelTesterMixin: ...@@ -626,14 +646,14 @@ class TFModelTesterMixin:
model.save_pretrained(tmpdirname, saved_model=False) model.save_pretrained(tmpdirname, saved_model=False)
model = model_class.from_pretrained(tmpdirname) model = model_class.from_pretrained(tmpdirname)
outputs_dict = model(input_ids) outputs_dict = model(inputs)
hidden_states = outputs_dict[0] hidden_states = outputs_dict[0]
# Add a dense layer on top to test integration with other keras modules # Add a dense layer on top to test integration with other keras modules
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states) outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
# Compile extended model # Compile extended model
extended_model = tf.keras.Model(inputs=[input_ids], outputs=[outputs]) extended_model = tf.keras.Model(inputs=[inputs], outputs=[outputs])
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric]) extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
def test_keyword_and_dict_args(self): def test_keyword_and_dict_args(self):
...@@ -647,6 +667,8 @@ class TFModelTesterMixin: ...@@ -647,6 +667,8 @@ class TFModelTesterMixin:
inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
input_ids = inputs_keywords.pop("input_ids", None) input_ids = inputs_keywords.pop("input_ids", None)
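# vision models have no input_ids; fall back to pixel_values as the main input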
if input_ids is None:
input_ids = inputs_keywords.pop("pixel_values", None)
outputs_keywords = model(input_ids, **inputs_keywords) outputs_keywords = model(input_ids, **inputs_keywords)
output_dict = outputs_dict[0].numpy() output_dict = outputs_dict[0].numpy()
output_keywords = outputs_keywords[0].numpy() output_keywords = outputs_keywords[0].numpy()
...@@ -1236,7 +1258,8 @@ class TFModelTesterMixin: ...@@ -1236,7 +1258,8 @@ class TFModelTesterMixin:
# Test that model correctly compute the loss with kwargs # Test that model correctly compute the loss with kwargs
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True) prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
input_ids = prepared_for_class.pop("input_ids") input_name = "input_ids" if "input_ids" in prepared_for_class else "pixel_values"
input_ids = prepared_for_class.pop(input_name)
loss = model(input_ids, **prepared_for_class)[0] loss = model(input_ids, **prepared_for_class)[0]
self.assertEqual(loss.shape, [loss_size]) self.assertEqual(loss.shape, [loss_size])
...@@ -1255,7 +1278,7 @@ class TFModelTesterMixin: ...@@ -1255,7 +1278,7 @@ class TFModelTesterMixin:
signature_names = list(signature.keys()) signature_names = list(signature.keys())
# Create a dictionary holding the location of the tensors in the tuple # Create a dictionary holding the location of the tensors in the tuple
tuple_index_mapping = {0: "input_ids"} tuple_index_mapping = {0: input_name}
for label_key in label_keys: for label_key in label_keys:
label_key_index = signature_names.index(label_key) label_key_index = signature_names.index(label_key)
tuple_index_mapping[label_key_index] = label_key tuple_index_mapping[label_key_index] = label_key
......
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow ViT model. """
import inspect
import os
import tempfile
import unittest
from transformers import ViTConfig
from transformers.file_utils import cached_property, is_tf_available, is_vision_available
from transformers.testing_utils import require_tf, require_vision, slow, tooslow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
if is_tf_available():
import tensorflow as tf
from transformers import TFViTForImageClassification, TFViTModel
from transformers.models.vit.modeling_tf_vit import to_2tuple
if is_vision_available():
from PIL import Image
from transformers import ViTFeatureExtractor
class TFViTModelTester:
def __init__(
self,
parent,
batch_size=13,
image_size=30,
patch_size=2,
num_channels=3,
is_training=True,
use_labels=True,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
type_sequence_label_size=10,
initializer_range=0.02,
num_labels=3,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.is_training = is_training
self.use_labels = use_labels
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.scope = scope
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return ViTConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
)
def create_and_check_model(self, config, pixel_values, labels):
model = TFViTModel(config=config)
result = model(pixel_values, training=False)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = to_2tuple(self.image_size)
patch_size = to_2tuple(self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
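# with the default tester values (image_size=30, patch_size=2): 15 * 15 = 225 patches, so seq_len = 226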
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
# Test with an image of a different size than the one specified in the config.
image_size = self.image_size // 2
pixel_values = pixel_values[:, :, :image_size, :image_size]
result = model(pixel_values, interpolate_pos_encoding=True, training=False)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = to_2tuple(image_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
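# with the halved image (image_size=15, patch_size=2): 7 * 7 = 49 patches, so seq_len = 50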
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
model = TFViTForImageClassification(config)
result = model(pixel_values, labels=labels, training=False)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
# Test with an image of a different size than the one specified in the config.
image_size = self.image_size // 2
pixel_values = pixel_values[:, :, :image_size, :image_size]
result = model(pixel_values, interpolate_pos_encoding=True, training=False)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_tf
class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
"""
Here we also override some of the tests of test_modeling_tf_common.py, as ViT does not use input_ids, inputs_embeds or
attention_mask, and its sequence length is derived from the image and patch sizes rather than seq_length.
"""
all_model_classes = (TFViTModel, TFViTForImageClassification) if is_tf_available() else ()
test_resize_embeddings = False
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFViTModelTester(self)
self.config_tester = ConfigTester(self, config_class=ViTConfig, has_text_modality=False, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_inputs_embeds(self):
# ViT does not use inputs_embeds
pass
def test_graph_mode_with_inputs_embeds(self):
# ViT does not use inputs_embeds
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
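# for ViT, get_input_embeddings() returns the patch embedding layer rather than a token embedding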
self.assertIsInstance(model.get_input_embeddings(), (tf.keras.layers.Layer))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, tf.keras.layers.Layer))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
# overwrite from common since `encoder_seq_length` and `encoder_key_length` are calculated
# in a different way than in text models.
@tooslow
def test_saved_model_creation_extended(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
config.output_attentions = True
if hasattr(config, "use_cache"):
config.use_cache = True
# in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_len = num_patches + 1
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
for model_class in self.all_model_classes:
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
num_out = len(model(class_inputs_dict))
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=True)
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
model = tf.keras.models.load_model(saved_model_dir)
outputs = model(class_inputs_dict)
if self.is_encoder_decoder:
output_hidden_states = outputs["encoder_hidden_states"]
output_attentions = outputs["encoder_attentions"]
else:
output_hidden_states = outputs["hidden_states"]
output_attentions = outputs["attentions"]
self.assertEqual(len(outputs), num_out)
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertEqual(len(output_hidden_states), expected_num_layers)
self.assertListEqual(
list(output_hidden_states[0].shape[-2:]),
[seq_len, self.model_tester.hidden_size],
)
self.assertEqual(len(output_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(output_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
# in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_len = num_patches + 1
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
out_len = len(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder:
added_hidden_states = 2
else:
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
# ViT has a different seq_length
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_length = num_patches + 1
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
model = TFViTModel.from_pretrained("google/vit-base-patch16-224", from_pt=True)
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_vision
class TFViTModelIntegrationTest(unittest.TestCase):
@cached_property
def default_feature_extractor(self):
return ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224") if is_vision_available() else None
@slow
def test_inference_image_classification_head(self):
model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224", from_pt=True)
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="tf")
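# the feature extractor resizes and normalizes the image, returning pixel_values of shape (1, 3, 224, 224) for this checkpoint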
# forward pass
outputs = model(**inputs)
# verify the logits
expected_shape = tf.TensorShape((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = tf.constant([-0.2744, 0.8215, -0.0836])
tf.debugging.assert_near(outputs.logits[0, :3], expected_slice, atol=1e-4)