"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ece0903e111b96718a0d845fa358c9581d03f15d"
Commit 898ce064 authored by thomwolf's avatar thomwolf
Browse files

add tests on TF2.0 & PT checkpoint => model convertion functions

parent 09935867
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import os
import copy import copy
import json import json
import logging import logging
...@@ -118,7 +119,7 @@ class TFCommonTestCases: ...@@ -118,7 +119,7 @@ class TFCommonTestCases:
tf_model = model_class(config) tf_model = model_class(config)
pt_model = pt_model_class(config) pt_model = pt_model_class(config)
# Check we can load pt model in tf and vice-versa (architecture similar) # Check we can load pt model in tf and vice-versa with model => model functions
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict) tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=inputs_dict)
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model) pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
...@@ -132,6 +133,26 @@ class TFCommonTestCases: ...@@ -132,6 +133,26 @@ class TFCommonTestCases:
max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy())) max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
self.assertLessEqual(max_diff, 2e-2) self.assertLessEqual(max_diff, 2e-2)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with TemporaryDirectory() as tmpdirname:
pt_checkpoint_path = os.path.join(tmpdirname, 'pt_model.bin')
torch.save(pt_model.state_dict(), pt_checkpoint_path)
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
tf_checkpoint_path = os.path.join(tmpdirname, 'tf_model.h5')
tf_model.save_weights(tf_checkpoint_path)
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
pt_model.eval()
pt_inputs_dict = dict((name, torch.from_numpy(key.numpy()).to(torch.long))
for name, key in inputs_dict.items())
with torch.no_grad():
pto = pt_model(**pt_inputs_dict)
tfo = tf_model(inputs_dict)
max_diff = np.amax(np.abs(tfo[0].numpy() - pto[0].numpy()))
self.assertLessEqual(max_diff, 2e-2)
def test_compile_tf_model(self): def test_compile_tf_model(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment