Unverified Commit be4a6c64 authored by Yih-Dar, committed by GitHub

Add TFViTModel (#13778)



* Start the work for TFViTModel

* Convert to TF code - need to check in the follow up commits

* Clean up model code

* Expose TFViTModel

* make style

* make quality

* Add test

* make style & quality

* Fix some imports

* fix wrong usage - *kwargs => **kwargs

* Fix Conv2D weight loading (PT->TF) issue

* Add tests for images with different sizes + fix model

* Fix some common tests for TFViTModel

* Use inputs instead of input_ids in test_compile_tf_model

* Add a comment about transpose and Conv2D in convert_tf_weight_name_to_pt_weight_name

* Avoid transpose in TFViT call

* Fix Conv2D issue in load_tf2_weights_in_pytorch_model

* Use tf.keras.layers.Conv2D instead of tf.nn.conv2d

* Using simpler heuristic to detect Conv2D layer

* Change convert_tf_weight_name_to_pt_weight_name to return TransposeType

* Check tf_weight_shape is not None before using it

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix missing comma

* fix input dtype
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 6326aa4b
...@@ -503,7 +503,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| VisualBert | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| ViT | ❌ | ❌ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
......
...@@ -223,6 +223,13 @@ TFAutoModelForCausalLM
:members:
TFAutoModelForImageClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFAutoModelForImageClassification
:members:
TFAutoModelForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
...@@ -120,6 +120,20 @@ ViTForImageClassification
:members: forward
TFViTModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFViTModel
:members: call
TFViTForImageClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFViTForImageClassification
:members: call
FlaxVitModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
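For context, a minimal usage sketch of the newly documented TFViTModel (not part of the diff): the checkpoint name and image URL below are the ones commonly used in the ViT docs, and `from_pt=True` is assumed to be needed while the hub checkpoint only ships PyTorch weights.

from PIL import Image
import requests

from transformers import ViTFeatureExtractor, TFViTModel

# Any RGB image works; this is the standard COCO sample used in the HF docs.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
model = TFViTModel.from_pretrained("google/vit-base-patch16-224-in21k", from_pt=True)

inputs = feature_extractor(images=image, return_tensors="tf")
outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state  # (batch_size, num_patches + 1, hidden_size)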
...@@ -1396,6 +1396,7 @@ if is_tf_available():
_import_structure["models.auto"].extend(
[
"TF_MODEL_FOR_CAUSAL_LM_MAPPING",
"TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_MASKED_LM_MAPPING", "TF_MODEL_FOR_MASKED_LM_MAPPING",
"TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
...@@ -1408,6 +1409,7 @@ if is_tf_available():
"TF_MODEL_WITH_LM_HEAD_MAPPING",
"TFAutoModel",
"TFAutoModelForCausalLM",
"TFAutoModelForImageClassification",
"TFAutoModelForMaskedLM", "TFAutoModelForMaskedLM",
"TFAutoModelForMultipleChoice", "TFAutoModelForMultipleChoice",
"TFAutoModelForPreTraining", "TFAutoModelForPreTraining",
...@@ -1734,6 +1736,13 @@ if is_tf_available():
"TFTransfoXLPreTrainedModel",
]
)
_import_structure["models.vit"].extend(
[
"TFViTForImageClassification",
"TFViTModel",
"TFViTPreTrainedModel",
]
)
_import_structure["models.wav2vec2"].extend( _import_structure["models.wav2vec2"].extend(
[ [
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -3133,6 +3142,7 @@ if TYPE_CHECKING:
)
from .models.auto import (
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_MASKED_LM_MAPPING,
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
...@@ -3145,6 +3155,7 @@ if TYPE_CHECKING:
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel,
TFAutoModelForCausalLM,
TFAutoModelForImageClassification,
TFAutoModelForMaskedLM,
TFAutoModelForMultipleChoice,
TFAutoModelForPreTraining,
...@@ -3406,6 +3417,7 @@ if TYPE_CHECKING:
TFTransfoXLModel,
TFTransfoXLPreTrainedModel,
)
from .models.vit import TFViTForImageClassification, TFViTModel, TFViTPreTrainedModel
from .models.wav2vec2 import (
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
TFWav2Vec2ForCTC,
......
...@@ -21,13 +21,24 @@ import re
import numpy
from .file_utils import ExplicitEnum
from .utils import logging
logger = logging.get_logger(__name__)
class TransposeType(ExplicitEnum):
"""
Possible ...
"""
NO = "no"
SIMPLE = "simple"
CONV2D = "conv2d"
def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="", tf_weight_shape=None):
""" """
Convert a TF 2.0 model variable name in a pytorch model weight name. Convert a TF 2.0 model variable name in a pytorch model weight name.
...@@ -39,8 +50,8 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="") ...@@ -39,8 +50,8 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="")
return tuple with: return tuple with:
- pytorch model weight name - pytorch model weight name
- transpose: boolean indicating whether TF2.0 and PyTorch weights matrices are transposed with regards to each - transpose: `TransposeType` member indicating whether and how TF2.0 and PyTorch weights matrices should be
other transposed with regards to each other
""" """
tf_name = tf_name.replace(":0", "") # device ids tf_name = tf_name.replace(":0", "") # device ids
tf_name = re.sub( tf_name = re.sub(
...@@ -56,11 +67,17 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="") ...@@ -56,11 +67,17 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="")
tf_name = tf_name[1:] # Remove level zero tf_name = tf_name[1:] # Remove level zero
# When should we transpose the weights # When should we transpose the weights
if tf_name[-1] == "kernel" and tf_weight_shape is not None and tf_weight_shape.rank == 4:
# A simple heuristic to detect conv layer using weight array shape
transpose = TransposeType.CONV2D
elif bool(
tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"]
or "emb_projs" in tf_name
or "out_projs" in tf_name
):
transpose = TransposeType.SIMPLE
else:
transpose = TransposeType.NO
# Convert standard TF2.0 names in PyTorch names
if tf_name[-1] == "kernel" or tf_name[-1] == "embeddings" or tf_name[-1] == "gamma":
...@@ -165,7 +182,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
for symbolic_weight in symbolic_weights:
sw_name = symbolic_weight.name
name, transpose = convert_tf_weight_name_to_pt_weight_name(
sw_name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=symbolic_weight.shape
)
# Find associated numpy array in pytorch model state dict
...@@ -182,7 +199,12 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
array = pt_state_dict[name].numpy()
if transpose is TransposeType.CONV2D:
# Conv2D weight:
# PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
# -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
array = numpy.transpose(array, axes=(2, 3, 1, 0))
elif transpose is TransposeType.SIMPLE:
array = numpy.transpose(array)
if len(symbolic_weight.shape) < len(array.shape):
...@@ -326,7 +348,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
tf_weights_map = {}
for tf_weight in tf_weights:
pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(
tf_weight.name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape
)
tf_weights_map[pt_name] = (tf_weight.numpy(), transpose)
...@@ -350,7 +372,12 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
array, transpose = tf_weights_map[pt_weight_name]
if transpose is TransposeType.CONV2D:
# Conv2D weight:
# TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
# -> PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
array = numpy.transpose(array, axes=(3, 2, 0, 1))
elif transpose is TransposeType.SIMPLE:
array = numpy.transpose(array)
if len(pt_weight.shape) < len(array.shape):
......
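To illustrate the Conv2D handling added above: PyTorch stores a Conv2D kernel as (num_out_channel, num_in_channel, kernel_h, kernel_w) while Keras stores it as (kernel_h, kernel_w, num_in_channel, num_out_channel), so a plain numpy.transpose is not enough. A minimal standalone sketch of both directions, with shapes chosen arbitrarily for illustration (ViT-base patch embedding uses 768 output channels and 16x16 patches):

import numpy as np

# PyTorch-style Conv2D kernel: (num_out_channel, num_in_channel, kernel_h, kernel_w)
pt_kernel = np.random.rand(768, 3, 16, 16)

# PT -> TF, as done in load_pytorch_weights_in_tf2_model for TransposeType.CONV2D
tf_kernel = np.transpose(pt_kernel, axes=(2, 3, 1, 0))
assert tf_kernel.shape == (16, 16, 3, 768)

# TF -> PT, as done in load_tf2_weights_in_pytorch_model for TransposeType.CONV2D
pt_roundtrip = np.transpose(tf_kernel, axes=(3, 2, 0, 1))
assert pt_roundtrip.shape == (768, 3, 16, 16)
assert np.array_equal(pt_roundtrip, pt_kernel)  # the two transposes are inverses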
...@@ -73,6 +73,7 @@ if is_torch_available():
if is_tf_available():
_import_structure["modeling_tf_auto"] = [
"TF_MODEL_FOR_CAUSAL_LM_MAPPING",
"TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
"TF_MODEL_FOR_MASKED_LM_MAPPING", "TF_MODEL_FOR_MASKED_LM_MAPPING",
"TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING",
"TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING",
...@@ -85,6 +86,7 @@ if is_tf_available(): ...@@ -85,6 +86,7 @@ if is_tf_available():
"TF_MODEL_WITH_LM_HEAD_MAPPING", "TF_MODEL_WITH_LM_HEAD_MAPPING",
"TFAutoModel", "TFAutoModel",
"TFAutoModelForCausalLM", "TFAutoModelForCausalLM",
"TFAutoModelForImageClassification",
"TFAutoModelForMaskedLM", "TFAutoModelForMaskedLM",
"TFAutoModelForMultipleChoice", "TFAutoModelForMultipleChoice",
"TFAutoModelForPreTraining", "TFAutoModelForPreTraining",
...@@ -175,6 +177,7 @@ if TYPE_CHECKING: ...@@ -175,6 +177,7 @@ if TYPE_CHECKING:
if is_tf_available(): if is_tf_available():
from .modeling_tf_auto import ( from .modeling_tf_auto import (
TF_MODEL_FOR_CAUSAL_LM_MAPPING, TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_MASKED_LM_MAPPING,
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
...@@ -187,6 +190,7 @@ if TYPE_CHECKING:
TF_MODEL_WITH_LM_HEAD_MAPPING,
TFAutoModel,
TFAutoModelForCausalLM,
TFAutoModelForImageClassification,
TFAutoModelForMaskedLM,
TFAutoModelForMultipleChoice,
TFAutoModelForPreTraining,
......
...@@ -64,6 +64,7 @@ TF_MODEL_MAPPING_NAMES = OrderedDict(
("pegasus", "TFPegasusModel"),
("blenderbot", "TFBlenderbotModel"),
("blenderbot-small", "TFBlenderbotSmallModel"),
("vit", "TFViTModel"),
("wav2vec2", "TFWav2Vec2Model"), ("wav2vec2", "TFWav2Vec2Model"),
("hubert", "TFHubertModel"), ("hubert", "TFHubertModel"),
] ]
...@@ -144,6 +145,13 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ...@@ -144,6 +145,13 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
] ]
) )
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Image-classification
("vit", "TFViTForImageClassification"),
]
)
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
...@@ -302,6 +310,9 @@ TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
...@@ -352,6 +363,13 @@ class TFAutoModelForCausalLM(_BaseAutoModelClass):
TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling")
class TFAutoModelForImageClassification(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
TFAutoModelForImageClassification = auto_class_update(TFAutoModelForImageClassification, head_doc="image classification")
class TFAutoModelForMaskedLM(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING
......
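A hedged usage sketch for the new auto class (not part of the diff): the checkpoint name is only illustrative, `from_pt=True` is assumed to be needed while the hub hosts PyTorch weights only, and a random tensor stands in for real feature-extractor output.

import tensorflow as tf

from transformers import TFAutoModelForImageClassification

model = TFAutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224", from_pt=True)

# pixel_values: float32 tensor of shape (batch_size, num_channels, height, width)
pixel_values = tf.random.uniform((1, 3, 224, 224))
outputs = model(pixel_values)
logits = outputs.logits  # (batch_size, num_labels), here (1, 1000)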
...@@ -17,7 +17,7 @@
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _LazyModule, is_flax_available, is_tf_available, is_torch_available, is_vision_available
_import_structure = {
...@@ -35,6 +35,12 @@ if is_torch_available():
"ViTPreTrainedModel",
]
if is_tf_available():
_import_structure["modeling_tf_vit"] = [
"TFViTForImageClassification",
"TFViTModel",
"TFViTPreTrainedModel",
]
if is_flax_available():
_import_structure["modeling_flax_vit"] = [
...@@ -57,6 +63,9 @@ if TYPE_CHECKING:
ViTPreTrainedModel,
)
if is_tf_available():
from .modeling_tf_vit import TFViTForImageClassification, TFViTModel, TFViTPreTrainedModel
if is_flax_available():
from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
......
This diff is collapsed (new file: src/transformers/models/vit/modeling_tf_vit.py).
...@@ -176,6 +176,9 @@ class TFAlbertPreTrainedModel:
TF_MODEL_FOR_CAUSAL_LM_MAPPING = None
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None
TF_MODEL_FOR_MASKED_LM_MAPPING = None
...@@ -224,6 +227,15 @@ class TFAutoModelForCausalLM:
requires_backends(cls, ["tf"])
class TFAutoModelForImageClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
class TFAutoModelForMaskedLM:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
...@@ -1971,6 +1983,29 @@ class TFTransfoXLPreTrainedModel:
requires_backends(cls, ["tf"])
class TFViTForImageClassification:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFViTModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
class TFViTPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
...@@ -1476,6 +1476,8 @@ class ModelTesterMixin:
tf_inputs_dict[key] = tensor
elif key == "input_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
elif key == "pixel_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
else:
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)
...@@ -1525,6 +1527,8 @@ class ModelTesterMixin:
tf_inputs_dict[key] = tf.convert_to_tensor(tensor, dtype=tf.int32)
elif key == "input_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
elif key == "pixel_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
else:
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)
......
...@@ -49,6 +49,7 @@ if is_tf_available():
from transformers import (
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_MASKED_LM_MAPPING,
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
...@@ -126,7 +127,10 @@ class TFModelTesterMixin:
elif model_class in get_values(TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in [
*get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
*get_values(TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
]:
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in get_values(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING): elif model_class in get_values(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING):
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
...@@ -460,6 +464,8 @@ class TFModelTesterMixin:
pt_inputs_dict[name] = key
elif name == "input_values":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
elif name == "pixel_values":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
...@@ -504,6 +510,8 @@ class TFModelTesterMixin:
pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long)
elif name == "input_values":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
elif name == "pixel_values":
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
else:
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
# need to rename encoder-decoder "inputs" for PyTorch
...@@ -605,7 +613,7 @@ class TFModelTesterMixin:
for model_class in self.all_model_classes:
if self.is_encoder_decoder:
inputs = {
"decoder_input_ids": tf.keras.Input(
batch_shape=(2, max_input),
name="decoder_input_ids",
...@@ -613,10 +621,22 @@ class TFModelTesterMixin:
),
"input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"),
}
# TODO: A better way to handle vision models
elif model_class.__name__ in ["TFViTModel", "TFViTForImageClassification"]:
inputs = tf.keras.Input(
batch_shape=(
3,
self.model_tester.num_channels,
self.model_tester.image_size,
self.model_tester.image_size,
),
name="pixel_values",
dtype="float32",
)
elif model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
inputs = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32")
else:
inputs = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32")
# Prepare our model
model = model_class(config)
...@@ -626,14 +646,14 @@ class TFModelTesterMixin:
model.save_pretrained(tmpdirname, saved_model=False)
model = model_class.from_pretrained(tmpdirname)
outputs_dict = model(inputs)
hidden_states = outputs_dict[0]
# Add a dense layer on top to test integration with other keras modules
outputs = tf.keras.layers.Dense(2, activation="softmax", name="outputs")(hidden_states)
# Compile extended model
extended_model = tf.keras.Model(inputs=[inputs], outputs=[outputs])
extended_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
def test_keyword_and_dict_args(self):
...@@ -647,6 +667,8 @@ class TFModelTesterMixin:
inputs_keywords = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
input_ids = inputs_keywords.pop("input_ids", None)
if input_ids is None:
input_ids = inputs_keywords.pop("pixel_values", None)
outputs_keywords = model(input_ids, **inputs_keywords)
output_dict = outputs_dict[0].numpy()
output_keywords = outputs_keywords[0].numpy()
...@@ -1236,7 +1258,8 @@ class TFModelTesterMixin:
# Test that model correctly compute the loss with kwargs
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
input_name = "input_ids" if "input_ids" in prepared_for_class else "pixel_values"
input_ids = prepared_for_class.pop(input_name)
loss = model(input_ids, **prepared_for_class)[0]
self.assertEqual(loss.shape, [loss_size])
...@@ -1255,7 +1278,7 @@ class TFModelTesterMixin:
signature_names = list(signature.keys())
# Create a dictionary holding the location of the tensors in the tuple
tuple_index_mapping = {0: input_name}
for label_key in label_keys:
label_key_index = signature_names.index(label_key)
tuple_index_mapping[label_key_index] = label_key
......
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow ViT model. """
import inspect
import os
import tempfile
import unittest
from transformers import ViTConfig
from transformers.file_utils import cached_property, is_tf_available, is_vision_available
from transformers.testing_utils import require_tf, require_vision, slow, tooslow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
if is_tf_available():
import tensorflow as tf
from transformers import TFViTForImageClassification, TFViTModel
from transformers.models.vit.modeling_tf_vit import to_2tuple
if is_vision_available():
from PIL import Image
from transformers import ViTFeatureExtractor
class TFViTModelTester:
def __init__(
self,
parent,
batch_size=13,
image_size=30,
patch_size=2,
num_channels=3,
is_training=True,
use_labels=True,
hidden_size=32,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
type_sequence_label_size=10,
initializer_range=0.02,
num_labels=3,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.is_training = is_training
self.use_labels = use_labels
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.scope = scope
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return ViTConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
is_decoder=False,
initializer_range=self.initializer_range,
)
def create_and_check_model(self, config, pixel_values, labels):
model = TFViTModel(config=config)
result = model(pixel_values, training=False)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = to_2tuple(self.image_size)
patch_size = to_2tuple(self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
# Test with an image with different size than the one specified in config.
image_size = self.image_size // 2
pixel_values = pixel_values[:, :, :image_size, :image_size]
result = model(pixel_values, interpolate_pos_encoding=True, training=False)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = to_2tuple(image_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches + 1, self.hidden_size))
def create_and_check_for_image_classification(self, config, pixel_values, labels):
config.num_labels = self.type_sequence_label_size
model = TFViTForImageClassification(config)
result = model(pixel_values, labels=labels, training=False)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
# Test with an image with different size than the one specified in config.
image_size = self.image_size // 2
pixel_values = pixel_values[:, :, :image_size, :image_size]
result = model(pixel_values, interpolate_pos_encoding=True, training=False)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.type_sequence_label_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_tf
class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_tf_common.py, as ViT does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (TFViTModel, TFViTForImageClassification) if is_tf_available() else ()
test_resize_embeddings = False
test_head_masking = False
test_onnx = False
def setUp(self):
self.model_tester = TFViTModelTester(self)
self.config_tester = ConfigTester(self, config_class=ViTConfig, has_text_modality=False, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_inputs_embeds(self):
# ViT does not use inputs_embeds
pass
def test_graph_mode_with_inputs_embeds(self):
# ViT does not use inputs_embeds
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (tf.keras.layers.Layer))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, tf.keras.layers.Layer))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.call)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
# overwrite from common since `encoder_seq_length` and `encoder_key_length` are calculated
# in a different way than in text models.
@tooslow
def test_saved_model_creation_extended(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.output_hidden_states = True
config.output_attentions = True
if hasattr(config, "use_cache"):
config.use_cache = True
# in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_len = num_patches + 1
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
for model_class in self.all_model_classes:
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
num_out = len(model(class_inputs_dict))
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=True)
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
model = tf.keras.models.load_model(saved_model_dir)
outputs = model(class_inputs_dict)
if self.is_encoder_decoder:
output_hidden_states = outputs["encoder_hidden_states"]
output_attentions = outputs["encoder_attentions"]
else:
output_hidden_states = outputs["hidden_states"]
output_attentions = outputs["attentions"]
self.assertEqual(len(outputs), num_out)
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertEqual(len(output_hidden_states), expected_num_layers)
self.assertListEqual(
list(output_hidden_states[0].shape[-2:]),
[seq_len, self.model_tester.hidden_size],
)
self.assertEqual(len(output_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(output_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
# in ViT, the seq_len equals the number of patches + 1 (we add 1 for the [CLS] token)
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_len = num_patches + 1
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
out_len = len(outputs)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class), training=False)
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder:
added_hidden_states = 2
else:
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
def test_hidden_states_output(self):
def check_hidden_states_output(inputs_dict, config, model_class):
model = model_class(config)
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
expected_num_layers = getattr(
self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
)
self.assertEqual(len(hidden_states), expected_num_layers)
# ViT has a different seq_length
image_size = to_2tuple(self.model_tester.image_size)
patch_size = to_2tuple(self.model_tester.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
seq_length = num_patches + 1
self.assertListEqual(
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
inputs_dict["output_hidden_states"] = True
check_hidden_states_output(inputs_dict, config, model_class)
# check that output_hidden_states also work using config
del inputs_dict["output_hidden_states"]
config.output_hidden_states = True
check_hidden_states_output(inputs_dict, config, model_class)
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
model = TFViTModel.from_pretrained("google/vit-base-patch16-224", from_pt=True)
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
return image
@require_vision
class TFViTModelIntegrationTest(unittest.TestCase):
@cached_property
def default_feature_extractor(self):
return ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224") if is_vision_available() else None
@slow
def test_inference_image_classification_head(self):
model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224", from_pt=True)
feature_extractor = self.default_feature_extractor
image = prepare_img()
inputs = feature_extractor(images=image, return_tensors="tf")
# forward pass
outputs = model(**inputs)
# verify the logits
expected_shape = tf.TensorShape((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = tf.constant([-0.2744, 0.8215, -0.0836])
tf.debugging.assert_near(outputs.logits[0, :3], expected_slice, atol=1e-4)