Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
66e86567
Unverified
Commit
66e86567
authored
Jun 08, 2022
by
Joao Gante
Committed by
GitHub
Jun 08, 2022
Browse files
CLI: Print all different tensors on exception (#17612)
parent
e9d51387
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
31 additions
and
28 deletions
+31
-28
src/transformers/commands/pt_to_tf.py
src/transformers/commands/pt_to_tf.py
+31
-28
No files found.
src/transformers/commands/pt_to_tf.py
View file @
66e86567
...
...
@@ -87,10 +87,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
train_parser
.
set_defaults
(
func
=
convert_command_factory
)
@
staticmethod
def
compare_pt_tf_model
s
(
pt_model
,
pt_input
,
tf_model
,
tf_input
):
def
find_pt_tf_difference
s
(
pt_model
,
pt_input
,
tf_model
,
tf_input
):
"""
Compares the TensorFlow and PyTorch models, given their inputs, returning a
tuple with the maximum observed
difference
and its source
.
Compares the TensorFlow and PyTorch models, given their inputs, returning a
dictionary with all tensor
difference
s
.
"""
pt_outputs
=
pt_model
(
**
pt_input
,
output_hidden_states
=
True
)
tf_outputs
=
tf_model
(
**
tf_input
,
output_hidden_states
=
True
)
...
...
@@ -104,18 +104,14 @@ class PTtoTFCommand(BaseTransformersCLICommand):
f
"
{
tf_out_attrs
}
)"
)
# 2. For each output attribute, ALL values must be the same
def
_compate_pt_tf_models
(
pt_out
,
tf_out
,
attr_name
=
""
):
max_difference
=
0
max_difference_source
=
""
# 2. For each output attribute, computes the difference
def
_find_pt_tf_differences
(
pt_out
,
tf_out
,
differences
,
attr_name
=
""
):
# If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig in
# recursivelly, keeping the name of the attribute.
if
isinstance
(
pt_out
,
(
torch
.
Tensor
)):
difference
=
np
.
max
(
np
.
abs
(
pt_out
.
detach
().
numpy
()
-
tf_out
.
numpy
()))
if
difference
>
max_difference
:
max_difference
=
difference
max_difference_source
=
attr_name
if
isinstance
(
pt_out
,
torch
.
Tensor
):
tensor_difference
=
np
.
max
(
np
.
abs
(
pt_out
.
detach
().
numpy
()
-
tf_out
.
numpy
()))
differences
[
attr_name
]
=
tensor_difference
else
:
root_name
=
attr_name
for
i
,
pt_item
in
enumerate
(
pt_out
):
...
...
@@ -127,14 +123,11 @@ class PTtoTFCommand(BaseTransformersCLICommand):
else
:
branch_name
=
root_name
+
f
"[
{
i
}
]"
tf_item
=
tf_out
[
i
]
difference
,
difference_source
=
_compate_pt_tf_models
(
pt_item
,
tf_item
,
branch_name
)
if
difference
>
max_difference
:
max_difference
=
difference
max_difference_source
=
difference_source
differences
=
_find_pt_tf_differences
(
pt_item
,
tf_item
,
differences
,
branch_name
)
return
max_
difference
,
max_difference_source
return
difference
s
return
_
compate_pt_tf_model
s
(
pt_outputs
,
tf_outputs
)
return
_
find_pt_tf_difference
s
(
pt_outputs
,
tf_outputs
,
{}
)
def
__init__
(
self
,
model_name
:
str
,
local_dir
:
str
,
no_pr
:
bool
,
new_weights
:
bool
,
*
args
):
self
.
_logger
=
logging
.
get_logger
(
"transformers-cli/pt_to_tf"
)
...
...
@@ -213,11 +206,15 @@ class PTtoTFCommand(BaseTransformersCLICommand):
tf_input
.
update
({
"decoder_input_ids"
:
tf
.
convert_to_tensor
(
decoder_input_ids
)})
# Confirms that cross loading PT weights into TF worked.
crossload_diff
,
diff_source
=
self
.
compare_pt_tf_models
(
pt_model
,
pt_input
,
tf_from_pt_model
,
tf_input
)
if
crossload_diff
>=
MAX_ERROR
:
crossload_differences
=
self
.
find_pt_tf_differences
(
pt_model
,
pt_input
,
tf_from_pt_model
,
tf_input
)
max_crossload_diff
=
max
(
crossload_differences
.
values
())
if
max_crossload_diff
>
MAX_ERROR
:
raise
ValueError
(
"The cross-loaded TF model has different outputs, something went wrong! (max difference ="
f
"
{
crossload_diff
:.
3
e
}
, observed in
{
diff_source
}
)"
"The cross-loaded TensorFlow model has different outputs, something went wrong! Exaustive list of"
f
" maximum tensor differences above the error threshold (
{
MAX_ERROR
}
):
\n
"
+
"
\n
"
.
join
(
[
f
"
{
key
}
:
{
value
:.
3
e
}
"
for
key
,
value
in
crossload_differences
.
items
()
if
value
>
MAX_ERROR
]
)
)
# Save the weights in a TF format (if needed) and confirms that the results are still good
...
...
@@ -226,11 +223,15 @@ class PTtoTFCommand(BaseTransformersCLICommand):
tf_from_pt_model
.
save_weights
(
tf_weights_path
)
del
tf_from_pt_model
# will no longer be used, and may have a large memory footprint
tf_model
=
tf_class
.
from_pretrained
(
self
.
_local_dir
)
converted_diff
,
diff_source
=
self
.
compare_pt_tf_models
(
pt_model
,
pt_input
,
tf_model
,
tf_input
)
if
converted_diff
>=
MAX_ERROR
:
conversion_differences
=
self
.
find_pt_tf_differences
(
pt_model
,
pt_input
,
tf_model
,
tf_input
)
max_conversion_diff
=
max
(
conversion_differences
.
values
())
if
max_conversion_diff
>
MAX_ERROR
:
raise
ValueError
(
"The converted TF model has different outputs, something went wrong! (max difference ="
f
"
{
converted_diff
:.
3
e
}
, observed in
{
diff_source
}
)"
"The converted TensorFlow model has different outputs, something went wrong! Exaustive list of maximum"
f
" tensor differences above the error threshold (
{
MAX_ERROR
}
):
\n
"
+
"
\n
"
.
join
(
[
f
"
{
key
}
:
{
value
:.
3
e
}
"
for
key
,
value
in
conversion_differences
.
items
()
if
value
>
MAX_ERROR
]
)
)
if
not
self
.
_no_pr
:
...
...
@@ -245,8 +246,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
create_pr
=
True
,
pr_commit_summary
=
"Add TF weights"
,
pr_commit_description
=
(
f
"Validated by the `pt_to_tf` CLI. Max crossload output difference=
{
crossload_diff
:.
3
e
}
;"
f
" Max converted output difference=
{
converted_diff
:.
3
e
}
."
"Model converted by the `transformers`' `pt_to_tf` CLI -- all converted model outputs and"
" hidden layers were validated against its Pytorch counterpart. Maximum crossload output"
f
" difference=
{
max_crossload_diff
:.
3
e
}
; Maximum converted output"
f
" difference=
{
max_conversion_diff
:.
3
e
}
."
),
)
self
.
_logger
.
warn
(
f
"PR open in
{
hub_pr_url
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment