"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "12313838d33373d06d35b48c3c501fa832f16443"
Unverified Commit dce33f21 authored by Yih-Dar, committed by GitHub

Improve PT/TF equivalence test (#16557)



* add error message
* Use names in the error message
* allow ModelOutput
* rename to check_pt_tf_outputs and move outside
* fix style
* skip past_key_values in a better way
* Add comments
* improve code for label/loss
* make the logic clear by moving the ignore keys out
* fix _postprocessing_to_ignore
* fix _postprocessing_to_ignore: create new outputs from the remaining fields
* ignore past_key_values in TFGPT2 models for now
* make check_pt_tf_outputs better regarding names
* move check_pt_tf_models outside
* rename methods
* remove test_pt_tf_model_equivalence in TFCLIPModelTest
* Reduce TFViTMAEModelTest.test_pt_tf_model_equivalence
* move prepare_pt_inputs_from_tf_inputs outside check_pt_tf_models
* Fix quality
* Clean-up TFLxmertModelTester.test_pt_tf_model_equivalence
* Fix quality
* fix
* fix style
* Clean-up TFLEDModelTest.test_pt_tf_model_equivalence
* Fix quality
* add docstring
* improve comment
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 7f730085
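The heart of the PR is the reworked comparison logic in tests/test_modeling_tf_common.py (whose diff is collapsed below): a recursive `check_pt_tf_outputs` that threads the name of each output through the recursion, so a failure reports which tensor diverged. A minimal sketch of the idea, assuming both outputs have already been converted to NumPy arrays; this is a simplification, not the PR's actual code:

```python
import numpy as np

def check_pt_tf_outputs(tf_outputs, pt_outputs, names="outputs", tol=1e-5):
    # Recurse into tuples/lists (e.g. hidden_states, attentions), extending the
    # name with an index so a failure pinpoints the diverging tensor.
    if isinstance(tf_outputs, (tuple, list)):
        assert len(tf_outputs) == len(pt_outputs), f"{names}: lengths differ"
        for idx, (tf_o, pt_o) in enumerate(zip(tf_outputs, pt_outputs)):
            check_pt_tf_outputs(tf_o, pt_o, names=f"{names}_{idx}", tol=tol)
    else:
        tf_array = np.array(tf_outputs, dtype=np.float32)
        pt_array = np.array(pt_outputs, dtype=np.float32)
        # Zero out positions that are NaN in either framework before diffing.
        nan_mask = np.isnan(tf_array) | np.isnan(pt_array)
        tf_array[nan_mask] = 0
        pt_array[nan_mask] = 0
        max_diff = np.amax(np.abs(tf_array - pt_array))
        assert max_diff <= tol, f"{names}: max difference {max_diff:.3e} > {tol}"
```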
tests/clip/test_modeling_tf_clip.py
@@ -23,7 +23,7 @@ from importlib import import_module
 import requests
 from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
-from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_vision, slow
+from transformers.testing_utils import require_tf, require_vision, slow
 from transformers.utils import is_tf_available, is_vision_available
 from ..test_configuration_common import ConfigTester
@@ -31,7 +31,6 @@ from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_ten
 if is_tf_available():
-    import numpy as np
     import tensorflow as tf
     from transformers import TFCLIPModel, TFCLIPTextModel, TFCLIPVisionModel, TFSharedEmbeddings
@@ -497,130 +496,6 @@ class TFCLIPModelTest(TFModelTesterMixin, unittest.TestCase):
             after_outputs = model(inputs_dict)
             self.assert_outputs_same(after_outputs, outputs)
-    # overwrite from common since CLIPModel/TFCLIPModel return CLIPOutput/TFCLIPOutput
-    @is_pt_tf_cross_test
-    def test_pt_tf_model_equivalence(self):
-        import torch
-        import transformers
-
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        for model_class in self.all_model_classes:
-            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
-            pt_model_class = getattr(transformers, pt_model_class_name)
-
-            config.output_hidden_states = True
-
-            tf_model = model_class(config)
-            pt_model = pt_model_class(config)
-
-            # Check we can load pt model in tf and vice-versa with model => model functions
-            tf_model = transformers.load_pytorch_model_in_tf2_model(
-                tf_model, pt_model, tf_inputs=self._prepare_for_class(inputs_dict, model_class)
-            )
-            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
-
-            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
-            pt_model.eval()
-            pt_inputs_dict = {}
-            for name, key in self._prepare_for_class(inputs_dict, model_class).items():
-                if type(key) == bool:
-                    pt_inputs_dict[name] = key
-                elif name == "input_values":
-                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
-                elif name == "pixel_values":
-                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
-                else:
-                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
-
-            # need to rename encoder-decoder "inputs" for PyTorch
-            if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
-                pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
-
-            with torch.no_grad():
-                pto = pt_model(**pt_inputs_dict)
-            tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)
-
-            self.assertEqual(len(tfo), len(pto), "Output lengths differ between TF and PyTorch")
-            for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()):
-                if not (isinstance(tf_output, tf.Tensor) and isinstance(pt_output, torch.Tensor)):
-                    continue
-
-                tf_out = tf_output.numpy()
-                pt_out = pt_output.numpy()
-                self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch")
-
-                if len(tf_out.shape) > 0:
-                    tf_nans = np.copy(np.isnan(tf_out))
-                    pt_nans = np.copy(np.isnan(pt_out))
-
-                    pt_out[tf_nans] = 0
-                    tf_out[tf_nans] = 0
-                    pt_out[pt_nans] = 0
-                    tf_out[pt_nans] = 0
-
-                max_diff = np.amax(np.abs(tf_out - pt_out))
-                self.assertLessEqual(max_diff, 4e-2)
-
-            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
-                torch.save(pt_model.state_dict(), pt_checkpoint_path)
-                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
-
-                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
-                tf_model.save_weights(tf_checkpoint_path)
-                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
-
-            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
-            pt_model.eval()
-            pt_inputs_dict = {}
-            for name, key in self._prepare_for_class(inputs_dict, model_class).items():
-                if type(key) == bool:
-                    key = np.array(key, dtype=bool)
-                    pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long)
-                elif name == "input_values":
-                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
-                elif name == "pixel_values":
-                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
-                else:
-                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
-
-            # need to rename encoder-decoder "inputs" for PyTorch
-            if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
-                pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
-
-            with torch.no_grad():
-                pto = pt_model(**pt_inputs_dict)
-            tfo = tf_model(self._prepare_for_class(inputs_dict, model_class))
-
-            self.assertEqual(len(tfo), len(pto), "Output lengths differ between TF and PyTorch")
-            for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()):
-                if not (isinstance(tf_output, tf.Tensor) and isinstance(pt_output, torch.Tensor)):
-                    continue
-
-                tf_out = tf_output.numpy()
-                pt_out = pt_output.numpy()
-                self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch")
-
-                if len(tf_out.shape) > 0:
-                    tf_nans = np.copy(np.isnan(tf_out))
-                    pt_nans = np.copy(np.isnan(pt_out))
-
-                    pt_out[tf_nans] = 0
-                    tf_out[tf_nans] = 0
-                    pt_out[pt_nans] = 0
-                    tf_out[pt_nans] = 0
-
-                max_diff = np.amax(np.abs(tf_out - pt_out))
-                self.assertLessEqual(max_diff, 4e-2)
     @slow
     def test_model_from_pretrained(self):
         for model_name in TF_CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
...
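Both the removed test above and the common test it now defers to exercise two weight-transfer paths: direct model-to-model conversion and a checkpoint round-trip. A condensed sketch of the checkpoint path, using the public `transformers` cross-loading helpers that appear in the diff (the wrapper function itself is hypothetical, not code from this PR):

```python
import os
import tempfile

import torch
import transformers

def cross_load_via_checkpoints(tf_model, pt_model):
    # Save the PT weights, load them into the TF model, then save the TF
    # weights and load them back into the PT model.
    with tempfile.TemporaryDirectory() as tmpdirname:
        pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
        torch.save(pt_model.state_dict(), pt_checkpoint_path)
        tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)

        tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
        tf_model.save_weights(tf_checkpoint_path)
        pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
    return tf_model, pt_model
```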
tests/led/test_modeling_tf_led.py
@@ -17,14 +17,13 @@
 import unittest
 from transformers import LEDConfig, is_tf_available
-from transformers.testing_utils import is_pt_tf_cross_test, require_tf, slow
+from transformers.testing_utils import require_tf, slow
 from ..test_configuration_common import ConfigTester
 from ..test_modeling_tf_common import TFModelTesterMixin, ids_tensor
 if is_tf_available():
-    import numpy as np
     import tensorflow as tf
     from transformers import TFLEDForConditionalGeneration, TFLEDModel
@@ -362,128 +361,6 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
             self.assertEqual(model.config.output_hidden_states, True)
             check_encoder_attentions_output(outputs)
-    # TODO: Remove this once a more thorough pt/tf equivalence could be implemented in `test_modeling_tf_common.py`.
-    # (Currently, such a test will fail some other model tests: it requires some time to fix them.)
-    @is_pt_tf_cross_test
-    def test_pt_tf_model_equivalence_extra(self):
-        import torch
-        import transformers
-
-        def prepare_pt_inputs_from_tf_inputs(tf_inputs_dict):
-            pt_inputs_dict = {}
-            for name, key in tf_inputs_dict.items():
-                if type(key) == bool:
-                    pt_inputs_dict[name] = key
-                elif name == "input_values":
-                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
-                elif name == "pixel_values":
-                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
-                else:
-                    pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
-            return pt_inputs_dict
-
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        for model_class in self.all_model_classes:
-            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
-            pt_model_class = getattr(transformers, pt_model_class_name)
-
-            config.output_hidden_states = True
-
-            tf_model = model_class(config)
-            pt_model = pt_model_class(config)
-
-            tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
-            tf_inputs_dict_maybe_with_labels = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
-
-            # Check we can load pt model in tf and vice-versa with model => model functions
-            tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
-            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
-
-            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
-            pt_model.eval()
-
-            pt_inputs_dict = prepare_pt_inputs_from_tf_inputs(tf_inputs_dict)
-            pt_inputs_dict_maybe_with_labels = prepare_pt_inputs_from_tf_inputs(tf_inputs_dict_maybe_with_labels)
-
-            # need to rename encoder-decoder "inputs" for PyTorch
-            if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
-                pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
-
-            with torch.no_grad():
-                pto = pt_model(**pt_inputs_dict)
-            tfo = tf_model(tf_inputs_dict, training=False)
-
-            tf_hidden_states = tfo[0].numpy()
-            pt_hidden_states = pto[0].numpy()
-
-            tf_nans = np.isnan(tf_hidden_states)
-            pt_nans = np.isnan(pt_hidden_states)
-
-            pt_hidden_states[tf_nans] = 0
-            tf_hidden_states[tf_nans] = 0
-            pt_hidden_states[pt_nans] = 0
-            tf_hidden_states[pt_nans] = 0
-
-            max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
-            self.assertLessEqual(max_diff, 1e-4)
-
-            has_labels = any(
-                x in tf_inputs_dict_maybe_with_labels for x in ["labels", "next_sentence_label", "start_positions"]
-            )
-            if has_labels:
-                with torch.no_grad():
-                    pto = pt_model(**pt_inputs_dict_maybe_with_labels)
-                tfo = tf_model(tf_inputs_dict_maybe_with_labels, training=False)
-
-                # Some models' output class doesn't have a `loss` attribute even though `labels` are used.
-                tf_loss = getattr(tfo, "loss", None)
-                pt_loss = getattr(pto, "loss", None)
-
-                # Some models require an extra condition to return loss. For example, `BertForPreTraining` requires
-                # both `labels` and `next_sentence_label`.
-                # Moreover, some PT models return loss while the corresponding TF/Flax models don't.
-                if tf_loss is not None and pt_loss is not None:
-                    tf_loss = tf.math.reduce_mean(tf_loss).numpy()
-                    pt_loss = pt_loss.numpy()
-
-                    tf_nans = np.isnan(tf_loss)
-                    pt_nans = np.isnan(pt_loss)
-                    # the 2 losses need to be both nan or both not nan
-                    # (`TapasForQuestionAnswering` gives nan loss here)
-                    self.assertEqual(tf_nans, pt_nans)
-
-                    if not tf_nans:
-                        max_diff = np.amax(np.abs(tf_loss - pt_loss))
-                        # `TFFunnelForTokenClassification` (and potentially other TF token classification models) gives
-                        # a large difference (up to 0.1x). PR #15294 addresses this issue.
-                        # There is also an inconsistency between PT/TF `XLNetLMHeadModel`.
-                        # Before these issues are fixed & merged, set a higher threshold here to pass the test.
-                        self.assertLessEqual(max_diff, 1e-4)
-
-                    tf_logits = tfo[1].numpy()
-                    pt_logits = pto[1].numpy()
-
-                    # check on the shape
-                    self.assertEqual(tf_logits.shape, pt_logits.shape)
-
-                    tf_nans = np.isnan(tf_logits)
-                    pt_nans = np.isnan(pt_logits)
-
-                    pt_logits[tf_nans] = 0
-                    tf_logits[tf_nans] = 0
-                    pt_logits[pt_nans] = 0
-                    tf_logits[pt_nans] = 0
-
-                    max_diff = np.amax(np.abs(tf_logits - pt_logits))
-                    self.assertLessEqual(max_diff, 1e-4)
     def test_xla_mode(self):
         # TODO JP: Make LED XLA compliant
         pass
...
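The loss comparison in the removed test above has to bridge a shape mismatch: TF models may return a per-sample loss tensor while PT returns a scalar. A minimal sketch of that step, distilled from the code above (the helper function and its `test_case` parameter are hypothetical):

```python
import numpy as np
import tensorflow as tf

def compare_losses(test_case, tf_loss, pt_loss, tol=1e-4):
    # Reduce a possibly per-sample TF loss to a scalar before comparing.
    tf_loss = tf.math.reduce_mean(tf_loss).numpy()
    pt_loss = pt_loss.detach().cpu().numpy()

    # The two losses must be NaN together or not at all
    # (e.g. `TapasForQuestionAnswering` gives a NaN loss here).
    test_case.assertEqual(np.isnan(tf_loss), np.isnan(pt_loss))
    if not np.isnan(tf_loss):
        test_case.assertLessEqual(np.amax(np.abs(tf_loss - pt_loss)), tol)
```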
tests/lxmert/test_modeling_tf_lxmert.py
@@ -272,6 +272,8 @@ class TFLxmertModelTester(object):
         if return_obj_labels:
             inputs_dict["obj_labels"] = obj_labels
+        else:
+            config.task_obj_predict = False
         return config, inputs_dict
@@ -486,135 +488,31 @@ class TFLxmertModelTest(TFModelTesterMixin, unittest.TestCase):
             config.output_hidden_states = True
             check_hidden_states_output(config, inputs_dict, model_class)
-    def test_pt_tf_model_equivalence(self):
-        from transformers import is_torch_available
-
-        if not is_torch_available():
-            return
-
-        import torch
-        import transformers
-
-        for model_class in self.all_model_classes:
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
-                return_obj_labels="PreTraining" in model_class.__name__
-            )
-
-            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
-            pt_model_class = getattr(transformers, pt_model_class_name)
-
-            config.output_hidden_states = True
-            config.task_obj_predict = False
-
-            tf_model = model_class(config)
-            pt_model = pt_model_class(config)
-
-            # Check we can load pt model in tf and vice-versa with model => model functions
-            tf_model = transformers.load_pytorch_model_in_tf2_model(
-                tf_model, pt_model, tf_inputs=self._prepare_for_class(inputs_dict, model_class)
-            )
-            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
-
-            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
-            pt_model.eval()
-
-            # Delete obj labels as we want to compute the hidden states and not the loss
-            if "obj_labels" in inputs_dict:
-                del inputs_dict["obj_labels"]
-
-            def torch_type(key):
-                if key in ("visual_feats", "visual_pos"):
-                    return torch.float32
-                else:
-                    return torch.long
-
-            def recursive_numpy_convert(iterable):
-                return_dict = {}
-                for key, value in iterable.items():
-                    if isinstance(value, dict):
-                        return_dict[key] = recursive_numpy_convert(value)
-                    else:
-                        if isinstance(value, (list, tuple)):
-                            return_dict[key] = (
-                                torch.from_numpy(iter_value.numpy()).to(torch_type(key)) for iter_value in value
-                            )
-                        else:
-                            return_dict[key] = torch.from_numpy(value.numpy()).to(torch_type(key))
-                return return_dict
-
-            pt_inputs_dict = recursive_numpy_convert(self._prepare_for_class(inputs_dict, model_class))
-
-            # need to rename encoder-decoder "inputs" for PyTorch
-            if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
-                pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
-
-            with torch.no_grad():
-                pto = pt_model(**pt_inputs_dict)
-            tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)
-            tf_hidden_states = tfo[0].numpy()
-            pt_hidden_states = pto[0].numpy()
-
-            tf_nans = np.copy(np.isnan(tf_hidden_states))
-            pt_nans = np.copy(np.isnan(pt_hidden_states))
-
-            pt_hidden_states[tf_nans] = 0
-            tf_hidden_states[tf_nans] = 0
-            pt_hidden_states[pt_nans] = 0
-            tf_hidden_states[pt_nans] = 0
-
-            max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
-
-            # Debug info (remove when fixed)
-            if max_diff >= 2e-2:
-                print("===")
-                print(model_class)
-                print(config)
-                print(inputs_dict)
-                print(pt_inputs_dict)
-            self.assertLessEqual(max_diff, 6e-2)
-
-            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                import os
-
-                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
-                torch.save(pt_model.state_dict(), pt_checkpoint_path)
-                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
-
-                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
-                tf_model.save_weights(tf_checkpoint_path)
-                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
-
-            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
-            pt_model.eval()
-            pt_inputs_dict = dict(
-                (name, torch.from_numpy(key.numpy()).to(torch.long))
-                for name, key in self._prepare_for_class(inputs_dict, model_class).items()
-            )
-
-            for key, value in pt_inputs_dict.items():
-                if key in ("visual_feats", "visual_pos"):
-                    pt_inputs_dict[key] = value.to(torch.float32)
-                else:
-                    pt_inputs_dict[key] = value.to(torch.long)
-
-            with torch.no_grad():
-                pto = pt_model(**pt_inputs_dict)
-            tfo = tf_model(self._prepare_for_class(inputs_dict, model_class))
-            tfo = tfo[0].numpy()
-            pto = pto[0].numpy()
-
-            tf_nans = np.copy(np.isnan(tfo))
-            pt_nans = np.copy(np.isnan(pto))
-
-            pto[tf_nans] = 0
-            tfo[tf_nans] = 0
-            pto[pt_nans] = 0
-            tfo[pt_nans] = 0
-
-            max_diff = np.amax(np.abs(tfo - pto))
-            self.assertLessEqual(max_diff, 6e-2)
+    def prepare_pt_inputs_from_tf_inputs(self, tf_inputs_dict):
+        import torch
+
+        pt_inputs_dict = {}
+        for key, value in tf_inputs_dict.items():
+            if isinstance(value, dict):
+                pt_inputs_dict[key] = self.prepare_pt_inputs_from_tf_inputs(value)
+            elif isinstance(value, (list, tuple)):
+                pt_inputs_dict[key] = (self.prepare_pt_inputs_from_tf_inputs(iter_value) for iter_value in value)
+            elif isinstance(value, bool):
+                pt_inputs_dict[key] = value
+            elif key == "input_values":
+                pt_inputs_dict[key] = torch.from_numpy(value.numpy()).to(torch.float32)
+            elif key == "pixel_values":
+                pt_inputs_dict[key] = torch.from_numpy(value.numpy()).to(torch.float32)
+            elif key == "input_features":
+                pt_inputs_dict[key] = torch.from_numpy(value.numpy()).to(torch.float32)
+            # other general float inputs
+            elif tf_inputs_dict[key].dtype.is_floating:
+                pt_inputs_dict[key] = torch.from_numpy(value.numpy()).to(torch.float32)
+            else:
+                pt_inputs_dict[key] = torch.from_numpy(value.numpy()).to(torch.long)
+
+        return pt_inputs_dict
     def test_save_load(self):
         for model_class in self.all_model_classes:
...
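The new `prepare_pt_inputs_from_tf_inputs` hook above replaces LXMERT's bespoke `torch_type` helper: instead of hard-coding `visual_feats`/`visual_pos`, any floating-point TF input is converted to `torch.float32` via `dtype.is_floating`, and everything else to `torch.long`. A hypothetical usage (names, shapes, and values invented for illustration):

```python
import numpy as np
import tensorflow as tf

tf_inputs_dict = {
    "input_ids": tf.constant([[101, 2023, 102]], dtype=tf.int32),
    "visual_feats": tf.constant(np.random.rand(1, 10, 2048), dtype=tf.float32),
    "visual_pos": tf.constant(np.random.rand(1, 10, 4), dtype=tf.float32),
}
# Inside a test case inheriting the hook:
#   pt_inputs = self.prepare_pt_inputs_from_tf_inputs(tf_inputs_dict)
#   pt_inputs["input_ids"].dtype    -> torch.int64 (torch.long)
#   pt_inputs["visual_feats"].dtype -> torch.float32, via dtype.is_floating
```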
This diff is collapsed. (The bulk of the change, the new common check_pt_tf_models/check_pt_tf_outputs logic in tests/test_modeling_tf_common.py, is not shown here.)
tests/vit_mae/test_modeling_tf_vit_mae.py
@@ -28,7 +28,7 @@ import numpy as np
 from transformers import ViTMAEConfig
 from transformers.file_utils import cached_property, is_tf_available, is_vision_available
-from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_vision, slow, torch_device
+from transformers.testing_utils import require_tf, require_vision, slow
 from ..test_configuration_common import ConfigTester
 from ..test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_tensor
@@ -363,140 +363,20 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase):
     # overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
     # to generate masks during test
-    @is_pt_tf_cross_test
-    def test_pt_tf_model_equivalence(self):
-        import torch
-        import transformers
-
-        # make masks reproducible
-        np.random.seed(2)
-
-        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-        num_patches = int((config.image_size // config.patch_size) ** 2)
-        noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
-        pt_noise = torch.from_numpy(noise).to(device=torch_device)
-        tf_noise = tf.constant(noise)
-
-        def prepare_pt_inputs_from_tf_inputs(tf_inputs_dict):
-            pt_inputs_dict = {}
-            for name, key in tf_inputs_dict.items():
-                pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
-            return pt_inputs_dict
-
-        def check_outputs(tf_outputs, pt_outputs, model_class, names):
-            """
-            Args:
-                model_class: The class of the model that is currently being tested. For example, `TFBertModel`,
-                    `TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Currently unused, but it could make
-                    debugging easier and faster.
-                names: A string, or a tuple of strings. These specify what tf_outputs/pt_outputs represent in the model outputs.
-                    Currently unused, but in the future, we could use this information to make the error message clearer
-                    by giving the name(s) of the output tensor(s) with large difference(s) between PT and TF.
-            """
-            # Allow `list` because `(TF)TransfoXLModelOutput.mems` is a list of tensors.
-            if type(tf_outputs) in [tuple, list]:
-                self.assertEqual(type(tf_outputs), type(pt_outputs))
-                self.assertEqual(len(tf_outputs), len(pt_outputs))
-                if type(names) == tuple:
-                    for tf_output, pt_output, name in zip(tf_outputs, pt_outputs, names):
-                        check_outputs(tf_output, pt_output, model_class, names=name)
-                elif type(names) == str:
-                    for idx, (tf_output, pt_output) in enumerate(zip(tf_outputs, pt_outputs)):
-                        check_outputs(tf_output, pt_output, model_class, names=f"{names}_{idx}")
-                else:
-                    raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
-            elif isinstance(tf_outputs, tf.Tensor):
-                self.assertTrue(isinstance(pt_outputs, torch.Tensor))
-
-                tf_outputs = tf_outputs.numpy()
-                if isinstance(tf_outputs, np.float32):
-                    tf_outputs = np.array(tf_outputs, dtype=np.float32)
-                pt_outputs = pt_outputs.detach().to("cpu").numpy()
-
-                tf_nans = np.isnan(tf_outputs)
-                pt_nans = np.isnan(pt_outputs)
-
-                pt_outputs[tf_nans] = 0
-                tf_outputs[tf_nans] = 0
-                pt_outputs[pt_nans] = 0
-                tf_outputs[pt_nans] = 0
-
-                max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
-                # Set a higher tolerance (2e-5) here than the one in the common test (1e-5).
-                # TODO: A deeper look to decide the best (common) tolerance for the test to be strict but not too flaky.
-                self.assertLessEqual(max_diff, 2e-5)
-            else:
-                raise ValueError(
-                    f"`tf_outputs` should be a `tuple` or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."
-                )
-
-        def check_pt_tf_models(tf_model, pt_model):
-            # we are not preparing a model with labels because of the formation
-            # of the ViT MAE model
-
-            # send pytorch model to the correct device
-            pt_model.to(torch_device)
-
-            # Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
-            pt_model.eval()
-
-            pt_inputs_dict = prepare_pt_inputs_from_tf_inputs(tf_inputs_dict)
-
-            # send pytorch inputs to the correct device
-            pt_inputs_dict = {
-                k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
-            }
-
-            # Original test: check without `labels`
-            with torch.no_grad():
-                pt_outputs = pt_model(**pt_inputs_dict, noise=pt_noise)
-            tf_outputs = tf_model(tf_inputs_dict, noise=tf_noise)
-
-            tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
-            pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
-
-            self.assertEqual(tf_keys, pt_keys)
-            check_outputs(tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=tf_keys)
-
-        for model_class in self.all_model_classes:
-            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-            # Output all for aggressive testing
-            config.output_hidden_states = True
-            if self.has_attentions:
-                config.output_attentions = True
-
-            pt_model_class_name = model_class.__name__[2:]  # Skip the "TF" at the beginning
-            pt_model_class = getattr(transformers, pt_model_class_name)
-
-            tf_model = model_class(config)
-            pt_model = pt_model_class(config)
-
-            tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
-
-            # Check we can load pt model in tf and vice-versa with model => model functions
-            tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
-            pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
-
-            check_pt_tf_models(tf_model, pt_model)
-
-            # Check we can load pt model in tf and vice-versa with checkpoint => model functions
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
-                torch.save(pt_model.state_dict(), pt_checkpoint_path)
-                tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
-
-                tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
-                tf_model.save_weights(tf_checkpoint_path)
-                pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
-
-                check_pt_tf_models(tf_model, pt_model)
+    def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict):
+        # make masks reproducible
+        np.random.seed(2)
+
+        num_patches = int((tf_model.config.image_size // tf_model.config.patch_size) ** 2)
+        noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
+        tf_noise = tf.constant(noise)
+
+        # Add `noise` argument.
+        # PT inputs will be prepared in `super().check_pt_tf_models()` with this added `noise` argument
+        tf_inputs_dict["noise"] = tf_noise
+
+        super().check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
     # overwrite from common since TFViTMAEForPretraining outputs loss along with
     # logits and mask indices. loss and mask indices are not suitable for integration
...
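TFViTMAEForPretraining masks patches with random noise, so an unseeded PT/TF comparison would diverge for reasons unrelated to the weights. The override above pins the noise once in NumPy and feeds the same array to both frameworks; a standalone sketch of that idea with hypothetical config values:

```python
import numpy as np
import tensorflow as tf
import torch

np.random.seed(2)  # make masks reproducible

batch_size = 2
image_size, patch_size = 224, 16   # hypothetical config values
num_patches = (image_size // patch_size) ** 2

# One noise array, shared by both frameworks so they mask identical patches.
noise = np.random.uniform(size=(batch_size, num_patches))
tf_noise = tf.constant(noise)      # passed as tf_inputs_dict["noise"]
pt_noise = torch.from_numpy(noise) # prepared for the PT model from the same array
```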