Unverified Commit e6d23a4b authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Improve test_pt_tf_model_equivalence on PT side (#16731)



* Update test_pt_tf_model_equivalence on PT side
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 3dd57b15
...@@ -28,7 +28,6 @@ from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig ...@@ -28,7 +28,6 @@ from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from transformers.testing_utils import ( from transformers.testing_utils import (
is_flax_available, is_flax_available,
is_pt_flax_cross_test, is_pt_flax_cross_test,
is_pt_tf_cross_test,
require_torch, require_torch,
require_vision, require_vision,
slow, slow,
...@@ -602,149 +601,6 @@ class CLIPModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -602,149 +601,6 @@ class CLIPModelTest(ModelTesterMixin, unittest.TestCase):
text_config = CLIPTextConfig.from_pretrained(tmp_dir_name) text_config = CLIPTextConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict()) self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
# overwrite from common since CLIPModel/TFCLIPModel return CLIPOutput/TFCLIPOutput
@is_pt_tf_cross_test
def test_pt_tf_model_equivalence(self):
import numpy as np
import tensorflow as tf
import transformers
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning
if not hasattr(transformers, tf_model_class_name):
# transformers does not have TF version yet
return
tf_model_class = getattr(transformers, tf_model_class_name)
config.output_hidden_states = True
tf_model = tf_model_class(config)
pt_model = model_class(config)
# make sure only tf inputs are forward that actually exist in function args
tf_input_keys = set(inspect.signature(tf_model.call).parameters.keys())
# remove all head masks
tf_input_keys.discard("head_mask")
tf_input_keys.discard("cross_attn_head_mask")
tf_input_keys.discard("decoder_head_mask")
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: v for k, v in pt_inputs.items() if k in tf_input_keys}
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
pt_model.eval()
tf_inputs_dict = {}
for key, tensor in pt_inputs.items():
# skip key that does not exist in tf
if type(tensor) == bool:
tf_inputs_dict[key] = tensor
elif key == "input_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
elif key == "pixel_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
else:
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)
# Check we can load pt model in tf and vice-versa with model => model functions
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device)
# need to rename encoder-decoder "inputs" for PyTorch
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
# pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
with torch.no_grad():
pto = pt_model(**pt_inputs)
tfo = tf_model(tf_inputs_dict, training=False)
self.assertEqual(len(tfo), len(pto), "Output lengths differ between TF and PyTorch")
for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()):
if not (isinstance(tf_output, tf.Tensor) and isinstance(pt_output, torch.Tensor)):
continue
tf_out = tf_output.numpy()
pt_out = pt_output.cpu().numpy()
self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch")
if len(tf_out.shape) > 0:
tf_nans = np.copy(np.isnan(tf_out))
pt_nans = np.copy(np.isnan(pt_out))
pt_out[tf_nans] = 0
tf_out[tf_nans] = 0
pt_out[pt_nans] = 0
tf_out[pt_nans] = 0
max_diff = np.amax(np.abs(tf_out - pt_out))
self.assertLessEqual(max_diff, 4e-2)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
pt_model = pt_model.to(torch_device)
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
pt_model.eval()
tf_inputs_dict = {}
for key, tensor in pt_inputs.items():
# skip key that does not exist in tf
if type(tensor) == bool:
tensor = np.array(tensor, dtype=bool)
tf_inputs_dict[key] = tf.convert_to_tensor(tensor, dtype=tf.int32)
elif key == "input_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
elif key == "pixel_values":
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
else:
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)
# need to rename encoder-decoder "inputs" for PyTorch
# if "inputs" in pt_inputs_dict and self.is_encoder_decoder:
# pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs")
with torch.no_grad():
pto = pt_model(**pt_inputs)
tfo = tf_model(tf_inputs_dict)
self.assertEqual(len(tfo), len(pto), "Output lengths differ between TF and PyTorch")
for tf_output, pt_output in zip(tfo.to_tuple(), pto.to_tuple()):
if not (isinstance(tf_output, tf.Tensor) and isinstance(pt_output, torch.Tensor)):
continue
tf_out = tf_output.numpy()
pt_out = pt_output.cpu().numpy()
self.assertEqual(tf_out.shape, pt_out.shape, "Output component shapes differ between TF and PyTorch")
if len(tf_out.shape) > 0:
tf_nans = np.copy(np.isnan(tf_out))
pt_nans = np.copy(np.isnan(pt_out))
pt_out[tf_nans] = 0
tf_out[tf_nans] = 0
pt_out[pt_nans] = 0
tf_out[pt_nans] = 0
max_diff = np.amax(np.abs(tf_out - pt_out))
self.assertLessEqual(max_diff, 4e-2)
# overwrite from common since FlaxCLIPModel returns nested output # overwrite from common since FlaxCLIPModel returns nested output
# which is not supported in the common test # which is not supported in the common test
@is_pt_flax_cross_test @is_pt_flax_cross_test
......
...@@ -15,16 +15,13 @@ ...@@ -15,16 +15,13 @@
import copy import copy
import os
import tempfile
import unittest import unittest
import numpy as np import numpy as np
import transformers
from transformers import LxmertConfig, is_tf_available, is_torch_available from transformers import LxmertConfig, is_tf_available, is_torch_available
from transformers.models.auto import get_values from transformers.models.auto import get_values
from transformers.testing_utils import is_pt_tf_cross_test, require_torch, slow, torch_device from transformers.testing_utils import require_torch, slow, torch_device
from ..test_configuration_common import ConfigTester from ..test_configuration_common import ConfigTester
from ..test_modeling_common import ModelTesterMixin, ids_tensor from ..test_modeling_common import ModelTesterMixin, ids_tensor
...@@ -527,6 +524,8 @@ class LxmertModelTester: ...@@ -527,6 +524,8 @@ class LxmertModelTester:
if return_obj_labels: if return_obj_labels:
inputs_dict["obj_labels"] = obj_labels inputs_dict["obj_labels"] = obj_labels
else:
config.task_obj_predict = False
return config, inputs_dict return config, inputs_dict
...@@ -740,121 +739,30 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -740,121 +739,30 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
self.assertIsNotNone(hidden_states_vision.grad) self.assertIsNotNone(hidden_states_vision.grad)
self.assertIsNotNone(attentions_vision.grad) self.assertIsNotNone(attentions_vision.grad)
@is_pt_tf_cross_test def prepare_tf_inputs_from_pt_inputs(self, pt_inputs_dict):
def test_pt_tf_model_equivalence(self):
for model_class in self.all_model_classes:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
return_obj_labels="PreTraining" in model_class.__name__
)
tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning
if not hasattr(transformers, tf_model_class_name):
# transformers does not have TF version yet
return
tf_model_class = getattr(transformers, tf_model_class_name)
config.output_hidden_states = True tf_inputs_dict = {}
config.task_obj_predict = False for key, value in pt_inputs_dict.items():
# skip key that does not exist in tf
pt_model = model_class(config)
tf_model = tf_model_class(config)
# Check we can load pt model in tf and vice-versa with model => model functions
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
def recursive_numpy_convert(iterable):
return_dict = {}
for key, value in iterable.items():
if type(value) == bool:
return_dict[key] = value
if isinstance(value, dict): if isinstance(value, dict):
return_dict[key] = recursive_numpy_convert(value) tf_inputs_dict[key] = self.prepare_pt_inputs_from_tf_inputs(value)
else: elif isinstance(value, (list, tuple)):
if isinstance(value, (list, tuple)): tf_inputs_dict[key] = (self.prepare_pt_inputs_from_tf_inputs(iter_value) for iter_value in value)
return_dict[key] = ( elif type(value) == bool:
tf.convert_to_tensor(iter_value.cpu().numpy(), dtype=tf.int32) for iter_value in value tf_inputs_dict[key] = value
) elif key == "input_values":
tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32)
elif key == "pixel_values":
tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32)
elif key == "input_features":
tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32)
# other general float inputs
elif value.is_floating_point():
tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.float32)
else: else:
return_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.int32) tf_inputs_dict[key] = tf.convert_to_tensor(value.cpu().numpy(), dtype=tf.int32)
return return_dict
tf_inputs_dict = recursive_numpy_convert(pt_inputs)
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model).to(torch_device)
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
pt_model.eval()
# Delete obj labels as we want to compute the hidden states and not the loss
if "obj_labels" in inputs_dict:
del inputs_dict["obj_labels"]
pt_inputs = self._prepare_for_class(inputs_dict, model_class)
tf_inputs_dict = recursive_numpy_convert(pt_inputs)
with torch.no_grad():
pto = pt_model(**pt_inputs)
tfo = tf_model(tf_inputs_dict, training=False)
tf_hidden_states = tfo[0].numpy()
pt_hidden_states = pto[0].cpu().numpy()
tf_nans = np.copy(np.isnan(tf_hidden_states))
pt_nans = np.copy(np.isnan(pt_hidden_states))
pt_hidden_states[tf_nans] = 0
tf_hidden_states[tf_nans] = 0
pt_hidden_states[pt_nans] = 0
tf_hidden_states[pt_nans] = 0
max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
# Debug info (remove when fixed)
if max_diff >= 2e-2:
print("===")
print(model_class)
print(config)
print(inputs_dict)
print(pt_inputs)
self.assertLessEqual(max_diff, 6e-2)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
pt_model.eval()
for key, value in pt_inputs.items():
if key in ("visual_feats", "visual_pos"):
pt_inputs[key] = value.to(torch.float32)
else:
pt_inputs[key] = value.to(torch.long)
with torch.no_grad():
pto = pt_model(**pt_inputs)
tfo = tf_model(tf_inputs_dict)
tfo = tfo[0].numpy()
pto = pto[0].cpu().numpy()
tf_nans = np.copy(np.isnan(tfo))
pt_nans = np.copy(np.isnan(pto))
pto[tf_nans] = 0
tfo[tf_nans] = 0
pto[pt_nans] = 0
tfo[pt_nans] = 0
max_diff = np.amax(np.abs(tfo - pto)) return tf_inputs_dict
self.assertLessEqual(max_diff, 6e-2)
@require_torch @require_torch
......
This diff is collapsed.
...@@ -565,8 +565,7 @@ class TFModelTesterMixin: ...@@ -565,8 +565,7 @@ class TFModelTesterMixin:
# Output all for aggressive testing # Output all for aggressive testing
config.output_hidden_states = True config.output_hidden_states = True
if self.has_attentions: config.output_attentions = self.has_attentions
config.output_attentions = True
# Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency # Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
# of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`. # of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
......
...@@ -17,14 +17,13 @@ ...@@ -17,14 +17,13 @@
import inspect import inspect
import math import math
import os
import tempfile import tempfile
import unittest import unittest
import numpy as np import numpy as np
from transformers import ViTMAEConfig from transformers import ViTMAEConfig
from transformers.testing_utils import is_pt_tf_cross_test, require_torch, require_vision, slow, torch_device from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property, is_torch_available, is_vision_available from transformers.utils import cached_property, is_torch_available, is_vision_available
from ..test_configuration_common import ConfigTester from ..test_configuration_common import ConfigTester
...@@ -321,150 +320,20 @@ class ViTMAEModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -321,150 +320,20 @@ class ViTMAEModelTest(ModelTesterMixin, unittest.TestCase):
# overwrite from common since ViTMAEForPretraining has random masking, we need to fix the noise # overwrite from common since ViTMAEForPretraining has random masking, we need to fix the noise
# to generate masks during test # to generate masks during test
@is_pt_tf_cross_test def check_pt_tf_models(self, tf_model, pt_model, pt_inputs_dict):
def test_pt_tf_model_equivalence(self):
import numpy as np
import tensorflow as tf
import transformers
# make masks reproducible # make masks reproducible
np.random.seed(2) np.random.seed(2)
config, _ = self.model_tester.prepare_config_and_inputs_for_common() num_patches = int((pt_model.config.image_size // pt_model.config.patch_size) ** 2)
num_patches = int((config.image_size // config.patch_size) ** 2)
noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches)) noise = np.random.uniform(size=(self.model_tester.batch_size, num_patches))
pt_noise = torch.from_numpy(noise).to(device=torch_device) pt_noise = torch.from_numpy(noise)
tf_noise = tf.constant(noise)
def prepare_tf_inputs_from_pt_inputs(pt_inputs_dict):
tf_inputs_dict = {}
for key, tensor in pt_inputs_dict.items():
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
return tf_inputs_dict
def check_outputs(tf_outputs, pt_outputs, model_class, names):
"""
Args:
model_class: The class of the model that is currently testing. For example, `TFBertModel`,
TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Currently unused, but it could make
debugging easier and faster.
names: A string, or a tuple of strings. These specify what tf_outputs/pt_outputs represent in the model outputs.
Currently unused, but in the future, we could use this information to make the error message clearer
by giving the name(s) of the output tensor(s) with large difference(s) between PT and TF.
"""
# Allow `list` because `(TF)TransfoXLModelOutput.mems` is a list of tensors.
if type(tf_outputs) in [tuple, list]:
self.assertEqual(type(tf_outputs), type(pt_outputs))
self.assertEqual(len(tf_outputs), len(pt_outputs))
if type(names) == tuple:
for tf_output, pt_output, name in zip(tf_outputs, pt_outputs, names):
check_outputs(tf_output, pt_output, model_class, names=name)
elif type(names) == str:
for idx, (tf_output, pt_output) in enumerate(zip(tf_outputs, pt_outputs)):
check_outputs(tf_output, pt_output, model_class, names=f"{names}_{idx}")
else:
raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
elif isinstance(tf_outputs, tf.Tensor):
self.assertTrue(isinstance(pt_outputs, torch.Tensor))
tf_outputs = tf_outputs.numpy()
if isinstance(tf_outputs, np.float32):
tf_outputs = np.array(tf_outputs, dtype=np.float32)
pt_outputs = pt_outputs.detach().to("cpu").numpy()
tf_nans = np.isnan(tf_outputs)
pt_nans = np.isnan(pt_outputs)
pt_outputs[tf_nans] = 0
tf_outputs[tf_nans] = 0
pt_outputs[pt_nans] = 0
tf_outputs[pt_nans] = 0
max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
self.assertLessEqual(max_diff, 1e-5)
else:
raise ValueError(
f"`tf_outputs` should be a `tuple` or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."
)
def check_pt_tf_models(tf_model, pt_model, pt_inputs_dict):
# we are not preparing a model with labels because of the formation
# of the ViT MAE model
# send pytorch model to the correct device
pt_model.to(torch_device)
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
pt_model.eval()
tf_inputs_dict = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
# send pytorch inputs to the correct device
pt_inputs_dict = {
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
}
# Original test: check without `labels`
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs_dict, noise=pt_noise)
tf_outputs = tf_model(tf_inputs_dict, noise=tf_noise)
tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(tf_keys, pt_keys)
check_outputs(tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=tf_keys)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning
# Output all for aggressive testing
config.output_hidden_states = True
config.output_attentions = self.has_attentions
tf_model_class = getattr(transformers, tf_model_class_name)
tf_model = tf_model_class(config)
pt_model = model_class(config)
# make sure only tf inputs are forward that actually exist in function args
tf_input_keys = set(inspect.signature(tf_model.call).parameters.keys())
# remove all head masks
tf_input_keys.discard("head_mask")
tf_input_keys.discard("cross_attn_head_mask")
tf_input_keys.discard("decoder_head_mask")
pt_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs_dict = {k: v for k, v in pt_inputs_dict.items() if k in tf_input_keys}
# Check we can load pt model in tf and vice-versa with model => model functions
tf_inputs_dict = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with tempfile.TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5") # Add `noise` argument.
tf_model.save_weights(tf_checkpoint_path) # PT inputs will be prepared in `super().check_pt_tf_models()` with this added `noise` argument
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path) pt_inputs_dict["noise"] = pt_noise
pt_model = pt_model.to(torch_device)
check_pt_tf_models(tf_model, pt_model, pt_inputs_dict) super().check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
def test_save_load(self): def test_save_load(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment