Commit 6dce6dda authored by thomwolf's avatar thomwolf
Browse files

fixing TF 2.0 model - adding more severe test on pt/tf equivalence

parent c56d921d
......@@ -71,6 +71,8 @@ class TFCommonTestCases:
if not is_torch_available():
return
import torch
import numpy as np
import transformers
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......@@ -79,12 +81,22 @@ class TFCommonTestCases:
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beggining
pt_model_class = getattr(transformers, pt_model_class_name)
tf_model = model_class(config)
pt_model = pt_model_class(config)
tf_model = model_class(config, output_hidden_states=True)
pt_model = pt_model_class(config, output_hidden_states=True)
# Check we can load pt model in tf and vice-versa (architecture similar)
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
pt_model.eval()
pt_inputs_dict = dict((name, torch.from_numpy(key.numpy()).to(torch.long))
for name, key in inputs_dict.items())
with torch.no_grad():
pto = pt_model(**pt_inputs_dict)
tfo = tf_model(inputs_dict)
max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
self.assertLessEqual(max_diff, 2e-2)
def test_keyword_and_dict_args(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment