Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
66e86567
Unverified
Commit
66e86567
authored
Jun 08, 2022
by
Joao Gante
Committed by
GitHub
Jun 08, 2022
Browse files
CLI: Print all different tensors on exception (#17612)
parent
e9d51387
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
31 additions
and
28 deletions
+31
-28
src/transformers/commands/pt_to_tf.py
src/transformers/commands/pt_to_tf.py
+31
-28
No files found.
src/transformers/commands/pt_to_tf.py
View file @
66e86567
...
...
@@ -87,10 +87,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
train_parser
.
set_defaults
(
func
=
convert_command_factory
)
@
staticmethod
def
compare_pt_tf_model
s
(
pt_model
,
pt_input
,
tf_model
,
tf_input
):
def
find_pt_tf_difference
s
(
pt_model
,
pt_input
,
tf_model
,
tf_input
):
"""
Compares the TensorFlow and PyTorch models, given their inputs, returning a
tuple with the maximum observed
difference
and its source
.
Compares the TensorFlow and PyTorch models, given their inputs, returning a
dictionary with all tensor
difference
s
.
"""
pt_outputs
=
pt_model
(
**
pt_input
,
output_hidden_states
=
True
)
tf_outputs
=
tf_model
(
**
tf_input
,
output_hidden_states
=
True
)
...
...
@@ -104,18 +104,14 @@ class PTtoTFCommand(BaseTransformersCLICommand):
f
"
{
tf_out_attrs
}
)"
)
# 2. For each output attribute, ALL values must be the same
def
_compate_pt_tf_models
(
pt_out
,
tf_out
,
attr_name
=
""
):
max_difference
=
0
max_difference_source
=
""
# 2. For each output attribute, computes the difference
def
_find_pt_tf_differences
(
pt_out
,
tf_out
,
differences
,
attr_name
=
""
):
# If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig in
# recursively, keeping the name of the attribute.
if
isinstance
(
pt_out
,
(
torch
.
Tensor
)):
difference
=
np
.
max
(
np
.
abs
(
pt_out
.
detach
().
numpy
()
-
tf_out
.
numpy
()))
if
difference
>
max_difference
:
max_difference
=
difference
max_difference_source
=
attr_name
if
isinstance
(
pt_out
,
torch
.
Tensor
):
tensor_difference
=
np
.
max
(
np
.
abs
(
pt_out
.
detach
().
numpy
()
-
tf_out
.
numpy
()))
differences
[
attr_name
]
=
tensor_difference
else
:
root_name
=
attr_name
for
i
,
pt_item
in
enumerate
(
pt_out
):
...
...
@@ -127,14 +123,11 @@ class PTtoTFCommand(BaseTransformersCLICommand):
else
:
branch_name
=
root_name
+
f
"[
{
i
}
]"
tf_item
=
tf_out
[
i
]
difference
,
difference_source
=
_compate_pt_tf_models
(
pt_item
,
tf_item
,
branch_name
)
if
difference
>
max_difference
:
max_difference
=
difference
max_difference_source
=
difference_source
differences
=
_find_pt_tf_differences
(
pt_item
,
tf_item
,
differences
,
branch_name
)
return
max_
difference
,
max_difference_source
return
difference
s
return
_
compate_pt_tf_model
s
(
pt_outputs
,
tf_outputs
)
return
_
find_pt_tf_difference
s
(
pt_outputs
,
tf_outputs
,
{}
)
def
__init__
(
self
,
model_name
:
str
,
local_dir
:
str
,
no_pr
:
bool
,
new_weights
:
bool
,
*
args
):
self
.
_logger
=
logging
.
get_logger
(
"transformers-cli/pt_to_tf"
)
...
...
@@ -213,11 +206,15 @@ class PTtoTFCommand(BaseTransformersCLICommand):
tf_input
.
update
({
"decoder_input_ids"
:
tf
.
convert_to_tensor
(
decoder_input_ids
)})
# Confirms that cross loading PT weights into TF worked.
crossload_diff
,
diff_source
=
self
.
compare_pt_tf_models
(
pt_model
,
pt_input
,
tf_from_pt_model
,
tf_input
)
if
crossload_diff
>=
MAX_ERROR
:
crossload_differences
=
self
.
find_pt_tf_differences
(
pt_model
,
pt_input
,
tf_from_pt_model
,
tf_input
)
max_crossload_diff
=
max
(
crossload_differences
.
values
())
if
max_crossload_diff
>
MAX_ERROR
:
raise
ValueError
(
"The cross-loaded TF model has different outputs, something went wrong! (max difference ="
f
"
{
crossload_diff
:.
3
e
}
, observed in
{
diff_source
}
)"
"The cross-loaded TensorFlow model has different outputs, something went wrong! Exaustive list of"
f
" maximum tensor differences above the error threshold (
{
MAX_ERROR
}
):
\n
"
+
"
\n
"
.
join
(
[
f
"
{
key
}
:
{
value
:.
3
e
}
"
for
key
,
value
in
crossload_differences
.
items
()
if
value
>
MAX_ERROR
]
)
)
# Saves the weights in a TF format (if needed) and confirms that the results are still good
...
...
@@ -226,11 +223,15 @@ class PTtoTFCommand(BaseTransformersCLICommand):
tf_from_pt_model
.
save_weights
(
tf_weights_path
)
del
tf_from_pt_model
# will no longer be used, and may have a large memory footprint
tf_model
=
tf_class
.
from_pretrained
(
self
.
_local_dir
)
converted_diff
,
diff_source
=
self
.
compare_pt_tf_models
(
pt_model
,
pt_input
,
tf_model
,
tf_input
)
if
converted_diff
>=
MAX_ERROR
:
conversion_differences
=
self
.
find_pt_tf_differences
(
pt_model
,
pt_input
,
tf_model
,
tf_input
)
max_conversion_diff
=
max
(
conversion_differences
.
values
())
if
max_conversion_diff
>
MAX_ERROR
:
raise
ValueError
(
"The converted TF model has different outputs, something went wrong! (max difference ="
f
"
{
converted_diff
:.
3
e
}
, observed in
{
diff_source
}
)"
"The converted TensorFlow model has different outputs, something went wrong! Exaustive list of maximum"
f
" tensor differences above the error threshold (
{
MAX_ERROR
}
):
\n
"
+
"
\n
"
.
join
(
[
f
"
{
key
}
:
{
value
:.
3
e
}
"
for
key
,
value
in
conversion_differences
.
items
()
if
value
>
MAX_ERROR
]
)
)
if
not
self
.
_no_pr
:
...
...
@@ -245,8 +246,10 @@ class PTtoTFCommand(BaseTransformersCLICommand):
create_pr
=
True
,
pr_commit_summary
=
"Add TF weights"
,
pr_commit_description
=
(
f
"Validated by the `pt_to_tf` CLI. Max crossload output difference=
{
crossload_diff
:.
3
e
}
;"
f
" Max converted output difference=
{
converted_diff
:.
3
e
}
."
"Model converted by the `transformers`' `pt_to_tf` CLI -- all converted model outputs and"
" hidden layers were validated against its Pytorch counterpart. Maximum crossload output"
f
" difference=
{
max_crossload_diff
:.
3
e
}
; Maximum converted output"
f
" difference=
{
max_conversion_diff
:.
3
e
}
."
),
)
self
.
_logger
.
warn
(
f
"PR open in
{
hub_pr_url
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment