Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
898ce064
"examples/vscode:/vscode.git/clone" did not exist on "b7374ad4a5eb4c7efa5a2702dcc7f418dd9cf13d"
Commit
898ce064
authored
Oct 15, 2019
by
thomwolf
Browse files
add tests on TF2.0 & PT checkpoint => model conversion functions
parent
09935867
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
1 deletion
+22
-1
transformers/tests/modeling_tf_common_test.py
transformers/tests/modeling_tf_common_test.py
+22
-1
No files found.
transformers/tests/modeling_tf_common_test.py
View file @
898ce064
...
...
@@ -14,6 +14,7 @@
# limitations under the License.
from
__future__
import
absolute_import
,
division
,
print_function
import
os
import
copy
import
json
import
logging
...
...
@@ -118,7 +119,7 @@ class TFCommonTestCases:
tf_model
=
model_class
(
config
)
pt_model
=
pt_model_class
(
config
)
# Check we can load pt model in tf and vice-versa
(architecture similar)
# Check we can load pt model in tf and vice-versa
with model => model functions
tf_model
=
transformers
.
load_pytorch_model_in_tf2_model
(
tf_model
,
pt_model
,
tf_inputs
=
inputs_dict
)
pt_model
=
transformers
.
load_tf2_model_in_pytorch_model
(
pt_model
,
tf_model
)
...
...
@@ -132,6 +133,26 @@ class TFCommonTestCases:
max_diff
=
np
.
amax
(
np
.
abs
(
tfo
[
0
].
numpy
()
-
pto
[
0
].
numpy
()))
self
.
assertLessEqual
(
max_diff
,
2e-2
)
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
with
TemporaryDirectory
()
as
tmpdirname
:
pt_checkpoint_path
=
os
.
path
.
join
(
tmpdirname
,
'pt_model.bin'
)
torch
.
save
(
pt_model
.
state_dict
(),
pt_checkpoint_path
)
tf_model
=
transformers
.
load_pytorch_checkpoint_in_tf2_model
(
tf_model
,
pt_checkpoint_path
)
tf_checkpoint_path
=
os
.
path
.
join
(
tmpdirname
,
'tf_model.h5'
)
tf_model
.
save_weights
(
tf_checkpoint_path
)
pt_model
=
transformers
.
load_tf2_checkpoint_in_pytorch_model
(
pt_model
,
tf_checkpoint_path
)
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
pt_model
.
eval
()
pt_inputs_dict
=
dict
((
name
,
torch
.
from_numpy
(
key
.
numpy
()).
to
(
torch
.
long
))
for
name
,
key
in
inputs_dict
.
items
())
with
torch
.
no_grad
():
pto
=
pt_model
(
**
pt_inputs_dict
)
tfo
=
tf_model
(
inputs_dict
)
max_diff
=
np
.
amax
(
np
.
abs
(
tfo
[
0
].
numpy
()
-
pto
[
0
].
numpy
()))
self
.
assertLessEqual
(
max_diff
,
2e-2
)
def
test_compile_tf_model
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment