"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "35236b870ee4102fb82f7a1d4713dc83af78a00a"
Unverified commit 8f2cc1c3 authored by Yih-Dar, committed by GitHub

Add TFCLIPModel (#13967)



* Start the work for TFCLIPModel

* Convert to TF code (TODO: loss + doc)

* Clean up

* Fix pooled_output for TFCLIPTextTransformer - using tf.gather_nd (a sketch of the idea follows this commit message)

* assert -> raise error

* Expose TFCLIPModel

* Deal with dummy_inputs

* Add tests

* Fix all tests. TODO: manual check weight loading + add more comments

* Fix pt tf equivalence test

* fixes

* update TFCLIPVisionEmbeddings's Conv2D

* Fix loss + overwrite test_pt_tf_model_equivalence from common

* Add a comment about the MainLayer change in test_keras_save_load

* Set return_loss=True in TFCLIPModelTester + make tests pass

* overwrite test_pt_tf_model_equivalence from tf common

* fix base_model_prefix

* Fix examples

* remove unused

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* apply review suggestions

* change self.pre_layrnorm to self.pre_layernorm

* apply more review suggestions

* return attention probs before dropout (to align with PT)

* fix weight init

* fix

* build doc

* fix missing doc

* fix for test
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 2d30443c
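For context on the pooled_output fix mentioned in the commit message: CLIP's text pooled output is the hidden state at the EOS token position, and in TF this can be gathered per example with `tf.gather_nd`. The following is a minimal, self-contained sketch of that idea; shapes, token ids, and variable names are illustrative, not the actual TFCLIPTextTransformer code.

```python
import tensorflow as tf

# Illustrative sketch: take the hidden state at the EOS position of each sequence.
# In CLIP the EOS token has the highest id in the vocabulary, so its position is
# argmax(input_ids) along the sequence axis.
batch_size, seq_len, hidden_size = 2, 5, 4
sequence_output = tf.random.normal((batch_size, seq_len, hidden_size))
input_ids = tf.constant(
    [[49406, 320, 1125, 49407, 0],
     [49406, 1125, 49407, 0, 0]],
    dtype=tf.int64,
)

eos_positions = tf.math.argmax(input_ids, axis=-1)  # (batch_size,)
indices = tf.stack([tf.range(batch_size, dtype=tf.int64), eos_positions], axis=1)
pooled_output = tf.gather_nd(sequence_output, indices)  # (batch_size, hidden_size)
print(pooled_output.shape)
```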
@@ -202,7 +202,7 @@ Flax), PyTorch, and/or TensorFlow.
| BlenderbotSmall | ✅ | ✅ | ✅ | ✅ | ✅ |
| CamemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| Canine | ✅ | ❌ | ✅ | ❌ | ❌ |
| CLIP | ✅ | ✅ | ✅ | ✅ | ✅ |
| ConvBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| CTRL | ✅ | ❌ | ✅ | ✅ | ❌ |
| DeBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
......
@@ -125,6 +125,23 @@ This model was contributed by [valhalla](https://huggingface.co/valhalla). The o
[[autodoc]] CLIPVisionModel
    - forward

## TFCLIPModel

[[autodoc]] TFCLIPModel
    - call
    - get_text_features
    - get_image_features

## TFCLIPTextModel

[[autodoc]] TFCLIPTextModel
    - call

## TFCLIPVisionModel

[[autodoc]] TFCLIPVisionModel
    - call

## FlaxCLIPModel

[[autodoc]] FlaxCLIPModel
......
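To illustrate the new doc entries above, here is a minimal usage sketch of TFCLIPModel, mirroring the existing PyTorch CLIP example. It assumes the openai/clip-vit-base-patch32 checkpoint; if TF weights are not available on the Hub for it, `from_pt=True` would be needed.

```python
import requests
import tensorflow as tf
from PIL import Image
from transformers import CLIPProcessor, TFCLIPModel

model = TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(
    text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="tf", padding=True
)

outputs = model(**inputs)
logits_per_image = outputs.logits_per_image  # image-text similarity scores
probs = tf.nn.softmax(logits_per_image, axis=1)  # probabilities over the candidate texts
```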
@@ -1549,6 +1549,15 @@ if is_tf_available():
            "TFCamembertModel",
        ]
    )
    _import_structure["models.clip"].extend(
        [
            "TF_CLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
            "TFCLIPModel",
            "TFCLIPPreTrainedModel",
            "TFCLIPTextModel",
            "TFCLIPVisionModel",
        ]
    )
    _import_structure["models.convbert"].extend(
        [
            "TF_CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -3394,6 +3403,13 @@ if TYPE_CHECKING:
            TFCamembertForTokenClassification,
            TFCamembertModel,
        )
        from .models.clip import (
            TF_CLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFCLIPModel,
            TFCLIPPreTrainedModel,
            TFCLIPTextModel,
            TFCLIPVisionModel,
        )
        from .models.convbert import (
            TF_CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFConvBertForMaskedLM,
......
@@ -63,6 +63,12 @@ def gelu_fast(x):
    return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x)))


def quick_gelu(x):
    x = tf.convert_to_tensor(x)
    coeff = tf.cast(1.702, x.dtype)
    return x * tf.math.sigmoid(coeff * x)


if version.parse(tf.version.VERSION) >= version.parse("2.4"):

    def approximate_gelu_wrap(x):
@@ -84,6 +90,7 @@ ACT2FN = {
    "mish": mish,
    "tanh": tf.keras.activations.tanh,
    "gelu_fast": gelu_fast,
    "quick_gelu": quick_gelu,
}
......
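The new quick_gelu activation is the x * sigmoid(1.702 * x) approximation used by CLIP (matching PyTorch's QuickGELU). A small, self-contained comparison against TensorFlow's exact GELU, purely illustrative and not part of the diff:

```python
import numpy as np
import tensorflow as tf


def quick_gelu(x):
    # Same formula as the function added above: x * sigmoid(1.702 * x).
    x = tf.convert_to_tensor(x)
    coeff = tf.cast(1.702, x.dtype)
    return x * tf.math.sigmoid(coeff * x)


x = tf.constant(np.linspace(-3.0, 3.0, 7), dtype=tf.float32)
print(quick_gelu(x).numpy())
print(tf.keras.activations.gelu(x).numpy())  # exact GELU (TF >= 2.4), for comparison
```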
@@ -29,6 +29,7 @@ logger = logging.get_logger(__name__)
TF_MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("clip", "TFCLIPModel"),
        ("deberta-v2", "TFDebertaV2Model"),
        ("deberta", "TFDebertaModel"),
        ("rembert", "TFRemBertModel"),
......
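With this mapping entry in place, the TF auto class can resolve CLIP configs to the new model. A hedged sketch, assuming TF weights exist for the checkpoint (otherwise `from_pt=True` would be needed):

```python
from transformers import TFAutoModel

# "clip" now resolves to TFCLIPModel through TF_MODEL_MAPPING_NAMES.
model = TFAutoModel.from_pretrained("openai/clip-vit-base-patch32")
print(type(model).__name__)  # expected: TFCLIPModel
```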
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
from ...file_utils import (
    _LazyModule,
    is_flax_available,
    is_tf_available,
    is_tokenizers_available,
    is_torch_available,
    is_vision_available,
@@ -47,6 +48,15 @@ if is_torch_available():
        "CLIPVisionModel",
    ]

if is_tf_available():
    _import_structure["modeling_tf_clip"] = [
        "TF_CLIP_PRETRAINED_MODEL_ARCHIVE_LIST",
        "TFCLIPModel",
        "TFCLIPPreTrainedModel",
        "TFCLIPTextModel",
        "TFCLIPVisionModel",
    ]

if is_flax_available():
    _import_structure["modeling_flax_clip"] = [
        "FlaxCLIPModel",
@@ -78,6 +88,15 @@ if TYPE_CHECKING:
            CLIPVisionModel,
        )

    if is_tf_available():
        from .modeling_tf_clip import (
            TF_CLIP_PRETRAINED_MODEL_ARCHIVE_LIST,
            TFCLIPModel,
            TFCLIPPreTrainedModel,
            TFCLIPTextModel,
            TFCLIPVisionModel,
        )

    if is_flax_available():
        from .modeling_flax_clip import (
            FlaxCLIPModel,
......
This diff is collapsed.
@@ -704,6 +704,57 @@ class TFCamembertModel:
        requires_backends(self, ["tf"])


TF_CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = None


class TFCLIPModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])

    def call(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFCLIPPreTrainedModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])

    def call(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFCLIPTextModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])

    def call(self, *args, **kwargs):
        requires_backends(self, ["tf"])


class TFCLIPVisionModel:
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])

    def call(self, *args, **kwargs):
        requires_backends(self, ["tf"])


TF_CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
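These dummy objects keep `from transformers import TFCLIPModel` working in environments without TensorFlow and defer to an informative error at use time via `requires_backends`. An illustrative sketch of the resulting behavior (the exact error text is paraphrased, not quoted):

```python
# Only meaningful in an environment where TensorFlow is not installed.
from transformers import TFCLIPModel  # the import itself succeeds via the dummy object

try:
    TFCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
except ImportError as err:
    print(err)  # points the user to a TensorFlow installation
```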
@@ -29,6 +29,7 @@ from transformers.file_utils import is_torch_available, is_vision_available
from transformers.testing_utils import (
    is_flax_available,
    is_pt_flax_cross_test,
    is_pt_tf_cross_test,
    require_torch,
    require_vision,
    slow,
@@ -581,6 +582,148 @@ class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
        self.assertTrue(models_equal)

    # overwrite from common since CLIPModel/TFCLIPModel return CLIPOutput/TFCLIPOutput
    @is_pt_tf_cross_test
    def test_pt_tf_model_equivalence(self):
        import numpy as np
        import tensorflow as tf

        import transformers

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            tf_model_class_name = "TF" + model_class.__name__  # Add the "TF" at the beginning

            if not hasattr(transformers, tf_model_class_name):
                # transformers does not have a TF version yet
                return

            tf_model_class = getattr(transformers, tf_model_class_name)

            config.output_hidden_states = True

            tf_model = tf_model_class(config)
            pt_model = model_class(config)

            # make sure only the TF inputs that actually exist in the function args are forwarded
            tf_input_keys = set(inspect.signature(tf_model.call).parameters.keys())

            # remove all head masks
            tf_input_keys.discard("head_mask")
            tf_input_keys.discard("cross_attn_head_mask")
            tf_input_keys.discard("decoder_head_mask")

            pt_inputs = self._prepare_for_class(inputs_dict, model_class)
            pt_inputs = {k: v for k, v in pt_inputs.items() if k in tf_input_keys}

            # Check that predictions on the first output (logits/hidden-states) are close enough given low-level computational differences
            pt_model.eval()

            tf_inputs_dict = {}
            for key, tensor in pt_inputs.items():
                # skip keys that do not exist in TF
                if type(tensor) == bool:
                    tf_inputs_dict[key] = tensor
                elif key == "input_values":
                    tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
                elif key == "pixel_values":
                    tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
                else:
                    tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)

            # Check we can load pt model in tf and vice-versa with model => model functions
            tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)

            # need to rename encoder-decoder "inputs" for PyTorch
            # if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
            #     pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

            with torch.no_grad():
                pto = pt_model(**pt_inputs)

            tfo = tf_model(tf_inputs_dict, training=False)

            self.assertEqual(len(tfo), len(pto), "Output lengths differ between TF and PyTorch")
            for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()):
                if not (isinstance(tf_output, tf.Tensor) and isinstance(pt_output, torch.Tensor)):
                    continue

                tf_out = tf_output.numpy()
                pt_out = pt_output.numpy()

                self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch")

                if len(tf_out.shape) > 0:
                    tf_nans = np.copy(np.isnan(tf_out))
                    pt_nans = np.copy(np.isnan(pt_out))

                    pt_out[tf_nans] = 0
                    tf_out[tf_nans] = 0
                    pt_out[pt_nans] = 0
                    tf_out[pt_nans] = 0

                max_diff = np.amax(np.abs(tf_out - pt_out))
                self.assertLessEqual(max_diff, 4e-2)

            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
            with tempfile.TemporaryDirectory() as tmpdirname:
                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
                torch.save(pt_model.state_dict(), pt_checkpoint_path)
                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)

                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
                tf_model.save_weights(tf_checkpoint_path)
                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)

            # Check that predictions on the first output (logits/hidden-states) are close enough given low-level computational differences
            pt_model.eval()

            tf_inputs_dict = {}
            for key, tensor in pt_inputs.items():
                # skip keys that do not exist in TF
                if type(tensor) == bool:
                    tensor = np.array(tensor, dtype=bool)
                    tf_inputs_dict[key] = tf.convert_to_tensor(tensor, dtype=tf.int32)
                elif key == "input_values":
                    tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
                elif key == "pixel_values":
                    tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
                else:
                    tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)

            # need to rename encoder-decoder "inputs" for PyTorch
            # if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
            #     pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")

            with torch.no_grad():
                pto = pt_model(**pt_inputs)

            tfo = tf_model(tf_inputs_dict)

            self.assertEqual(len(tfo), len(pto), "Output lengths differ between TF and PyTorch")
            for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()):
                if not (isinstance(tf_output, tf.Tensor) and isinstance(pt_output, torch.Tensor)):
                    continue

                tf_out = tf_output.numpy()
                pt_out = pt_output.numpy()

                self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch")

                if len(tf_out.shape) > 0:
                    tf_nans = np.copy(np.isnan(tf_out))
                    pt_nans = np.copy(np.isnan(pt_out))

                    pt_out[tf_nans] = 0
                    tf_out[tf_nans] = 0
                    pt_out[pt_nans] = 0
                    tf_out[pt_nans] = 0

                max_diff = np.amax(np.abs(tf_out - pt_out))
                self.assertLessEqual(max_diff, 4e-2)

    # overwrite from common since FlaxCLIPModel returns nested output
    # which is not supported in the common test
    @is_pt_flax_cross_test
......
This diff is collapsed.
@@ -282,6 +282,8 @@ class TFModelTesterMixin:
            for module in (import_module(model_class.__module__),)
            for module_member_name in dir(module)
            if module_member_name.endswith("MainLayer")
            # This condition is required, since `modeling_tf_clip.py` has 3 classes whose names end with `MainLayer`.
            and module_member_name[: -len("MainLayer")] == model_class.__name__[: -len("Model")]
            for module_member in (getattr(module, module_member_name),)
            if isinstance(module_member, type)
            and tf.keras.layers.Layer in module_member.__bases__
@@ -458,7 +460,7 @@ class TFModelTesterMixin:
                    "input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"),
                }
            # TODO: A better way to handle vision models
            elif model_class.__name__ in ["TFViTModel", "TFViTForImageClassification", "TFCLIPVisionModel"]:
                inputs = tf.keras.Input(
                    batch_shape=(
                        3,
@@ -469,6 +471,20 @@ class TFModelTesterMixin:
                    name="pixel_values",
                    dtype="float32",
                )
            elif model_class.__name__ in ["TFCLIPModel"]:
                inputs = {
                    "input_ids": tf.keras.Input(batch_shape=(3, max_input), name="input_ids", dtype="int32"),
                    "pixel_values": tf.keras.Input(
                        batch_shape=(
                            3,
                            self.model_tester.vision_model_tester.num_channels,
                            self.model_tester.vision_model_tester.image_size,
                            self.model_tester.vision_model_tester.image_size,
                        ),
                        name="pixel_values",
                        dtype="float32",
                    ),
                }
            elif model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
                inputs = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32")
            else:
@@ -1244,6 +1260,13 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
    return output


def random_attention_mask(shape, rng=None, name=None, dtype=None):
    attn_mask = ids_tensor(shape, vocab_size=2, rng=None, name=None, dtype=dtype)
    # make sure that at least one token is attended to for each batch
    attn_mask = tf.concat([tf.constant(value=1, shape=(shape[0], 1), dtype=dtype), attn_mask[:, 1:]], axis=1)
    return attn_mask


def floats_tensor(shape, scale=1.0, rng=None, name=None, dtype=None):
    """Creates a random float32 tensor"""
    if rng is None:
......
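The new random_attention_mask helper returns a 0/1 mask whose first column is forced to 1 so every example attends to at least one token. A self-contained sketch of the same trick (names and shapes here are illustrative, not the test helper itself):

```python
import tensorflow as tf

batch_size, seq_len = 2, 7

# Random 0/1 values, then force the first position of each row to 1 so that
# every example attends to at least one token.
attn_mask = tf.random.uniform((batch_size, seq_len), minval=0, maxval=2, dtype=tf.int32)
attn_mask = tf.concat([tf.ones((batch_size, 1), dtype=tf.int32), attn_mask[:, 1:]], axis=1)
print(attn_mask.numpy())
```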
@@ -22,7 +22,7 @@ import unittest
from transformers import ViTConfig
from transformers.file_utils import cached_property, is_tf_available, is_vision_available
from transformers.testing_utils import require_tf, require_vision, slow, tooslow

from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
@@ -200,7 +200,7 @@ class TFViTModelTest(TFModelTesterMixin, unittest.TestCase):
    # overwrite from common since `encoder_seq_length` and `encoder_key_length` are calculated
    # in a different way than in text models.
    @tooslow
    def test_saved_model_creation_extended(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.output_hidden_states = True
......
@@ -111,6 +111,8 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
    "BeitForMaskedImageModeling",
    "CLIPTextModel",
    "CLIPVisionModel",
    "TFCLIPTextModel",
    "TFCLIPVisionModel",
    "FlaxCLIPTextModel",
    "FlaxCLIPVisionModel",
    "FlaxWav2Vec2ForCTC",
......